package edu.berkeley.nlp.PCFGLA;

import edu.berkeley.nlp.discPCFG.ParsingObjectiveFunction;
import edu.berkeley.nlp.io.PennTreebankReader;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.syntax.Trees;
import edu.berkeley.nlp.treebank.ChineseTreebankLanguagePack;
import edu.berkeley.nlp.util.Counter;
import java.io.FileInputStream;
import java.io.InputStreamReader;
import java.io.StringReader;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/berkeley/nlp/PCFGLA/Corpus.class */
public class Corpus {
    /**
     * Treebank type of the most recently constructed corpus. Starts as WSJ and is
     * overwritten (for all instances, since it is static) by the private constructor.
     */
    public static TreeBankType myTreebank = TreeBankType.WSJ;
    /** Trees used for training; may be truncated by the fraction-taking constructor. */
    ArrayList<Tree<String>> trainTrees;
    /** Held-out trees returned by {@code getValidationTrees()}. */
    ArrayList<Tree<String>> validationTrees;
    /** Development-test trees returned by {@code getDevTestingTrees()}. */
    ArrayList<Tree<String>> devTestTrees;
    /** Final-test trees returned by {@code getFinalTestingTrees()}. */
    ArrayList<Tree<String>> finalTestTrees;

    /* loaded from: input_file:edu/berkeley/nlp/PCFGLA/Corpus$TreeBankType.class */
    /**
     * Selects which loader the private {@code Corpus} constructor runs for a given
     * data path.
     */
    public enum TreeBankType {
        /** Handled by {@code loadBrown}. */
        BROWN,
        /** Handled by {@code loadWSJ}. */
        WSJ,
        /** Handled by {@code loadChinese}. */
        CHINESE,
        /** Handled by {@code loadGerman}. */
        GERMAN,
        /** Handled by {@code loadSpanish}. */
        SPANISH,
        /** Handled by {@code loadCONLL} with the ISO8859_1 charset selected. */
        FRENCH,
        /** Handled by {@code loadCONLL} with the UTF-8 charset selected. */
        CONLL,
        /** Handled by {@code loadSingleFile}. */
        SINGLEFILE
    }

    /**
     * Loads a corpus without skipping any section: delegates to the five-argument
     * constructor with {@code -1} as the section to skip.
     *
     * @param str          path handed to the treebank loader; {@code null} loads built-in dummy sentences
     * @param treeBankType which treebank loader to run
     * @param d            fraction of the loaded training trees to keep (leading portion)
     * @param z            forwarded to the WSJ loader, which skips training and validation trees when true
     */
    public Corpus(String str, TreeBankType treeBankType, double d, boolean z) {
        this(str, treeBankType, d, z, -1);
    }

    /**
     * Loads a corpus and then keeps only the leading {@code ceil(size * d)} training
     * trees, reporting the resulting word and tree counts on standard output.
     *
     * @param str          path handed to the treebank loader; {@code null} loads built-in dummy sentences
     * @param treeBankType which treebank loader to run
     * @param d            fraction of the loaded training trees to keep
     * @param z            forwarded to the WSJ loader, which skips training and validation trees when true
     * @param i            forwarded to the WSJ loader as the section to leave out ({@code -1} for none)
     */
    public Corpus(String str, TreeBankType treeBankType, double d, boolean z, int i) {
        this(str, treeBankType, z, i);
        final int originalCount = this.trainTrees.size();
        final int keptCount = (int) Math.ceil(originalCount * d);
        this.trainTrees = new ArrayList<>(this.trainTrees.subList(0, keptCount));

        int wordCount = 0;
        for (Tree<String> tree : this.trainTrees) {
            wordCount += tree.getYield().size();
        }
        System.out.println("In training set we have # of words: " + wordCount);
        System.out.println("reducing number of training trees from " + originalCount + " to " + this.trainTrees.size());
    }

    /**
     * Initialises the four tree lists, records the treebank type in the static
     * {@code myTreebank} field and loads the data.
     *
     * <p>With a non-null path the loader matching {@code treeBankType} is run; any
     * exception it raises is reported on standard output (message plus full stack
     * trace) and rethrown wrapped in an {@link Error}. With a {@code null} path a
     * small set of hard-coded dummy sentences is parsed, normalised and added to the
     * training, dev-test and validation lists (the final-test list stays empty).
     *
     * @param str          treebank location, or {@code null} for the dummy sentences
     * @param treeBankType which loader to run
     * @param z            only used by the WSJ loader
     * @param i            only used by the WSJ loader (section to skip, {@code -1} for none)
     */
    private Corpus(String str, TreeBankType treeBankType, boolean z, int i) {
        this.trainTrees = new ArrayList<>();
        this.validationTrees = new ArrayList<>();
        this.devTestTrees = new ArrayList<>();
        this.finalTestTrees = new ArrayList<>();
        myTreebank = treeBankType;
        if (str != null) {
            try {
                if (myTreebank == TreeBankType.CHINESE) {
                    System.out.println("Loading CHINESE data!");
                    loadChinese(str);
                } else if (myTreebank == TreeBankType.WSJ) {
                    System.out.println("Loading ENGLISH WSJ data!");
                    loadWSJ(str, z, i);
                } else if (myTreebank == TreeBankType.GERMAN) {
                    System.out.println("Loading GERMAN data!");
                    loadGerman(str);
                } else if (myTreebank == TreeBankType.BROWN) {
                    System.out.println("Loading BROWN data!");
                    loadBrown(str);
                } else if (myTreebank == TreeBankType.SPANISH) {
                    System.out.println("Loading SPANISH data!");
                    loadSpanish(str);
                } else if (myTreebank == TreeBankType.FRENCH) {
                    System.out.println("Loading FRENCH data!");
                    loadCONLL(str, true);
                } else if (myTreebank == TreeBankType.CONLL) {
                    System.out.println("Loading CoNLL converted data!");
                    loadCONLL(str, false);
                } else if (myTreebank == TreeBankType.SINGLEFILE) {
                    System.out.println("Loading data from single file!");
                    loadSingleFile(str);
                }
            } catch (Exception e) {
                System.out.println("Error loading trees!");
                // Printing getStackTrace().toString() only showed the array's identity
                // hash; emit the real trace on the same stream as the message above.
                e.printStackTrace(System.out);
                throw new Error(e.getMessage(), e);
            }
            return;
        }

        System.out.println("Loading one dummy sentence into training set only.");
        List<String> dummySentences = new ArrayList<>();
        // Selects which hard-coded sentence set is used; fixed at 4 in this build.
        final int dummyCase = 4;
        switch (dummyCase) {
            case 0:
                dummySentences.add("((S (A x) (C x)))");
                dummySentences.add("((S (E x) (B x)))");
                break;
            case 1:
                dummySentences.add("((S (NP (NP (DT The) (JJ complicated) (NN language)) (PP (IN in) (NP (DT the) (JJ huge) (JJ new) (NN law)))) (VP (VBZ has) (VP (VBD muddied) (NP (DT the) (NN fight)))) (. .)))");
                break;
            case 2:
                dummySentences.add("((S (Z1 (Z2 x) (NNPS x)) (U3 (Uu (A1 (NNP x1) (NNPS x2))))))");
                dummySentences.add("((S (K (U2 (Z1 (Z2 x) (NNP x)))) (U7 (NNS x))))");
                dummySentences.add("((S (Z1 (NNPS x) (NN x)) (F (CC y) (ZZ z))))");
                break;
            case 3:
                dummySentences.add("((X (C (B b) (B b)) (F (E (D d)))))");
                dummySentences.add("((Y (C (B a) (B a)) (E (D d))))");
                dummySentences.add("((X (C (B b) (B b)) (E (D d))))");
                break;
            case 4:
                dummySentences.add("( (S (SBAR (IN In) (NN order) (S (VP (TO to) (VP (VB strengthen) (NP (NP (JJ cultural) (NN exchange) (CC and) (NN contact)) (PP (IN between) (NP (NP (NP (DT the) (NNS descendents)) (PP (IN of) (NP (DT the) (NNPS Emperors)))) (UCP (PP (IN at) (NP (NN home))) (CC and) (ADVP (RB abroad)))))))))) (, ,) (NP (NNP China)) (VP (MD will) (VP (VB hold) (NP (DT the) (JJ \") (NNP China) (NNP Art) (NNP Festival) (NN \")) (PP (IN in) (NP (NP (NNP Beijing)) (CC and) (NNP Shenzhen))) (ADVP (RB simultaneously)) (PP (IN from) (NP (DT the) (NN 8th))) (PP (TO to) (NP (NP (DT the) (JJ 18th)) (PP (IN of) (NP (NNP December))))) (NP (DT this) (NN year)))) (. .)) )");
                dummySentences.add("( (S (PP (IN In) (NP (NP (NN order) (S (VP (TO to) (VP (VB strengthen) (NP (NP (JJ cultural) (NN exchange) (CC and) (NN contact)) (PP (IN between) (NP (NP (DT the) (NNS descendents)) (PP (IN of) (NP (DT the) (NNPS Emperors))) (PP (IN at) (NP (NN home)))))))))) (CC and) (ADVP (RB abroad)))) (, ,) (NP (NNP China)) (VP (MD will) (VP (VB hold) (NP (DT the) (JJ \") (NNP China) (NNP Art) (NNP Festival) (NN \")) (PP (IN in) (NP (NP (NNP Beijing)) (CC and) (NNP Shenzhen))) (ADVP (RB simultaneously)) (PP (IN from) (NP (DT the) (NN 8th))) (PP (TO to) (NP (NP (DT the) (JJ 18th)) (PP (IN of) (NP (NNP December))))) (NP (DT this) (NN year)))) (. .)) )");
                dummySentences.add("( (S (PP (IN In) (NP (NN order) (S (VP (TO to) (VP (VB strengthen) (NP (NP (JJ cultural) (NN exchange) (CC and) (NN contact)) (PP (IN between) (NP (NP (DT the) (NNS descendents)) (PP (IN of) (NP (DT the) (NNPS Emperors)))))) (UCP (PP (IN at) (ADVP (RB home))) (CC and) (ADVP (RB abroad)))))))) (, ,) (NP (NNP China)) (VP (MD will) (VP (VB hold) (NP (DT the) (`` \") (NNP China) (NNP Art) (NNP Festival) (NN \")) (PP (IN in) (NP (NNP Beijing) (CC and) (NNP Shenzhen))) (ADVP (RB simultaneously)) (PP (PP (IN from) (NP (DT the) (NN 8th))) (PP (IN to) (NP (DT the) (NN 18th))) (PP (IN of) (NP (NNP December)))) (NP (DT this) (NN year)))) (. .)) )");
                break;
            case 5:
                dummySentences.add("((X (C (B a) (B a)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (E (D d) (D d))))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (E (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                break;
            case 6:
                dummySentences.add("((Y (C (B b) (B b)) (E (D d) (D d))))");
                dummySentences.add("((Y (C (B b) (D b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (U (C (B b) (B b))) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                break;
            case 7:
                dummySentences.add("((X (S (NP (X (PRP I))) (VP like))))");
                dummySentences.add("((X (C (U (V (W (B a) (B a))))) (D d)))");
                dummySentences.add("((X (Y (Z (V (C (B a) (B a))) (D d)))))");
                dummySentences.add("((X (C (B a) (B a)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (E (D d) (D d))))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (U (C (B b) (B b))) (D d)))");
                dummySentences.add("((Y (E (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                dummySentences.add("((Y (C (B b) (B b)) (D d)))");
                break;
            default:
                break;
        }

        Trees.StandardTreeNormalizer normalizer = new Trees.StandardTreeNormalizer();
        for (String bracketed : dummySentences) {
            Tree<String> parsed = new Trees.PennTreeReader(new StringReader(bracketed)).next();
            Tree<String> normalized = normalizer.transformTree(parsed);
            // The same tree instance is shared by all three lists.
            this.trainTrees.add(normalized);
            this.devTestTrees.add(normalized);
            this.validationTrees.add(normalized);
        }
    }

    /**
     * Fills the four tree lists from the Chinese treebank at {@code str}, reading
     * each file-number range with the encoding named by
     * {@code ChineseTreebankLanguagePack.ENCODING}, and prints the resulting sizes.
     *
     * @param str treebank location passed to {@code readTrees}
     * @throws Exception if any range yields no trees or reading fails
     */
    private void loadChinese(String str) throws Exception {
        System.out.print("Loading Chinese treebank trees...");
        final Charset encoding = Charset.forName(ChineseTreebankLanguagePack.ENCODING);

        // Training data: files 1-25, 26-270 and 400-1151.
        this.trainTrees.addAll(readTrees(str, 1, 25, encoding));
        this.trainTrees.addAll(readTrees(str, 26, 270, encoding));
        this.trainTrees.addAll(readTrees(str, 400, 1151, encoding));

        // Dev-test and validation are both read from files 301-325.
        this.devTestTrees.addAll(readTrees(str, 301, 325, encoding));
        this.validationTrees.addAll(readTrees(str, 301, 325, encoding));

        this.finalTestTrees.addAll(readTrees(str, 271, 300, encoding));

        String sizes = "" + this.trainTrees.size()
                + " " + this.validationTrees.size()
                + " " + this.devTestTrees.size()
                + " " + this.finalTestTrees.size()
                + " trees...";
        System.out.print(sizes);
        System.out.println("done");
    }

    /**
     * Reads the Brown sections cf, cg, ck, cl, cm, cn, cp and cr from
     * sub-directories of {@code str}, splits each with {@code splitTrainValidTest}
     * and appends the parts to the corresponding corpus lists.
     *
     * @param str directory containing one sub-directory per section
     * @throws Exception if a section yields no trees or reading fails
     */
    private void loadBrown(String str) throws Exception {
        final String[] sections = {"cf", "cg", "ck", "cl", "cm", "cn", "cp", "cr"};
        for (String section : sections) {
            List<Tree<String>> sectionTrain = new ArrayList<>();
            List<Tree<String>> sectionValidation = new ArrayList<>();
            List<Tree<String>> sectionDevTest = new ArrayList<>();
            List<Tree<String>> sectionFinalTest = new ArrayList<>();

            List<Tree<String>> sectionTrees = readTrees(str + "/" + section, 0, 1000, Charset.defaultCharset());
            splitTrainValidTest(sectionTrees, sectionTrain, sectionValidation, sectionDevTest, sectionFinalTest);

            this.trainTrees.addAll(sectionTrain);
            this.validationTrees.addAll(sectionValidation);
            this.devTestTrees.addAll(sectionDevTest);
            this.finalTestTrees.addAll(sectionFinalTest);

            System.out.println("I read " + sectionTrain.size() + " training trees from section " + section);
        }
    }

    /**
     * Fills the corpus lists from the Spanish data at {@code str} using the platform
     * default charset: range 1-1 becomes training data, and range 2-279 is read
     * separately for each of the validation, dev-test and final-test lists.
     *
     * @param str treebank location passed to {@code readTrees}
     * @throws Exception if a range yields no trees or reading fails
     */
    private void loadSpanish(String str) throws Exception {
        System.out.print("Loading Spanish trees...");
        final Charset charset = Charset.defaultCharset();
        final int heldOutFirst = 2;
        final int heldOutLast = 279;

        this.trainTrees.addAll(readTrees(str, 1, 1, charset));
        this.validationTrees.addAll(readTrees(str, heldOutFirst, heldOutLast, charset));
        this.devTestTrees.addAll(readTrees(str, heldOutFirst, heldOutLast, charset));
        this.finalTestTrees.addAll(readTrees(str, heldOutFirst, heldOutLast, charset));

        System.out.println("done");
    }

    /**
     * Reads every tree from the UTF-8 encoded file {@code str}, normalises each one
     * and uses the result as both the training and the dev-test list (the two fields
     * end up referring to the same list object).
     *
     * <p>The file is opened in a try-with-resources block so the underlying stream is
     * closed whether reading succeeds or fails.
     *
     * @param str path of the file holding the bracketed trees
     * @throws Exception if the file cannot be read or contains no trees
     */
    private void loadSingleFile(String str) throws Exception {
        System.out.print("Loading trees from single file...");
        try (FileInputStream fileStream = new FileInputStream(str);
                InputStreamReader reader = new InputStreamReader(fileStream, "UTF-8")) {
            Trees.PennTreeReader pennTreeReader = new Trees.PennTreeReader(reader);
            // Consume the reader fully while the stream is still open.
            while (pennTreeReader.hasNext()) {
                this.trainTrees.add(pennTreeReader.next());
            }
        }

        Trees.StandardTreeNormalizer normalizer = new Trees.StandardTreeNormalizer();
        ArrayList<Tree<String>> normalizedTrees = new ArrayList<>();
        for (Tree<String> rawTree : this.trainTrees) {
            normalizedTrees.add(normalizer.transformTree(rawTree));
        }
        if (normalizedTrees.isEmpty()) {
            throw new Exception("failed to load any trees at " + str);
        }
        this.trainTrees = normalizedTrees;
        this.devTestTrees = this.trainTrees;
        System.out.println("done");
    }

    /**
     * Loads CoNLL-style data: range 1 for training, range 2 for both validation and
     * dev-test (read twice), range 3 for final test. Every tree is wrapped under a
     * ROOT node by {@code readAndPreprocessTrees}. Afterwards the training, dev-test
     * and final-test trees whose root does not have exactly one child are printed.
     *
     * @param str treebank location passed to the reader
     * @param z   {@code true} selects the ISO8859_1 charset, {@code false} UTF-8
     * @throws Exception if a range yields no trees or reading fails
     */
    private void loadCONLL(String str, boolean z) throws Exception {
        final Charset charset;
        if (z) {
            charset = Charset.forName("ISO8859_1");
        } else {
            charset = Charset.forName("UTF-8");
        }
        System.out.print("Loading CoNLL trees...");

        this.trainTrees = readAndPreprocessTrees(str, 1, 1, charset);
        this.validationTrees = readAndPreprocessTrees(str, 2, 2, charset);
        this.devTestTrees = readAndPreprocessTrees(str, 2, 2, charset);
        this.finalTestTrees = readAndPreprocessTrees(str, 3, 3, charset);

        reportMalformed(this.trainTrees, "v");
        reportMalformed(this.devTestTrees, "v");
        reportMalformed(this.finalTestTrees, "t");

        System.out.println("done");
    }

    /** Prints each tree in {@code trees} whose root does not have exactly one child, tagged with {@code tag}. */
    private static void reportMalformed(List<Tree<String>> trees, String tag) {
        for (Tree<String> candidate : trees) {
            if (candidate.getChildren().size() != 1) {
                System.out.println("Malformed " + tag + ": " + candidate);
            }
        }
    }

    /**
     * Reads the trees in the given range via {@code readTrees} and wraps each one as
     * the only child of a new node labelled "ROOT".
     *
     * @param str     treebank location
     * @param i       first file number of the range
     * @param i2      last file number of the range
     * @param charset encoding used to read the files
     * @return the wrapped trees, in reading order
     * @throws Exception if the range yields no trees or reading fails
     */
    private ArrayList<Tree<String>> readAndPreprocessTrees(String str, int i, int i2, Charset charset) throws Exception {
        ArrayList<Tree<String>> wrapped = new ArrayList<>();
        for (Tree<String> original : readTrees(str, i, i2, charset)) {
            List<Tree<String>> onlyChild = new ArrayList<>(1);
            onlyChild.add(original);
            Tree<String> rooted = new Tree<>("ROOT", onlyChild);
            wrapped.add(rooted);
        }
        return wrapped;
    }

    /**
     * Fills the corpus lists from the WSJ treebank at {@code str} using the platform
     * default charset. File numbers 200-2199 form the training range (optionally
     * minus one block of 100 files), 2100-2199 the validation range, 2200-2299 the
     * dev-test range and 2300-2399 the final-test range.
     *
     * @param str treebank location passed to {@code readTrees}
     * @param z   when {@code true}, training and validation trees are not loaded
     * @param i   section to leave out of training (files {@code i*100 .. i*100+99}),
     *            or {@code -1} to keep all of them
     * @throws Exception if a range yields no trees or reading fails
     */
    private void loadWSJ(String str, boolean z, int i) throws Exception {
        System.out.print("Loading WSJ trees...");
        final Charset charset = Charset.defaultCharset();

        if (!z) {
            if (i == -1) {
                this.trainTrees.addAll(readTrees(str, 200, 2199, charset));
            } else {
                System.out.println("Skipping section " + i + ".");
                switch (i) {
                    case 2:
                        // Skipped section is the first one: start after it.
                        this.trainTrees.addAll(readTrees(str, 300, 2199, charset));
                        break;
                    case 21:
                        // Skipped section is the last one: stop before it.
                        this.trainTrees.addAll(readTrees(str, 200, 2099, charset));
                        break;
                    default:
                        final int skippedStart = i * 100;
                        this.trainTrees.addAll(readTrees(str, 200, skippedStart - 1, charset));
                        this.trainTrees.addAll(readTrees(str, skippedStart + 100, 2199, charset));
                        break;
                }
            }
            this.validationTrees.addAll(readTrees(str, 2100, 2199, charset));
        }

        this.devTestTrees.addAll(readTrees(str, 2200, 2299, charset));
        this.finalTestTrees.addAll(readTrees(str, 2300, 2399, charset));
        System.out.println("done");
    }

    /**
     * Reads the German trees (range 1-3, UTF-8), relabels each tree's root as
     * "PSEUDO", wraps it under a new "ROOT" node and distributes the results by
     * reading position: the first 18602 go to training, positions 18602-19601 go to
     * both validation and dev-test, and everything after that to final test.
     *
     * @param str treebank location passed to {@code readTrees}
     * @throws Exception if no trees are found or reading fails
     */
    private void loadGerman(String str) throws Exception {
        System.out.print("Loading German trees...");
        final int firstHeldOut = 18602;
        final int lastHeldOut = 19601;

        int position = 0;
        for (Tree<String> sentence : readTrees(str, 1, 3, Charset.forName("UTF-8"))) {
            sentence.setLabel("PSEUDO");
            List<Tree<String>> children = new ArrayList<>(1);
            children.add(sentence);
            Tree<String> rooted = new Tree<>("ROOT", children);

            if (position > lastHeldOut) {
                this.finalTestTrees.add(rooted);
            } else if (position >= firstHeldOut) {
                this.validationTrees.add(rooted);
                this.devTestTrees.add(rooted);
            } else {
                this.trainTrees.add(rooted);
            }
            position++;
        }
        System.out.println("done.\nThere are " + this.trainTrees.size() + " " + this.devTestTrees.size() + " " + this.finalTestTrees.size() + " trees.");
    }

    /**
     * Reads the trees in the given file-number range with
     * {@code PennTreebankReader.readTrees} and passes each through a
     * {@code Trees.StandardTreeNormalizer}.
     *
     * @param str     treebank location
     * @param i       first file number of the range
     * @param i2      last file number of the range
     * @param charset encoding used to read the files
     * @return the normalised trees, in the order the reader produced them
     * @throws Exception if no tree was read, or if the underlying reader fails
     */
    public static List<Tree<String>> readTrees(String str, int i, int i2, Charset charset) throws Exception {
        Collection<Tree<String>> rawTrees = PennTreebankReader.readTrees(str, i, i2, charset);
        Trees.StandardTreeNormalizer normalizer = new Trees.StandardTreeNormalizer();

        List<Tree<String>> normalizedTrees = new ArrayList<>();
        for (Tree<String> rawTree : rawTrees) {
            normalizedTrees.add(normalizer.transformTree(rawTree));
        }

        if (normalizedTrees.isEmpty()) {
            throw new Exception("failed to load any trees at " + str + " from " + i + " to " + i2);
        }
        return normalizedTrees;
    }

    /**
     * Distributes the trees of {@code list} over four output lists by position
     * modulo ten: positions 0-6 go to {@code list2}, 7 to {@code list3}, 8 to
     * {@code list4} and 9 to {@code list5}.
     *
     * @param list  trees to split
     * @param list2 receives seven out of every ten trees
     * @param list3 receives the eighth tree of every ten
     * @param list4 receives the ninth tree of every ten
     * @param list5 receives the tenth tree of every ten
     */
    public static void splitTrainValidTest(List<Tree<String>> list, List<Tree<String>> list2, List<Tree<String>> list3, List<Tree<String>> list4, List<Tree<String>> list5) {
        final int total = list.size();
        for (int index = 0; index < total; index++) {
            final Tree<String> tree = list.get(index);
            switch (index % 10) {
                case 7:
                    list3.add(tree);
                    break;
                case 8:
                    list4.add(tree);
                    break;
                case 9:
                    list5.add(tree);
                    break;
                default:
                    // Remainders 0 through 6.
                    list2.add(tree);
                    break;
            }
        }
    }

    /**
     * Returns the trees of {@code list} that pass the requested filters, in their
     * original order. Trees whose yield has exactly one element are always dropped.
     *
     * <p>Previously the WHNP scan merely broke out of its loop and the tree was kept
     * regardless, so {@code z2} had no effect on the result; a match now excludes the
     * tree.
     *
     * @param list trees to filter; trees may be modified in place when {@code z3} is set
     * @param z    when {@code true}, drop trees for which {@code hasUnariesOtherThanRoot()} holds
     * @param z2   when {@code true}, drop trees having a non-terminal whose label contains "WHNP"
     * @param z3   when {@code true}, call {@code removeUnaryChains()} on trees that have a unary chain
     * @return a new list holding the surviving trees
     */
    public static List<Tree<String>> filterTreesForConditional(List<Tree<String>> list, boolean z, boolean z2, boolean z3) {
        List<Tree<String>> kept = new ArrayList<>(list.size());
        for (Tree<String> tree : list) {
            if (tree.getYield().size() == 1) {
                continue;
            }
            if (tree.hasUnaryChain() && z3) {
                tree.removeUnaryChains();
            }
            if (z2 && hasWhnpNonTerminal(tree)) {
                continue;
            }
            if (z && tree.hasUnariesOtherThanRoot()) {
                continue;
            }
            kept.add(tree);
        }
        return kept;
    }

    /** Reports whether any non-terminal of {@code tree} has a label containing "WHNP". */
    private static boolean hasWhnpNonTerminal(Tree<String> tree) {
        for (Tree<String> node : tree.getNonTerminals()) {
            if (node.getLabel().contains("WHNP")) {
                return true;
            }
        }
        return false;
    }

    /**
     * Convenience overload of the eight-argument {@code binarizeAndFilterTrees} that
     * passes {@code false} for its last flag.
     *
     * @param list         trees to process
     * @param i            printed as the "vertical" annotation level and forwarded to {@code TreeAnnotations.processTree}
     * @param i2           printed as the "horizontal" annotation level and forwarded to {@code TreeAnnotations.processTree}
     * @param i3           maximum yield size; longer trees are left out
     * @param binarization forwarded to {@code TreeAnnotations.processTree}
     * @param z            forwarded to {@code TreeAnnotations.processTree}
     * @param z2           when {@code true}, the annotation levels are printed
     * @return the processed trees
     */
    public static List<Tree<String>> binarizeAndFilterTrees(List<Tree<String>> list, int i, int i2, int i3, Binarization binarization, boolean z, boolean z2) {
        return binarizeAndFilterTrees(list, i, i2, i3, binarization, z, z2, false);
    }

    /**
     * Runs {@code TreeAnnotations.processTree} on every tree whose yield has at most
     * {@code i3} elements and collects the results; longer trees are left out.
     * Progress messages are written to standard output.
     *
     * @param list         trees to process
     * @param i            printed as the "vertical" annotation level and forwarded to {@code processTree}
     * @param i2           printed as the "horizontal" annotation level and forwarded to {@code processTree}
     * @param i3           maximum yield size (inclusive)
     * @param binarization forwarded to {@code processTree}
     * @param z            forwarded to {@code processTree}
     * @param z2           when {@code true}, the annotation levels are printed
     * @param z3           forwarded to {@code processTree}
     * @return the processed trees, in input order
     */
    public static List<Tree<String>> binarizeAndFilterTrees(List<Tree<String>> list, int i, int i2, int i3, Binarization binarization, boolean z, boolean z2, boolean z3) {
        System.out.print("Binarizing and annotating trees...");
        if (z2) {
            System.out.println("annotation levels: vertical=" + i + " horizontal=" + i2);
        }

        List<Tree<String>> processed = new ArrayList<>();
        for (Tree<String> tree : list) {
            if (tree.getYield().size() > i3) {
                continue;
            }
            processed.add(TreeAnnotations.processTree(tree, i, i2, binarization, z, z3, true));
        }

        System.out.print("done.\n");
        return processed;
    }

    /**
     * Returns the training trees.
     *
     * @return the internal list itself (not a copy), so changes made by the caller are visible to this corpus
     */
    public List<Tree<String>> getTrainTrees() {
        return this.trainTrees;
    }

    /**
     * Returns the validation trees.
     *
     * @return the internal list itself (not a copy), so changes made by the caller are visible to this corpus
     */
    public List<Tree<String>> getValidationTrees() {
        return this.validationTrees;
    }

    /**
     * Returns the dev-test trees.
     *
     * @return the internal list itself (not a copy), so changes made by the caller are visible to this corpus
     */
    public List<Tree<String>> getDevTestingTrees() {
        return this.devTestTrees;
    }

    /**
     * Returns the final-test trees.
     *
     * @return the internal list itself (not a copy), so changes made by the caller are visible to this corpus
     */
    public List<Tree<String>> getFinalTestingTrees() {
        return this.finalTestTrees;
    }

    /**
     * Applies {@code makePosTree} to every tree in {@code list}, printing a progress
     * message before and after.
     *
     * @param list trees to convert
     * @return the converted trees, in input order
     */
    public static List<Tree<String>> makePosTrees(List<Tree<String>> list) {
        System.out.print("Making POS-trees...");
        List<Tree<String>> posTrees = new ArrayList<>();
        for (Tree<String> tree : list) {
            posTrees.add(makePosTree(tree));
        }
        System.out.print(" done.\n");
        return posTrees;
    }

    /**
     * Builds a right-branching chain over the pre-terminals of {@code tree}. The
     * chain ends in a "STOP" node with a single "STOP" child; working backwards from
     * the last word, each pre-terminal label gets a node whose first child is that
     * label over the original terminal and whose second child is the chain built so
     * far. The result is wrapped under a node carrying the label of {@code tree}.
     *
     * @param tree source tree; its terminals are reused in the result, not copied
     * @return the new chain-shaped tree
     */
    public static Tree<String> makePosTree(Tree<String> tree) {
        List<Tree<String>> words = tree.getTerminals();
        List<String> tags = tree.getPreTerminalYield();

        List<Tree<String>> stopChildren = new ArrayList<>();
        stopChildren.add(new Tree<String>("STOP"));
        Tree<String> chain = new Tree<>("STOP", stopChildren);

        for (int position = tags.size() - 1; position >= 0; position--) {
            String tag = tags.get(position);

            List<Tree<String>> wordAsChild = new ArrayList<>();
            wordAsChild.add(words.get(position));
            Tree<String> taggedWord = new Tree<>(tag, wordAsChild);

            List<Tree<String>> link = new ArrayList<>();
            link.add(taggedWord);
            link.add(chain);
            chain = new Tree<>(tag, link);
        }

        List<Tree<String>> rootChildren = new ArrayList<>();
        rootChildren.add(chain);
        return new Tree<>(tree.getLabel(), rootChildren);
    }

    /**
     * Replaces infrequent words in the yields of {@code stateSetTreeList} with the
     * signature returned by {@code simpleLexicon.getSignature(word, position)}.
     *
     * <p>A first pass counts every word and adds it to
     * {@code simpleLexicon.wordIndexer}; a second pass rewrites, in place, each word
     * whose count is at most {@code i}.
     *
     * @param stateSetTreeList trees whose yield entries are rewritten in place
     * @param simpleLexicon    supplies the word indexer and the signatures
     * @param i                highest count at which a word is still replaced
     */
    public static void replaceRareWords(StateSetTreeList stateSetTreeList, SimpleLexicon simpleLexicon, int i) {
        Counter wordCounts = new Counter();

        // Pass 1: count and index every word.
        for (Iterator<Tree<StateSet>> trees = stateSetTreeList.iterator(); trees.hasNext(); ) {
            for (StateSet leaf : trees.next().getYield()) {
                String word = leaf.getWord();
                wordCounts.incrementCount(word, 1.0d);
                simpleLexicon.wordIndexer.add(word);
            }
        }

        // Pass 2: swap rare words for their signature, tracking sentence position.
        for (Iterator<Tree<StateSet>> trees = stateSetTreeList.iterator(); trees.hasNext(); ) {
            int position = 0;
            for (StateSet leaf : trees.next().getYield()) {
                if (wordCounts.getCount(leaf.getWord()) <= i) {
                    leaf.setWord(simpleLexicon.getSignature(leaf.getWord(), position));
                }
                position++;
            }
        }
    }

    /**
     * Replaces infrequent words in the terminals of {@code list} with the signature
     * returned by {@code simpleLexicon.getSignature(word, position)}.
     *
     * <p>A first pass counts every word of every yield and adds it to
     * {@code simpleLexicon.wordIndexer}; a second pass relabels, in place, each
     * terminal whose label has a count of at most {@code i}.
     *
     * @param list          trees whose terminal labels are rewritten in place
     * @param simpleLexicon supplies the word indexer and the signatures
     * @param i             highest count at which a word is still replaced
     */
    public static void replaceRareWords(List<Tree<String>> list, SimpleLexicon simpleLexicon, int i) {
        Counter wordCounts = new Counter();

        // Pass 1: count and index every word.
        for (Tree<String> tree : list) {
            for (String word : tree.getYield()) {
                wordCounts.incrementCount(word, 1.0d);
                simpleLexicon.wordIndexer.add(word);
            }
        }

        // Pass 2: swap rare words for their signature, tracking sentence position.
        for (Tree<String> tree : list) {
            int position = 0;
            for (Tree<String> terminal : tree.getTerminals()) {
                if (wordCounts.getCount(terminal.getLabel()) <= i) {
                    terminal.setLabel(simpleLexicon.getSignature(terminal.getLabel(), position));
                }
                position++;
            }
        }
    }

    /**
     * Lower-cases the label of every terminal of every tree in {@code list}, in
     * place, using {@code String.toLowerCase()} (default-locale rules).
     *
     * @param list trees whose terminal labels are rewritten
     */
    public static void lowercaseWords(List<Tree<String>> list) {
        for (Tree<String> tree : list) {
            for (Tree<String> terminal : tree.getTerminals()) {
                String lowered = terminal.getLabel().toLowerCase();
                terminal.setLabel(lowered);
            }
        }
    }
}
