/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.continuous.kraskov;

import infodynamics.measures.continuous.kraskov.MutualInfoCalculatorMultiVariateKraskov;
import infodynamics.utils.FirstIndexComparatorDouble;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import java.util.Arrays;

public class MutualInfoCalculatorMultiVariateKraskov2
extends MutualInfoCalculatorMultiVariateKraskov {
    protected static final int JOINT_NORM_VAL_COLUMN = 0;
    protected static final int JOINT_NORM_TIMESTEP_COLUMN = 1;
    protected static final double CUTOFF_MULTIPLIER = 1.5;

    @Override
    public double computeAverageLocalOfObservations(int[] nArray) throws Exception {
        if (!tryKeepAllPairsNorms || this.sourceObservations.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM) {
            double[][] dArray = this.destObservations;
            this.destObservations = MatrixUtils.extractSelectedTimePointsReusingArrays(dArray, nArray);
            double d = this.computeAverageLocalOfObservationsWhileComputingDistances();
            this.destObservations = dArray;
            return d;
        }
        if (this.xNorms == null) {
            this.computeNorms();
        }
        int n = this.sourceObservations.length;
        int n2 = (int)(1.5 * Math.log(n) / Math.log(2.0));
        double d = 0.0;
        double d2 = 0.0;
        double d3 = 0.0;
        for (int i = 0; i < n; ++i) {
            int n3;
            int n4;
            int n5 = nArray[i];
            double[][] dArray = new double[n][2];
            for (int j = 0; j < n; ++j) {
                int n6 = nArray[j];
                dArray[j][0] = Math.max(this.xNorms[i][j], this.yNorms[n5][n6]);
                dArray[j][1] = j;
            }
            double d4 = 0.0;
            double d5 = 0.0;
            int[] nArray2 = null;
            if (this.k <= n2) {
                nArray2 = MatrixUtils.kMinIndices(dArray, 0, this.k);
            } else {
                Arrays.sort(dArray, FirstIndexComparatorDouble.getInstance());
                nArray2 = new int[this.k];
                for (n4 = 0; n4 < this.k; ++n4) {
                    nArray2[n4] = (int)dArray[n4][1];
                }
            }
            for (n4 = 0; n4 < this.k; ++n4) {
                n3 = nArray2[n4];
                if (this.xNorms[i][n3] > d4) {
                    d4 = this.xNorms[i][n3];
                }
                if (!(this.yNorms[n5][nArray[n3]] > d5)) continue;
                d5 = this.yNorms[n5][nArray[n3]];
            }
            n4 = 0;
            n3 = 0;
            for (int j = 0; j < n; ++j) {
                if (this.xNorms[i][j] <= d4) {
                    ++n4;
                }
                if (!(this.yNorms[n5][nArray[j]] <= d5)) continue;
                ++n3;
            }
            d2 += (double)n4;
            d3 += (double)n3;
            d += MathsUtils.digamma(n4) + MathsUtils.digamma(n3);
        }
        d /= (double)n;
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", d2 /= (double)n, d3 /= (double)n));
        }
        double d6 = MathsUtils.digamma(this.k) - 1.0 / (double)this.k - d + MathsUtils.digamma(n);
        this.miComputed = true;
        if (nArray == null) {
            this.lastAverage = d6;
        }
        return d6;
    }

    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        if (!tryKeepAllPairsNorms || this.sourceObservations.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM) {
            return this.computeAverageLocalOfObservationsWhileComputingDistances();
        }
        if (this.xNorms == null) {
            this.computeNorms();
        }
        int n = this.sourceObservations.length;
        int n2 = (int)(1.5 * Math.log(n) / Math.log(2.0));
        double d = 0.0;
        double d2 = 0.0;
        double d3 = 0.0;
        for (int i = 0; i < n; ++i) {
            int n3;
            int n4;
            double[][] dArray = new double[n][2];
            for (int j = 0; j < n; ++j) {
                dArray[j][0] = Math.max(this.xNorms[i][j], this.yNorms[i][j]);
                dArray[j][1] = j;
            }
            double d4 = 0.0;
            double d5 = 0.0;
            int[] nArray = null;
            if (this.k <= n2) {
                nArray = MatrixUtils.kMinIndices(dArray, 0, this.k);
            } else {
                Arrays.sort(dArray, FirstIndexComparatorDouble.getInstance());
                nArray = new int[this.k];
                for (n4 = 0; n4 < this.k; ++n4) {
                    nArray[n4] = (int)dArray[n4][1];
                }
            }
            for (n4 = 0; n4 < this.k; ++n4) {
                n3 = nArray[n4];
                if (this.xNorms[i][n3] > d4) {
                    d4 = this.xNorms[i][n3];
                }
                if (!(this.yNorms[i][n3] > d5)) continue;
                d5 = this.yNorms[i][n3];
            }
            n4 = 0;
            n3 = 0;
            for (int j = 0; j < n; ++j) {
                if (this.xNorms[i][j] <= d4) {
                    ++n4;
                }
                if (!(this.yNorms[i][j] <= d5)) continue;
                ++n3;
            }
            d2 += (double)n4;
            d3 += (double)n3;
            d += MathsUtils.digamma(n4) + MathsUtils.digamma(n3);
        }
        this.lastAverage = MathsUtils.digamma(this.k) - 1.0 / (double)this.k - (d /= (double)n) + MathsUtils.digamma(n);
        this.miComputed = true;
        if (this.debug) {
            System.out.printf("Average n_x=%.3f, Average n_y=%.3f", d2 /= (double)n, d3 /= (double)n);
            System.out.printf("psi(k=%d)=%.4f - 1/k=%.4f - averageDiGammas=%.4f -psi(N)=%.4f => %.4f\n", this.k, MathsUtils.digamma(this.k), 1.0 / (double)this.k, d, MathsUtils.digamma(n), this.lastAverage);
        }
        return this.lastAverage;
    }

    public double computeAverageLocalOfObservationsWhileComputingDistances() throws Exception {
        int n = this.sourceObservations.length;
        int n2 = (int)(1.5 * Math.log(n) / Math.log(2.0));
        double d = 0.0;
        double d2 = 0.0;
        double d3 = 0.0;
        for (int i = 0; i < n; ++i) {
            int n3;
            int n4;
            double[][] dArray = this.normCalculator.computeNorms(this.sourceObservations, this.destObservations, i);
            double[][] dArray2 = new double[n][2];
            for (int j = 0; j < n; ++j) {
                dArray2[j][0] = Math.max(dArray[j][0], dArray[j][1]);
                dArray2[j][1] = j;
            }
            double d4 = 0.0;
            double d5 = 0.0;
            int[] nArray = null;
            if (this.k <= n2) {
                nArray = MatrixUtils.kMinIndices(dArray2, 0, this.k);
            } else {
                Arrays.sort(dArray2, FirstIndexComparatorDouble.getInstance());
                nArray = new int[this.k];
                for (n4 = 0; n4 < this.k; ++n4) {
                    nArray[n4] = (int)dArray2[n4][1];
                }
            }
            for (n4 = 0; n4 < this.k; ++n4) {
                n3 = nArray[n4];
                if (dArray[n3][0] > d4) {
                    d4 = dArray[n3][0];
                }
                if (!(dArray[n3][1] > d5)) continue;
                d5 = dArray[n3][1];
            }
            n4 = 0;
            n3 = 0;
            for (int j = 0; j < n; ++j) {
                if (dArray[j][0] <= d4) {
                    ++n4;
                }
                if (!(dArray[j][1] <= d5)) continue;
                ++n3;
            }
            d2 += (double)n4;
            d3 += (double)n3;
            d += MathsUtils.digamma(n4) + MathsUtils.digamma(n3);
        }
        this.lastAverage = MathsUtils.digamma(this.k) - 1.0 / (double)this.k - (d /= (double)n) + MathsUtils.digamma(n);
        this.miComputed = true;
        if (this.debug) {
            System.out.printf("Average n_x=%.3f, Average n_y=%.3f\n", d2 /= (double)n, d3 /= (double)n);
            System.out.printf("psi(k=%d)=%.4f - 1/k=%.4f - averageDiGammas=%.4f + psi(N)=%.4f => %.4f\n", this.k, MathsUtils.digamma(this.k), 1.0 / (double)this.k, d, MathsUtils.digamma(n), this.lastAverage);
        }
        return this.lastAverage;
    }

    @Override
    public double[] computeLocalOfPreviousObservations() throws Exception {
        int n = this.sourceObservations.length;
        int n2 = (int)(1.5 * Math.log(n) / Math.log(2.0));
        double[] dArray = new double[n];
        double d = MathsUtils.digamma(this.k);
        double d2 = 1.0 / (double)this.k;
        double d3 = MathsUtils.digamma(n);
        double d4 = 0.0;
        double d5 = 0.0;
        double d6 = 0.0;
        for (int i = 0; i < n; ++i) {
            int n3;
            int n4;
            double[][] dArray2 = this.normCalculator.computeNorms(this.sourceObservations, this.destObservations, i);
            double[][] dArray3 = new double[n][2];
            for (int j = 0; j < n; ++j) {
                dArray3[j][0] = Math.max(dArray2[j][0], dArray2[j][1]);
                dArray3[j][1] = j;
            }
            double d7 = 0.0;
            double d8 = 0.0;
            int[] nArray = null;
            if (this.k <= n2) {
                nArray = MatrixUtils.kMinIndices(dArray3, 0, this.k);
            } else {
                Arrays.sort(dArray3, FirstIndexComparatorDouble.getInstance());
                nArray = new int[this.k];
                for (n4 = 0; n4 < this.k; ++n4) {
                    nArray[n4] = (int)dArray3[n4][1];
                }
            }
            for (n4 = 0; n4 < this.k; ++n4) {
                n3 = nArray[n4];
                if (dArray2[n3][0] > d7) {
                    d7 = dArray2[n3][0];
                }
                if (!(dArray2[n3][1] > d8)) continue;
                d8 = dArray2[n3][1];
            }
            n4 = 0;
            n3 = 0;
            for (int j = 0; j < n; ++j) {
                if (dArray2[j][0] <= d7) {
                    ++n4;
                }
                if (!(dArray2[j][1] <= d8)) continue;
                ++n3;
            }
            d5 += (double)n4;
            d6 += (double)n3;
            double d9 = MathsUtils.digamma(n4);
            double d10 = MathsUtils.digamma(n3);
            dArray[i] = d - d2 - d9 - d10 + d3;
            d4 += d9 + d10;
        }
        d4 /= (double)n;
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", d5 /= (double)n, d6 /= (double)n));
        }
        this.lastAverage = d - d2 - d4 + d3;
        this.miComputed = true;
        return dArray;
    }

    @Override
    public String printConstants(int n) throws Exception {
        String string = String.format("digamma(k=%d)=%.3e - 1/k=%.3e + digamma(N=%d)=%.3e => %.3e", this.k, MathsUtils.digamma(this.k), 1.0 / (double)this.k, n, MathsUtils.digamma(n), MathsUtils.digamma(this.k) - 1.0 / (double)this.k + MathsUtils.digamma(n));
        return string;
    }
}

