package rs.fon.is.weka.clustering;

import java.util.Enumeration;

import weka.clusterers.ClusterEvaluation;
import weka.clusterers.SimpleKMeans;
import weka.core.Attribute;
import weka.core.Instances;
import weka.core.SelectedTag;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.MultiFilter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.Remove;

/**
 * Command-line example that clusters the instances of a CSV file with Weka's
 * {@link SimpleKMeans}.
 *
 * <p>The program performs the following steps:
 * <ol>
 *   <li>loads the dataset from {@link #FILE_NAME} and prints its attributes;</li>
 *   <li>removes the first two attributes and normalizes the remaining ones;</li>
 *   <li>builds one k-means model per candidate number of clusters and prints
 *       the sum of squared errors of each (input for the Elbow method);</li>
 *   <li>builds a final model with the chosen number of clusters and prints
 *       its evaluation.</li>
 * </ol>
 */
public class KMeansClusteringExample {

	/** Path of the CSV file the dataset is read from. */
	private static final String FILE_NAME = "data/wholesale_horeca_customers.csv";

	/** Random seed passed to every clusterer built by this example. */
	private static final int SEED = 11;

	/** Maximum number of iterations passed to every clusterer built by this example. */
	private static final int MAX_ITERATIONS = 20;

	/**
	 * Runs the clustering example and writes all results to standard output.
	 *
	 * @param args command-line arguments (not used)
	 * @throws Exception if the data cannot be loaded, filtered or clustered
	 */
	public static void main(String[] args) throws Exception {

		// read data from CSV file
		DataSource loader = new DataSource(FILE_NAME);
		Instances data = loader.getDataSet();

		// check the attributes
		System.out.println("Attributes in the dataset:");
		printAttributeData(data);

		// drop the unwanted attributes and normalize the remaining ones
		Instances clusteringData = prepareClusteringData(data);

		// check the correct attributes were removed
		System.out.println("\nAttributes to be used for clustering:");
		printAttributeData(clusteringData);

		// use the Elbow method to find the optimal number of clusters
		int[] clusterNum = {2, 3, 4, 5, 6, 7, 8};
		double[] sumOfSquaredErrors = new double[clusterNum.length];

		for (int i = 0; i < clusterNum.length; i++) {
			SimpleKMeans kMeansClusterer = buildKMeans(clusteringData, clusterNum[i], false);
			sumOfSquaredErrors[i] = kMeansClusterer.getSquaredError();
		}

		// print sum of squared errors for different numbers of clusters
		System.out.println("\nSum of squared errors for different k values");
		System.out.println("K\t\tSSE");
		for (int i = 0; i < sumOfSquaredErrors.length; i++) {
			System.out.println(clusterNum[i] + "\t\t" + sumOfSquaredErrors[i]);
		}

		// optionally draw the Elbow plot (ElbowMethodPlot is not defined in this file)
//		ElbowMethodPlot plot = new ElbowMethodPlot(clusterNum, sumOfSquaredErrors);
//		plot.drawPlot();

		// choose the best K value and build a clustering model with that value
		int k = 5;
		SimpleKMeans bestKMeans = buildKMeans(clusteringData, k, true);

		// evaluate the clustering model on the same (filtered) data it was built
		// from; the raw dataset still contains the removed attributes and
		// unnormalized values, so it does not match the model's input format
		ClusterEvaluation eval = new ClusterEvaluation();
		eval.setClusterer(bestKMeans);
		eval.evaluateClusterer(clusteringData);

		System.out.println(eval.clusterResultsToString());
	}

	/**
	 * Produces the dataset used for clustering by removing the attributes at
	 * (zero-based) indices 0 and 1 and normalizing what is left.
	 *
	 * @param data the dataset as loaded from the file; it is not modified
	 * @return a filtered copy of {@code data}
	 * @throws Exception if one of the filters cannot be applied
	 */
	private static Instances prepareClusteringData(Instances data) throws Exception {
		// Remove filter for the first two attributes (expected to be the nominal ones)
		Remove removeFilter = new Remove();
		removeFilter.setAttributeIndicesArray(new int[] {0, 1});

		// Normalize filter to bring the remaining attributes to a common scale
		Normalize normalizeFilter = new Normalize();
		normalizeFilter.setScale(1);

		// MultiFilter chains removal and normalization. The filters are set
		// BEFORE the input format, because a Weka filter must be fully configured
		// when setInputFormat is called; the MultiFilter then initializes each
		// sub-filter with the output format of its predecessor.
		MultiFilter multiFilter = new MultiFilter();
		multiFilter.setFilters(new Filter[] {removeFilter, normalizeFilter});
		multiFilter.setInputFormat(data);

		return Filter.useFilter(data, multiFilter);
	}

	/**
	 * Configures and builds a k-means model that uses k-means++ initialization,
	 * {@link #SEED} and {@link #MAX_ITERATIONS}.
	 *
	 * @param trainingData   the instances to cluster
	 * @param numClusters    the number of clusters (k) to create
	 * @param displayStdDevs whether the model's textual output should include
	 *                       standard deviations
	 * @return the clusterer after {@code buildClusterer} has been called on it
	 * @throws Exception if the clusterer cannot be configured or built
	 */
	private static SimpleKMeans buildKMeans(Instances trainingData, int numClusters,
			boolean displayStdDevs) throws Exception {
		SimpleKMeans kMeans = new SimpleKMeans();
		kMeans.setSeed(SEED);
		kMeans.setNumClusters(numClusters);
		kMeans.setMaxIterations(MAX_ITERATIONS);
		kMeans.setInitializationMethod(
				new SelectedTag(SimpleKMeans.KMEANS_PLUS_PLUS, SimpleKMeans.TAGS_SELECTION));
		kMeans.setDisplayStdDevs(displayStdDevs);
		kMeans.buildClusterer(trainingData);
		return kMeans;
	}

	/**
	 * Prints one line with name and type for every attribute returned by
	 * {@code dataset.enumerateAttributes()}.
	 *
	 * @param dataset the dataset whose attributes are listed
	 */
	private static void printAttributeData(Instances dataset) {
		Enumeration<Attribute> attributes = dataset.enumerateAttributes();
		while (attributes.hasMoreElements()) {
			Attribute a = attributes.nextElement();
			System.out.println("- " + a.name() + ": " + Attribute.typeToString(a));
		}
	}
}
