/*
 * Decompiled with CFR 0.152.
 */
package opennlp.ccg.parse.tagger.sequencescoring;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.logging.Level;
import java.util.logging.Logger;
import opennlp.ccg.lexicon.Word;
import opennlp.ccg.ngrams.StandardNgramModel;
import opennlp.ccg.parse.tagger.Constants;
import opennlp.ccg.parse.tagger.ProbIndexPair;
import opennlp.ccg.parse.tagger.sequencescoring.Backpointer;
import opennlp.ccg.parse.tagger.sequencescoring.Trellis;
import opennlp.ccg.util.Interner;
import opennlp.ccg.util.Pair;

/**
 * Rescores per-position tag candidates with an n-gram model over tag sequences.
 *
 * <p>The input holds one list of (score, tag) candidates per position. Each tag is wrapped in a
 * {@link Word} and scored by the inherited {@link StandardNgramModel}. The candidate scores and the
 * n-gram scores are combined in a trellis using a forward pass. If the algorithm is
 * {@code FORWARDBACKWARD}, a backward pass follows (see {@link #setAlgorithm}).
 *
 * <p>Scratch state for {@link #rescoreSequence} lives in unsynchronized instance fields. A single
 * instance should therefore not be used by several threads at once.
 */
public class SequenceScorer
extends StandardNgramModel {
    /** Tag labels (interned Words) for each (position, candidate) cell of the current call. */
    private Trellis<Word> seqLabs;
    /** Initial log-space scores for each cell, derived from the input candidate scores. */
    private Trellis<Double> initScores;
    /** Forward scores, normalized per position; overwritten in place by the backward pass. */
    private Trellis<Double> fbScores;
    /** For each cell at position > 0, predecessor indices in the order produced by sorting. */
    private Trellis<Backpointer> backPointers;
    /** Maximum number of previous-position candidates (taken in input order) scored per cell. */
    private int searchBeam = 200;
    // Reusable row buffers used to build the trellises; cleared at the start of every call.
    private final List<List<Double>> tmpInitScores = new ArrayList<List<Double>>(500);
    private final List<List<Double>> tmpFwdScores = new ArrayList<List<Double>>(500);
    private final List<List<Word>> tmpSeqLabs = new ArrayList<List<Word>>(500);
    private final List<List<Backpointer>> tmpBkpointers = new ArrayList<List<Backpointer>>(500);
    /** Interner applied to every tag Word created by this scorer. */
    private final Interner<Word> words = new Interner<Word>();
    /** Selected algorithm; only FORWARDBACKWARD enables the backward pass. */
    private Constants.TaggingAlgorithm alg = Constants.TaggingAlgorithm.FORWARDBACKWARD;

    /**
     * Creates a scorer backed by the given n-gram model file.
     *
     * @param order  n-gram order passed to {@link StandardNgramModel}
     * @param lmFile path of the model file passed to {@link StandardNgramModel}
     * @throws IOException if the superclass fails to load the model
     */
    public SequenceScorer(int order, String lmFile) throws IOException {
        super(order, lmFile);
    }

    /**
     * Reads the n-gram order from an ARPA-format model file.
     *
     * <p>Skips ahead to the {@code \data\} header. It then returns the largest N found among the
     * immediately following {@code ngram N=count} lines. I/O errors are logged, and the order read so
     * far (possibly 0) is returned.
     *
     * @param tagSequenceModel path of the model file
     * @return the model order, or 0 if the file cannot be read or has no {@code ngram} header lines
     */
    public static int findOrder(String tagSequenceModel) {
        int ord = 0;
        BufferedReader reader = null;
        try {
            reader = new BufferedReader(new FileReader(new File(tagSequenceModel)));
            String ln = reader.readLine();
            while (ln != null && !ln.startsWith("\\data\\")) {
                ln = reader.readLine();
            }
            if (ln != null) {
                ln = reader.readLine();
                // Short-circuit && so EOF (null) ends the loop instead of throwing.
                while (ln != null && ln.startsWith("ngram ")) {
                    String n = ln.substring("ngram ".length()).split("=")[0].trim();
                    ord = Math.max(ord, Integer.parseInt(n));
                    ln = reader.readLine();
                }
            }
        }
        catch (FileNotFoundException fnfe) {
            Logger.getLogger(SequenceScorer.class.getName()).log(Level.SEVERE, null, fnfe);
        }
        catch (IOException ioe) {
            Logger.getLogger(SequenceScorer.class.getName()).log(Level.SEVERE, null, ioe);
        }
        finally {
            if (reader != null) {
                try {
                    reader.close();
                }
                catch (IOException ignored) {
                    // Nothing useful to do here: whatever order could be read has already been read.
                }
            }
        }
        return ord;
    }

    /**
     * Selects the tagging algorithm.
     *
     * <p>With {@code FORWARDBACKWARD} (the default), {@link #rescoreSequence} runs a backward pass
     * after the forward pass. Any other value skips the backward pass.
     */
    public void setAlgorithm(Constants.TaggingAlgorithm newAlg) {
        this.alg = newAlg;
    }

    /**
     * Sets how many previous-position candidates (in input order) are scored for each cell.
     *
     * @param newBeam beam width; must be at least 1
     * @throws IllegalArgumentException if {@code newBeam < 1} (a zero beam yields NaN scores)
     */
    public void setSearchBeam(int newBeam) {
        if (newBeam < 1) {
            throw new IllegalArgumentException("search beam must be at least 1, got " + newBeam);
        }
        this.searchBeam = newBeam;
    }

    /**
     * Rescores each position's candidate tags using the tag-sequence n-gram model.
     *
     * <p>Candidate scores greater than 0 are treated as probabilities and converted to log space.
     * Other values are used unchanged; presumably they are already log probabilities (confirm with
     * callers). A forward pass combines these with n-gram scores over the best predecessor
     * histories. If the algorithm is {@code FORWARDBACKWARD}, a backward pass follows. Scores are
     * normalized per position after each pass.
     *
     * <p><b>Mutates its argument.</b> Each inner list is replaced by a new list of (probability, tag)
     * pairs, sorted by {@link ProbIndexPair}'s natural ordering. The same outer list is returned.
     *
     * @param observationSequence one list of (score, tag) candidates per position
     * @return {@code observationSequence}, with its inner lists replaced by the rescored candidates
     */
    public List<List<Pair<Double, String>>> rescoreSequence(List<List<Pair<Double, String>>> observationSequence) {
        this.buildTrellises(observationSequence);
        this.runForwardPass(observationSequence);
        if (this.alg == Constants.TaggingAlgorithm.FORWARDBACKWARD) {
            this.runBackwardPass(observationSequence);
        }
        this.replaceWithRescoredTaggings(observationSequence);
        return observationSequence;
    }

    /** Builds the per-call trellises: log-space input scores, tag labels, and empty score/backpointer cells. */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private void buildTrellises(List<List<Pair<Double, String>>> observationSequence) {
        this.tmpInitScores.clear();
        this.tmpFwdScores.clear();
        this.tmpSeqLabs.clear();
        this.tmpBkpointers.clear();
        for (List<Pair<Double, String>> tw : observationSequence) {
            List<Double> scrs = new ArrayList<Double>(tw.size());
            List<Double> fscs = new ArrayList<Double>(tw.size());
            List<Word> sLabs = new ArrayList<Word>(tw.size());
            List<Backpointer> bpts = new ArrayList<Backpointer>(tw.size());
            for (Pair<Double, String> tagging : tw) {
                // NOTE(review): a score of exactly 0.0 is kept as log score 0.0, i.e. treated like
                // probability 1. Confirm whether inputs can be 0.
                scrs.add(tagging.a > 0.0 ? Math.log(tagging.a) : tagging.a);
                fscs.add(null);
                sLabs.add(this.words.intern(Word.createWord(tagging.b, null, null, null, null, null, null)));
                bpts.add(null);
            }
            this.tmpInitScores.add(scrs);
            this.tmpSeqLabs.add(sLabs);
            this.tmpFwdScores.add(fscs);
            this.tmpBkpointers.add(bpts);
        }
        // Trellis's constructor signature is not visible from here, so it is invoked raw as before.
        this.initScores = new Trellis(this.tmpInitScores);
        this.fbScores = new Trellis(this.tmpFwdScores);
        this.backPointers = new Trellis(this.tmpBkpointers);
        this.seqLabs = new Trellis(this.tmpSeqLabs);
    }

    /** Fills fbScores and backPointers left to right, normalizing each position's scores. */
    private void runForwardPass(List<List<Pair<Double, String>>> observationSequence) {
        for (int u = 0; u < observationSequence.size(); ++u) {
            int width = observationSequence.get(u).size();
            double normTot = 0.0;
            for (int v = 0; v < width; ++v) {
                Word currTag = this.seqLabs.getCoord(u, v);
                Double obsScore = this.initScores.getCoord(u, v);
                if (u == 0) {
                    // NOTE(review): getBestHist(0, v, ...) already ends with currTag, so this scores
                    // "<s> t t" rather than "<s> t". Confirm whether the extra add is intended.
                    List<Word> startHist = this.getBestHist(u, v, this.order);
                    startHist.add(currTag);
                    double fs = this.lmScore(startHist) + obsScore;
                    normTot += Math.exp(fs);
                    this.fbScores.setCoord(u, v, fs);
                    continue;
                }
                // Only the first `searchBeam` predecessors in input order are scored (not the best-scoring ones).
                int beam = Math.min(observationSequence.get(u - 1).size(), this.searchBeam);
                ProbIndexPair[] bestPrevScores = new ProbIndexPair[beam];
                for (int z = 0; z < beam; ++z) {
                    List<Word> hist = this.getBestHist(u - 1, z, this.order - 1);
                    hist.add(currTag);
                    double seqScore = this.lmScore(hist);
                    double fs = this.fbScores.getCoord(u - 1, z) + seqScore + obsScore;
                    bestPrevScores[z] = new ProbIndexPair(fs, z);
                }
                // getBestHist follows Backpointer.get(0). This presumably relies on the sort putting the
                // best predecessor first; confirm against ProbIndexPair.compareTo.
                Arrays.sort(bestPrevScores);
                double fsum = 0.0;
                ArrayList<Integer> bks = new ArrayList<Integer>(bestPrevScores.length);
                for (ProbIndexPair scored : bestPrevScores) {
                    fsum += Math.exp(scored.a);
                    bks.add(scored.b);
                }
                normTot += fsum;
                this.fbScores.setCoord(u, v, Math.log(fsum));
                this.backPointers.setCoord(u, v, new Backpointer(bks));
            }
            this.normalizeTrellisColumn(u, width, normTot);
        }
    }

    /** Overwrites fbScores right to left with backward scores, normalizing each position's scores. */
    private void runBackwardPass(List<List<Pair<Double, String>>> observationSequence) {
        int size = observationSequence.size();
        for (int u = size - 1; u >= 0; --u) {
            int width = observationSequence.get(u).size();
            double normTot = 0.0;
            for (int v = 0; v < width; ++v) {
                Double obsScore = this.initScores.getCoord(u, v);
                if (u == size - 1) {
                    // NOTE(review): this end-of-sentence history is built but never scored, so no
                    // "</s>" probability contributes. Also, obsScore is added again here although the
                    // forward score already includes it. Confirm intent.
                    List<Word> endHist = this.getBestHist(u, v, this.order - 1);
                    endHist.add(this.words.intern(Word.createWord("</s>", null, null, null, null, null, null)));
                    double bsc = this.fbScores.getCoord(u, v) + obsScore;
                    normTot += Math.exp(bsc);
                    this.fbScores.setCoord(u, v, bsc);
                    continue;
                }
                List<Word> hist = this.getBestHist(u, v, this.order - 1);
                List<Pair<Double, String>> followingTaggedWd = observationSequence.get(u + 1);
                double backwardSum = 0.0;
                for (int z = 0; z < followingTaggedWd.size(); ++z) {
                    Word followingTag = this.words.intern(Word.createWord(followingTaggedWd.get(z).b.intern(), null, null, null, null, null, null));
                    if (z > 0) {
                        // Swap out the previous candidate for the following position.
                        hist.remove(hist.size() - 1);
                    }
                    hist.add(followingTag);
                    backwardSum += Math.exp(this.lmScore(hist) + this.fbScores.getCoord(u + 1, z));
                }
                // NOTE(review): the forward score at (u, v) is replaced here rather than combined
                // with the backward sum. Confirm this is the intended posterior.
                double newSc = Math.log(backwardSum) + obsScore;
                normTot += Math.exp(newSc);
                this.fbScores.setCoord(u, v, newSc);
            }
            this.normalizeTrellisColumn(u, width, normTot);
        }
    }

    /** Divides position u's scores by normTot in probability space, storing the result back in log space. */
    private void normalizeTrellisColumn(int u, int width, double normTot) {
        for (int v = 0; v < width; ++v) {
            this.fbScores.setCoord(u, v, Math.log(Math.exp(this.fbScores.getCoord(u, v)) / normTot));
        }
    }

    /** Replaces each inner list with (probability, tag) pairs from fbScores, sorted by ProbIndexPair order. */
    private void replaceWithRescoredTaggings(List<List<Pair<Double, String>>> observationSequence) {
        for (int i = 0; i < observationSequence.size(); ++i) {
            List<Pair<Double, String>> tagging = observationSequence.get(i);
            ProbIndexPair[] fwdScrs = new ProbIndexPair[tagging.size()];
            for (int j = 0; j < tagging.size(); ++j) {
                fwdScrs[j] = new ProbIndexPair(Math.exp(this.fbScores.getCoord(i, j)), j);
            }
            Arrays.sort(fwdScrs);
            List<Pair<Double, String>> newTagging = new ArrayList<Pair<Double, String>>(fwdScrs.length);
            for (ProbIndexPair scored : fwdScrs) {
                Double renorm = Double.valueOf(scored.a);
                if (renorm.equals(Constants.one)) {
                    // Reuse the shared constant instance for values equal to it.
                    renorm = Constants.one;
                }
                newTagging.add(new Pair<Double, String>(renorm, tagging.get(scored.b.intValue()).b));
            }
            observationSequence.set(i, newTagging);
        }
    }

    /**
     * Returns the value of the inherited {@code logprob()} for {@code seq}.
     *
     * <p>It first calls {@code setWordsToScore(seq, false)} and {@code prepareToScoreWords()}.
     * Whether this scores the whole sequence or only its final n-gram is determined by
     * StandardNgramModel (not visible here).
     */
    private double lmScore(List<Word> seq) {
        this.setWordsToScore(seq, false);
        this.prepareToScoreWords();
        return this.logprob();
    }

    /**
     * Reconstructs the tag history ending at cell (i, j), following the first entry of each
     * cell's Backpointer.
     *
     * <ul>
     *   <li>{@code i == -1} yields just {@code <s>}.</li>
     *   <li>{@code i == 0} yields {@code <s>} followed by the tag at (0, j), regardless of
     *       {@code order}.</li>
     *   <li>For {@code i > 0} (assuming {@code order >= 0}): if {@code i > order}, the result is
     *       the tags of the last {@code order} positions ending at i. Otherwise it is {@code <s>}
     *       followed by the tags of positions 0..i.</li>
     * </ul>
     *
     * @return a new mutable list that callers may append to
     */
    private List<Word> getBestHist(int i, int j, int order) {
        int size = Math.max(order, 0);
        if (i == -1) {
            List<Word> start = new ArrayList<Word>(size);
            start.add(this.words.intern(Word.createWord("<s>", null, null, null, null, null, null)));
            return start;
        }
        if (i == 0) {
            List<Word> first = this.getBestHist(i - 1, 0, order - 1);
            first.add(this.seqLabs.getCoord(i, j));
            return first;
        }
        if (order == 0) {
            return new ArrayList<Word>(size);
        }
        // Backpointers exist only for positions > 0, so look one up only here.
        Backpointer bp = this.backPointers.getCoord(i, j);
        List<Word> retVal = this.getBestHist(i - 1, bp.get(0), order - 1);
        retVal.add(this.seqLabs.getCoord(i, j));
        return retVal;
    }
}

