/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.continuous.kraskov;

import infodynamics.measures.continuous.kraskov.MutualInfoCalculatorMultiVariateKraskov;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;
import java.util.Arrays;

/**
 * Multivariate continuous mutual-information calculator using the first Kraskov (KSG)
 * estimator:
 * <pre>
 *   I(X;Y) = digamma(k) - &lt;digamma(n_x + 1) + digamma(n_y + 1)&gt; + digamma(N)
 * </pre>
 * For each observation the search radius is its k-th smallest joint-space norm, where the
 * joint norm is the max of the source and destination norms. n_x / n_y count the
 * observations whose source / destination norm is strictly below that radius.
 */
public class MutualInfoCalculatorMultiVariateKraskov1
extends MutualInfoCalculatorMultiVariateKraskov {
    /**
     * {@code MatrixUtils.kthMin} is used to find the k-th smallest joint norm when
     * k is at most {@code CUTOFF_MULTIPLIER * log2(N)}; above that the joint norms are
     * fully sorted instead.
     */
    protected static final double CUTOFF_MULTIPLIER = 1.5;

    /**
     * Computes the average MI over the observations in their original order.
     *
     * @return the average mutual information
     * @throws Exception propagated from the norm / digamma computations
     */
    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        return this.computeAverageLocalOfObservations(null);
    }

    /**
     * Computes the average MI, optionally pairing each source observation with a
     * reordered destination observation (e.g. for a surrogate / permutation test).
     *
     * <p>When all-pairs norms are not kept (disabled, or more than
     * {@code MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM} observations), norms are computed on
     * the fly against temporarily reordered destination data. Otherwise the cached
     * {@code xNorms} / {@code yNorms} (computed lazily here) are indexed through the
     * reordering. In both paths {@code lastAverage} is only updated when
     * {@code nArray == null}.
     *
     * @param nArray reordering of destination time indices, or {@code null} for the
     *     original order; presumably a permutation of 0..N-1 (TODO confirm with callers)
     * @return the average mutual information for this (re)ordering
     * @throws Exception propagated from the norm / digamma computations
     */
    @Override
    public double computeAverageLocalOfObservations(int[] nArray) throws Exception {
        if (!tryKeepAllPairsNorms || this.sourceObservations.length > MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM) {
            return this.computeAverageWithReorderedDestination(nArray);
        }
        if (this.xNorms == null) {
            this.computeNorms();
        }
        int n = this.sourceObservations.length;
        int sortCutoff = selectionCutoff(n);
        double sumDigammas = 0.0;
        double sumNx = 0.0;
        double sumNy = 0.0;
        // Reused across observations: every entry is overwritten before it is read.
        double[] jointNorms = new double[n];
        for (int i = 0; i < n; ++i) {
            // Row of yNorms for the destination observation paired with source i.
            int yRow = (nArray == null) ? i : nArray[i];
            for (int j = 0; j < n; ++j) {
                int yCol = (nArray == null) ? j : nArray[j];
                jointNorms[j] = Math.max(this.xNorms[i][j], this.yNorms[yRow][yCol]);
            }
            double radius = this.kthNearestJointNorm(jointNorms, sortCutoff);
            int nx = 0;
            int ny = 0;
            for (int j = 0; j < n; ++j) {
                if (this.xNorms[i][j] < radius) {
                    ++nx;
                }
                int yCol = (nArray == null) ? j : nArray[j];
                if (this.yNorms[yRow][yCol] < radius) {
                    ++ny;
                }
            }
            sumNx += nx;
            sumNy += ny;
            sumDigammas += MathsUtils.digamma(nx + 1) + MathsUtils.digamma(ny + 1);
        }
        double averageDigammas = sumDigammas / n;
        this.printAverageCounts(sumNx, sumNy, n);
        double mi = MathsUtils.digamma(this.k) - averageDigammas + MathsUtils.digamma(n);
        this.miComputed = true;
        if (nArray == null) {
            // Only the un-reordered result is stored; surrogates must not replace it.
            this.lastAverage = mi;
        }
        return mi;
    }

    /**
     * Computes the average MI from on-the-fly norms, temporarily substituting reordered
     * destination observations when {@code nArray} is non-null.
     *
     * <p>The original destination array is always restored, even if the computation
     * throws. For a reordering, the previously stored {@code lastAverage} is also
     * restored, matching the cached-norms path, which never stores surrogate results.
     *
     * @param nArray reordering of destination time indices, or {@code null}
     * @return the average mutual information for this (re)ordering
     * @throws Exception propagated from the norm / digamma computations
     */
    private double computeAverageWithReorderedDestination(int[] nArray) throws Exception {
        double[][] originalDest = this.destObservations;
        double storedAverage = this.lastAverage;
        try {
            if (nArray != null) {
                this.destObservations = MatrixUtils.extractSelectedTimePointsReusingArrays(originalDest, nArray);
            }
            return this.computeAverageLocalOfObservationsWhileComputingDistances();
        } finally {
            this.destObservations = originalDest;
            if (nArray != null) {
                this.lastAverage = storedAverage;
            }
        }
    }

    /**
     * Computes the average MI, calculating each observation's norms on the fly rather
     * than using cached all-pairs norms. The result is always stored in
     * {@code lastAverage}.
     *
     * @return the average mutual information
     * @throws Exception propagated from the norm / digamma computations
     */
    public double computeAverageLocalOfObservationsWhileComputingDistances() throws Exception {
        int n = this.sourceObservations.length;
        int sortCutoff = selectionCutoff(n);
        double sumDigammas = 0.0;
        double sumNx = 0.0;
        double sumNy = 0.0;
        for (int i = 0; i < n; ++i) {
            int[] counts = this.marginalNeighbourCounts(i, n, sortCutoff);
            sumNx += counts[0];
            sumNy += counts[1];
            sumDigammas += MathsUtils.digamma(counts[0] + 1) + MathsUtils.digamma(counts[1] + 1);
        }
        double averageDigammas = sumDigammas / n;
        this.printAverageCounts(sumNx, sumNy, n);
        this.lastAverage = MathsUtils.digamma(this.k) - averageDigammas + MathsUtils.digamma(n);
        this.miComputed = true;
        return this.lastAverage;
    }

    /**
     * Computes the local MI value for every observation, using on-the-fly norms.
     * Also stores the average of those locals in {@code lastAverage}.
     *
     * @return one local MI value per observation, in observation order
     * @throws Exception propagated from the norm / digamma computations
     */
    @Override
    public double[] computeLocalOfPreviousObservations() throws Exception {
        int n = this.sourceObservations.length;
        int sortCutoff = selectionCutoff(n);
        double[] locals = new double[n];
        double digammaK = MathsUtils.digamma(this.k);
        double digammaN = MathsUtils.digamma(n);
        double sumDigammas = 0.0;
        double sumNx = 0.0;
        double sumNy = 0.0;
        for (int i = 0; i < n; ++i) {
            int[] counts = this.marginalNeighbourCounts(i, n, sortCutoff);
            double digammaNx = MathsUtils.digamma(counts[0] + 1);
            double digammaNy = MathsUtils.digamma(counts[1] + 1);
            locals[i] = digammaK - digammaNx - digammaNy + digammaN;
            sumNx += counts[0];
            sumNy += counts[1];
            sumDigammas += digammaNx + digammaNy;
        }
        double averageDigammas = sumDigammas / n;
        this.printAverageCounts(sumNx, sumNy, n);
        this.lastAverage = digammaK - averageDigammas + digammaN;
        this.miComputed = true;
        return locals;
    }

    /**
     * Formats the constant part of the estimator, digamma(k) + digamma(N).
     *
     * @param n the number of observations N to report
     * @return a human-readable summary of the constant terms
     * @throws Exception propagated from the digamma computation
     */
    @Override
    public String printConstants(int n) throws Exception {
        String string = String.format("digamma(k=%d)=%.3e + digamma(N=%d)=%.3e => %.3e", this.k, MathsUtils.digamma(this.k), n, MathsUtils.digamma(n), MathsUtils.digamma(this.k) + MathsUtils.digamma(n));
        return string;
    }

    /**
     * Computes observation {@code i}'s norms to every observation on the fly and returns
     * {@code {n_x, n_y}}. These are the counts of source / destination norms strictly
     * below the k-th smallest joint norm.
     *
     * @param i index of the observation whose neighbourhood is examined
     * @param n number of observations
     * @param sortCutoff value from {@link #selectionCutoff(int)} for {@code n}
     * @return a two-element array {@code {n_x, n_y}}
     * @throws Exception propagated from the norm computation
     */
    private int[] marginalNeighbourCounts(int i, int n, int sortCutoff) throws Exception {
        // Column 0 is treated as the source (x) norm, column 1 as the destination (y) norm.
        double[][] norms = this.normCalculator.computeNorms(this.sourceObservations, this.destObservations, i);
        double[] jointNorms = new double[n];
        for (int j = 0; j < n; ++j) {
            jointNorms[j] = Math.max(norms[j][0], norms[j][1]);
        }
        double radius = this.kthNearestJointNorm(jointNorms, sortCutoff);
        int nx = 0;
        int ny = 0;
        for (int j = 0; j < n; ++j) {
            if (norms[j][0] < radius) {
                ++nx;
            }
            if (norms[j][1] < radius) {
                ++ny;
            }
        }
        return new int[] {nx, ny};
    }

    /**
     * Returns the k-th smallest entry of {@code jointNorms}, i.e. the joint-space search
     * radius. {@code jointNorms} may be reordered in place.
     *
     * <p>NOTE(review): treating this as the k-th nearest <em>other</em> point assumes the
     * norm computations exclude observation i itself (e.g. an infinite self-norm). That
     * is not visible here; TODO confirm.
     *
     * @param jointNorms joint-space norms from one observation to all observations
     * @param sortCutoff largest k for which {@code MatrixUtils.kthMin} is used
     * @return the k-th smallest joint norm
     * @throws Exception propagated from {@code MatrixUtils.kthMin}
     */
    private double kthNearestJointNorm(double[] jointNorms, int sortCutoff) throws Exception {
        if (this.k <= sortCutoff) {
            return MatrixUtils.kthMin(jointNorms, this.k);
        }
        Arrays.sort(jointNorms);
        return jointNorms[this.k - 1];
    }

    /**
     * Returns {@code (int) (CUTOFF_MULTIPLIER * log2(n))}, the largest k for which
     * selection via {@code MatrixUtils.kthMin} is preferred over a full sort.
     */
    private static int selectionCutoff(int n) {
        return (int) (CUTOFF_MULTIPLIER * Math.log(n) / Math.log(2.0));
    }

    /** Prints the mean n_x and n_y counts when debug output is enabled. */
    private void printAverageCounts(double sumNx, double sumNy, int n) {
        if (this.debug) {
            System.out.println(String.format("Average n_x=%.3f, Average n_y=%.3f", sumNx / n, sumNy / n));
        }
    }
}

