package rs.ac.bg.fon.ai.dataPreparation;

import java.util.Enumeration;
import java.util.Random;

import weka.attributeSelection.BestFirst;
import weka.attributeSelection.WrapperSubsetEval;
import weka.classifiers.Evaluation;
import weka.classifiers.bayes.NaiveBayes;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.attribute.AttributeSelection;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Example of wrapper-based attribute selection evaluated with a percentage split.
 *
 * <p>The program loads an ARFF dataset, splits it 80/20 into training and test
 * sets, selects attributes on the training set only (Naive Bayes wrapped in a
 * {@code WrapperSubsetEval}, searched with {@code BestFirst}), retrains the
 * classifier on the reduced training set, removes the same attributes from the
 * test set and prints the confusion matrix and the evaluation summary.
 */
public class PercentSplitAttributeSelection {

		/** Path of the ARFF file loaded by {@link #main(String[])}. */
		private static final String FILE_NAME = "data/census90-income.arff";

		/**
		 * Runs the whole example: load, split, select attributes, train, evaluate.
		 *
		 * @param args command-line arguments (not used)
		 * @throws Exception if the dataset cannot be read or any Weka step fails
		 */
		public static void main(String[] args) throws Exception {

			// read the dataset from the file; the last attribute is the class
			DataSource dataSrc = new DataSource(FILE_NAME);
			Instances data = dataSrc.getDataSet();
			data.setClassIndex(data.numAttributes() - 1);

			// print the (non-class) attributes of the original dataset
			printAttributeData(data);

			// split the full dataset so that 80% is used for training and 20% for testing;
			// a fixed seed keeps the split reproducible
			int seed = 9;
			Random randomizer = new Random(seed);
			// work on a copy so the original dataset keeps its order
			Instances randomData = new Instances(data);
			// NOTE(review): randomize() reshuffles every instance after stratify(), so the
			// resulting split is a plain random split, not a stratified one. Both calls are
			// kept in this order so that the split (and the reported results) stay the same.
			randomData.stratify(2);
			randomData.randomize(randomizer);
			// now do the 80-20 split
			int trainSize = (int) Math.round(randomData.size() * 0.80);
			int testSize = randomData.size() - trainSize;
			Instances trainData = new Instances(randomData, 0, trainSize);
			Instances testData = new Instances(randomData, trainSize, testSize);

			// Attribute selection using the wrapper method:
			// 1) create the classifier to be wrapped; it is not built here because the
			//    wrapper evaluator trains its own copies of it during the subset search
			NaiveBayes nbClsf = new NaiveBayes();
			nbClsf.setUseSupervisedDiscretization(true);

			// 2) create the evaluator wrapped around the classifier
			WrapperSubsetEval wrappedEval = new WrapperSubsetEval();
			wrappedEval.setClassifier(nbClsf);

			// 3) configure the AttributeSelection filter; setInputFormat() goes last,
			//    after all options are set, as required by the Weka filter contract.
			//    Only the training data is used: the test data must stay unseen.
			AttributeSelection attSelector = new AttributeSelection();
			attSelector.setEvaluator(wrappedEval);
			attSelector.setSearch(new BestFirst());
			attSelector.setInputFormat(trainData);

			// 4) apply the filter to obtain the training set with the reduced attribute set
			Instances reducedTrainData = Filter.useFilter(trainData, attSelector);

			// print the attributes that survived the selection
			System.out.println("\n\n--------------------------------------\n\n");
			System.out.println("Attributes in the train data set after attribute selection");
			printAttributeData(reducedTrainData);

			// train the classifier on the training data with the reduced attribute set
			nbClsf.buildClassifier(reducedTrainData);

			// the test set must have exactly the same attributes as the training set,
			// so find the indices of the attributes that the selection dropped ...
			int[] attToRemove = removedAttributesIndices(testData, reducedTrainData);

			// ... and remove them from the test set with the Remove filter
			Remove removeFilter = new Remove();
			removeFilter.setAttributeIndicesArray(attToRemove);
			removeFilter.setInputFormat(testData);
			Instances reducedTestData = Filter.useFilter(testData, removeFilter);

			// evaluate on the test set; the Evaluation is initialised with the training
			// data so that class priors are not estimated from the test instances
			Evaluation eval = new Evaluation(reducedTrainData);
			eval.evaluateModel(nbClsf, reducedTestData);

			System.out.println( eval.toMatrixString() );
			System.out.println( eval.toSummaryString() );

		}

		/**
		 * Prints the name and type of every attribute returned by
		 * {@code dataset.enumerateAttributes()}, one per line.
		 *
		 * @param dataset dataset whose attributes are listed
		 */
		private static void printAttributeData(Instances dataset) {
			Enumeration<Attribute> attributes = dataset.enumerateAttributes();
			while (attributes.hasMoreElements()) {
				Attribute a = attributes.nextElement();
				System.out.println("- " + a.name() + ": " + Attribute.typeToString(a));
			}
		}

		/**
		 * Finds the indices (0-based, relative to {@code fullDataset}) of the attributes
		 * that are present in {@code fullDataset} but absent from {@code reducedDataset}.
		 * The class attribute of {@code fullDataset} is never reported.
		 *
		 * <p>Attributes are walked by their real index, so the result is correct
		 * regardless of where the class attribute sits, and the returned array holds
		 * exactly the indices found (no padding entries).
		 *
		 * @param fullDataset    dataset with the complete attribute set
		 * @param reducedDataset dataset holding a subset of those attributes
		 * @return indices of the attributes to remove from {@code fullDataset},
		 *         in ascending order; empty if nothing was removed
		 */
		private static int[] removedAttributesIndices(Instances fullDataset, Instances reducedDataset) {
			int total = fullDataset.numAttributes();
			int classIdx = fullDataset.classIndex();

			// upper bound on the number of removed attributes; trimmed below
			int[] candidates = new int[total];
			int count = 0;
			for (int i = 0; i < total; i++) {
				if (i == classIdx) {
					continue; // the class attribute must never be removed
				}
				if (!containsAttribute(reducedDataset, fullDataset.attribute(i))) {
					candidates[count++] = i;
				}
			}

			int[] attToRemove = new int[count];
			System.arraycopy(candidates, 0, attToRemove, 0, count);
			return attToRemove;
		}

		/**
		 * Tells whether {@code dataset} has an attribute (among those returned by
		 * {@code enumerateAttributes()}) that is equal to {@code wanted}.
		 */
		private static boolean containsAttribute(Instances dataset, Attribute wanted) {
			Enumeration<Attribute> attributes = dataset.enumerateAttributes();
			while (attributes.hasMoreElements()) {
				if (wanted.equals(attributes.nextElement())) {
					return true;
				}
			}
			return false;
		}
}
