/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.binary;

import edu.berkeley.compbio.jlibsvm.ContinuousModel;
import edu.berkeley.compbio.jlibsvm.DiscreteModel;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.LabelParser;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.binary.AlphaModel;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationProblem;
import edu.berkeley.compbio.jlibsvm.binary.SvmBinaryCrossValidationResults;
import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import edu.berkeley.compbio.jlibsvm.scaler.NoopScalingModel;
import edu.berkeley.compbio.jlibsvm.scaler.ScalingModel;
import edu.berkeley.compbio.ml.CrossValidationResults;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Collection;
import java.util.Properties;
import java.util.StringTokenizer;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A two-class SVM model that maps a point to a signed decision value and from there to one of
 * two labels ({@link #getTrueLabel()} for strictly positive values, {@link #getFalseLabel()}
 * otherwise).
 *
 * <p>The support vectors, their coefficients and the offset ({@code SVs}, {@code alphas},
 * {@code numSVs}, {@code supportVectors}, {@code rho}) are not declared here; they are inherited
 * from the superclass hierarchy.
 *
 * @param <L> the label type
 * @param <P> the type of the points being classified
 */
public class BinaryModel<L extends Comparable, P>
extends AlphaModel<L, P>
implements DiscreteModel<L, P>,
ContinuousModel<P> {
    public ImmutableSvmParameterPoint<L, P> param;
    private static final Logger logger = Logger.getLogger(BinaryModel.class);
    public float obj;
    public float upperBoundPositive;
    public float upperBoundNegative;
    public ScalingModel<P> scalingModel = new NoopScalingModel();
    public float r;
    public SvmBinaryCrossValidationResults<L, P> crossValidationResults;
    L trueLabel;
    L falseLabel;

    /**
     * @return the cross-validation results attached to this model; may be {@code null} if none
     *         were assigned
     */
    @Override
    public CrossValidationResults getCrossValidationResults() {
        return this.crossValidationResults;
    }

    /**
     * @return the labels known to the parameter point of this model
     */
    @Override
    public Collection<L> getLabels() {
        return this.param.getLabels();
    }

    /**
     * Creates an empty model; {@link #param} and the labels are left unset.
     */
    public BinaryModel() {
    }

    /**
     * Builds the parameter point of a model from a property set.
     *
     * <p>The {@code kernel_type} property must name a {@link KernelFunction} class that has a
     * constructor taking a {@link Properties}; the {@code label} property is a
     * whitespace-separated list of labels, each of which is parsed with {@code labelParser}
     * and registered with a {@code null} weight.
     *
     * @param props       the properties describing the model
     * @param labelParser converts each label token into a label object
     * @throws SvmException if the kernel cannot be instantiated or the {@code label} property
     *                      is missing
     */
    public BinaryModel(Properties props, LabelParser<L> labelParser) {
        ImmutableSvmParameterPoint.Builder builder = new ImmutableSvmParameterPoint.Builder();
        try {
            builder.kernel = (KernelFunction)Class.forName(props.getProperty("kernel_type")).getConstructor(Properties.class).newInstance(props);
        }
        catch (Throwable e) {
            // Deliberately broad: every failure to load the kernel (including linkage errors)
            // has always been reported to callers as an SvmException.
            throw new SvmException(e);
        }
        String labelList = props.getProperty("label");
        if (labelList == null) {
            // StringTokenizer would otherwise fail with an uninformative NullPointerException.
            throw new SvmException("Can't load binary model: the \"label\" property is missing");
        }
        StringTokenizer st = new StringTokenizer(labelList);
        while (st.hasMoreTokens()) {
            builder.putWeight((Comparable)labelParser.parse(st.nextToken()), null);
        }
        this.param = builder.build();
    }

    /**
     * Creates a model that uses the given parameter point.
     *
     * @param param the parameter point (kernel, labels, ...) of this model
     */
    public BinaryModel(ImmutableSvmParameterPoint<L, P> param) {
        this.param = param;
    }

    /**
     * @return the label predicted for decision values that are not strictly positive
     */
    public L getFalseLabel() {
        return this.falseLabel;
    }

    /**
     * @return the scaling model applied to a point before kernel evaluation
     */
    @NotNull
    public ScalingModel<P> getScalingModel() {
        return this.scalingModel;
    }

    /**
     * @param scalingModel the scaling model to apply to a point before kernel evaluation
     */
    public void setScalingModel(@NotNull ScalingModel<P> scalingModel) {
        this.scalingModel = scalingModel;
    }

    /**
     * @return the label predicted for strictly positive decision values
     */
    public L getTrueLabel() {
        return this.trueLabel;
    }

    /**
     * Predicts the label of a point from the sign of its decision value.
     *
     * @param x the point to classify
     * @return the true label if the decision value is strictly positive, else the false label
     */
    @Override
    public L predictLabel(P x) {
        return this.labelFor(this.predictValue(x).floatValue());
    }

    /**
     * Sums all values of the inherited {@code supportVectors} map, accumulating in single
     * precision.
     *
     * @return the sum of the support vector coefficients
     */
    public float getSumAlpha() {
        float total = 0.0f;
        for (Double alpha : this.supportVectors.values()) {
            // Compound assignment narrows back to float after each addition, exactly as before.
            total += alpha;
        }
        return total;
    }

    /**
     * Estimates the probability that a point belongs to the true class.
     *
     * <p>If no cross-validation results or no sigmoid are available, an error is logged and the
     * hard decision is returned instead (1 for a strictly positive decision value, else 0), in
     * line with {@link #getTrueProbability(float[], int[])}.
     *
     * @param x the point to classify
     * @return the estimated probability of the true label
     */
    public float getTrueProbability(P x) {
        return this.probabilityFor(this.predictValue(x).floatValue());
    }

    /**
     * Estimates the probability that a point carries the given label.
     *
     * @param x the point to classify
     * @param l one of the two labels of this model
     * @return the probability of the true label if {@code l} equals it, or its complement if
     *         {@code l} equals the false label
     * @throws SvmException if {@code l} is neither of the two labels of this model
     */
    public float getProbability(P x, L l) {
        if (l.equals(this.trueLabel)) {
            return this.getTrueProbability(x);
        }
        if (l.equals(this.falseLabel)) {
            return 1.0f - this.getTrueProbability(x);
        }
        throw new SvmException("Can't compute probability: " + l + " is not one of the classes in this binary model (" + this.trueLabel + ", " + this.falseLabel + ")");
    }

    /**
     * Computes the decision value of a point: the coefficient-weighted sum of the kernel values
     * between the scaled point and each support vector, minus {@code rho}.
     *
     * @param x the point to evaluate
     * @return the decision value
     */
    @Override
    public Float predictValue(P x) {
        float sum = 0.0f;
        P scaledX = this.scalingModel.scaledCopy(x);
        for (int i = 0; i < this.numSVs; ++i) {
            // The kernel value is truncated to float before being weighted, as before.
            float kvalue = (float)this.param.kernel.evaluate(scaledX, this.SVs[i]);
            sum += this.alphas[i] * kvalue;
        }
        return sum - this.rho;
    }

    /**
     * Estimates the probability of the true class from precomputed kernel values.
     *
     * <p>If no cross-validation results or no sigmoid are available, an error is logged and the
     * hard decision is returned instead (1 for a strictly positive decision value, else 0).
     *
     * @param kvalues    precomputed kernel values
     * @param svIndexMap for each support vector index of this model, the index into
     *                   {@code kvalues} holding its kernel value
     * @return the estimated probability of the true label
     */
    public float getTrueProbability(float[] kvalues, int[] svIndexMap) {
        return this.probabilityFor(this.predictValue(kvalues, svIndexMap).floatValue());
    }

    /**
     * Computes the decision value from precomputed kernel values: the coefficient-weighted sum
     * of the looked-up kernel values, minus {@code rho}.
     *
     * @param kvalues    precomputed kernel values
     * @param svIndexMap for each support vector index of this model, the index into
     *                   {@code kvalues} holding its kernel value
     * @return the decision value
     */
    public Float predictValue(float[] kvalues, int[] svIndexMap) {
        float sum = 0.0f;
        for (int i = 0; i < this.numSVs; ++i) {
            sum += this.alphas[i] * kvalues[svIndexMap[i]];
        }
        return sum - this.rho;
    }

    /**
     * Predicts the label from precomputed kernel values.
     *
     * @param kvalues    precomputed kernel values
     * @param svIndexMap for each support vector index of this model, the index into
     *                   {@code kvalues} holding its kernel value
     * @return the true label if the decision value is strictly positive, else the false label
     */
    public L predictLabel(float[] kvalues, int[] svIndexMap) {
        return this.labelFor(this.predictValue(kvalues, svIndexMap).floatValue());
    }

    /**
     * Logs (at debug level only) the objective value, {@code rho}, the length of the support
     * vector array and the number of support vectors whose coefficient magnitude reaches the
     * upper bound of their class.
     *
     * @param problem the problem used to look up the target value of each support vector
     */
    public void printSolutionInfo(BinaryClassificationProblem<L, P> problem) {
        if (!logger.isDebugEnabled()) {
            return;
        }
        logger.debug("obj = " + this.obj + ", rho = " + this.rho);
        int nBSV = 0;
        for (int i = 0; i < this.numSVs; ++i) {
            double magnitude = Math.abs(this.alphas[i]);
            if (!(magnitude > 0.0)) {
                // Zero (or NaN) coefficients are skipped without consulting the problem.
                continue;
            }
            boolean positive = problem.getTargetValue(this.SVs[i]).equals(this.trueLabel);
            double bound = positive ? this.upperBoundPositive : this.upperBoundNegative;
            if (magnitude >= bound) {
                ++nBSV;
            }
        }
        logger.debug("nSV = " + this.SVs.length + ", nBSV = " + nBSV);
    }

    /**
     * Writes the superclass data, an {@code nr_class 2} line and the support vectors to the
     * stream, then closes it.
     *
     * <p>The stream is closed even when writing fails, so that it does not leak.
     *
     * @param fp the stream to write to; closed by this method
     * @throws IOException if writing or closing the stream fails
     */
    @Override
    public void writeToStream(DataOutputStream fp) throws IOException {
        boolean written = false;
        try {
            super.writeToStream(fp);
            fp.writeBytes("nr_class 2\n");
            this.writeSupportVectors(fp);
            written = true;
        }
        finally {
            if (written) {
                fp.close();
            } else {
                try {
                    fp.close();
                }
                catch (IOException ignored) {
                    // A write already failed; let that original exception propagate.
                }
            }
        }
    }

    /**
     * Maps a decision value to a label: strictly positive means the true label.
     */
    private L labelFor(float decisionValue) {
        return decisionValue > 0.0f ? this.trueLabel : this.falseLabel;
    }

    /**
     * Converts a decision value into a probability using the cross-validation sigmoid, falling
     * back to a logged hard 0/1 decision when the sigmoid is unavailable.
     */
    private float probabilityFor(float decisionValue) {
        if (this.crossValidationResults == null) {
            logger.error("Can't compute probability in binary model without crossvalidationresults");
            return decisionValue > 0.0f ? 1.0f : 0.0f;
        }
        if (this.crossValidationResults.sigmoid == null) {
            logger.error("Can't compute probability in binary model without sigmoid");
            return decisionValue > 0.0f ? 1.0f : 0.0f;
        }
        return this.crossValidationResults.sigmoid.predict(decisionValue);
    }
}

