package edu.berkeley.nlp.discPCFG;

import edu.berkeley.nlp.PCFGLA.ArrayParser;
import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.ConditionalTrainer;
import edu.berkeley.nlp.PCFGLA.ConstrainedHierarchicalTwoChartParser;
import edu.berkeley.nlp.PCFGLA.ConstrainedTwoChartsParser;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.ParserConstrainer;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SpanPredictor;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.math.SloppyMath;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

/* loaded from: input_file:edu/berkeley/nlp/discPCFG/ParsingObjectiveFunction.class */
/**
 * Differentiable objective for discriminative training of a PCFG from a {@link Linearizer}'s weight vector.
 *
 * <p>For each training tree the workers compute the gold-tree log-likelihood and the log-likelihood over
 * all (constrained) parses of its yield. The raw objective is the sum over trees of {@code goldLL - allLL}.
 * Trees whose scores are unusable contribute {@code -1000}. The raw objective is optionally L1/L2
 * regularized and finally negated, so the value and derivative handed to the optimizer are meant to be
 * minimized.
 *
 * <p>Work is split across {@code nProcesses} {@link Calculator} tasks run on a fixed thread pool. Call
 * {@link #shutdown()} when done to release the pool.
 */
public class ParsingObjectiveFunction implements ObjectiveFunction {
    public static final int NO_REGULARIZATION = 0;
    public static final int L1_REGULARIZATION = 1;
    public static final int L2_REGULARIZATION = 2;
    Grammar grammar;
    SimpleLexicon lexicon;
    SpanPredictor spanPredictor;
    Linearizer linearizer;
    /** One of {@link #NO_REGULARIZATION}, {@link #L1_REGULARIZATION}, {@link #L2_REGULARIZATION}. */
    int myRegularization;
    double sigma;
    /** Cached (negated, regularized) objective for {@link #x}. */
    double lastValue;
    /** Cached (negated, regularized) gradient for {@link #x}. */
    double[] lastDerivative;
    /** Cached (negated) gradient before regularization for {@link #x}. */
    double[] lastUnregularizedDerivative;
    /** Weight vector the cached values were computed for; {@code null} forces recomputation. */
    double[] x;
    int dimension;
    int nGrammarWeights;
    int nLexiconWeights;
    int nSpanWeights;
    int nProcesses;
    /** Base name of the gzipped constraint files ({@code <base>-<k>.data}); may be {@code null}. */
    String consBaseName;
    StateSetTreeList[] trainingTrees;
    ExecutorService pool;
    Calculator[] tasks;
    double bestObjectiveSoFar;
    String outFileName;
    double[] spanGoldCounts;

    /**
     * Worker that accumulates the objective and expected-count derivatives over one shard of trees.
     */
    public class Calculator implements Callable<Counts> {
        ArrayParser gParser;
        ConstrainedTwoChartsParser eParser;
        StateSetTreeList myTrees;
        String consName;
        int myID;
        int nCounts;
        Counts myCounts;
        /** Per-tree constraint charts, lazily loaded on the first {@link #call()}. */
        boolean[][][][][] myConstraints;
        int unparsableTrees;
        int incorrectLLTrees;
        boolean doNotProjectConstraints;
        double[] myDerivatives;

        /**
         * @param stateSetTreeList the trees assigned to this worker
         * @param str constraint file base name, or {@code null} for unconstrained parsing
         * @param i this worker's id (also selects which constraint files it reads)
         * @param grammar grammar shared with the objective
         * @param lexicon lexicon shared with the objective
         * @param spanPredictor optional span predictor (may be {@code null})
         * @param i2 stored as {@code nCounts}
         * @param z if {@code true}, constraints are not passed through {@code projectConstraints}
         */
        Calculator(StateSetTreeList stateSetTreeList, String str, int i, Grammar grammar, Lexicon lexicon, SpanPredictor spanPredictor, int i2, boolean z) {
            this.nCounts = i2;
            this.consName = str;
            this.myTrees = stateSetTreeList;
            this.doNotProjectConstraints = z;
            this.myID = i;
            this.gParser = new ArrayParser(grammar, lexicon);
            this.eParser = newEParser(grammar, lexicon, spanPredictor);
        }

        /** Creates the parser for all-parses scores; hierarchical when {@code Options.hierarchicalChart} is set. */
        protected ConstrainedTwoChartsParser newEParser(Grammar grammar, Lexicon lexicon, SpanPredictor spanPredictor) {
            if (ConditionalTrainer.Options.hierarchicalChart) {
                return new ConstrainedHierarchicalTwoChartParser(grammar, lexicon, spanPredictor, grammar.finalLevel);
            }
            return new ConstrainedTwoChartsParser(grammar, lexicon, spanPredictor);
        }

        /**
         * Loads this worker's constraint charts, one per tree, from files named
         * {@code <consName>-<round * nProcesses + myID>.data}. This mirrors the block-wise round-robin
         * distribution of trees done in the enclosing constructor.
         *
         * <p>A chart whose length does not match its sentence length is reported and replaced by
         * {@code null}.
         *
         * @throws IllegalStateException if a constraint file cannot be read
         */
        protected void loadConstraints() {
            int nTrees = this.myTrees.size();
            // Rank 5: one boolean[][][][] chart per tree.
            this.myConstraints = new boolean[nTrees][][][][];
            if (this.consName == null) {
                return;
            }
            boolean[][][][][] block = null;
            int nextRound = 0;
            int posInBlock = 0;
            for (int t = 0; t < nTrees; t++) {
                if (block == null || posInBlock >= block.length) {
                    String fileName = this.consName + "-" + ((nextRound * ParsingObjectiveFunction.this.nProcesses) + this.myID) + ".data";
                    block = loadData(fileName);
                    if (block == null) {
                        throw new IllegalStateException("Could not load constraints from " + fileName);
                    }
                    nextRound++;
                    posInBlock = 0;
                    System.out.print(".");
                }
                boolean[][][][] constraints = block[posInBlock];
                if (!this.doNotProjectConstraints) {
                    this.eParser.projectConstraints(constraints);
                }
                this.myConstraints[t] = constraints;
                if (constraints.length != this.myTrees.get(t).getYield().size()) {
                    // Report the 0-based indices of the offending entry.
                    System.out.println("My ID: " + this.myID + ", block: " + (nextRound - 1) + ", sentence: " + posInBlock);
                    System.out.println("Sentence length and constraints length do not match!");
                    this.myConstraints[t] = null;
                }
                posInBlock++;
            }
        }

        /**
         * Computes the summed {@code goldLL - allLL} over this worker's trees and accumulates the
         * expected-count derivatives into a fresh array of length {@code dimension}.
         *
         * @return the accumulated counts (also stored in {@link #myCounts})
         */
        @Override // java.util.concurrent.Callable
        public Counts call() {
            double objective = 0.0d;
            this.myDerivatives = new double[ParsingObjectiveFunction.this.dimension];
            this.unparsableTrees = 0;
            this.incorrectLLTrees = 0;
            if (this.myConstraints == null) {
                loadConstraints();
            }
            int treeIndex = -1;
            Iterator<Tree<StateSet>> it = this.myTrees.iterator();
            while (it.hasNext()) {
                Tree<StateSet> tree = it.next();
                treeIndex++;
                List<StateSet> yield = tree.getYield();
                boolean[][][][] constraints = null;
                if (this.consName != null) {
                    constraints = this.myConstraints[treeIndex];
                    // loadConstraints() nulls out mismatched charts; such sentences are parsed without
                    // constraints (same as consName == null) instead of dereferencing null.
                    if (constraints != null && constraints.length != yield.size()) {
                        System.out.println("My ID: " + this.myID + ", block: 0, sentence: " + treeIndex);
                        System.out.println("Sentence length (" + yield.size() + ") and constraints length (" + constraints.length + ") do not match!");
                        System.exit(-1);
                    }
                }
                double allLL = this.eParser.doConstrainedInsideOutsideScores(yield, constraints, false, null, null, false);
                double goldLL = ConditionalTrainer.Options.hierarchicalChart
                        ? this.eParser.doInsideOutsideScores(tree, false, false, this.eParser.spanScores)
                        : this.gParser.doInsideOutsideScores(tree, false, false, this.eParser.spanScores);
                if (treeIndex % 500 == 0) {
                    System.out.print(".");
                }
                if (sanityCheckLLs(goldLL, allLL, tree)) {
                    this.eParser.incrementExpectedCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, yield);
                    if (ConditionalTrainer.Options.hierarchicalChart) {
                        this.eParser.incrementExpectedGoldCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, tree);
                    } else {
                        this.gParser.incrementExpectedGoldCounts(ParsingObjectiveFunction.this.linearizer, this.myDerivatives, tree);
                    }
                    objective += goldLL - allLL;
                } else {
                    // Fixed penalty for trees whose scores cannot be used.
                    objective -= 1000.0d;
                }
            }
            this.myCounts = new Counts(objective, this.myDerivatives, this.unparsableTrees, this.incorrectLLTrees);
            System.out.print(" " + this.myID + " ");
            return this.myCounts;
        }

        /**
         * Reads a gzipped, Java-serialized {@code boolean[][][][][]} from {@code str}.
         *
         * <p>Uses native Java deserialization, so the file must come from a trusted source.
         *
         * @return the deserialized array, or {@code null} if it could not be read
         */
        public boolean[][][][][] loadData(String str) {
            try (FileInputStream fileIn = new FileInputStream(str);
                 GZIPInputStream gzipIn = new GZIPInputStream(fileIn);
                 ObjectInputStream objectIn = new ObjectInputStream(gzipIn)) {
                return (boolean[][][][][]) objectIn.readObject();
            } catch (IOException e) {
                System.out.println("IOException\n" + e);
                return null;
            } catch (ClassNotFoundException e) {
                System.out.println("Class not found!");
                return null;
            }
        }

        /**
         * Rejects a tree when either score is "very dangerous" (per {@link SloppyMath#isVeryDangerous}), or when
         * the gold LL exceeds the all-parses LL by more than {@code 1e-4}.
         *
         * @param d gold-tree log-likelihood
         * @param d2 all-parses log-likelihood
         * @param tree the tree, printed on an LL inconsistency
         * @return {@code true} if the tree's scores may be used
         */
        protected boolean sanityCheckLLs(double d, double d2, Tree<StateSet> tree) {
            if (SloppyMath.isVeryDangerous(d2) || SloppyMath.isVeryDangerous(d)) {
                this.unparsableTrees++;
                return false;
            }
            if (d - d2 <= 1.0E-4d) {
                return true;
            }
            System.out.println("Something is wrong! The gold LL is " + d + " and the all LL is " + d2);
            System.out.println(tree);
            this.incorrectLLTrees++;
            return false;
        }
    }

    /** Result of one {@link Calculator} run. */
    public class Counts {
        /** Summed {@code goldLL - allLL} (with penalties) over the worker's trees. */
        double myObjective;
        /** Accumulated derivatives, length {@code dimension}. */
        double[] myDerivatives;
        int unparsableTrees;
        int incorrectLLTrees;

        public Counts(double d, double[] dArr, int i, int i2) {
            this.myObjective = d;
            this.myDerivatives = dArr;
            this.unparsableTrees = i;
            this.incorrectLLTrees = i2;
        }
    }

    @Override // edu.berkeley.nlp.math.Function
    public int dimension() {
        return this.dimension;
    }

    /** Returns the negated, regularized objective at {@code dArr} (recomputed only if the weights changed). */
    @Override // edu.berkeley.nlp.math.Function
    public double valueAt(double[] dArr) {
        ensureCache(dArr);
        return this.lastValue;
    }

    /** Returns the negated, regularized gradient at {@code dArr}; the returned array is the cached instance. */
    @Override // edu.berkeley.nlp.math.DifferentiableFunction
    public double[] derivativeAt(double[] dArr) {
        ensureCache(dArr);
        return this.lastDerivative;
    }

    /** Returns the negated gradient before regularization; the returned array is the cached instance. */
    @Override // edu.berkeley.nlp.math.DifferentiableRegularizableFunction
    public double[] unregularizedDerivativeAt(double[] dArr) {
        ensureCache(dArr);
        return this.lastUnregularizedDerivative;
    }

    /**
     * Recomputes value and gradients if {@code dArr} differs from the cached weights.
     *
     * <p>When the negated value improves on the best seen so far, the grammar is also saved (unless
     * {@code Options.dontSaveGrammarsAfterEachIteration} is set).
     */
    private void ensureCache(double[] dArr) {
        if (!requiresUpdate(dArr)) {
            return;
        }
        this.linearizer.delinearizeWeights(dArr);
        this.grammar = this.linearizer.getGrammar();
        this.lexicon = this.linearizer.getLexicon();
        this.spanPredictor = this.linearizer.getSpanPredictor();
        if (this.x == null) {
            this.x = (double[]) dArr.clone();
        } else {
            System.arraycopy(dArr, 0, this.x, 0, this.x.length);
        }
        System.out.print("Task: ");
        double objective = 0.0d;
        int unparsable = 0;
        int incorrectLL = 0;
        double[] gradient = new double[this.dimension];
        for (Counts counts : runCalculators()) {
            objective += counts.myObjective;
            for (int k = 0; k < this.dimension; k++) {
                gradient[k] += counts.myDerivatives[k];
            }
            unparsable += counts.unparsableTrees;
            incorrectLL += counts.incorrectLLTrees;
        }
        if (this.spanPredictor != null) {
            // Span weights occupy the last spanGoldCounts.length slots of the weight vector.
            int offset = this.dimension - this.spanGoldCounts.length;
            double spanSumBeforeGold = 0.0d;
            for (int k = 0; k < this.spanGoldCounts.length; k++) {
                int idx = offset + k;
                spanSumBeforeGold += gradient[idx];
                gradient[idx] += this.spanGoldCounts[k];
                if (SloppyMath.isVeryDangerous(gradient[idx])) {
                    System.out.print(gradient[idx] + " ");
                }
            }
            System.out.println(spanSumBeforeGold);
        }
        System.out.print(" done. ");
        if (unparsable > 0) {
            System.out.println(unparsable + " trees were not parsable.");
        }
        if (incorrectLL > 0) {
            System.out.println(incorrectLL + " trees had a higher gold LL than all LL.");
        }
        System.out.print("\nThe objective was " + objective);
        this.lastUnregularizedDerivative = (double[]) gradient.clone();
        switch (this.myRegularization) {
            case L1_REGULARIZATION:
                objective = l1_regularize(objective, gradient);
                System.out.print(" and is " + objective + " after L1 regularization");
                break;
            case L2_REGULARIZATION:
                objective = l2_regularize(objective, gradient);
                System.out.print(" and is " + objective + " after L2 regularization");
                break;
            default:
                break;
        }
        System.out.print(".\n");
        // Negate so that the optimizer minimizes.
        double value = -objective;
        for (int k = 0; k < gradient.length; k++) {
            gradient[k] = -gradient[k];
            this.lastUnregularizedDerivative[k] = -this.lastUnregularizedDerivative[k];
        }
        this.lastValue = value;
        this.lastDerivative = gradient;
        saveIfBest(value);
    }

    /** Runs every {@link Calculator} (inline for a single process, otherwise on the pool) and returns their counts. */
    private Counts[] runCalculators() {
        Counts[] results = new Counts[this.nProcesses];
        if (this.nProcesses == 1) {
            results[0] = this.tasks[0].call();
            return results;
        }
        Future<?>[] futures = new Future<?>[this.nProcesses];
        for (int p = 0; p < this.nProcesses; p++) {
            futures[p] = this.pool.submit(this.tasks[p]);
        }
        // Future.get() blocks until each task finishes; no need to spin on isDone().
        for (int p = 0; p < this.nProcesses; p++) {
            try {
                results[p] = (Counts) futures[p].get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("Interrupted while waiting for calculator task " + p, e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("Calculator task " + p + " failed", e.getCause());
            }
        }
        return results;
    }

    /** Saves the current grammar as {@code <outFileName>-<(int) value>} if {@code value} is a new minimum. */
    private void saveIfBest(double value) {
        if (value >= this.bestObjectiveSoFar || ConditionalTrainer.Options.dontSaveGrammarsAfterEachIteration) {
            return;
        }
        this.bestObjectiveSoFar = value;
        ParserData parserData = new ParserData(this.lexicon, this.grammar, this.spanPredictor, Numberer.getNumberers(), this.grammar.numSubStates, 1, 0, Binarization.RIGHT);
        int suffix = (int) value;
        System.out.println("Saving grammar to " + this.outFileName + "-" + suffix + ".");
        if (!parserData.Save(this.outFileName + "-" + suffix)) {
            System.out.println("Saving failed!");
        }
    }

    /**
     * Replaces NaN entries of {@code dArr} with {@code -Infinity} (mutating the caller's array), then reports
     * whether it differs from the cached weights.
     */
    private boolean requiresUpdate(double[] dArr) {
        // NaN is never == to anything (including Double.NaN), so it must be detected with Double.isNaN.
        for (int i = 0; i < dArr.length; i++) {
            if (Double.isNaN(dArr[i])) {
                System.out.println("Optimizer proposed " + dArr[i] + " for weight " + i);
                dArr[i] = Double.NEGATIVE_INFINITY;
            }
        }
        if (this.x == null) {
            return true;
        }
        for (int i = 0; i < this.x.length; i++) {
            if (this.x[i] != dArr[i]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Applies a Gaussian (L2) penalty {@code sum(x^2) / (2 sigma^2)} to the objective and subtracts
     * {@code x / sigma^2} from {@code dArr} in place. Non-finite gradient entries are reset to zero.
     *
     * @return the penalized objective, or {@code d} unchanged if it is "very dangerous"
     */
    public double l2_regularize(double d, double[] dArr) {
        if (SloppyMath.isVeryDangerous(d)) {
            return d;
        }
        double sigmaSquared = this.sigma * this.sigma;
        double sumSquares = 0.0d;
        for (int i = 0; i < this.x.length; i++) {
            sumSquares += this.x[i] * this.x[i];
        }
        double penalized = d - (sumSquares / (2.0d * sigmaSquared));
        for (int i = 0; i < this.x.length; i++) {
            dArr[i] = dArr[i] - (this.x[i] / sigmaSquared);
            if (SloppyMath.isVeryDangerous(dArr[i])) {
                System.out.println("Setting regularized derivative to zero because it is Inf.");
                dArr[i] = 0.0d;
            }
        }
        return penalized;
    }

    /**
     * Applies an L1 penalty to the objective and an L1 sub-gradient adjustment to {@code dArr} in place.
     *
     * <p>The penalty is {@code sum|x|/(2 sigma^2)} for the grammar and lexicon segments and {@code sum|x|/2}
     * for the span segment. Sums are accumulated in {@code double}; earlier versions summed into {@code int},
     * which truncated every term with |x| < 1 to zero.
     *
     * @return the penalized objective, or {@code d} unchanged if it is "very dangerous"
     */
    public double l1_regularize(double d, double[] dArr) {
        if (SloppyMath.isVeryDangerous(d)) {
            return d;
        }
        double sigmaSquared = this.sigma * this.sigma;
        int grammarEnd = this.nGrammarWeights;
        int lexiconEnd = grammarEnd + this.nLexiconWeights;
        int spanEnd = lexiconEnd + this.nSpanWeights;
        double penalty = (sumAbs(0, grammarEnd) / (2.0d * sigmaSquared))
                + (sumAbs(grammarEnd, lexiconEnd) / (2.0d * sigmaSquared))
                + (sumAbs(lexiconEnd, spanEnd) / (2.0d * 1.0d));
        // NOTE(review): the gradient step below is 1/sigma^2 while the penalty uses 1/(2 sigma^2), and at
        // x == 0 the adjustment pushes |grad| outward rather than shrinking it. Both are preserved as-is;
        // confirm against the optimizer's expectations before changing.
        applyL1Subgradient(dArr, 0, grammarEnd, 1.0d / sigmaSquared);
        applyL1Subgradient(dArr, grammarEnd, lexiconEnd, 1.0d / sigmaSquared);
        applyL1Subgradient(dArr, lexiconEnd, spanEnd, 1.0d);
        return d - penalty;
    }

    /** Sum of |x[k]| for k in [from, to). */
    private double sumAbs(int from, int to) {
        double sum = 0.0d;
        for (int k = from; k < to; k++) {
            sum += Math.abs(this.x[k]);
        }
        return sum;
    }

    /**
     * Adjusts dArr[from, to) by the L1 step {@code c}. Entries that end up non-finite or above 1e10 in
     * magnitude are zeroed in both {@code dArr} and {@link #lastUnregularizedDerivative}.
     */
    private void applyL1Subgradient(double[] dArr, int from, int to, double c) {
        for (int k = from; k < to; k++) {
            if (this.x[k] < 0.0d) {
                dArr[k] += c;
            } else if (this.x[k] > 0.0d) {
                dArr[k] -= c;
            } else if (dArr[k] < -c) {
                dArr[k] -= c;
            } else if (dArr[k] > c) {
                dArr[k] += c;
            } else {
                dArr[k] = 0.0d;
                this.lastUnregularizedDerivative[k] = 0.0d;
            }
            if (SloppyMath.isVeryDangerous(dArr[k]) || Math.abs(dArr[k]) > 1.0E10d) {
                System.out.println("Setting regularized derivative to zero because it is " + dArr[k]);
                dArr[k] = 0.0d;
                this.lastUnregularizedDerivative[k] = 0.0d;
            }
        }
    }

    public ParsingObjectiveFunction() {
    }

    /**
     * @param linearizer maps between the weight vector and grammar/lexicon/span predictor
     * @param stateSetTreeList training trees
     * @param d regularization sigma
     * @param i regularization type ({@link #NO_REGULARIZATION}, {@link #L1_REGULARIZATION}, {@link #L2_REGULARIZATION})
     * @param str constraint file base name, or {@code null}
     * @param i2 number of worker processes
     * @param str2 output file base name for saved grammars
     * @param z passed to each {@link Calculator} as {@code doNotProjectConstraints}
     * @param z2 unused
     */
    public ParsingObjectiveFunction(Linearizer linearizer, StateSetTreeList stateSetTreeList, double d, int i, String str, int i2, String str2, boolean z, boolean z2) {
        this.sigma = d;
        this.myRegularization = i;
        this.grammar = linearizer.getGrammar();
        this.lexicon = linearizer.getLexicon();
        this.spanPredictor = linearizer.getSpanPredictor();
        this.linearizer = linearizer;
        this.outFileName = str2;
        this.dimension = linearizer.dimension();
        this.nGrammarWeights = linearizer.getNGrammarWeights();
        this.nLexiconWeights = linearizer.getNLexiconWeights();
        this.nSpanWeights = linearizer.getNSpanWeights();
        if (this.spanPredictor != null) {
            this.spanGoldCounts = this.spanPredictor.countGoldSpanFeatures(stateSetTreeList);
        }
        int blockSize = stateSetTreeList.size() / i2;
        this.consBaseName = str;
        if (str != null) {
            // Use the constraint files' block size so trees line up with the files each Calculator reads.
            boolean[][][][][] firstBlock = ParserConstrainer.loadData(str + "-0.data");
            if (firstBlock != null) {
                blockSize = firstBlock.length;
            }
        }
        // Guard against "% 0" when there are fewer trees than processes or an empty constraint file.
        blockSize = Math.max(1, blockSize);
        this.nProcesses = i2;
        this.trainingTrees = new StateSetTreeList[this.nProcesses];
        for (int p = 0; p < this.nProcesses; p++) {
            this.trainingTrees[p] = new StateSetTreeList();
        }
        // Block b of consecutive trees goes to process b % nProcesses.
        for (int t = 0; t < stateSetTreeList.size(); t++) {
            int block = t / blockSize;
            this.trainingTrees[block % this.nProcesses].add(stateSetTreeList.get(t));
        }
        for (int p = 0; p < this.nProcesses; p++) {
            System.out.println("Process " + p + " has " + this.trainingTrees[p].size() + " trees.");
        }
        this.pool = Executors.newFixedThreadPool(this.nProcesses);
        this.tasks = new Calculator[this.nProcesses];
        for (int p = 0; p < this.nProcesses; p++) {
            this.tasks[p] = newCalculator(z, p);
        }
        this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
    }

    /** Shuts down the worker thread pool. */
    @Override // edu.berkeley.nlp.discPCFG.ObjectiveFunction
    public void shutdown() {
        this.pool.shutdown();
    }

    /** Factory for workers; override to supply a specialized {@link Calculator}. */
    protected Calculator newCalculator(boolean z, int i) {
        return new Calculator(this.trainingTrees[i], this.consBaseName, i, this.grammar, this.lexicon, this.spanPredictor, this.dimension, z);
    }

    public double[] getCurrentWeights() {
        return this.linearizer.getLinearizedWeights();
    }

    /** Not supported by this objective; always returns {@code null}. */
    @Override // edu.berkeley.nlp.discPCFG.ObjectiveFunction
    public <F, L> double[] getLogProbabilities(EncodedDatum encodedDatum, double[] dArr, Encoding<F, L> encoding, IndexLinearizer indexLinearizer) {
        return null;
    }

    /** Changes sigma and invalidates the cache and the best-objective tracker. */
    public void setSigma(double d) {
        this.sigma = d;
        this.x = null;
        this.bestObjectiveSoFar = Double.POSITIVE_INFINITY;
    }
}
