package edu.berkeley.nlp.math;

import edu.berkeley.nlp.util.CallbackFunction;
import edu.berkeley.nlp.util.Logger;
import java.io.Serializable;
import java.util.LinkedList;

/* loaded from: input_file:edu/berkeley/nlp/math/LBFGSMinimizer.class */
/**
 * Limited-memory BFGS minimizer for differentiable functions.
 *
 * <p>Each iteration builds a search direction from the current gradient and a bounded history of
 * iterate/gradient differences, runs a {@link BacktrackingLineSearcher} along that direction, and
 * stops when the relative change in function value drops below the requested tolerance or
 * {@code maxIterations} is reached.
 *
 * <p>The difference history lives in instance fields that are created in the constructor and are
 * never cleared, so pairs recorded by one {@code minimize} call are still present in the next call
 * on the same instance.
 */
public class LBFGSMinimizer implements GradientMinimizer, Serializable {
    private static final long serialVersionUID = 36473897808840226L;

    /** Guard term added inside the denominator of the relative-change test in {@link #converged}. */
    double EPS;
    /** Upper bound on the number of iterations performed by one {@code minimize} call. */
    int maxIterations;
    /** Maximum number of difference vectors retained in each history list. */
    int maxHistorySize;
    /** Differences between successive iterates; index 0 is the most recently recorded. */
    LinkedList<double[]> inputDifferenceVectorList;
    /** Differences between successive gradients; index 0 is the most recently recorded. */
    LinkedList<double[]> derivativeDifferenceVectorList;
    /** Optional hook invoked after every non-terminating iteration; transient, so not serialized. */
    transient CallbackFunction iterCallbackFunction;
    /** Iterations whose index is below this value never take the convergence exit. */
    int minIterations;
    /** Value written to the line searcher's {@code stepSizeMultiplier} on the first iteration. */
    double initialStepSizeMultiplier;
    /** Value written to the line searcher's {@code stepSizeMultiplier} on every later iteration. */
    double stepSizeMultiplier;

    /** Callback shape for receiving the current point and iteration index. Not used in this class. */
    public interface IterationCallbackFunction {
        void iterationDone(double[] dArr, int i);
    }

    /**
     * Creates a minimizer with the defaults: 20 iterations at most, a history of 5 pairs, no
     * minimum iteration count, first-iteration step multiplier 0.01 and later multiplier 0.5.
     */
    public LBFGSMinimizer() {
        this.EPS = 1.0E-10d;
        this.maxIterations = 20;
        this.maxHistorySize = 5;
        this.inputDifferenceVectorList = new LinkedList<>();
        this.derivativeDifferenceVectorList = new LinkedList<>();
        this.iterCallbackFunction = null;
        this.minIterations = -1;
        this.initialStepSizeMultiplier = 0.01d;
        this.stepSizeMultiplier = 0.5d;
    }

    /**
     * Creates a minimizer with the default settings and the given iteration limit.
     *
     * @param i maximum number of iterations per {@code minimize} call
     */
    public LBFGSMinimizer(int i) {
        this();
        this.maxIterations = i;
    }

    /**
     * Sets the minimum iteration count. The method name is misspelled; it is retained so existing
     * callers keep working. Prefer {@link #setMinIterations(int)}.
     *
     * @param i iteration index below which the convergence test is not applied
     */
    public void setMinIteratons(int i) {
        this.minIterations = i;
    }

    /**
     * Correctly spelled equivalent of {@link #setMinIteratons(int)}.
     *
     * @param i iteration index below which the convergence test is not applied
     */
    public void setMinIterations(int i) {
        setMinIteratons(i);
    }

    /**
     * @param i maximum number of iterations per {@code minimize} call
     */
    public void setMaxIterations(int i) {
        this.maxIterations = i;
    }

    /**
     * @param d step-size multiplier handed to the line searcher on the first iteration
     */
    public void setInitialStepSizeMultiplier(double d) {
        this.initialStepSizeMultiplier = d;
    }

    /**
     * @param d step-size multiplier handed to the line searcher on every iteration after the first
     */
    public void setStepSizeMultiplier(double d) {
        this.stepSizeMultiplier = d;
    }

    /**
     * @param i maximum number of difference pairs to retain; takes effect on the next recorded pair
     */
    public void setMaxHistorySize(int i) {
        this.maxHistorySize = i;
    }

    /**
     * @param callbackFunction hook to invoke after each non-terminating iteration, or {@code null}
     */
    public void setIterationCallbackFunction(CallbackFunction callbackFunction) {
        this.iterCallbackFunction = callbackFunction;
    }

    /**
     * Applies the current inverse-Hessian approximation to a vector.
     *
     * <p>The result is not negated; {@code minimize} scales its own direction by -1 afterwards.
     *
     * @param i    dimension used to size the initial diagonal
     * @param dArr vector to multiply (a gradient, in {@code minimize})
     * @return the product of the approximate inverse Hessian with {@code dArr}
     * @throws RuntimeException if a stored pair has zero inner product
     */
    public double[] getSearchDirection(int i, double[] dArr) {
        return implicitMultiply(getInitialInverseHessianDiagonal(i), dArr);
    }

    /**
     * Builds the diagonal of the initial inverse-Hessian approximation as a constant vector.
     *
     * @param i length of the returned array
     * @return array of length {@code i} filled with the current scaling factor
     */
    protected double[] getInitialInverseHessianDiagonal(int i) {
        return DoubleArrays.constantArray(initialInverseHessianScale(), i);
    }

    /** Same as the {@code int} overload, sized by {@code differentiableFunction.dimension()}. */
    private double[] getInitialInverseHessianDiagonal(DifferentiableFunction differentiableFunction) {
        return DoubleArrays.constantArray(initialInverseHessianScale(), differentiableFunction.dimension());
    }

    /**
     * Scaling factor for the initial diagonal: 1 while the history is empty, otherwise
     * (y . s) / (y . y) for the most recently recorded pair.
     */
    private double initialInverseHessianScale() {
        if (this.derivativeDifferenceVectorList.isEmpty()) {
            return 1.0d;
        }
        double[] newestDerivativeDifference = getLastDerivativeDifference();
        double numerator = DoubleArrays.innerProduct(newestDerivativeDifference, getLastInputDifference());
        double denominator = DoubleArrays.innerProduct(newestDerivativeDifference, newestDerivativeDifference);
        return numerator / denominator;
    }

    /**
     * Minimizes without progress logging; equivalent to passing {@code false} as the last argument
     * of the four-argument overload.
     */
    @Override // edu.berkeley.nlp.math.GradientMinimizer
    public double[] minimize(DifferentiableFunction differentiableFunction, double[] dArr, double d) {
        return minimize(differentiableFunction, dArr, d, false);
    }

    /**
     * Runs up to {@code maxIterations} quasi-Newton iterations starting from a copy of {@code dArr}.
     *
     * @param differentiableFunction function to minimize
     * @param dArr                   starting point; cloned, so the caller's array is not the iterate
     * @param d                      tolerance passed to {@link #converged}
     * @param z                      when {@code true}, logs the value reached after every iteration
     * @return the point accepted by the convergence test, or the last iterate when the iteration
     *         limit is hit
     * @throws RuntimeException if a stored history pair has zero inner product
     */
    @Override // edu.berkeley.nlp.math.GradientMinimizer
    public double[] minimize(DifferentiableFunction differentiableFunction, double[] dArr, double d, boolean z) {
        BacktrackingLineSearcher lineSearcher = new BacktrackingLineSearcher();
        double[] guess = DoubleArrays.clone(dArr);
        for (int iteration = 0; iteration < this.maxIterations; iteration++) {
            // Value and gradient are re-evaluated at the top of every iteration on purpose: the
            // callback below is handed the live iterate array and could have changed it.
            double value = differentiableFunction.valueAt(guess);
            double[] gradient = differentiableFunction.derivativeAt(guess);
            double[] direction = implicitMultiply(getInitialInverseHessianDiagonal(differentiableFunction), gradient);
            // Return value is ignored, as in the original; presumably scale() works in place.
            DoubleArrays.scale(direction, -1.0d);
            lineSearcher.stepSizeMultiplier =
                    (iteration == 0) ? this.initialStepSizeMultiplier : this.stepSizeMultiplier;
            double[] nextGuess = lineSearcher.minimize(differentiableFunction, guess, direction);
            double nextValue = differentiableFunction.valueAt(nextGuess);
            double[] nextGradient = differentiableFunction.derivativeAt(nextGuess);
            if (z) {
                printProgress(iteration, nextValue);
            }
            if (iteration >= this.minIterations && converged(value, nextValue, d)) {
                return nextGuess;
            }
            updateHistories(guess, nextGuess, gradient, nextGradient);
            guess = nextGuess;
            if (this.iterCallbackFunction != null) {
                this.iterCallbackFunction.callback(guess, Integer.valueOf(iteration), Double.valueOf(nextValue), nextGradient);
            }
        }
        return guess;
    }

    /** Logs the function value reached at the end of iteration {@code i}. */
    private void printProgress(int i, double d) {
        Logger.logs(String.format("[LBFGSMinimizer.minimize] Iteration %d ended with value %.6f\n", Integer.valueOf(i), Double.valueOf(d)), new Object[0]);
    }

    /**
     * Convergence test on two successive function values. Declared {@code protected} in the
     * original class file; the decompiled source widened it to {@code public}, which is kept.
     *
     * @param d  value before the step
     * @param d2 value after the step
     * @param d3 tolerance
     * @return {@code true} when the values are exactly equal, or when
     *         {@code |d2 - d| / (|d2 + d + EPS| / 2)} is below {@code d3}
     */
    public boolean converged(double d, double d2, double d3) {
        if (d == d2) {
            return true;
        }
        double valueChange = Math.abs(d2 - d);
        double valueAverage = Math.abs((d2 + d) + this.EPS) / 2.0d;
        return valueChange / valueAverage < d3;
    }

    /**
     * Records the iterate and gradient differences produced by one step. Declared
     * {@code protected} in the original class file; the decompiled source widened it to
     * {@code public}, which is kept.
     *
     * <p>A pair whose inner product is exactly zero (for example when the step did not move the
     * iterate) is not stored: {@link #implicitMultiply} divides by that inner product and would
     * throw on its next invocation.
     *
     * @param dArr  iterate before the step
     * @param dArr2 iterate after the step
     * @param dArr3 gradient before the step
     * @param dArr4 gradient after the step
     */
    public void updateHistories(double[] dArr, double[] dArr2, double[] dArr3, double[] dArr4) {
        double[] inputDifference = DoubleArrays.addMultiples(dArr2, 1.0d, dArr, -1.0d);
        double[] derivativeDifference = DoubleArrays.addMultiples(dArr4, 1.0d, dArr3, -1.0d);
        if (DoubleArrays.innerProduct(inputDifference, derivativeDifference) == 0.0d) {
            return;
        }
        pushOntoList(inputDifference, this.inputDifferenceVectorList);
        pushOntoList(derivativeDifference, this.derivativeDifferenceVectorList);
    }

    /** Inserts at the head and drops the tail entry once the list exceeds {@code maxHistorySize}. */
    private void pushOntoList(double[] dArr, LinkedList<double[]> linkedList) {
        linkedList.addFirst(dArr);
        if (linkedList.size() > this.maxHistorySize) {
            linkedList.removeLast();
        }
    }

    /** Number of stored pairs, read from the input-difference list. */
    private int historySize() {
        return this.inputDifferenceVectorList.size();
    }

    /** Input difference at history index {@code i}; 0 is the most recently recorded. */
    private double[] getInputDifference(int i) {
        return this.inputDifferenceVectorList.get(i);
    }

    /** Derivative difference at history index {@code i}; 0 is the most recently recorded. */
    private double[] getDerivativeDifference(int i) {
        return this.derivativeDifferenceVectorList.get(i);
    }

    /** Most recently recorded derivative difference (head of the list). */
    private double[] getLastDerivativeDifference() {
        return this.derivativeDifferenceVectorList.getFirst();
    }

    /** Most recently recorded input difference (head of the list). */
    private double[] getLastInputDifference() {
        return this.inputDifferenceVectorList.getFirst();
    }

    /**
     * Multiplies {@code dArr2} by the limited-memory inverse-Hessian approximation using the
     * two-loop recursion.
     *
     * <p>History index 0 is the newest pair, so the first loop runs over increasing indices
     * (newest to oldest) and the second over decreasing indices (oldest to newest). The previous
     * version ran both loops the other way round, which applied the stored updates in reverse
     * chronological order whenever more than one pair was stored.
     *
     * @param dArr  diagonal of the initial inverse-Hessian approximation
     * @param dArr2 vector to multiply; not modified (a clone is worked on)
     * @return the product vector
     * @throws RuntimeException if a stored pair has zero inner product
     */
    private double[] implicitMultiply(double[] dArr, double[] dArr2) {
        int pairCount = historySize();
        double[] curvature = new double[pairCount]; // s . y for each pair (used as a divisor)
        double[] alpha = new double[pairCount];
        double[] right = DoubleArrays.clone(dArr2);

        // First loop: newest pair to oldest pair.
        for (int k = 0; k < pairCount; k++) {
            double[] inputDifference = getInputDifference(k);
            double[] derivativeDifference = getDerivativeDifference(k);
            curvature[k] = DoubleArrays.innerProduct(inputDifference, derivativeDifference);
            if (curvature[k] == 0.0d) {
                throw new RuntimeException("LBFGSMinimizer.implicitMultiply: Curvature problem.");
            }
            alpha[k] = DoubleArrays.innerProduct(inputDifference, right) / curvature[k];
            right = DoubleArrays.addMultiples(right, 1.0d, derivativeDifference, (-1.0d) * alpha[k]);
        }

        double[] left = DoubleArrays.pointwiseMultiply(dArr, right);

        // Second loop: oldest pair to newest pair.
        for (int k = pairCount - 1; k >= 0; k--) {
            double beta = DoubleArrays.innerProduct(getDerivativeDifference(k), left) / curvature[k];
            left = DoubleArrays.addMultiples(left, 1.0d, getInputDifference(k), alpha[k] - beta);
        }
        return left;
    }
}
