/*
 * Decompiled with CFR 0.152.
 */
package it.uniroma2.exp.task;

import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.binary.C_SVC;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.labelinverter.StringLabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblem;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblemImpl;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassificationSVM;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import it.uniroma2.dtk.dt.GenericDT;
import it.uniroma2.dtk.op.convolution.ShuffledCircularConvolution;
import it.uniroma2.tk.TreeKernel;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.tree.Tree;
import java.io.BufferedReader;
import java.io.FileReader;
import java.util.HashMap;
import java.util.Map;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class QCTaskTester {
    /** Training file used when no paths are given on the command line. */
    private static final String DEFAULT_TRAIN_FILE = "/home/lorenzo/esperimenti/dtk/qc/data/smallTrain";
    /** Testing file used when no paths are given on the command line. */
    private static final String DEFAULT_TEST_FILE = "/home/lorenzo/esperimenti/dtk/qc/data/smallTest";
    /** Marker ending the label and starting the tree text on each example line. */
    private static final String BEGIN_TREE = "|BT|";
    /** Marker ending the tree text on each example line. */
    private static final String END_TREE = "|ET|";
    /** Value assigned to TreeKernel.lambda and passed as GenericDT's third constructor argument. */
    private static final double LAMBDA = 0.4;
    /** Second GenericDT constructor argument; presumably the vector dimension -- TODO confirm. */
    private static final int DT_VECTOR_SIZE = 4096;
    /** Solver tolerance passed to the SVM parameter builder. */
    private static final float SVM_EPS = 0.001f;
    /** Kernel cache size passed to the SVM parameter builder. */
    private static final float SVM_CACHE_SIZE = 2048.0f;
    /** One-vs-all threshold used only by the distributed experiment. */
    private static final double DISTRIBUTED_ONE_VS_ALL_THRESHOLD = 0.1;

    /**
     * Runs the distributed-tree experiment followed by the tree-kernel experiment.
     *
     * @param args optional {@code trainFile testFile}; when fewer than two arguments are
     *             given, the default paths are used
     * @throws Exception if loading, distributing, training or testing fails
     */
    public static void main(String[] args) throws Exception {
        QCTaskTester att = new QCTaskTester();
        if (args.length >= 2) {
            att.run(args[0], args[1]);
        } else {
            att.run();
        }
    }

    /**
     * Runs both experiments on the default training and testing files.
     *
     * @throws Exception if loading or any experiment fails
     */
    public void run() throws Exception {
        this.run(DEFAULT_TRAIN_FILE, DEFAULT_TEST_FILE);
    }

    /**
     * Loads both example files, then runs the distributed experiment and the original
     * tree-kernel experiment, printing timings and accuracy to standard output.
     *
     * @param trainFile path of the training examples file
     * @param testFile  path of the testing examples file
     * @throws Exception if loading or any experiment fails
     */
    public void run(String trainFile, String testFile) throws Exception {
        System.out.println("Loading training and testing set...");
        HashMap<Tree, String> trainExamples = this.loadExamples(trainFile);
        HashMap<Tree, String> testExamples = this.loadExamples(testFile);
        System.out.println(trainExamples.size() + " training examples, " + testExamples.size() + " testing examples");
        this.runDistributed(trainExamples, testExamples);
        this.runOriginal(trainExamples, testExamples);
    }

    /**
     * Trains a one-vs-all C-SVC directly on trees with {@link TreeKernel} and prints the
     * training time and the fraction of testing examples labelled correctly.
     */
    private void runOriginal(HashMap<Tree, String> trainExamples, HashMap<Tree, String> testExamples) {
        HashMap<Tree, Integer> exampleIds = assignIds(trainExamples);
        System.out.println("Building parameters...");
        MultiClassificationSVM<String, Tree> svm = new MultiClassificationSVM<String, Tree>(new C_SVC());
        MultiClassProblemImpl<String, Tree> artificialProblem = new MultiClassProblemImpl<String, Tree>(String.class, new StringLabelInverter(), trainExamples, exampleIds, new NoopScalingModel());
        ImmutableSvmParameterPoint.Builder svmParamBuilder = newParameterBuilder();
        // lambda is a static field on TreeKernel; it is assigned before the kernel is constructed.
        TreeKernel.lambda = LAMBDA;
        svmParamBuilder.kernel = new TreeKernel();
        System.out.println("Training...");
        long time = System.currentTimeMillis();
        MultiClassModel<String, Tree> model = svm.train((MultiClassProblem<String, Tree>) artificialProblem, (ImmutableSvmParameter<String, Tree>) svmParamBuilder.build());
        printElapsed(time);
        System.out.println("Testing...");
        int correct = 0;
        for (Map.Entry<Tree, String> example : testExamples.entrySet()) {
            String predicted = (String) model.predictLabel(example.getKey());
            if (predicted.equals(example.getValue())) {
                ++correct;
            }
        }
        printAccuracy(correct, testExamples.size());
    }

    /**
     * Maps every tree to a vector with {@link GenericDT}, trains a one-vs-all C-SVC with a
     * dot-product kernel and prints timings and the fraction of testing examples labelled
     * correctly.
     */
    private void runDistributed(HashMap<Tree, String> originalTrainExamples, HashMap<Tree, String> originalTestExamples) throws Exception {
        System.out.println("Running distributed experiment...");
        // NOTE(review): constructor arguments kept exactly as before; the meaning of the leading 0
        // is not visible here -- confirm against GenericDT.
        GenericDT dt = new GenericDT(0, DT_VECTOR_SIZE, LAMBDA, ShuffledCircularConvolution.class);
        long time = System.currentTimeMillis();
        System.out.println("Distributing training trees...");
        HashMap<double[], String> trainExamples = distribute(dt, originalTrainExamples);
        System.out.println("Distributing testing trees...");
        HashMap<double[], String> testExamples = distribute(dt, originalTestExamples);
        printElapsed(time);
        HashMap<double[], Integer> exampleIds = assignIds(trainExamples);
        System.out.println("Building parameters...");
        MultiClassificationSVM<String, double[]> svm = new MultiClassificationSVM<String, double[]>(new C_SVC());
        MultiClassProblemImpl<String, double[]> artificialProblem = new MultiClassProblemImpl<String, double[]>(String.class, new StringLabelInverter(), trainExamples, exampleIds, new NoopScalingModel());
        ImmutableSvmParameterPoint.Builder svmParamBuilder = newParameterBuilder();
        svmParamBuilder.kernel = new KernelFunction<double[]>(){

            @Override
            public double evaluate(double[] x, double[] y) {
                try {
                    return ArrayMath.dot(x, y);
                }
                catch (Exception e) {
                    // Substituting 0.0 here would silently distort training; fail loudly instead.
                    throw new IllegalStateException("Dot product failed in distributed kernel", e);
                }
            }
        };
        svmParamBuilder.oneVsAllThreshold = DISTRIBUTED_ONE_VS_ALL_THRESHOLD;
        System.out.println("Training...");
        time = System.currentTimeMillis();
        MultiClassModel<String, double[]> model = svm.train((MultiClassProblem<String, double[]>) artificialProblem, (ImmutableSvmParameter<String, double[]>) svmParamBuilder.build());
        printElapsed(time);
        System.out.println("Testing...");
        int correct = 0;
        for (Map.Entry<double[], String> example : testExamples.entrySet()) {
            String predicted;
            try {
                predicted = (String) model.predictLabel(example.getKey());
            }
            catch (Exception e) {
                // Prediction failures count as misses instead of aborting the evaluation.
                predicted = null;
            }
            // NOTE(review): a null label presumably means no one-vs-all score passed the
            // threshold -- confirm against MultiClassModel.predictLabel.
            if (predicted == null) {
                System.out.println("No good label found!");
            } else if (predicted.equals(example.getValue())) {
                ++correct;
            }
        }
        printAccuracy(correct, testExamples.size());
    }

    /**
     * Reads labelled trees from a text file. Each non-blank line has the form
     * {@code label |BT| tree |ET| ...}; the label and the tree text are trimmed and the tree
     * text is parsed with {@link Tree#fromPennTree}.
     *
     * @param file path of the examples file
     * @return map from parsed tree to its label (trees comparing equal overwrite each other)
     * @throws IllegalArgumentException if a non-blank line lacks the markers in order
     * @throws Exception if reading or tree parsing fails
     */
    private HashMap<Tree, String> loadExamples(String file) throws Exception {
        HashMap<Tree, String> examples = new HashMap<Tree, String>();
        BufferedReader in = new BufferedReader(new FileReader(file));
        try {
            String line;
            int lineNumber = 0;
            while ((line = in.readLine()) != null) {
                ++lineNumber;
                if (line.trim().length() == 0) {
                    continue;
                }
                int labelEnd = line.indexOf(BEGIN_TREE);
                int treeStart = labelEnd + BEGIN_TREE.length();
                int treeEnd = labelEnd < 0 ? -1 : line.indexOf(END_TREE, treeStart);
                if (treeEnd < 0) {
                    throw new IllegalArgumentException("Malformed example at " + file + ":" + lineNumber + ": " + line);
                }
                String label = line.substring(0, labelEnd).trim();
                Tree t = Tree.fromPennTree(line.substring(treeStart, treeEnd).trim());
                examples.put(t, label);
            }
        }
        finally {
            // Close the file even when parsing a line throws.
            in.close();
        }
        return examples;
    }

    /** Creates a parameter builder with the settings shared by both experiments (kernel not set). */
    private static ImmutableSvmParameterPoint.Builder newParameterBuilder() {
        ImmutableSvmParameterPoint.Builder builder = new ImmutableSvmParameterPoint.Builder();
        builder.eps = SVM_EPS;
        builder.cache_size = SVM_CACHE_SIZE;
        builder.allVsAllMode = MultiClassModel.AllVsAllMode.None;
        builder.oneVsAllMode = MultiClassModel.OneVsAllMode.Best;
        return builder;
    }

    /** Assigns consecutive ids starting at 0 to the map's keys, in iteration order. */
    private static <K> HashMap<K, Integer> assignIds(HashMap<K, String> examples) {
        HashMap<K, Integer> ids = new HashMap<K, Integer>(examples.size());
        int nextId = 0;
        for (K key : examples.keySet()) {
            ids.put(key, nextId);
            ++nextId;
        }
        return ids;
    }

    /**
     * Replaces every tree with its vector from {@code dt.dt}, keeping the label. double[] keys
     * hash by identity, so each returned array is a distinct key.
     */
    private static HashMap<double[], String> distribute(GenericDT dt, HashMap<Tree, String> examples) throws Exception {
        HashMap<double[], String> distributed = new HashMap<double[], String>(examples.size());
        for (Map.Entry<Tree, String> example : examples.entrySet()) {
            distributed.put(dt.dt(example.getKey()), example.getValue());
        }
        return distributed;
    }

    /** Prints the seconds elapsed since {@code startMillis}. */
    private static void printElapsed(long startMillis) {
        System.out.println((double) (System.currentTimeMillis() - startMillis) / 1000.0 + " seconds");
    }

    /** Prints the fraction of correct predictions, or a notice when there were no test examples. */
    private static void printAccuracy(int correct, int total) {
        if (total == 0) {
            System.out.println("Accuracy: n/a (no testing examples)");
        } else {
            System.out.println("Accuracy: " + (double) correct / (double) total);
        }
    }
}

