#include <stdio.h>
#include <string.h>

#include <algorithm>
#include <map>
#include <set>
#include <vector>

#include <windows.h>

#pragma warning(disable: 4087)

// Human-readable description of each of the 22 attribute columns, in file
// order. Entry i describes the value stored in MushroomDesc::Desc[i + 1]
// (Desc[0] holds the class label). The list is terminated by a null pointer.
// Each entry names the attribute and lists its "meaning=code" pairs.
static const char* var_desc[] =
{
    "cap shape (bell=b,conical=c,convex=x,flat=f)",
    "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
    "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
    "bruises? (bruises=t,no=f)",
    "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
    "gill attachment (attached=a,descending=d,free=f,notched=n)",
    "gill spacing (close=c,crowded=w,distant=d)",
    "gill size (broad=b,narrow=n)",
    "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
    "stalk shape (enlarging=e,tapering=t)",
    "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
    "stalk surface above ring (fibrous=f,scaly=y,silky=k,smooth=s)",
    "stalk surface below ring (fibrous=f,scaly=y,silky=k,smooth=s)",
    "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    "veil type (partial=p,universal=u)",
    "veil color (brown=n,orange=o,white=w,yellow=y)",
    "ring number (none=n,one=o,two=t)",
    "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
    "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
    "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
    "habitat (grasses=g,leaves=l,meadows=m,paths=p,\n\turban=u,waste=w,woods=d)",
    0
};

 
// One record of the data set, stored as 23 single-character codes.
// Desc[0] is the class label (analysisFile accepts 'p' and 'e' only);
// Desc[i + 1] is the value of the attribute described by var_desc[i].
struct MushroomDesc
{
	char Desc[23];
};

// A single "attribute == value" rule discovered by analysisFile: the value
// occurs in records of one class and never in records of the other.
struct MushroomFeature
{
	// true when the value was seen only in poisonous records,
	// false when it was seen only in edible records.
	bool isPoison;
	// Attribute column (index into var_desc); the value itself lives in
	// MushroomDesc::Desc[index + 1].
	int index;
	// The attribute's value code.
	char ch;
	// Fraction of the records of that class that carry this value
	// (as computed by checkProb).
	float prob;
};

// Ordering predicate for MushroomFeature: ascending by attribute index,
// with ties broken by ascending value code. isPoison and prob take no part
// in the comparison.
struct MushroomFeatureComp
{
	bool operator()(const MushroomFeature& left, const MushroomFeature& right) const
	{
		// Different columns: the column number alone decides.
		if (left.index != right.index)
		{
			return left.index < right.index;
		}
		// Same column: fall back to the value code.
		return left.ch < right.ch;
	}
};



// A list of data-set records.
typedef std::vector<MushroomDesc> MushroomVect;
// A list of attribute value codes (used to remember values already seen in a column).
typedef std::vector<char> FeatureVect;

// A list of discovered rules, in discovery order.
typedef std::vector<MushroomFeature> MushroomFeatureVect;

// A set of rules, unique and ordered by (index, ch) via MushroomFeatureComp.
typedef std::set<MushroomFeature, MushroomFeatureComp> MushroomFeatureSet;

struct MushroomFeatureCorrelationMapComp
{
	bool operator()(const MushroomFeatureSet& left, const MushroomFeatureSet& right)const
	{
		MushroomFeatureComp comp;
		return comp(*left.begin(), *right.begin());
	}
};


// Maps a set of features to a float score; checkCorrelation stores the ratio
// nCorrelation / nTotal for a feature pair here. Keys are ordered by
// MushroomFeatureCorrelationMapComp.
typedef std::map<MushroomFeatureSet, float, MushroomFeatureCorrelationMapComp> MushroomFeatureCorrelationMap;

//typedef std::pair<MushroomFeatureSet, float> MushroomFeaturePair;


// Loads records from a comma-separated text file and appends them to vect.
//
// Every line is expected to contain 23 one-character fields separated by
// single commas (e.g. "p,x,s,n,..."), so field i is found at byte offset i*2.
// Lines that are too short to hold all 23 fields (blank or truncated lines)
// are skipped; decoding them would pick up stale bytes left in the buffer by
// a previous, longer line.
//
// szFileName  path of the file to open (text mode).
// vect        receives one MushroomDesc per accepted line.
//
// If the file cannot be opened the reason is reported on stderr and vect is
// left untouched.
void readFile(const char* szFileName, MushroomVect& vect)
{
	const int nFields = 23;                        // size of MushroomDesc::Desc
	const size_t nMinLength = nFields * 2 - 1;     // 23 codes + 22 separators

	FILE* stream = fopen(szFileName, "rt");
	if (stream == NULL)
	{
		perror(szFileName);
		return;
	}

	char buffer[256];
	MushroomDesc desc;
	while (fgets(buffer, static_cast<int>(sizeof(buffer)), stream))
	{
		// Length of the line content, not counting the line terminator.
		if (strcspn(buffer, "\r\n") < nMinLength)
		{
			continue;
		}
		for (int i = 0; i < nFields; i ++)
		{
			// Even offsets hold the codes, odd offsets the commas.
			desc.Desc[i] = buffer[i*2];
		}
		vect.push_back(desc);
	}
	fclose(stream);
}

// Writes every record to stdout, one per line. Each of the 23 field codes is
// followed by a comma, including the last one.
void printFile(MushroomVect& vect)
{
	const int kFields = 23;
	char line[256];

	for (size_t row = 0; row < vect.size(); ++row)
	{
		// Build "c,c,c,...,c," in place, then terminate it.
		char* out = line;
		for (int col = 0; col < kFields; ++col)
		{
			*out++ = vect[row].Desc[col];
			*out++ = ',';
		}
		*out = '\0';
		printf("%s\n", line);
	}
}

// Tells whether any record in vect has value ch for attribute `index`.
// The attribute is read from Desc[index + 1] because Desc[0] is the class
// label. Returns false for an empty vect.
bool findInMushroom(MushroomVect& vect, int index, char ch)
{
	const int column = index + 1;
	for (MushroomVect::const_iterator rec = vect.begin(); rec != vect.end(); ++rec)
	{
		if (rec->Desc[column] == ch)
		{
			return true;
		}
	}
	return false;
}

// Tells whether the value code ch is already present in vect.
// Returns false for an empty vect.
bool findInFeature(FeatureVect& vect, char ch)
{
	FeatureVect::const_iterator pos = vect.begin();
	while (pos != vect.end())
	{
		if (*pos == ch)
		{
			return true;
		}
		++pos;
	}
	return false;
}

// Returns the fraction of records in vect whose attribute `index`
// (stored in Desc[index + 1]) equals ch.
//
// The result lies in [0, 1]. An empty vect yields 0 rather than the NaN that
// dividing 0 by 0 would produce.
float checkProb(MushroomVect& vect, int index, char ch)
{
	if (vect.empty())
	{
		return 0.0f;
	}

	int counter = 0;
	for (size_t i = 0; i < vect.size(); i ++)
	{
		if (vect[i].Desc[index + 1] == ch)
		{
			counter ++;
		}
	}
	return (float)counter/(float)vect.size();
}

// Measures how often pairs of features occur together.
//
// For every pair (i, j), i < j, of features that point to the same class
// (both poison or both edible), the records of vect are scanned and two
// counts are taken:
//   nTotal       records matching at least one of the two features
//   nCorrelation records matching both features
// When nCorrelation is positive, the pair is reported on stdout and
// nCorrelation / nTotal is inserted into `correlation` under the key
// {feature[i], feature[j]}. std::map::insert keeps an existing entry if an
// equivalent key is already present.
//
// vect         records to scan.
// feature      candidate features (see analysisFile).
// correlation  receives one entry per co-occurring pair.
void checkCorrelation(MushroomVect& vect, MushroomFeatureVect& feature, MushroomFeatureCorrelationMap& correlation)
{
	MushroomFeatureSet featureSet;
	const int nFeatures = static_cast<int>(feature.size());

	for (int i = 0; i < nFeatures; i ++)
	{
		for (int j = i + 1; j < nFeatures; j ++)
		{
			// Only features that agree on the class are paired.
			if (feature[i].isPoison != feature[j].isPoison)
			{
				continue;
			}

			// Counters are per pair, so they start from zero here.
			int nTotal = 0;
			int nCorrelation = 0;
			for (size_t k = 0; k < vect.size(); k ++)
			{
				const bool hasFirst = vect[k].Desc[feature[i].index + 1] == feature[i].ch;
				const bool hasSecond = vect[k].Desc[feature[j].index + 1] == feature[j].ch;
				if (hasFirst || hasSecond)
				{
					nTotal ++;
					if (hasFirst && hasSecond)
					{
						nCorrelation ++;
					}
				}
			}

			// Record the pair while i and j still identify it.
			// nTotal >= nCorrelation > 0 here, so the division is safe.
			if (nCorrelation > 0)
			{
				featureSet.clear();
				featureSet.insert(feature[i]);
				featureSet.insert(feature[j]);

				printf("the correlation is between %d,%d with %d over %d\n", i, j, nCorrelation, nTotal);

				correlation.insert(MushroomFeatureCorrelationMap::value_type(featureSet, (float)nCorrelation/(float)nTotal));
			}
		}
	}
}

void analysisFile(MushroomVect& vect, MushroomFeatureVect& mushroomFeature)
{
	MushroomVect poison, edible;

	MushroomFeature myFeature;
	

	FeatureVect feature;
	int i;
	int row;
	char ch;
	for (i = 0; i < vect.size(); i ++)
	{
		if (vect[i].Desc[0] == 'p')
		{
			poison.push_back(vect[i]);
		}
		else
		{
			if (vect[i].Desc[0] == 'e')
			{
				edible.push_back(vect[i]);
			}
			else
			{
				printf("corrupted file");
				return;
			}
		}
	}
	for (i = 0; i < 22; i ++)
	{
		// new column
		feature.clear();
		for (row = 0; row < poison.size(); row ++)
		{
			ch = poison[row].Desc[i+1];
			if (ch != '?')
			{
				if (!findInFeature(feature, ch))
				{
					// a new feature
					feature.push_back(ch);
					if (!findInMushroom(edible, i, ch))
					{
						printf("poison mushroom is %s[%d]=%c\n", var_desc[i], i, ch);
						//later we check the possibility 
						myFeature.isPoison = true;
						myFeature.index = i;
						myFeature.ch = ch;
						myFeature.prob = checkProb(poison, i, ch);
						mushroomFeature.push_back(myFeature);
					}
				}
			}
		}
	}

	// swap
	for (i = 0; i < 22; i ++)
	{
		// new column
		feature.clear();
		for (row = 0; row < edible.size(); row ++)
		{
			ch = edible[row].Desc[i+1];
			if (ch != '?')
			{
				if (!findInFeature(feature, ch))
				{
					// a new feature
					feature.push_back(ch);
					if (!findInMushroom(poison, i, ch))
					{
						printf("edible mushroom is %s[%d]=%c\n", var_desc[i], i, ch);
						//later we check the possibility 
						myFeature.isPoison = false;
						myFeature.index = i;
						myFeature.ch = ch;
						myFeature.prob = checkProb(edible, i, ch);
						mushroomFeature.push_back(myFeature);
					}
				}
			}
		}
	}
	for (i = 0; i < mushroomFeature.size(); i ++)
	{
		printf("%s: %s[%d]=%c with %f\n", mushroomFeature[i].isPoison?"poison":"edible", var_desc[mushroomFeature[i].index], mushroomFeature[i].index, mushroomFeature[i].ch, mushroomFeature[i].prob);
	}
}


// Classifies every record in vect with the given features and reports how
// many predictions agree with the record's own class label.
//
// For each record a score starts at 0; every feature the record matches
// subtracts its prob when it is a poison feature and adds it when it is an
// edible feature (each match is also logged). A record counts as correctly
// predicted when it is labelled 'p' with a negative score or 'e' with a
// positive score; a score of exactly 0 is never counted. The score line is
// printed for correctly predicted records only.
//
// vect     records to classify; Desc[0] is the true label.
// feature  rules produced by analysisFile.
void predictFile(MushroomVect& vect, MushroomFeatureVect& feature)
{
	int counter = 0;

	for (size_t i = 0; i < vect.size(); i ++)
	{
		float prob = 0;
		for (size_t j = 0; j < feature.size(); j ++)
		{
			if (vect[i].Desc[feature[j].index + 1] == feature[j].ch)
			{
				if (feature[j].isPoison)
				{
					prob -= feature[j].prob;
				}
				else
				{
					prob += feature[j].prob;
				}
				printf("%s: feature:%s favor %f\n", feature[j].isPoison?"poison":"edible", var_desc[feature[j].index], feature[j].prob);
			}
		}

		const char label = vect[i].Desc[0];
		if ((label == 'p' && prob < 0) || (label == 'e' && prob > 0))
		{
			counter ++;
			printf("total prob: %f and fact is: %c\n", prob, label);
		}
	}

	// size() returns size_t, which must not be passed to %d as-is.
	printf("prediction rate: %d over %d\n", counter, static_cast<int>(vect.size()));
}




// Console entry point: loads the data set, derives the class-exclusive
// features from it, then scores the same records with those features.
// Always returns 0.
int main()
{
	// A named array can be passed to either a char* or a const char*
	// parameter; a bare string literal does not convert to char* in
	// standard C++11 and later.
	char szDataFile[] = "agaricus-lepiota.data";

	MushroomVect vect;
	MushroomFeatureVect feature;

	readFile(szDataFile, vect);
	//printFile(vect);
	analysisFile(vect, feature);

	predictFile(vect, feature);

	return 0;
}

// Windows GUI-subsystem entry point. It performs no work and creates no
// message loop, so it reports 0. The explicit return matters: unlike main(),
// a non-void function that falls off its end has undefined behaviour.
int APIENTRY WinMain(HINSTANCE hInstance, HINSTANCE hPrevInstance, LPSTR lpCmdLine, int nCmdShow)
{
	// The parameters are intentionally unused.
	(void)hInstance;
	(void)hPrevInstance;
	(void)lpCmdLine;
	(void)nCmdShow;

	return 0;
}