/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.multi;

import com.davidsoergel.conja.Function;
import com.davidsoergel.conja.Parallel;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.DistributionException;
import com.davidsoergel.trees.htpn.HierarchicalTypedPropertyNode;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.multi.BatchClusterLabelInverter;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassModel;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblem;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassProblemImpl;
import edu.berkeley.compbio.jlibsvm.multi.MultiClassificationSVM;
import edu.berkeley.compbio.jlibsvm.multi.VotingResult;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.ml.cluster.AbstractClusteringMethod;
import edu.berkeley.compbio.ml.cluster.BatchCluster;
import edu.berkeley.compbio.ml.cluster.BatchClusteringMethod;
import edu.berkeley.compbio.ml.cluster.ClusterException;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.Clusterable;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ClusteringTestResults;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.PointClusterFilter;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import edu.berkeley.compbio.ml.cluster.SupervisedClusteringMethod;
import java.io.Serializable;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Adapts a one-vs-one / one-vs-all {@link MultiClassificationSVM} to the batch, supervised
 * clustering-method interfaces: each potential training bin (label) becomes a {@link BatchCluster},
 * training samples are assigned to the cluster of their dominant label, and classification of a
 * point is delegated to the trained {@link MultiClassModel}.
 *
 * <p>Expected lifecycle (as implied by the code below): {@link #setBinarySvm},
 * {@link #createClusters()}, {@link #addAll}, {@link #train()}, then {@link #bestClusterMove} /
 * {@link #test}.
 *
 * @param <T> the type of sample being classified
 */
public class MultiClassificationSVMAdapter<T extends Clusterable<T>>
        extends AbstractClusteringMethod<T, BatchCluster<T>>
        implements BatchClusteringMethod<T>, SupervisedClusteringMethod<T> {
    private static final Logger logger = Logger.getLogger(MultiClassificationSVMAdapter.class);

    /** SVM hyper-parameters passed to {@link MultiClassificationSVM#train}. */
    final ImmutableSvmParameter<BatchCluster<T>, T> param;

    /** Training sample to its assigned cluster; guarded by the {@code trainingCount} monitor. */
    final Map<T, BatchCluster<T>> examples = new HashMap<T, BatchCluster<T>>();

    /** Training sample to its sequential id; guarded by the {@code trainingCount} monitor. */
    final Map<T, Integer> exampleIds = new HashMap<T, Integer>();

    /** Label to cluster; populated by {@link #createClusters()}. */
    Map<String, BatchCluster<T>> theClusterMap;

    /** The trained model; {@code null} until {@link #train()} has run. */
    private MultiClassModel<BatchCluster<T>, T> model;

    /** Binary SVM used for the underlying pairwise / one-vs-all problems. */
    private BinaryClassificationSVM<BatchCluster<T>, T> binarySvm;

    /** Number of training samples added so far; also used as the lock for the example maps. */
    final AtomicInteger trainingCount = new AtomicInteger(0);

    /**
     * Creates an adapter.
     *
     * @param potentialTrainingBins labels that may be used as training classes
     * @param predictLabelSets label sets passed through to the superclass
     * @param prohibitionModel optional model restricting which clusters a point may be assigned to
     * @param testLabels labels passed through to the superclass
     * @param param SVM parameters used at training time
     */
    public MultiClassificationSVMAdapter(Set<String> potentialTrainingBins,
                                         Map<String, Set<String>> predictLabelSets,
                                         ProhibitionModel<T> prohibitionModel,
                                         Set<String> testLabels,
                                         @NotNull ImmutableSvmParameter<BatchCluster<T>, T> param) {
        super(null, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);
        this.param = param;
    }

    /**
     * Sets the binary SVM used by {@link #train()}.
     *
     * @param binarySvm the binary classifier to wrap
     */
    public void setBinarySvm(BinaryClassificationSVM<BatchCluster<T>, T> binarySvm) {
        this.binarySvm = binarySvm;
    }

    /**
     * Adds every sample from the iterator as a training example, then logs how many samples have
     * been prepared. Samples are processed via {@code Parallel.forEach}, so {@link #add} may run
     * on several threads at once.
     *
     * @param trainingIterator source of training samples
     */
    public void addAll(ClusterableIterator<T> trainingIterator) {
        Parallel.forEach(trainingIterator, new Function<T, Void>() {
            @Override
            public Void apply(@Nullable T sample) {
                MultiClassificationSVMAdapter.this.add(sample);
                return null;
            }
        });
        logger.info("Prepared " + this.trainingCount + " training samples");
    }

    /**
     * Assigns one sample to the cluster of its dominant label (restricted to the potential
     * training bins) and records it as a training example with a fresh sequential id.
     *
     * @param sample the training sample
     * @throws IllegalStateException if {@link #createClusters()} has not been called, or no
     *     cluster exists for the sample's dominant label
     */
    private void add(T sample) {
        String label = (String) sample.getImmutableWeightedLabels()
                .getDominantKeyInSet(this.potentialTrainingBins);
        if (this.theClusterMap == null) {
            throw new IllegalStateException(
                    "createClusters() must be called before adding training samples");
        }
        BatchCluster<T> cluster = this.theClusterMap.get(label);
        if (cluster == null) {
            throw new IllegalStateException("No cluster exists for training label " + label);
        }
        // Several threads may add samples with the same label concurrently; BatchCluster is not
        // shown to be thread-safe, so serialize mutations per cluster.
        synchronized (cluster) {
            cluster.add(sample);
        }
        // The two maps are plain HashMaps; the id assignment and both puts must be atomic together.
        synchronized (this.trainingCount) {
            this.examples.put(sample, cluster);
            this.exampleIds.put(sample, this.trainingCount.getAndIncrement());
        }
    }

    /**
     * Creates one {@link BatchCluster} per potential training bin, numbering clusters from 0 in
     * iteration order, and registers each with the superclass via {@code addCluster}.
     */
    @Override
    public void createClusters() {
        this.theClusterMap =
                new HashMap<String, BatchCluster<T>>(this.potentialTrainingBins.size());
        int nextId = 0;
        for (String label : this.potentialTrainingBins) {
            if (this.theClusterMap.containsKey(label)) {
                continue;
            }
            BatchCluster<T> cluster = new BatchCluster<T>(nextId++);
            this.theClusterMap.put(label, cluster);
            this.addCluster(cluster);
        }
    }

    /**
     * Trains the multiclass model on the collected examples, then removes empty clusters and
     * normalizes cluster label probabilities.
     *
     * @throws IllegalStateException if no binary SVM has been set
     */
    @Override
    public void train() {
        if (this.binarySvm == null) {
            throw new IllegalStateException("setBinarySvm() must be called before train()");
        }
        MultiClassificationSVM<BatchCluster<T>, T> svm =
                new MultiClassificationSVM<BatchCluster<T>, T>(this.binarySvm);
        MultiClassProblemImpl<BatchCluster<T>, T> problem =
                new MultiClassProblemImpl<BatchCluster<T>, T>(BatchCluster.class,
                        new BatchClusterLabelInverter(), this.examples, this.exampleIds,
                        new NoopScalingModel());
        logger.debug("Performing multiclass training");
        this.model = svm.train((MultiClassProblem<BatchCluster<T>, T>) problem, this.param);
        this.removeEmptyClusters();
        this.normalizeClusterLabelProbabilities();
    }

    /**
     * Intentionally a no-op: this adapter contributes no extra entries to the results tree.
     *
     * @param innerResults the results node (unused)
     */
    public void putResults(HierarchicalTypedPropertyNode<String, Serializable, ?> innerResults) {
    }

    /**
     * Builds a copy of the trained model that excludes every cluster the prohibition model forbids
     * for {@code p}. If there is no prohibition model or it supplies no filter, the copy has no
     * disallowed clusters.
     *
     * @param p the point being classified
     * @return a model restricted for {@code p}
     */
    private MultiClassModel<BatchCluster<T>, T> makeMultiClassModelWithProhibition(@Nullable T p) {
        HashSet<BatchCluster<T>> disallowedClusters = new HashSet<BatchCluster<T>>();
        PointClusterFilter<T> clusterFilter =
                this.prohibitionModel == null ? null : this.prohibitionModel.getFilter(p);
        if (clusterFilter != null) {
            for (BatchCluster<T> cluster : this.model.getLabels()) {
                if (clusterFilter.isProhibited(cluster)) {
                    disallowedClusters.add(cluster);
                }
            }
        }
        return new MultiClassModel<BatchCluster<T>, T>(this.model, disallowedClusters);
    }

    /**
     * Runs the superclass test procedure and attaches the trained model's info to the result.
     */
    @Override
    public synchronized ClusteringTestResults test(ClusterableIterator<T> theTestIterator,
                                                   DissimilarityMeasure<String> intraLabelDistances)
            throws DistributionException, ClusterException {
        ClusteringTestResults result = super.test(theTestIterator, intraLabelDistances);
        result.setInfo(this.model.getInfo());
        return result;
    }

    /**
     * Predicts the best cluster for {@code p}, honoring the prohibition model when present.
     *
     * @param p the point to classify
     * @return the best move, with vote proportions and one-vs-all probabilities filled in
     * @throws NoGoodClusterException if the model yields no best label
     * @throws IllegalStateException if {@link #train()} has not been called
     */
    @Override
    public ClusterMove<T, BatchCluster<T>> bestClusterMove(T p) throws NoGoodClusterException {
        if (this.model == null) {
            throw new IllegalStateException("train() must be called before classifying points");
        }
        MultiClassModel<BatchCluster<T>, T> leaveOneOutModel = this.model;
        if (this.prohibitionModel != null) {
            try {
                leaveOneOutModel = this.makeMultiClassModelWithProhibition(p);
            } catch (NoSuchElementException e) {
                // Deliberate best-effort: if no prohibition filter can be built for this point,
                // fall back to the unrestricted model rather than failing classification.
                logger.debug("Could not build prohibition filter; using unrestricted model", e);
            }
        }
        VotingResult<BatchCluster<T>> r = leaveOneOutModel.predictLabelWithQuality(p);
        if (r.getBestLabel() == null) {
            throw new NoGoodClusterException();
        }
        ClusterMove<T, BatchCluster<T>> result = new ClusterMove<T, BatchCluster<T>>();
        result.bestCluster = r.getBestLabel();
        result.voteProportion = r.getBestVoteProportion();
        result.secondBestVoteProportion = r.getSecondBestVoteProportion();
        result.bestDistance = r.getBestOneVsAllProbability();
        result.secondBestDistance = r.getSecondBestOneVsAllProbability();
        return result;
    }
}

