/*
 * Decompiled with CFR 0.152.
 */
package infodynamics.measures.mixed.kraskov;

import infodynamics.measures.mixed.ConditionalMutualInfoCalculatorMultiVariateWithDiscreteSourceCommon;
import infodynamics.utils.EuclideanUtils;
import infodynamics.utils.MathsUtils;
import infodynamics.utils.MatrixUtils;

/**
 * Nearest-neighbour (Kraskov-style) estimator of the conditional mutual
 * information between continuous data X ({@code continuousDataX}) and a
 * discrete variable ({@code discreteData}, labelled "y" in the debug output),
 * conditioned on continuous data Z ({@code conditionedDataZ}).
 *
 * <p>For every sample point i the estimator:
 * <ol>
 *   <li>finds the k nearest neighbours of i in the joint (x, z) space, using the
 *       maximum of the x-norm and the z-norm as distance, among points whose
 *       discrete value equals that of point i;</li>
 *   <li>takes eps_x and eps_z as the largest x- and z-distances from i to those
 *       neighbours;</li>
 *   <li>counts n_z (points with z-distance &lt;= eps_z), n_xz (those which also
 *       have x-distance &lt;= eps_x) and n_yz (those which also share point i's
 *       discrete value).</li>
 * </ol>
 * The estimate is
 * {@code digamma(k) - 2/k + <digamma(n_z) - digamma(n_xz) - digamma(n_yz)> + <1/n_xz> + <1/n_yz>},
 * where {@code <.>} denotes the average over all sample points.
 *
 * <p>All pairwise norms are cached (O(n^2) memory) when
 * {@link #tryKeepAllPairsNorms} is set and there are at most
 * {@link #MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM} samples. Otherwise the norms
 * are recomputed for each sample point.
 */
public class ConditionalMutualInfoCalculatorMultiVariateWithDiscreteKraskov
extends ConditionalMutualInfoCalculatorMultiVariateWithDiscreteSourceCommon
implements Cloneable {
    /** Not referenced within this class; retained for subclasses and API compatibility. */
    protected static final double CUTOFF_MULTIPLIER = 1.5;
    /** Number of nearest neighbours used by the estimator (property {@link #PROP_K}). */
    protected int k = 4;
    /** Computes the x- and z-norms (norm type set via property {@link #PROP_NORM_TYPE}). */
    protected EuclideanUtils normCalculator = new EuclideanUtils(2);
    /** Cached pairwise x-norms, with +infinity on the diagonal; null when not computed. */
    protected double[][] xNorms;
    /** Cached pairwise z-norms, with +infinity on the diagonal; null when not computed. */
    protected double[][] zNorms;
    /** Cached element-wise max of {@link #xNorms} and {@link #zNorms}; null when not computed. */
    protected double[][] xzNorms;
    /** Whether to cache all pairwise norms when the data set is small enough. */
    public static boolean tryKeepAllPairsNorms = true;
    /** Largest number of samples for which all pairwise norms are cached. */
    public static int MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM = 2000;
    /** Property name for the number of nearest neighbours k (parsed as an int). */
    public static final String PROP_K = "k";
    /** Property name for the norm type, passed to {@link EuclideanUtils#setNormToUse}. */
    public static final String PROP_NORM_TYPE = "NORM_TYPE";

    /**
     * Initialises the calculator and discards any cached pairwise norms.
     * All arguments are forwarded unchanged to the superclass.
     */
    @Override
    public void initialise(int n, int n2, int n3) {
        super.initialise(n, n2, n3);
        this.clearNormsCache();
    }

    /**
     * Sets a property of the calculator.
     *
     * <p>{@link #PROP_K} sets k, and {@link #PROP_NORM_TYPE} sets the norm type.
     * Changing the norm type invalidates any cached norms. Property names are
     * matched case-insensitively. All other properties are passed to the
     * superclass.
     *
     * @param propertyName name of the property
     * @param propertyValue value of the property
     */
    @Override
    public void setProperty(String propertyName, String propertyValue) {
        if (propertyName.equalsIgnoreCase(PROP_K)) {
            this.k = Integer.parseInt(propertyValue);
        } else if (propertyName.equalsIgnoreCase(PROP_NORM_TYPE)) {
            this.normCalculator.setNormToUse(propertyValue);
            // Any cached norms were computed with the previous norm type.
            this.clearNormsCache();
        } else {
            super.setProperty(propertyName, propertyValue);
        }
    }

    /**
     * Finalises the added observations, discards any cached norms (which belong
     * to the previous data) and checks that every entry of {@code counts} is at
     * least k.
     *
     * <p>NOTE(review): each point is excluded from its own neighbourhood, so a
     * bin holding exactly k points leaves each of them only k-1 candidate
     * neighbours. Confirm whether {@code > k} is the intended requirement.
     *
     * @throws IllegalStateException if some entry of {@code counts} is below k
     * @throws Exception if the superclass finalisation fails
     */
    @Override
    public void finaliseAddObservations() throws Exception {
        super.finaliseAddObservations();
        this.clearNormsCache();
        for (int bin = 0; bin < this.counts.length; ++bin) {
            if (this.counts[bin] < this.k) {
                throw new IllegalStateException(
                        "This implementation assumes there are at least k items in each discrete bin"
                        + " (bin " + bin + " has " + this.counts[bin] + ", k = " + this.k + ")");
            }
        }
    }

    /** Drops the cached pairwise norms so they are recomputed on next use. */
    private void clearNormsCache() {
        this.xNorms = null;
        this.zNorms = null;
        this.xzNorms = null;
    }

    /** Returns true when all pairwise norms should be cached for the current data size. */
    private boolean keepAllPairsNorms() {
        return tryKeepAllPairsNorms
                && this.continuousDataX.length <= MAX_DATA_SIZE_FOR_KEEP_ALL_PAIRS_NORM;
    }

    /**
     * Computes and caches all pairwise x-norms, z-norms and their element-wise
     * maximum.
     *
     * <p>The diagonal is set to +infinity so that a point is never its own
     * neighbour. Each unordered pair is evaluated once and mirrored, because
     * the norm of a difference is symmetric in its two arguments.
     */
    protected void computeNorms() {
        int n = this.continuousDataX.length;
        this.xNorms = new double[n][n];
        this.zNorms = new double[n][n];
        this.xzNorms = new double[n][n];
        for (int i = 0; i < n; ++i) {
            this.xNorms[i][i] = Double.POSITIVE_INFINITY;
            this.zNorms[i][i] = Double.POSITIVE_INFINITY;
            this.xzNorms[i][i] = Double.POSITIVE_INFINITY;
            for (int j = i + 1; j < n; ++j) {
                double xNorm = this.normCalculator.norm(this.continuousDataX[i], this.continuousDataX[j]);
                double zNorm = this.normCalculator.norm(this.conditionedDataZ[i], this.conditionedDataZ[j]);
                double xzNorm = Math.max(xNorm, zNorm);
                this.xNorms[i][j] = xNorm;
                this.xNorms[j][i] = xNorm;
                this.zNorms[i][j] = zNorm;
                this.zNorms[j][i] = zNorm;
                this.xzNorms[i][j] = xzNorm;
                this.xzNorms[j][i] = xzNorm;
            }
        }
    }

    /**
     * Computes the estimate with the discrete data reordered by
     * {@code newOrdering}, while the continuous X and Z data keep their order.
     * The reordering is done via {@code MatrixUtils.extractSelectedTimePoints}.
     *
     * <p>The instance's own discrete data are never modified. As with the other
     * compute methods, the result is stored as the current estimate
     * ({@code condMi}) and {@code miComputed} is set.
     *
     * @param newOrdering time indices used to reorder the discrete data
     * @return the estimate for the reordered data
     * @throws Exception if the reordering or the estimate fails
     */
    public double computeAverageLocalOfObservations(int[] newOrdering) throws Exception {
        int[] reorderedDiscrete = MatrixUtils.extractSelectedTimePoints(this.discreteData, newOrdering);
        return this.computeAverageLocalFor(reorderedDiscrete, this.keepAllPairsNorms());
    }

    /**
     * Computes the estimate for the supplied observations. The pairwise norms
     * are cached if {@link #keepAllPairsNorms()} allows it; otherwise they are
     * recomputed for each point.
     *
     * @return the estimate, also stored in {@code condMi}
     * @throws Exception if the estimate fails
     */
    @Override
    public double computeAverageLocalOfObservations() throws Exception {
        return this.computeAverageLocalFor(this.discreteData, this.keepAllPairsNorms());
    }

    /**
     * Computes the estimate without using or filling the pairwise-norm cache.
     * Norms from each point to all others are recomputed per point.
     *
     * @return the estimate, also stored in {@code condMi}
     * @throws Exception if the estimate fails
     */
    public double computeAverageLocalOfObservationsWhileComputingDistances() throws Exception {
        return this.computeAverageLocalFor(this.discreteData, false);
    }

    /**
     * Core estimator shared by all compute methods (see the class Javadoc for
     * the formula).
     *
     * @param discrete discrete value for each sample point, aligned with the
     *        continuous data
     * @param useCachedNorms read distances from the (lazily computed) norm
     *        cache if true; otherwise compute them per point
     * @return the estimate, also stored in {@code condMi} with
     *         {@code miComputed} set
     * @throws Exception propagated from the digamma evaluation
     */
    private double computeAverageLocalFor(int[] discrete, boolean useCachedNorms) throws Exception {
        int n = this.continuousDataX.length;
        if (useCachedNorms && this.xNorms == null) {
            this.computeNorms();
        }
        // Buffers for the on-the-fly path. The cached path reads rows of
        // xNorms/zNorms directly.
        double[] xBuffer = useCachedNorms ? null : new double[n];
        double[] zBuffer = useCachedNorms ? null : new double[n];
        // Column 0 holds the joint distance searched on; column 1 holds the
        // point index. Every row is rewritten for each point, so one array is
        // reused across iterations.
        double[][] jointDistances = new double[n][2];

        double sumDigammas = 0.0;
        double sumInverseNxz = 0.0;
        double sumInverseNyz = 0.0;
        double sumNxz = 0.0;
        double sumNyz = 0.0;
        double sumNz = 0.0;
        for (int i = 0; i < n; ++i) {
            double[] xDist;
            double[] zDist;
            if (useCachedNorms) {
                xDist = this.xNorms[i];
                zDist = this.zNorms[i];
            } else {
                double[][] norms = this.normCalculator.computeNorms(this.continuousDataX, this.conditionedDataZ, i);
                for (int j = 0; j < n; ++j) {
                    xBuffer[j] = norms[j][0];
                    zBuffer[j] = norms[j][1];
                }
                // Exclude point i from its own neighbourhood, matching the
                // +infinity diagonal of the cached norms.
                xBuffer[i] = Double.POSITIVE_INFINITY;
                zBuffer[i] = Double.POSITIVE_INFINITY;
                xDist = xBuffer;
                zDist = zBuffer;
            }

            for (int j = 0; j < n; ++j) {
                jointDistances[j][0] = Math.max(xDist[j], zDist[j]);
                jointDistances[j][1] = j;
            }
            // k nearest neighbours in the joint (x, z) space, among points whose
            // discrete value equals that of point i.
            int[] neighbours = MatrixUtils.kMinIndicesSubjectTo(jointDistances, 0, this.k, discrete, discrete[i]);
            double epsX = 0.0;
            double epsZ = 0.0;
            for (int m = 0; m < this.k; ++m) {
                int neighbour = neighbours[m];
                if (xDist[neighbour] > epsX) {
                    epsX = xDist[neighbour];
                }
                if (zDist[neighbour] > epsZ) {
                    epsZ = zDist[neighbour];
                }
            }

            int nXZ = 0;
            int nYZ = 0;
            int nZ = 0;
            for (int j = 0; j < n; ++j) {
                // Written as !(a <= b) so that NaN distances are skipped.
                if (!(zDist[j] <= epsZ)) {
                    continue;
                }
                ++nZ;
                if (xDist[j] <= epsX) {
                    ++nXZ;
                }
                if (discrete[j] == discrete[i]) {
                    ++nYZ;
                }
            }
            sumNxz += nXZ;
            sumNyz += nYZ;
            sumNz += nZ;
            sumDigammas += MathsUtils.digamma(nZ) - MathsUtils.digamma(nXZ) - MathsUtils.digamma(nYZ);
            sumInverseNxz += 1.0 / nXZ;
            sumInverseNyz += 1.0 / nYZ;
        }

        double avDigammas = sumDigammas / n;
        double avInverseNxz = sumInverseNxz / n;
        double avInverseNyz = sumInverseNyz / n;
        this.condMi = MathsUtils.digamma(this.k) - 2.0 / this.k + avDigammas + avInverseNxz + avInverseNyz;
        this.miComputed = true;
        if (this.debug) {
            System.out.printf("Average n_xz=%.3f, Average n_yz=%.3f, Average n_z=%.3f\n",
                    sumNxz / n, sumNyz / n, sumNz / n);
            System.out.printf("Av = digamma(k)=%.3f + <digammas>=%.3f + <inverses>=%.3f - 2/k=%.3f = %.3f (<1/n_yz>=%.3f, <1/n_xz>=%.3f)\n",
                    MathsUtils.digamma(this.k), avDigammas, avInverseNxz + avInverseNyz,
                    2.0 / this.k, this.condMi, avInverseNyz, avInverseNxz);
        }
        return this.condMi;
    }

    /**
     * Not implemented.
     *
     * @throws Exception always
     */
    @Override
    public double[] computeLocalOfPreviousObservations() throws Exception {
        throw new Exception("Not implemented yet");
    }

    /**
     * Not implemented.
     *
     * <p>An implementation must first normalise {@code newContinuous} with
     * meansX/stdsX and {@code newConditionals} with meansZ/stdsZ when
     * {@code normalise} is set. That step is omitted here because the method
     * always throws.
     *
     * @param newContinuous new continuous (X) observations
     * @param newDiscrete new discrete observations
     * @param newConditionals new conditional (Z) observations
     * @throws Exception always
     */
    @Override
    public double[] computeLocalUsingPreviousObservations(double[][] newContinuous, int[] newDiscrete, double[][] newConditionals) throws Exception {
        throw new Exception("Not implemented yet");
    }
}

