package edu.berkeley.nlp.bp;

import edu.berkeley.nlp.math.SloppyMath;

/* loaded from: input_file:edu/berkeley/nlp/bp/TripletFactorPotential.class */
/**
 * Factor potential over exactly three discrete variables, backed by a dense table of
 * log-potentials indexed {@code [x0][x1][x2]}.
 *
 * <p>Potentials, incoming messages and outgoing messages are all log-space scores: they are
 * combined by addition and by {@code SloppyMath.logAdd}, and only {@link #computeMarginal}
 * exponentiates.
 */
public class TripletFactorPotential implements FactorPotential {
    /**
     * Log-potential table indexed {@code [x0][x1][x2]}. Stored by reference (not copied); its
     * dimensions presumably match the message lengths passed to the methods below — TODO confirm.
     */
    private final double[][][] potentials;

    /**
     * Creates a triplet factor backed by the given log-potential table.
     *
     * @param dArr log-potential table indexed {@code [x0][x1][x2]}; kept by reference
     */
    public TripletFactorPotential(double[][][] dArr) {
        this.potentials = dArr;
    }

    /**
     * Computes the log-space message from this factor to each of its three variables.
     *
     * <p>For each variable {@code v} and value {@code x}, this log-adds into {@code dArr2[v][x]}
     * the score {@code potential + (incoming messages of the two other variables)} for every
     * joint assignment in which variable {@code v} takes value {@code x}. Each of the three
     * outgoing messages is then log-normalized with {@code SloppyMath.logNormalize}.
     *
     * @param dArr incoming log messages; {@code dArr[v]} holds one score per value of variable
     *     {@code v} (v = 0, 1, 2), and its length is taken as that variable's domain size
     * @param dArr2 outgoing log messages, written in place with the same shape as {@code dArr};
     *     new scores are log-added to whatever it already holds, so callers presumably pre-fill
     *     it with {@code Double.NEGATIVE_INFINITY} — TODO confirm against callers
     */
    @Override // edu.berkeley.nlp.bp.FactorPotential
    public void computeLogMessages(double[][] dArr, double[][] dArr2) {
        final double[] in0 = dArr[0];
        final double[] in1 = dArr[1];
        final double[] in2 = dArr[2];
        final double[] out0 = dArr2[0];
        final double[] out1 = dArr2[1];
        final double[] out2 = dArr2[2];
        for (int x0 = 0; x0 < in0.length; x0++) {
            for (int x1 = 0; x1 < in1.length; x1++) {
                for (int x2 = 0; x2 < in2.length; x2++) {
                    final double pot = potentials[x0][x1][x2];
                    // Each recipient gets the sum of the OTHER two incoming messages, added
                    // directly. Computing "total - own" instead is wrong when the recipient's
                    // own message is -infinity: the total is then -infinity and the other two
                    // variables' contribution is lost.
                    out0[x0] = SloppyMath.logAdd(out0[x0], pot + in1[x1] + in2[x2]);
                    out1[x1] = SloppyMath.logAdd(out1[x1], pot + in0[x0] + in2[x2]);
                    out2[x2] = SloppyMath.logAdd(out2[x2], pot + in0[x0] + in1[x1]);
                }
            }
        }
        // Normalize once, after every joint assignment has contributed. Normalizing inside the
        // outer loop would rescale the partial sums accumulated so far relative to the
        // contributions of later x0 values, distorting the messages.
        for (int v = 0; v < 3; v++) {
            SloppyMath.logNormalize(dArr2[v]);
        }
    }

    /**
     * Computes the normalized joint belief of the three variables.
     *
     * <p>Cell {@code [x0][x1][x2]} of the result is
     * {@code exp(potential + dArr[0][x0] + dArr[1][x1] + dArr[2][x2] - logZ)}, where
     * {@code logZ} is the {@code SloppyMath.logAdd} of that score over all cells.
     *
     * <p>NOTE(review): if every joint score is -infinity, {@code logZ} is presumably
     * -infinity as well and every cell becomes NaN; confirm callers never hit this.
     *
     * @param dArr incoming log messages, one array per variable; their lengths give the
     *     dimensions of the result
     * @return a newly allocated {@code double[][][]} indexed {@code [x0][x1][x2]}
     */
    @Override // edu.berkeley.nlp.bp.FactorPotential
    public Object computeMarginal(double[][] dArr) {
        final double[] in0 = dArr[0];
        final double[] in1 = dArr[1];
        final double[] in2 = dArr[2];
        final int n0 = in0.length;
        final int n1 = in1.length;
        final int n2 = in2.length;
        // Joint log scores in row-major [x0][x1][x2] order, flattened so the normalizer can be
        // computed with a single logAdd call.
        final double[] logScores = new double[n0 * n1 * n2];
        int k = 0;
        for (int x0 = 0; x0 < n0; x0++) {
            for (int x1 = 0; x1 < n1; x1++) {
                for (int x2 = 0; x2 < n2; x2++) {
                    logScores[k++] = potentials[x0][x1][x2] + in0[x0] + in1[x1] + in2[x2];
                }
            }
        }
        final double logZ = SloppyMath.logAdd(logScores);
        final double[][][] marginal = new double[n0][n1][n2];
        k = 0;
        for (int x0 = 0; x0 < n0; x0++) {
            for (int x1 = 0; x1 < n1; x1++) {
                for (int x2 = 0; x2 < n2; x2++) {
                    marginal[x0][x1][x2] = Math.exp(logScores[k++] - logZ);
                }
            }
        }
        return marginal;
    }
}
