/*
 * Decompiled with CFR 0.152.
 */
package opennlp.ccg.synsem;

import java.io.File;
import java.io.IOException;
import java.net.URL;
import java.util.ArrayList;
import java.util.List;
import opennlp.ccg.grammar.Grammar;
import opennlp.ccg.lexicon.ListPairWord;
import opennlp.ccg.lexicon.Word;
import opennlp.ccg.ngrams.ConditionalProbabilityTable;
import opennlp.ccg.ngrams.NgramScorer;
import opennlp.ccg.perceptron.Alphabet;
import opennlp.ccg.perceptron.FeatureExtractor;
import opennlp.ccg.perceptron.FeatureList;
import opennlp.ccg.perceptron.FeatureVector;
import opennlp.ccg.synsem.DerivationHandler;
import opennlp.ccg.synsem.Sign;
import opennlp.ccg.synsem.SignScorer;
import opennlp.ccg.test.Regression;
import opennlp.ccg.test.RegressionInfo;
import opennlp.ccg.util.Pair;

/**
 * Scores derivations with four conditional probability tables (top, leaf,
 * unary and binary), one per kind of derivation step.
 *
 * <p>Each derivation step is described by a list of (key, value) string
 * pairs built by the static {@code add*Factors} methods; the log
 * probability of a derivation is the sum of the per-step log probabilities
 * looked up in the table matching the step kind. The resulting value can
 * be used directly as a score ({@link #score}) or exposed as a single
 * feature ({@link #extractFeatures}).
 */
public class GenerativeSyntacticModel
implements FeatureExtractor,
SignScorer {
    /** Name of the single feature holding the generative log probability. */
    public static String genlogprobkey = "genlogprob";

    // Factor keys: the first element of each (key, value) pair.
    public static final String EXPANSION = "E";
    public static final String PARENT = "P";
    public static final String HEAD = "H";
    public static final String SIBLING = "S";
    public static final String LEXCAT_PARENT = "CP";
    public static final String POS_PARENT = "T";
    public static final String WORD_PARENT = "W";
    public static final String LEXCAT_SIBLING = "CS";
    public static final String POS_SIBLING = "TS";
    public static final String WORD_SIBLING = "WS";
    public static final String LEXCAT_TOP = "CT";
    public static final String POS_TOP = "TT";
    public static final String WORD_TOP = "WT";

    // Values paired with the EXPANSION key, one per kind of step.
    public static final String LEFT = "left";
    public static final String RIGHT = "right";
    public static final String UNARY = "unary";
    public static final String LEAF = "leaf";

    /** Placeholder value used for the parent-side factors of the top step. */
    public static final String TOP = "<top>";

    protected ConditionalProbabilityTable topModel;
    protected ConditionalProbabilityTable leafModel;
    protected ConditionalProbabilityTable unaryModel;
    protected ConditionalProbabilityTable binaryModel;

    /** When true, every step prints its factor list to standard output. */
    protected boolean debugScore = false;
    protected Alphabet alphabet = null;

    /** Feature for {@link #genlogprobkey}; stays null until {@link #setAlphabet} supplies one. */
    protected Alphabet.Feature genlogprobFeature = null;

    /**
     * Loads the four probability tables from the given file names.
     *
     * @param topModelFN    file name of the top-step model
     * @param leafModelFN   file name of the leaf (lexical) step model
     * @param unaryModelFN  file name of the unary-step model
     * @param binaryModelFN file name of the binary-step model
     * @throws IOException propagated from constructing a {@link ConditionalProbabilityTable}
     */
    public GenerativeSyntacticModel(String topModelFN, String leafModelFN, String unaryModelFN, String binaryModelFN) throws IOException {
        this.topModel = new ConditionalProbabilityTable(topModelFN);
        this.leafModel = new ConditionalProbabilityTable(leafModelFN);
        this.unaryModel = new ConditionalProbabilityTable(unaryModelFN);
        this.binaryModel = new ConditionalProbabilityTable(binaryModelFN);
    }

    /**
     * Sets the debug flag on this model and on all four underlying tables.
     *
     * @param debugScore whether debug output should be produced
     */
    public void setDebug(boolean debugScore) {
        this.debugScore = debugScore;
        this.topModel.setDebug(debugScore);
        this.leafModel.setDebug(debugScore);
        this.unaryModel.setDebug(debugScore);
        this.binaryModel.setDebug(debugScore);
    }

    /**
     * Stores the alphabet and resolves the {@link #genlogprobkey} feature in
     * it: a closed alphabet is only queried via {@code index}, an open one
     * has the key added.
     *
     * @param alphabet the alphabet to resolve the feature against
     */
    @Override
    public void setAlphabet(Alphabet alphabet) {
        this.alphabet = alphabet;
        List<String> keys = new ArrayList<String>(1);
        keys.add(genlogprobkey);
        if (alphabet.closed()) {
            this.genlogprobFeature = alphabet.index(keys);
        } else {
            this.genlogprobFeature = alphabet.add(keys);
        }
    }

    /**
     * Returns a feature vector holding the log probability of the sign,
     * narrowed to {@code float}.
     *
     * @param sign     the sign to score
     * @param complete whether the sign is to be treated as a complete derivation
     * @return a vector with at most one feature
     */
    @Override
    public FeatureVector extractFeatures(Sign sign, boolean complete) {
        float narrowed = (float)this.logprob(sign, complete);
        return this.genLogProbVector(narrowed);
    }

    /**
     * Wraps a log probability in a feature list.
     *
     * @param logprob the value to store
     * @return a list containing the genlogprob feature with the given value,
     *         or an empty list when that feature is null
     */
    protected FeatureVector genLogProbVector(float logprob) {
        FeatureList features = new FeatureList(1);
        if (this.genlogprobFeature != null) {
            features.add(this.genlogprobFeature, Float.valueOf(logprob));
        }
        return features;
    }

    /**
     * Returns the log probability of the sign converted by
     * {@link NgramScorer#convertToProb}.
     *
     * @param sign     the sign to score
     * @param complete whether the sign is to be treated as a complete derivation
     * @return the converted score
     */
    @Override
    public double score(Sign sign, boolean complete) {
        return NgramScorer.convertToProb(this.logprob(sign, complete));
    }

    /**
     * Computes the log probability of the sign's derivation with a fresh
     * {@link LogProbGetter}.
     *
     * @param sign     the sign to score
     * @param complete if true the derivation is handled as complete (via
     *                 {@code handleCompleteDerivation}), otherwise via
     *                 {@code handleDerivation}
     * @return the log probability
     */
    public double logprob(Sign sign, boolean complete) {
        LogProbGetter getter = new LogProbGetter();
        return complete
            ? getter.handleCompleteDerivation(sign)
            : getter.handleDerivation(sign);
    }

    /**
     * Collects the factors of every step of the sign's complete derivation.
     *
     * @param sign the sign whose derivation is traversed
     * @return one word of factor pairs per derivation step, in traversal order
     */
    public static List<Word> getFactors(Sign sign) {
        FactorsGetter getter = new FactorsGetter();
        getter.handleCompleteDerivation(sign);
        return getter.factors;
    }

    /**
     * Appends the factors of the top step: the parent-side keys carry
     * {@link #TOP}, followed by the sign's supertag and the supertag, POS
     * and word form of its lexical head.
     *
     * @param sign  the root sign
     * @param pairs the list to append to
     */
    public static void addTopFactors(Sign sign, List<Pair<String, String>> pairs) {
        pairs.add(new Pair<String, String>(EXPANSION, TOP));
        pairs.add(new Pair<String, String>(PARENT, TOP));
        pairs.add(new Pair<String, String>(LEXCAT_PARENT, TOP));
        pairs.add(new Pair<String, String>(WORD_PARENT, TOP));
        pairs.add(new Pair<String, String>(HEAD, sign.getSupertag()));
        Sign head = sign.getLexHead();
        pairs.add(new Pair<String, String>(LEXCAT_TOP, head.getSupertag()));
        pairs.add(new Pair<String, String>(POS_TOP, head.getPOS()));
        pairs.add(new Pair<String, String>(WORD_TOP, head.getWordForm()));
    }

    /**
     * Appends the factors of a lexical step: the {@link #LEAF} expansion
     * followed by the parent factors of the sign.
     *
     * @param sign  the lexical sign
     * @param pairs the list to append to
     */
    public static void addLexFactors(Sign sign, List<Pair<String, String>> pairs) {
        pairs.add(new Pair<String, String>(EXPANSION, LEAF));
        addParentFactors(sign, pairs);
    }

    /**
     * Appends the parent factors: the sign's supertag plus the supertag,
     * POS and word form of its lexical head.
     *
     * @param sign  the parent sign
     * @param pairs the list to append to
     */
    public static void addParentFactors(Sign sign, List<Pair<String, String>> pairs) {
        pairs.add(new Pair<String, String>(PARENT, sign.getSupertag()));
        Sign head = sign.getLexHead();
        pairs.add(new Pair<String, String>(LEXCAT_PARENT, head.getSupertag()));
        pairs.add(new Pair<String, String>(POS_PARENT, head.getPOS()));
        pairs.add(new Pair<String, String>(WORD_PARENT, head.getWordForm()));
    }

    /**
     * Appends the factors of a unary step: the {@link #UNARY} expansion,
     * the parent factors, and the head child's supertag.
     *
     * @param sign      the parent sign
     * @param pairs     the list to append to
     * @param headChild the single child of the step
     */
    public static void addUnaryFactors(Sign sign, List<Pair<String, String>> pairs, Sign headChild) {
        pairs.add(new Pair<String, String>(EXPANSION, UNARY));
        addParentFactors(sign, pairs);
        pairs.add(new Pair<String, String>(HEAD, headChild.getSupertag()));
    }

    /**
     * Appends the factors of a binary step: the {@link #LEFT} or
     * {@link #RIGHT} expansion, the parent factors, the supertags of the
     * head and sibling children, and the supertag, POS and word form of the
     * sibling's lexical head.
     *
     * @param sign         the parent sign
     * @param pairs        the list to append to
     * @param left         selects {@link #LEFT} when true, {@link #RIGHT} otherwise
     * @param headChild    the head child
     * @param siblingChild the non-head child
     */
    public static void addBinaryFactors(Sign sign, List<Pair<String, String>> pairs, boolean left, Sign headChild, Sign siblingChild) {
        String direction = left ? LEFT : RIGHT;
        pairs.add(new Pair<String, String>(EXPANSION, direction));
        addParentFactors(sign, pairs);
        pairs.add(new Pair<String, String>(HEAD, headChild.getSupertag()));
        pairs.add(new Pair<String, String>(SIBLING, siblingChild.getSupertag()));
        Sign siblingHead = siblingChild.getLexHead();
        pairs.add(new Pair<String, String>(LEXCAT_SIBLING, siblingHead.getSupertag()));
        pairs.add(new Pair<String, String>(POS_SIBLING, siblingHead.getPOS()));
        pairs.add(new Pair<String, String>(WORD_SIBLING, siblingHead.getWordForm()));
    }

    /**
     * Returns the value following the option at {@code optionIndex}.
     *
     * @param args        the command-line arguments
     * @param optionIndex index of the option that requires a value
     * @param usage       usage text included in the error message
     * @return the argument immediately after the option
     * @throws IllegalArgumentException if the option is the last argument
     */
    private static String optionValue(String[] args, int optionIndex, String usage) {
        int valueIndex = optionIndex + 1;
        if (valueIndex >= args.length) {
            throw new IllegalArgumentException(
                "Missing value for option: " + args[optionIndex] + System.lineSeparator() + usage);
        }
        return args[valueIndex];
    }

    /**
     * Command-line entry point: loads a grammar and the four model files
     * from a directory, scores every testbed item that has at least one
     * parse, and prints the total and per-sentence log probability.
     *
     * @param args see the usage string printed by {@code -h}
     * @throws IOException              declared for the grammar URL
     *                                  conversion and the model constructor
     * @throws IllegalArgumentException if {@code -dir}, {@code -g} or
     *                                  {@code -t} is given without a value
     */
    public static void main(String[] args) throws IOException {
        String argstr = "(-dir <modeldir>) (-g <grammarfile>) (-t <testbedfile>) (-verbose)";
        String usage = "Usage: java opennlp.ccg.synsem.GenerativeSyntacticModel " + argstr;
        if (args.length > 0 && args[0].equals("-h")) {
            System.out.println(usage);
            System.exit(0);
        }

        String dir = ".";
        String grammarfn = "grammar.xml";
        String tbfn = "testbed.xml";
        boolean verbose = false;
        for (int i = 0; i < args.length; ++i) {
            String option = args[i];
            if (option.equals("-dir")) {
                // Options taking a value consume the next argument as well.
                dir = optionValue(args, i, usage);
                ++i;
            } else if (option.equals("-g")) {
                grammarfn = optionValue(args, i, usage);
                ++i;
            } else if (option.equals("-t")) {
                tbfn = optionValue(args, i, usage);
                ++i;
            } else if (option.equals("-v") || option.equals("-verbose")) {
                verbose = true;
            } else {
                // Unknown options are reported but do not stop the run.
                System.out.println("Unrecognized option: " + option);
            }
        }

        URL grammarURL = new File(grammarfn).toURI().toURL();
        System.out.println("Loading grammar from URL: " + grammarURL);
        Grammar grammar = new Grammar(grammarURL);

        System.out.println("Loading syntactic model from: " + dir);
        String topfn = dir + "/" + "top.flm";
        String leaffn = dir + "/" + "leaf.flm";
        String unaryfn = dir + "/" + "unary.flm";
        String binaryfn = dir + "/" + "binary.flm";
        GenerativeSyntacticModel model = new GenerativeSyntacticModel(topfn, leaffn, unaryfn, binaryfn);
        if (verbose) {
            model.setDebug(true);
        }

        double logprobttotal = 0.0;
        int numsents = 0;
        for (File f : Regression.getXMLFiles(new File(tbfn))) {
            System.out.println("Loading: " + f.getName());
            RegressionInfo rinfo = new RegressionInfo(grammar, f);
            int itemCount = 0;
            while (itemCount < rinfo.numberOfItems()) {
                RegressionInfo.TestItem item = rinfo.getItem(itemCount);
                ++itemCount;
                // Items without any parse are not scored and not counted.
                if (item.numOfParses == 0) {
                    continue;
                }
                ++numsents;
                if (verbose) {
                    System.out.println("scoring: " + item.sentence);
                } else {
                    System.out.print(".");
                }
                Sign sign = item.sign;
                double logprob = model.logprob(sign, true);
                logprobttotal += logprob;
                if (verbose) {
                    System.out.println(sign.getDerivationHistory().toString());
                    System.out.println("logprob: " + logprob);
                }
            }
            System.out.println();
        }
        System.out.println("total logprob: " + logprobttotal);
        // Floating-point division: prints NaN when no sentence was scored.
        System.out.println("logprob per sentence: " + logprobttotal / (double)numsents);
    }

    /**
     * Derivation handler that records the factors of every step as a
     * {@link ListPairWord} appended to {@link #factors}.
     */
    public static class FactorsGetter
    extends DerivationHandler<Void> {
        /** One entry per visited derivation step, in visiting order. */
        public List<Word> factors = new ArrayList<Word>();

        /** Pairs of the step currently being built. */
        private List<Pair<String, String>> pairs = null;

        /**
         * Starts a new step with a fresh list; the previous list has been
         * handed to a ListPairWord and is therefore not reused.
         */
        private void newPairs() {
            this.pairs = new ArrayList<Pair<String, String>>();
        }

        /** Wraps the current pairs in a word and appends it to the factors. */
        private void addPairs() {
            this.factors.add(new ListPairWord(this.pairs));
        }

        /** Records the top factors, then continues with the sign's own derivation. */
        @Override
        public Void topStep(Sign sign) {
            this.newPairs();
            GenerativeSyntacticModel.addTopFactors(sign, this.pairs);
            this.addPairs();
            this.handleDerivation(sign);
            return null;
        }

        /** Records the lexical factors of a leaf. */
        @Override
        public Void lexStep(Sign sign) {
            this.newPairs();
            GenerativeSyntacticModel.addLexFactors(sign, this.pairs);
            this.addPairs();
            return null;
        }

        /** Records the unary factors, then descends into the head child. */
        @Override
        public Void unaryStep(Sign sign, Sign headChild) {
            this.newPairs();
            GenerativeSyntacticModel.addUnaryFactors(sign, this.pairs, headChild);
            this.addPairs();
            this.handleDerivation(headChild);
            return null;
        }

        /** Records the binary factors, then descends into the head child and the sibling, in that order. */
        @Override
        public Void binaryStep(Sign sign, boolean left, Sign headChild, Sign siblingChild) {
            this.newPairs();
            GenerativeSyntacticModel.addBinaryFactors(sign, this.pairs, left, headChild, siblingChild);
            this.addPairs();
            this.handleDerivation(headChild);
            this.handleDerivation(siblingChild);
            return null;
        }
    }

    /**
     * Derivation handler that sums the log probabilities of all steps,
     * caching per-sign totals on the sign as {@link GenLogProb} data.
     *
     * <p>A single pairs list is reused for every step. Each step therefore
     * looks up its own log probability <em>before</em> recursing into its
     * children, because the recursive calls clear and refill the list.
     */
    public class LogProbGetter
    extends DerivationHandler<Double> {
        /** Scratch list, cleared at the start of every step. */
        private List<Pair<String, String>> pairs = new ArrayList<Pair<String, String>>();

        /** Formats the current pairs as {@code key-value } entries for debug output. */
        private String listPairs() {
            StringBuilder text = new StringBuilder();
            for (Pair<String, String> pair : this.pairs) {
                text.append(pair.a).append('-').append(pair.b).append(' ');
            }
            return text.toString();
        }

        /** Prints the current pairs with the given step label when debugging is on. */
        private void debugPairs(String label) {
            if (GenerativeSyntacticModel.this.debugScore) {
                System.out.println(label + " " + this.listPairs());
            }
        }

        /** Returns the total stored on the sign, or null if none is stored. */
        @Override
        public Double checkCache(Sign sign) {
            GenLogProb cached = (GenLogProb)sign.getData(GenLogProb.class);
            if (cached == null) {
                return null;
            }
            return Double.valueOf(cached.logprob);
        }

        /** Stores the total on the sign as a {@link GenLogProb}. */
        @Override
        public void cache(Sign sign, Double total) {
            sign.addData(new GenLogProb(total));
        }

        /** Top-model log probability plus that of the sign's own derivation. */
        @Override
        public Double topStep(Sign sign) {
            this.pairs.clear();
            GenerativeSyntacticModel.addTopFactors(sign, this.pairs);
            this.debugPairs("[topStep]");
            // Score this step first: the recursion below reuses the list.
            double stepLogProb = GenerativeSyntacticModel.this.topModel.logprob(this.pairs);
            return stepLogProb + this.handleDerivation(sign);
        }

        /** Leaf-model log probability of a lexical sign. */
        @Override
        public Double lexStep(Sign sign) {
            this.pairs.clear();
            GenerativeSyntacticModel.addLexFactors(sign, this.pairs);
            this.debugPairs("[lexStep]");
            return GenerativeSyntacticModel.this.leafModel.logprob(this.pairs);
        }

        /** Unary-model log probability plus that of the head child. */
        @Override
        public Double unaryStep(Sign sign, Sign headChild) {
            this.pairs.clear();
            GenerativeSyntacticModel.addUnaryFactors(sign, this.pairs, headChild);
            this.debugPairs("[unaryStep]");
            // Score this step first: the recursion below reuses the list.
            double stepLogProb = GenerativeSyntacticModel.this.unaryModel.logprob(this.pairs);
            return stepLogProb + this.handleDerivation(headChild);
        }

        /** Binary-model log probability plus those of the head child and the sibling. */
        @Override
        public Double binaryStep(Sign sign, boolean left, Sign headChild, Sign siblingChild) {
            this.pairs.clear();
            GenerativeSyntacticModel.addBinaryFactors(sign, this.pairs, left, headChild, siblingChild);
            this.debugPairs("[binaryStep]");
            // Score this step first: the recursion below reuses the list.
            double stepLogProb = GenerativeSyntacticModel.this.binaryModel.logprob(this.pairs);
            return stepLogProb + this.handleDerivation(headChild) + this.handleDerivation(siblingChild);
        }
    }

    /** Immutable holder for a log probability attached to a sign as cached data. */
    public static class GenLogProb {
        public final double logprob;

        /**
         * @param logprob the log probability to hold
         */
        public GenLogProb(double logprob) {
            this.logprob = logprob;
        }
    }
}

