/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous.cdi;

import dr.evomodel.treedatalikelihood.continuous.cdi.PrecisionType;
import dr.evomodel.treedatalikelihood.continuous.cdi.SafeMultivariateWithDriftIntegrator;
import dr.math.matrixAlgebra.missingData.MissingOps;
import java.util.Arrays;
import org.ejml.data.D1Matrix64F;
import org.ejml.data.DenseMatrix64F;
import org.ejml.ops.CommonOps;

/**
 * Drift integrator for an Ornstein-Uhlenbeck-type process whose actualization is diagonal.
 *
 * <p>For every buffer this class caches the diagonal vector {@code 1 - exp(-rate_i * t)} (the
 * "one-minus actualization") in {@link #diagonal1mActualizations}, computed with
 * {@link Math#expm1} so that small {@code rate * t} products keep full precision. The actualization
 * itself, {@code exp(-rate_i * t)}, is recovered on demand via {@link #oneMinus}.
 *
 * <p>NOTE(review): several methods index trait-sized buffers with {@code dimTrait} while the
 * stationary-variance and inversion code uses {@code dimProcess}; this appears to assume
 * {@code dimTrait == dimProcess} -- confirm against the superclass.
 */
public class SafeMultivariateDiagonalActualizedWithDriftIntegrator
extends SafeMultivariateWithDriftIntegrator {
    private static final boolean DEBUG = false;
    private static final boolean TIMING = false;
    /** Per-buffer diagonal of {@code 1 - exp(-rate * t)}; {@code dimTrait} entries per buffer. */
    private double[] diagonal1mActualizations;
    /** Per-diffusion stationary variance matrices; {@code dimProcess^2} entries per diffusion. */
    double[] stationaryVariances;
    /** Scratch vector of length {@code dimTrait}, reused to avoid per-call allocation. */
    private double[] vectorDiagQdi;
    /** Second scratch vector of length {@code dimTrait}, reused to avoid per-call allocation. */
    private double[] vectorDiagQdj;

    /**
     * Creates the integrator; all arguments are forwarded unchanged to the superclass constructor,
     * after which the diagonal-actualization storage is allocated.
     */
    public SafeMultivariateDiagonalActualizedWithDriftIntegrator(PrecisionType precisionType, int n, int n2, int n3, int n4, int n5) {
        super(precisionType, n, n2, n3, n4, n5);
        this.allocateStorage();
        System.err.println("Trying SafeMultivariateDiagonalActualizedWithDriftIntegrator");
    }

    /**
     * Copies the cached diagonal {@code 1 - exp(-rate * t)} of one buffer into {@code out}.
     *
     * @param bufferIndex buffer whose vector is requested; {@code -1} is not supported
     * @param out destination with at least {@code dimTrait} entries; only the first
     *     {@code dimTrait} entries are written
     * @throws RuntimeException if {@code bufferIndex == -1}
     */
    @Override
    public void getBranch1mActualization(int bufferIndex, double[] out) {
        if (bufferIndex == -1) {
            throw new RuntimeException("Not yet implemented");
        }
        assert (out != null);
        assert (out.length >= this.dimTrait);
        System.arraycopy(this.diagonal1mActualizations, bufferIndex * this.dimTrait, out, 0, this.dimTrait);
    }

    /**
     * Writes the diagonal actualization {@code exp(-rate * t)} of one buffer into the first
     * {@code dimTrait} entries of {@code out}.
     *
     * @param bufferIndex buffer whose actualization is requested
     * @param out destination with at least {@code dimTrait} entries; entries beyond
     *     {@code dimTrait} are left untouched
     */
    @Override
    public void getBranchActualization(int bufferIndex, double[] out) {
        this.getBranch1mActualization(bufferIndex, out);
        // Only dimTrait entries were filled above; transforming the whole array would clobber
        // whatever the caller keeps in any trailing slots.
        oneMinus(out, this.dimTrait);
    }

    /**
     * Computes, element-wise, {@code expectation[i] = actualization[i] * parentValue[i] +
     * displacement[i]} for the first {@code dimTrait} entries.
     */
    @Override
    public void getBranchExpectation(double[] actualization, double[] parentValue, double[] displacement, double[] expectation) {
        assert (expectation != null);
        assert (expectation.length >= this.dimTrait);
        assert (actualization != null);
        assert (actualization.length >= this.dimTrait);
        assert (parentValue != null);
        assert (parentValue.length >= this.dimTrait);
        assert (displacement != null);
        assert (displacement.length >= this.dimTrait);
        for (int i = 0; i < this.dimTrait; ++i) {
            expectation[i] = actualization[i] * parentValue[i] + displacement[i];
        }
    }

    /** Allocates the per-buffer, per-diffusion and scratch arrays used by this class. */
    private void allocateStorage() {
        this.diagonal1mActualizations = new double[this.dimTrait * this.bufferCount];
        this.stationaryVariances = new double[this.dimProcess * this.dimProcess * this.diffusionCount];
        this.vectorDiagQdi = new double[this.dimTrait];
        this.vectorDiagQdj = new double[this.dimTrait];
    }

    /**
     * Stores the stationary variance for one diffusion as
     * {@code inverseDiffusions[i][j] / (rate_i + rate_j)}.
     *
     * @param precisionIndex diffusion slot to write
     * @param diagonalRates diagonal rates; length must equal {@code dimProcess}
     * @param rotation forwarded to {@link #setStationaryVariance}; not used by this class
     */
    @Override
    public void setDiffusionStationaryVariance(int precisionIndex, double[] diagonalRates, double[] rotation) {
        assert (this.stationaryVariances != null);
        assert (this.dimProcess == diagonalRates.length);
        final int matrixSize = this.dimProcess * this.dimProcess;
        final int offset = matrixSize * precisionIndex;
        final double[] rateSums = new double[matrixSize];
        scalingMatrix(diagonalRates, rateSums);
        this.setStationaryVariance(offset, rateSums, matrixSize, rotation);
    }

    /**
     * Divides the diffusion block at {@code offset} of {@code inverseDiffusions} element-wise by
     * {@code rateSums} and stores the result at the same offset of {@link #stationaryVariances}.
     * The {@code rotation} argument is ignored here (kept for subclasses).
     */
    void setStationaryVariance(int offset, double[] rateSums, int matrixSize, double[] rotation) {
        scaleInv(this.inverseDiffusions, offset, rateSums, this.stationaryVariances, offset, matrixSize);
    }

    /** Computes {@code dest[destOffset + i] = source[sourceOffset + i] / divisor[i]} for {@code i < length}. */
    static void scaleInv(double[] source, int sourceOffset, double[] divisor, double[] dest, int destOffset, int length) {
        for (int i = 0; i < length; ++i) {
            dest[destOffset + i] = source[sourceOffset + i] / divisor[i];
        }
    }

    /** Fills the row-major {@code n x n} matrix {@code out[i][j] = values[i] + values[j]}. */
    private static void scalingMatrix(double[] values, double[] out) {
        final int dim = values.length;
        for (int i = 0; i < dim; ++i) {
            for (int j = 0; j < dim; ++j) {
                out[i * dim + j] = values[i] + values[j];
            }
        }
    }

    /**
     * Updates, for {@code updateCount} branches, the cached one-minus actualizations, the branch
     * variances and precisions, and the drift displacements.
     *
     * @param precisionIndex diffusion whose stationary/inverse-diffusion matrices are used
     * @param probabilityIndices buffer index of each updated branch
     * @param edgeLengths length of each updated branch
     * @param optimalValues packed, {@code dimProcess} entries per update; multiplied by the
     *     one-minus actualization to give the displacement
     * @param diagonalRates diagonal rates entering {@code 1 - exp(-rate * t)}
     * @param rotation passed through to the superclass and {@link #computeOUActualization}
     * @param updateCount number of branches to update
     */
    @Override
    public void updateOrnsteinUhlenbeckDiffusionMatrices(int precisionIndex, int[] probabilityIndices, double[] edgeLengths, double[] optimalValues, double[] diagonalRates, double[] rotation, int updateCount) {
        assert (this.diffusions != null);
        assert (probabilityIndices.length >= updateCount);
        assert (edgeLengths.length >= updateCount);
        super.updateOrnsteinUhlenbeckDiffusionMatrices(precisionIndex, probabilityIndices, edgeLengths, optimalValues, diagonalRates, rotation, updateCount);
        if (DEBUG) {
            System.err.println("Matrices (safe with actualized drift):");
        }
        final int traitMatrixSize = this.dimTrait * this.dimTrait;
        final int stationaryOffset = this.dimProcess * this.dimProcess * precisionIndex;

        // Step 1: one-minus actualizations (needed by steps 2 and 3).
        for (int up = 0; up < updateCount; ++up) {
            final int vectorOffset = this.dimTrait * probabilityIndices[up];
            final int matrixOffset = this.dimTrait * vectorOffset;
            this.computeOUActualization(diagonalRates, rotation, edgeLengths[up], vectorOffset, matrixOffset);
        }

        // Step 2: branch variances, then their inverses as precisions.
        for (int up = 0; up < updateCount; ++up) {
            final int matrixOffset = traitMatrixSize * probabilityIndices[up];
            final int vectorOffset = this.dimTrait * probabilityIndices[up];
            this.computeOUVarianceBranch(stationaryOffset, matrixOffset, vectorOffset, edgeLengths[up]);
            invertVectorSymmPosDef(this.variances, this.precisions, matrixOffset, this.dimProcess);
        }

        assert (optimalValues != null);
        assert (this.displacements != null);
        assert (optimalValues.length >= updateCount * this.dimProcess);

        // Step 3: displacements; optimalValues is packed contiguously, one block per update.
        int optimaOffset = 0;
        for (int up = 0; up < updateCount; ++up) {
            final int vectorOffset = this.dimTrait * probabilityIndices[up];
            final int matrixOffset = traitMatrixSize * probabilityIndices[up];
            this.computeOUActualizedDisplacement(optimalValues, optimaOffset, matrixOffset, vectorOffset);
            optimaOffset += this.dimProcess;
        }
    }

    /**
     * Stores {@code 1 - exp(-rates[i] * edgeLength)} at {@code vectorOffset} of
     * {@link #diagonal1mActualizations}. {@code rotation} and {@code matrixOffset} are unused here
     * (kept for overriding subclasses).
     */
    void computeOUActualization(double[] rates, double[] rotation, double edgeLength, int vectorOffset, int matrixOffset) {
        computeOUDiagonal1mActualization(rates, edgeLength, this.dimTrait, this.diagonal1mActualizations, vectorOffset);
    }

    /**
     * Writes {@code -expm1(-rates[i] * edgeLength)}, i.e. {@code 1 - exp(-rates[i] * edgeLength)},
     * for {@code i < dim} into {@code out} starting at {@code outOffset}.
     */
    static void computeOUDiagonal1mActualization(double[] rates, double edgeLength, int dim, double[] out, int outOffset) {
        for (int i = 0; i < dim; ++i) {
            // expm1 keeps precision when rates[i] * edgeLength is close to zero.
            out[outOffset + i] = -Math.expm1(-rates[i] * edgeLength);
        }
    }

    /**
     * Computes the branch variance at {@code matrixOffset} of {@code variances} from the cached
     * one-minus actualization at {@code vectorOffset} and the stationary variance at
     * {@code stationaryOffset}.
     */
    void computeOUVarianceBranch(int stationaryOffset, int matrixOffset, int vectorOffset, double edgeLength) {
        scalingActualizationMatrix(this.diagonal1mActualizations, vectorOffset, this.stationaryVariances, stationaryOffset, this.variances, matrixOffset, this.dimTrait, edgeLength, this.inverseDiffusions, stationaryOffset);
    }

    /**
     * Fills a row-major {@code dim x dim} variance block.
     *
     * <p>With {@code q_i = 1 - exp(-rate_i t)}, each entry is
     * {@code stationary[i][j] * (q_i + q_j - q_i q_j)}, which equals
     * {@code stationary[i][j] * (1 - exp(-(rate_i + rate_j) t))}. When the stationary entry is
     * infinite or {@code q_i + q_j == 0}, the entry falls back to
     * {@code edgeLength * diffusionVariance[i][j]} (presumably the Brownian limit, assuming
     * {@code inverseDiffusions} holds the diffusion variance -- confirm).
     */
    private static void scalingActualizationMatrix(double[] oneMinusActualization, int vectorOffset, double[] stationaryVariance, int stationaryOffset, double[] variance, int varianceOffset, int dim, double edgeLength, double[] diffusionVariance, int diffusionOffset) {
        for (int i = 0; i < dim; ++i) {
            final double qi = oneMinusActualization[vectorOffset + i];
            for (int j = 0; j < dim; ++j) {
                final double qj = oneMinusActualization[vectorOffset + j];
                final int entry = i * dim + j;
                final double stationary = stationaryVariance[stationaryOffset + entry];
                if (Double.isInfinite(stationary) || qi + qj == 0.0) {
                    variance[varianceOffset + entry] = edgeLength * diffusionVariance[diffusionOffset + entry];
                } else {
                    // Evaluation order kept as in the original to preserve rounding.
                    variance[varianceOffset + entry] = stationary * (-qi * qj + qi + qj);
                }
            }
        }
    }

    /**
     * Inverts the {@code dim x dim} block at {@code offset} of {@code source} with
     * {@code MissingOps.symmPosDefInvert} and writes the inverse at the same offset of {@code dest}.
     */
    private static void invertVectorSymmPosDef(double[] source, double[] dest, int offset, int dim) {
        DenseMatrix64F matrix = MissingOps.wrap(source, offset, dim, dim);
        DenseMatrix64F inverse = new DenseMatrix64F(dim, dim);
        MissingOps.symmPosDefInvert(matrix, inverse);
        MissingOps.unwrap(inverse, dest, offset);
    }

    /**
     * Sets {@code displacements[vectorOffset + i] = optima[optimaOffset + i] *
     * diagonal1mActualizations[vectorOffset + i]} for {@code i < dimTrait}; {@code matrixOffset}
     * is unused here.
     */
    void computeOUActualizedDisplacement(double[] optima, int optimaOffset, int matrixOffset, int vectorOffset) {
        for (int i = 0; i < this.dimTrait; ++i) {
            this.displacements[vectorOffset + i] = optima[optimaOffset + i] * this.diagonal1mActualizations[vectorOffset + i];
        }
    }

    /**
     * Replaces {@code precision} in place by (presumably) {@code Q * precision * Q}, with
     * {@code Q = diag(actualization)} taken at {@code vectorOffset}; {@code scratch} is overwritten.
     */
    @Override
    void actualizePrecision(DenseMatrix64F precision, DenseMatrix64F scratch, int n, int n2, int vectorOffset) {
        final double[] actualization = this.vectorDiagQdj;
        System.arraycopy(this.diagonal1mActualizations, vectorOffset, actualization, 0, this.dimTrait);
        oneMinus(actualization);
        MissingOps.diagMult(actualization, precision, scratch);
        MissingOps.diagMult(scratch, actualization, precision);
    }

    /**
     * Replaces {@code variance} in place by (presumably) {@code Q * variance * Q}, with
     * {@code Q = diag(actualization)} taken at {@code vectorOffset}.
     */
    @Override
    void actualizeVariance(DenseMatrix64F variance, int n, int n2, int vectorOffset) {
        final double[] actualization = this.vectorDiagQdi;
        System.arraycopy(this.diagonal1mActualizations, vectorOffset, actualization, 0, this.dimTrait);
        oneMinus(actualization);
        diagonalDoubleProduct(variance, actualization, variance);
    }

    /**
     * Applies {@code mean = actualization * mean + displacement} element-wise to the pre-order
     * partial at {@code meanOffset}, using the branch vectors at {@code vectorOffset}.
     */
    @Override
    void scaleAndDriftMean(int meanOffset, int n2, int vectorOffset) {
        for (int i = 0; i < this.dimTrait; ++i) {
            final double actualization = 1.0 - this.diagonal1mActualizations[vectorOffset + i];
            this.preOrderPartials[meanOffset + i] = actualization * this.preOrderPartials[meanOffset + i] + this.displacements[vectorOffset + i];
        }
    }

    /** Returns the stationary variance matrix stored for diffusion {@code precisionIndex}. */
    public double[] getStationaryVariance(int precisionIndex) {
        assert (this.stationaryVariances != null);
        return this.getMatrixProcess(precisionIndex, this.stationaryVariances);
    }

    /**
     * Computes {@code result = Qi * Pip * Qi + Qj * Pjp * Qj}, where {@code Qi} and {@code Qj} are
     * the diagonal actualizations at offsets {@code offsetI} and {@code offsetJ}. The two
     * actualization vectors are left in {@code vectorDiagQdi} / {@code vectorDiagQdj} for
     * {@link #computeWeightedSum}.
     */
    @Override
    void computePartialPrecision(int offsetI, int offsetJ, int n3, int n4, DenseMatrix64F Pip, DenseMatrix64F Pjp, DenseMatrix64F result) {
        final double[] qdi = this.vectorDiagQdi;
        System.arraycopy(this.diagonal1mActualizations, offsetI, qdi, 0, this.dimTrait);
        oneMinus(qdi);
        final double[] qdj = this.vectorDiagQdj;
        System.arraycopy(this.diagonal1mActualizations, offsetJ, qdj, 0, this.dimTrait);
        oneMinus(qdj);
        final DenseMatrix64F qdiPipQdi = this.matrix0;
        final DenseMatrix64F qdjPjpQdj = this.matrix1;
        diagonalDoubleProduct(Pip, qdi, qdiPipQdi);
        diagonalDoubleProduct(Pjp, qdj, qdjPjpQdj);
        CommonOps.add((D1Matrix64F) qdiPipQdi, qdjPjpQdj, (D1Matrix64F) result);
        if (DEBUG) {
            System.err.println("Qdi: " + Arrays.toString(qdi));
            System.err.println("\tQdiPipQdi: " + qdiPipQdi);
            System.err.println("\tQdj: " + Arrays.toString(qdj));
            System.err.println("\tQdjPjpQdj: " + qdjPjpQdj);
        }
    }

    /**
     * Delegates to {@code MissingOps.weightedSumActualized} using {@code matrixPip}/{@code matrixPjp}
     * and the actualization vectors left by {@link #computePartialPrecision}.
     */
    @Override
    void computeWeightedSum(double[] ipartial, double[] jpartial, int dimension, double[] out) {
        MissingOps.weightedSumActualized(ipartial, 0, this.matrixPip, this.vectorDiagQdi, 0, jpartial, 0, this.matrixPjp, this.vectorDiagQdj, 0, dimension, out);
    }

    /**
     * Computes (presumably) {@code result = diag(d) * matrix * diag(d)} via two
     * {@code MissingOps.diagMult} calls; {@code result} may alias {@code matrix}.
     */
    private static void diagonalDoubleProduct(DenseMatrix64F matrix, double[] d, DenseMatrix64F result) {
        MissingOps.diagMult(matrix, d, result);
        MissingOps.diagMult(d, result);
    }

    /** Replaces every entry {@code x} of {@code values} by {@code 1 - x}. */
    static void oneMinus(double[] values) {
        oneMinus(values, values.length);
    }

    /**
     * Replaces the first {@code length} entries {@code x} of {@code values} by {@code 1 - x};
     * entries at or beyond {@code length} are left unchanged.
     */
    static void oneMinus(double[] values, int length) {
        for (int i = 0; i < length; ++i) {
            values[i] = 1.0 - values[i];
        }
    }
}

