/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.jlibsvm.regression;

import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameter;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterGrid;
import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint;
import edu.berkeley.compbio.jlibsvm.SolutionVector;
import edu.berkeley.compbio.jlibsvm.SvmException;
import edu.berkeley.compbio.jlibsvm.qmatrix.BooleanInvertingKernelQMatrix;
import edu.berkeley.compbio.jlibsvm.regression.RegressionModel;
import edu.berkeley.compbio.jlibsvm.regression.RegressionProblem;
import edu.berkeley.compbio.jlibsvm.regression.RegressionSVM;
import edu.berkeley.compbio.jlibsvm.regression.RegressionSolverNu;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Nu-support-vector regression (SVM type {@code "nu_svr"}).
 * <p>
 * Training rejects parameter grids, because no cross-validation is available for regression here.
 * It optionally rescales the problem, then seeds one "positive" and one "negative"
 * {@link SolutionVector} per training example. Finally it runs {@link RegressionSolverNu} and
 * returns the compacted model.
 *
 * @param <P> the type of a training point
 * @param <R> the concrete regression problem type
 */
public class Nu_SVR<P, R extends RegressionProblem<P, R>>
extends RegressionSVM<P, R> {
    private static final Logger logger = Logger.getLogger(Nu_SVR.class);

    /**
     * Trains a nu-SVR model on {@code problem}.
     *
     * @param problem the regression problem to learn from
     * @param param   the training parameters; must be a single {@link ImmutableSvmParameterPoint}
     * @return the trained, compacted model
     * @throws SvmException if {@code param} fails validation, or if it is a grid or otherwise not a
     *                      single parameter point
     */
    @Override
    public RegressionModel<P> train(R problem, @NotNull ImmutableSvmParameter<Float, P> param) {
        validateParam(param);
        if (param instanceof ImmutableSvmParameterGrid && param.gridsearchBinaryMachinesIndependently) {
            throw new SvmException("Can't do grid search without cross-validation, which is not implemented for regression SVMs.");
        }
        // Any other non-point parameter (e.g. a grid without the per-machine flag) cannot be trained
        // directly; reject it here instead of failing with a ClassCastException on the cast below.
        if (!(param instanceof ImmutableSvmParameterPoint)) {
            throw new SvmException("Regression SVMs require a single parameter point, but got "
                    + param.getClass().getName());
        }
        return trainScaled(problem, (ImmutableSvmParameterPoint<Float, P>) param);
    }

    /**
     * Optionally rescales {@code problem}, builds the initial solution vectors, and runs the nu-SVR
     * solver.
     *
     * @param problem the regression problem to learn from
     * @param param   a single parameter point
     * @return the trained, compacted model, with {@code param}, the SVM type and the Laplace
     *         parameter attached
     */
    private RegressionModel<P> trainScaled(R problem, @NotNull ImmutableSvmParameterPoint<Float, P> param) {
        if (param.scalingModelLearner != null && param.scaleBinaryMachinesIndependently) {
            // Cast back to the concrete problem type so the scaled copy can stand in for the input.
            @SuppressWarnings("unchecked")
            R scaled = (R) problem.getScaledCopy(param.scalingModelLearner);
            problem = scaled;
        }

        // Stays -1 unless probability estimates are requested.
        float laplaceParameter = -1.0f;
        if (param.probability) {
            laplaceParameter = this.laplaceParameter(problem, param);
        }

        int numExamples = problem.getNumExamples();

        // A total alpha budget of C * nu * l / 2 is handed out greedily, at most C per example.
        // Both copies of an example start with the same alpha. This appears to follow LIBSVM's
        // nu-SVR initialization.
        float remaining = param.C * param.nu * (float) numExamples / 2.0f;
        List<SolutionVector<P>> solutionVectors = new ArrayList<SolutionVector<P>>(2 * numExamples);
        for (Map.Entry<P, Float> example : problem.getExamples().entrySet()) {
            P point = example.getKey();
            float target = example.getValue();
            float initAlpha = Math.min(remaining, param.C);
            remaining -= initAlpha;

            // The two copies carry the example's id: as-is for the first, negated for the second.
            // NOTE(review): presumably ids are nonzero so the copies stay distinguishable - TODO confirm.
            int id = problem.getId(point);

            SolutionVector<P> sv = new SolutionVector<P>(point, true, -target, initAlpha);
            sv.id = id;
            solutionVectors.add(sv);

            sv = new SolutionVector<P>(point, false, target, initAlpha);
            sv.id = -id;
            solutionVectors.add(sv);
        }

        BooleanInvertingKernelQMatrix qMatrix = new BooleanInvertingKernelQMatrix(param.kernel, numExamples, param.getCacheRows());
        RegressionSolverNu solver = new RegressionSolverNu(solutionVectors, qMatrix, param.C, param.eps, param.shrinking);
        @SuppressWarnings("unchecked")
        RegressionModel<P> model = solver.solve();

        model.param = param;
        model.setSvmType(getSvmType());
        model.laplaceParameter = laplaceParameter;
        // The solver's r holds the negated epsilon of the fitted tube.
        logger.info("epsilon = " + -model.r);
        model.compact();
        return model;
    }

    /**
     * Returns the SVM type identifier stored on trained models.
     *
     * @return {@code "nu_svr"}
     */
    @Override
    public String getSvmType() {
        return "nu_svr";
    }

    /**
     * Validates a single parameter point for nu-SVR.
     * <p>
     * This runs the inherited checks, then requires {@code 0 < nu <= 1} and {@code C > 0}. The
     * comparisons are written in positive form so that NaN values are rejected too.
     *
     * @param param the parameter point to check
     * @throws SvmException if {@code nu} is outside (0, 1] or {@code C} is not positive
     *                      (including NaN)
     */
    @Override
    public void validateParam(@NotNull ImmutableSvmParameterPoint<Float, P> param) {
        super.validateParam(param);
        if (!(param.nu > 0.0f && param.nu <= 1.0f)) {
            throw new SvmException("nu <= 0 or nu > 1");
        }
        if (!(param.C > 0.0f)) {
            throw new SvmException("C <= 0");
        }
    }
}

