package edu.berkeley.nlp.sequence.stationary;

import edu.berkeley.nlp.math.DoubleArrays;
import java.util.Arrays;

/* loaded from: input_file:edu/berkeley/nlp/sequence/stationary/StationaryForwardBackward.class */
public class StationaryForwardBackward {
    /** When true, posterior computation is followed by a normalization sanity check ({@code probCheck}). */
    boolean DEBUG = true;
    /** Capacity of the per-position work arrays, taken from the model. */
    int MAX_LENGTH;
    int numStates;
    /** Forward messages; row {@code i} is rescaled by {@code SCALE^alphaScalingFactors[i]}. */
    double[][] alphas;
    /** Backward messages; row {@code i} is rescaled by {@code SCALE^betaScalingFactors[i]}. */
    double[][] betas;
    StationarySequenceModel seqModel;
    /** Length of the sequence loaded by the last {@link #setInput} call. */
    int obsLength;
    /** Log partition function computed by the last full (non forward-only) {@link #setInput} call. */
    double normConstant;
    double[][] nodePotentials;
    /** Cumulative base-{@code SCALE} exponents divided out of the alphas up to each position. */
    double[] alphaScalingFactors;
    /** Cumulative base-{@code SCALE} exponents divided out of the betas from each position onward. */
    double[] betaScalingFactors;
    /**
     * {@code edgePosteriors[a][k]} is the posterior mass of the transition
     * {@code a -> allowableForwardTransitions[a][k]}, summed over all positions.
     */
    double[][] edgePosteriors;
    double[][] nodePosteriors;
    int[][] allowableForwardTransitions;
    int[][] allowableBackwardTransitions;
    private final double[][] edgeForwardPotentials;
    private final double[][] edgeBackwardPotentials;

    /** Rescaling base used to keep forward/backward messages inside double range. */
    public static final double SCALE = Math.exp(100.0d);
    /** Natural log of {@link #SCALE}. */
    public static final double LOG_SCALE = Math.log(SCALE);

    /**
     * Allocates work arrays sized by the model's state count and maximum sequence length, and
     * caches the model's allowed transitions and edge potentials.
     *
     * @param stationarySequenceModel model supplying states, transitions and edge potentials
     */
    public StationaryForwardBackward(StationarySequenceModel stationarySequenceModel) {
        this.seqModel = stationarySequenceModel;
        this.numStates = stationarySequenceModel.getNumStates();
        this.MAX_LENGTH = stationarySequenceModel.getMaximumSequenceLength();
        this.alphas = new double[MAX_LENGTH][numStates];
        this.betas = new double[MAX_LENGTH][numStates];
        this.alphaScalingFactors = new double[MAX_LENGTH];
        this.betaScalingFactors = new double[MAX_LENGTH];
        this.nodePotentials = new double[MAX_LENGTH][numStates];
        this.nodePosteriors = new double[MAX_LENGTH][numStates];
        this.allowableForwardTransitions = stationarySequenceModel.getAllowableForwardTransitions();
        this.allowableBackwardTransitions = stationarySequenceModel.getAllowableBackwardTransitions();
        this.edgeForwardPotentials = stationarySequenceModel.getForwardEdgePotentials();
        this.edgeBackwardPotentials = stationarySequenceModel.getBackwardEdgePotentials();
        // Ragged array: one slot per allowed outgoing transition of each state.
        this.edgePosteriors = new double[numStates][];
        for (int state = 0; state < numStates; state++) {
            edgePosteriors[state] = new double[allowableForwardTransitions[state].length];
        }
    }

    /**
     * Loads a sequence and runs full inference (forward pass, backward pass, posteriors).
     *
     * @param stationarySequenceInstance sequence whose node potentials are used
     */
    public void setInput(StationarySequenceInstance stationarySequenceInstance) {
        setInput(stationarySequenceInstance, false);
    }

    /**
     * Loads a sequence and runs inference on it.
     *
     * @param stationarySequenceInstance sequence whose node potentials are used
     * @param forwardOnly if true, only the forward pass runs; betas, posteriors and the
     *     normalization constant are then not recomputed (posteriors are left cleared and
     *     {@link #getLogNormalizationConstant()} returns 0)
     * @throws RuntimeException if every label sequence has zero (or infinite) score
     */
    public void setInput(StationarySequenceInstance stationarySequenceInstance, boolean forwardOnly) {
        assert stationarySequenceInstance.getSequenceLength() < seqModel.getMaximumSequenceLength()
                : "sequence length " + stationarySequenceInstance.getSequenceLength()
                        + " exceeds model maximum " + seqModel.getMaximumSequenceLength();
        this.obsLength = stationarySequenceInstance.getSequenceLength();
        this.normConstant = 0.0d;
        clearArrays();
        stationarySequenceInstance.fillNodePotentials(nodePotentials);
        forwardPass();
        if (forwardOnly) {
            return;
        }
        backwardPass();
        computePosteriors();
    }

    /** Zeros the node potentials/posteriors for the active prefix, all edge posteriors and scalings. */
    private void clearArrays() {
        for (int i = 0; i < obsLength; i++) {
            Arrays.fill(nodePotentials[i], 0.0d);
            Arrays.fill(nodePosteriors[i], 0.0d);
        }
        for (int state = 0; state < numStates; state++) {
            Arrays.fill(edgePosteriors[state], 0.0d);
        }
        Arrays.fill(alphaScalingFactors, 0.0d);
        Arrays.fill(betaScalingFactors, 0.0d);
    }

    /**
     * Returns {@code SCALE^exponent}. Small integral exponents are multiplied out directly
     * instead of going through {@link Math#pow}.
     */
    private static double getScaleFactor(double d) {
        if (d == 0.0d) {
            return 1.0d;
        }
        return d == 1.0d ? SCALE : d == 2.0d ? SCALE * SCALE : d == 3.0d ? SCALE * SCALE * SCALE : d == -1.0d ? 1.0d / SCALE : d == -2.0d ? (1.0d / SCALE) / SCALE : d == -3.0d ? ((1.0d / SCALE) / SCALE) / SCALE : Math.pow(SCALE, d);
    }

    /** Returns the log partition function from the last full inference run. */
    public double getLogNormalizationConstant() {
        return normConstant;
    }

    /** Returns the per-position state posteriors (live internal array; rows beyond the sequence are stale). */
    public double[][] getNodeMarginals() {
        return nodePosteriors;
    }

    /**
     * Returns the edge posteriors summed over positions, indexed like
     * {@code allowableForwardTransitions} (live internal array).
     */
    public double[][] getEdgeMarginalSums() {
        return edgePosteriors;
    }

    /**
     * Returns the highest-scoring state sequence for the loaded input, where a sequence's score is
     * the product of its node potentials and backward edge potentials.
     *
     * <p>Uses max-product dynamic programming in log space, so no rescaling is needed. Only the
     * node potentials loaded by the last {@link #setInput} call are read; ties go to the
     * lowest-index candidate.
     *
     * @return the decoded state per position; an empty array for an empty sequence
     */
    public int[] viterbiDecode() {
        if (obsLength == 0) {
            return new int[0];
        }
        int[][] backPointers = new int[obsLength][numStates];
        double[] prevScores = new double[numStates];
        double[] curScores = new double[numStates];
        for (int state = 0; state < numStates; state++) {
            prevScores[state] = Math.log(nodePotentials[0][state]);
        }
        for (int i = 1; i < obsLength; i++) {
            for (int state = 0; state < numStates; state++) {
                int[] preds = allowableBackwardTransitions[state];
                double[] edgePots = edgeBackwardPotentials[state];
                // Default to the first predecessor so unreachable states still get a valid pointer.
                int bestPrev = preds.length > 0 ? preds[0] : -1;
                double bestScore = Double.NEGATIVE_INFINITY;
                for (int k = 0; k < preds.length; k++) {
                    double score = prevScores[preds[k]] + Math.log(edgePots[k]);
                    if (score > bestScore) {
                        bestScore = score;
                        bestPrev = preds[k];
                    }
                }
                backPointers[i][state] = bestPrev;
                curScores[state] = bestScore + Math.log(nodePotentials[i][state]);
            }
            double[] swap = prevScores;
            prevScores = curScores;
            curScores = swap;
        }
        int[] path = new int[obsLength];
        path[obsLength - 1] = argMaxIndex(prevScores);
        for (int i = obsLength - 2; i >= 0; i--) {
            path[i] = backPointers[i + 1][path[i + 1]];
        }
        return path;
    }

    /** Returns the first index holding the largest value (0 if every entry is -Infinity). */
    private static int argMaxIndex(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /** Returns, for each position, the state with the highest node posterior. */
    public int[] nodePosteriorDecode() {
        int[] decoded = new int[obsLength];
        for (int i = 0; i < obsLength; i++) {
            decoded[i] = DoubleArrays.argMax(nodePosteriors[i]);
        }
        return decoded;
    }

    /**
     * Computes the rescaled forward messages. {@code alphas[i][s]} is proportional to the total
     * score of all prefixes ending in state {@code s} at position {@code i}, node potential
     * included.
     *
     * @throws RuntimeException if a position's largest alpha is zero or infinite
     */
    private void forwardPass() {
        for (int i = 0; i < obsLength; i++) {
            double[] row = alphas[i];
            double[] potentials = nodePotentials[i];
            double max = Double.NEGATIVE_INFINITY;
            if (i == 0) {
                for (int state = 0; state < numStates; state++) {
                    row[state] = potentials[state];
                    if (row[state] > max) {
                        max = row[state];
                    }
                }
            } else {
                double[] prev = alphas[i - 1];
                for (int state = 0; state < numStates; state++) {
                    double sum = 0.0d;
                    // Skip the predecessor sum entirely when the node potential zeroes it out.
                    if (potentials[state] > 0.0d) {
                        int[] preds = allowableBackwardTransitions[state];
                        double[] edgePots = edgeBackwardPotentials[state];
                        for (int k = 0; k < preds.length; k++) {
                            sum += prev[preds[k]] * edgePots[k];
                        }
                        sum *= potentials[state];
                    }
                    if (sum > max) {
                        max = sum;
                    }
                    row[state] = sum;
                }
            }
            if (max == 0.0d || Double.isInfinite(max)) {
                throw new RuntimeException(String.format("The alphas[%d] has max=%.3f", Integer.valueOf(i), Double.valueOf(max)));
            }
            int exponent = rescale(row, max);
            alphaScalingFactors[i] = (i == 0) ? exponent : alphaScalingFactors[i - 1] + exponent;
        }
    }

    /**
     * Computes the rescaled backward messages. {@code betas[i][s]} is proportional to the total
     * score of all suffixes starting in state {@code s} at position {@code i}, node potential
     * included.
     */
    private void backwardPass() {
        int last = obsLength - 1;
        for (int i = last; i >= 0; i--) {
            double[] row = betas[i];
            double[] potentials = nodePotentials[i];
            double max = 0.0d;
            if (i == last) {
                for (int state = 0; state < numStates; state++) {
                    row[state] = potentials[state];
                    if (row[state] > max) {
                        max = row[state];
                    }
                }
            } else {
                double[] next = betas[i + 1];
                for (int state = 0; state < numStates; state++) {
                    double sum = 0.0d;
                    if (potentials[state] > 0.0d) {
                        int[] succs = allowableForwardTransitions[state];
                        double[] edgePots = edgeForwardPotentials[state];
                        for (int k = 0; k < succs.length; k++) {
                            sum += edgePots[k] * next[succs[k]];
                        }
                        sum *= potentials[state];
                    }
                    if (sum > max) {
                        max = sum;
                    }
                    row[state] = sum;
                }
            }
            int exponent = rescale(row, max);
            betaScalingFactors[i] = (i == last) ? exponent : betaScalingFactors[i + 1] + exponent;
        }
    }

    /**
     * Divides the first {@code numStates} entries of {@code row} by an integer power of
     * {@link #SCALE} so that {@code max} would land within [1/SCALE, SCALE].
     *
     * @param row message row to rescale in place
     * @param max the largest entry of {@code row}
     * @return the exponent that was divided out (0 if the row was left unchanged)
     */
    private int rescale(double[] row, double max) {
        int exponent = 0;
        double divisor = 1.0d;
        while (max > SCALE) {
            max /= SCALE;
            divisor *= SCALE;
            exponent++;
        }
        while (max > 0.0d && max < 1.0d / SCALE) {
            max *= SCALE;
            divisor /= SCALE;
            exponent--;
        }
        if (exponent != 0) {
            for (int state = 0; state < numStates; state++) {
                row[state] /= divisor;
            }
        }
        return exponent;
    }

    /** Returns the (rescaled) forward messages; see {@code alphaScalingFactors}. */
    public double[][] getAlphas() {
        return alphas;
    }

    /** Returns the (rescaled) backward messages; see {@code betaScalingFactors}. */
    public double[][] getBetas() {
        return betas;
    }

    /**
     * Combines alphas and betas into node and summed edge posteriors, and stores the log
     * partition function in {@link #normConstant}.
     *
     * @throws RuntimeException if the final alphas sum to zero
     */
    private void computePosteriors() {
        int last = obsLength - 1;
        double total = DoubleArrays.add(alphas[last]);
        double totalExponent = alphaScalingFactors[last];
        if (total == 0.0d) {
            throw new RuntimeException("Forward-Backward: No non-zero label sequences");
        }
        for (int i = 0; i < last; i++) {
            double[] nextBetas = betas[i + 1];
            // Undo the rescaling of alpha_i and beta_{i+1} relative to the partition function.
            double exponent = (alphaScalingFactors[i] + betaScalingFactors[i + 1]) - totalExponent;
            double scaleFactor = getScaleFactor(exponent);
            assert Math.abs(exponent) <= 3.0d : "Exp scale is " + exponent;
            for (int from = 0; from < numStates; from++) {
                double alpha = alphas[i][from];
                if (alpha == 0.0d) {
                    continue;
                }
                int[] succs = allowableForwardTransitions[from];
                double[] edgePots = edgeForwardPotentials[from];
                double[] edgeSums = edgePosteriors[from];
                double normalizedAlpha = alpha / total;
                double nodeMass = 0.0d;
                for (int k = 0; k < succs.length; k++) {
                    double beta = nextBetas[succs[k]];
                    if (beta != 0.0d) {
                        double edgeMass = normalizedAlpha * edgePots[k] * beta * scaleFactor;
                        edgeSums[k] += edgeMass;
                        nodeMass += edgeMass;
                    }
                }
                // Node marginal at i is the total mass of its outgoing edges.
                nodePosteriors[i][from] = nodeMass;
            }
        }
        double lastScale = getScaleFactor((alphaScalingFactors[last] + betaScalingFactors[last]) - totalExponent);
        for (int state = 0; state < numStates; state++) {
            double alpha = alphas[last][state];
            if (alpha != 0.0d) {
                // Both alpha and beta include the final node potential, so divide one copy out.
                nodePosteriors[last][state] = ((alpha * betas[last][state]) / (total * nodePotentials[last][state])) * lastScale;
            }
        }
        this.normConstant = (totalExponent * LOG_SCALE) + Math.log(total);
        if (DEBUG) {
            probCheck();
        }
    }

    /** Returns |a - b| / max(a, b). */
    private double relativeDiff(double d, double d2) {
        return Math.abs(d - d2) / Math.max(d, d2);
    }

    /**
     * Checks that each position's node posteriors sum to ~1 and that the summed edge posteriors
     * total ~(length - 1), within 1% relative error.
     *
     * @throws RuntimeException if either check fails
     */
    private void probCheck() {
        for (int i = 0; i < obsLength; i++) {
            double nodeSum = 0.0d;
            for (double p : nodePosteriors[i]) {
                nodeSum += p;
            }
            if (relativeDiff(1.0d, nodeSum) > 0.01d) {
                throw new RuntimeException("Node Sum: " + nodeSum + " not 1.0 for " + i);
            }
        }
        double edgeSum = 0.0d;
        for (int state = 0; state < numStates; state++) {
            for (double p : edgePosteriors[state]) {
                edgeSum += p;
            }
        }
        if (relativeDiff(obsLength - 1.0d, edgeSum) > 0.01d) {
            throw new RuntimeException("Failed ProbCheck: Edge Sum: " + edgeSum + " for " + obsLength);
        }
    }
}
