package ac.uk.susx.jack.tag.runner;

import ac.uk.susx.jack.tag.cluster.KMeansClustering;
import ac.uk.susx.jack.tag.data.TFIDF;
import com.beust.jcommander.Parameter;
import com.clearspring.analytics.util.Lists;
import com.google.common.collect.Maps;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;

/* loaded from: input_file:ac/uk/susx/jack/tag/runner/KMeansRunner.class */
public class KMeansRunner {

    /** Training corpus directory; its "txt" documents are used to fit the k-means model. */
    private static final String TRAIN_DIR =
            "/Volumes/External/Phd-LDA/Experiment-5/combined-cleaned";

    /** Test corpus directories; the cluster assignment of each one is reported separately. */
    private static final String[] TEST_DIRS = {
        "/Volumes/External/Phd-LDA/Experiment-5/cancer_DOID_162_D2_cleaned_3000",
        "/Volumes/External/Phd-LDA/Experiment-5/immune_system_disease_DOID_2914_D2_cleaned",
        "/Volumes/External/Phd-LDA/Experiment-5/gastrointestinal_system_disease_DOID_77_D2_cleaned"
    };

    /** Number of hashed term-frequency features passed to {@code TFIDF.tf}. */
    private static final int NUM_FEATURES = 10000;

    /** Number of k-means cluster centres to fit. */
    private static final int NUM_CLUSTERS = 3;

    /**
     * Index of the feature-vector column in each transformed row.
     * NOTE(review): assumed to be the "features" column added by HashingTF after the columns of
     * {@code TFIDF.basicSchema()} and "tokens" — confirm against that schema.
     */
    private static final int FEATURES_COLUMN = 3;

    /**
     * Command-line options for a k-means run, bound through JCommander {@link Parameter}
     * annotations.
     *
     * <p>NOTE(review): {@link KMeansRunner#main} does not read these options yet; it uses the
     * hard-coded constants above. This is a non-static inner class, so an instance needs an
     * enclosing {@code KMeansRunner}; consider making it {@code static}.
     */
    public class Inputs {

        /** Directories of test documents ({@code -ts} / {@code --test}). */
        @Parameter(names = {"-ts", "--test"}, description = "Directories of test documents")
        private List<String> testDocs = Lists.newArrayList();

        /** Training documents directory ({@code -tr} / {@code --train}); mandatory. */
        @Parameter(names = {"-tr", "--train"}, description = "Training documents directory", required = true)
        private String trainDir = null;

        /** Number of cluster centres ({@code -k} / {@code --centres}); defaults to 3. */
        @Parameter(names = {"-k", "--centres"}, description = "Number of clusters centres")
        private int k = 3;

        /**
         * Number of tf-idf features ({@code -f} / {@code --features}); defaults to 0.
         * NOTE(review): how 0 is interpreted is not visible here.
         */
        @Parameter(names = {"-f", "--features"}, description = "Number of tfidf features to use")
        private int feats = 0;

        public Inputs() {
        }
    }

    /**
     * Trains a k-means model on the training corpus, prints the model's cluster centres, and then
     * prints how the documents of each test corpus are distributed over those clusters.
     *
     * @param strArr command-line arguments (currently ignored)
     */
    public static void main(String[] strArr) {
        // NOTE(review): TFIDF's methods are called statically below; this instance is kept only in
        // case its constructor has side effects — confirm, and remove it if not.
        new TFIDF();
        KMeansClustering kMeansClustering = new KMeansClustering();
        try {
            DataFrame trainTokens = loadTokenised(new File(TRAIN_DIR));
            // The same HashingTF instance is applied to the training data and to every test set.
            HashingTF tf = TFIDF.tf("tokens", "features", NUM_FEATURES);
            DataFrame trainFeatures = tf.transform(trainTokens);

            List<DataFrame> testFeatures = Lists.newArrayList();
            for (String dir : TEST_DIRS) {
                testFeatures.add(tf.transform(loadTokenised(new File(dir))));
            }

            System.out.println(trainFeatures.toString());
            KMeansModel model = kMeansClustering.train(trainFeatures, NUM_CLUSTERS);
            System.out.println("Cluster Centers: ");
            for (Vector centre : model.clusterCenters()) {
                System.out.println(centre);
            }
            for (DataFrame test : testFeatures) {
                predict(model, test.collectAsList(), FEATURES_COLUMN);
            }
        } catch (IOException e) {
            // CLI boundary: report the failure and stop.
            e.printStackTrace();
        }
    }

    /**
     * Reads the "txt" documents in {@code dir} and tokenises the "document" column into "tokens".
     *
     * @param dir directory to read
     * @return the tokenised data frame
     * @throws IOException if reading the documents fails
     */
    private static DataFrame loadTokenised(File dir) throws IOException {
        return TFIDF.tokenise(
                TFIDF.basicDataFrame(TFIDF.readData(dir, "txt"), TFIDF.basicSchema()),
                "document", "tokens");
    }

    /**
     * Assigns every row to a cluster with {@code kMeansModel} and prints two things per cluster:
     * the number of rows assigned to it, and that number as a percentage of all rows.
     *
     * @param kMeansModel trained model used for the assignment
     * @param list rows to assign
     * @param i index of the column holding each row's feature {@link Vector}
     */
    public static void predict(KMeansModel kMeansModel, List<Row> list, int i) {
        // TreeMap so that clusters are reported in ascending index order on every run.
        Map<Integer, Integer> counts = new TreeMap<Integer, Integer>();
        for (Row row : list) {
            int cluster = kMeansModel.predict((Vector) row.get(i));
            Integer previous = counts.get(cluster);
            counts.put(cluster, previous == null ? 1 : previous + 1);
        }

        int total = 0;
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            System.out.println("Class: " + entry.getKey() + " Count: " + entry.getValue());
            total += entry.getValue();
        }

        // Empty input leaves counts empty, so this loop never divides by a zero total.
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            // Compute in floating point: int / int would truncate every share below 100% to 0.
            double percentage = entry.getValue() * 100.0d / total;
            System.out.println("Class: " + entry.getKey() + " Percentage: " + percentage);
        }
    }
}
