/*
 * Decompiled with CFR 0.152.
 */
package com.aliasi.lm;

import com.aliasi.lm.LanguageModel;
import com.aliasi.symbol.SymbolTable;
import com.aliasi.tokenizer.Tokenizer;
import com.aliasi.tokenizer.TokenizerFactory;
import com.aliasi.util.Strings;
import java.io.IOException;
import java.io.ObjectInput;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * A compiled token n-gram language model read from an {@link ObjectInput}.
 *
 * <p>Token statistics are stored as a trie flattened into parallel arrays.
 * Node 0 is the root. Each node records the id of the token labelling it and
 * the log probability of that token given the context spelled by the path to
 * its parent. Internal nodes additionally record a log backoff weight and the
 * index of their first daughter.
 *
 * <p>None of this object's fields are modified after construction.
 */
public class CompiledTokenizedLM
        implements LanguageModel.Sequence, LanguageModel.Tokenized {

    /**
     * Token id that {@link #conditionalTokenEstimate} scores by backoff alone.
     * NOTE(review): presumably the id {@code SymbolTable.symbolToID} returns
     * for symbols not in the table; confirm.
     */
    private static final int UNKNOWN_TOKEN_ID = -1;

    /** Token id used to pad both ends of a sequence in {@link #log2Estimate(char[], int, int)}. */
    private static final int BOUNDARY_TOKEN_ID = -2;

    /** Produces the tokenizers used to score character sequences. */
    private final TokenizerFactory mTokenizerFactory;

    /** Maps token strings to the integer ids stored in {@link #mTokens}. */
    private final SymbolTable mSymbolTable;

    /** Scores the characters of tokens that have no id in the symbol table. */
    private final LanguageModel.Sequence mUnknownTokenModel;

    /** Scores each whitespace span produced by the tokenizer. */
    private final LanguageModel.Sequence mWhitespaceModel;

    /** Maximum n-gram length; contexts hold at most {@code mMaxNGram - 1} tokens. */
    private final int mMaxNGram;

    /** Token id labelling each trie node, indexed by node. */
    private final int[] mTokens;

    /** Log probability of each node's token given its parent's context. */
    private final float[] mLogProbs;

    /**
     * Log backoff weight of each internal node. A NaN entry is treated as
     * "no daughters" by {@link #hasDtrs}.
     */
    private final float[] mLogLambdas;

    /**
     * Daughters of internal node {@code n} occupy node indices
     * {@code [mFirstChild[n], mFirstChild[n + 1])}. They are binary-searched,
     * so they are assumed sorted by token id. The final entry is a sentinel
     * equal to the total node count.
     */
    private final int[] mFirstChild;

    /**
     * Reads a compiled model from the specified input.
     *
     * <p>Values are read in this order:
     * <ol>
     * <li>a UTF tokenizer-factory class name; if it is empty, a serialized
     *     factory object follows</li>
     * <li>the symbol table, the unknown-token model and the whitespace model
     *     (serialized objects)</li>
     * <li>the maximum n-gram length, the node count and the index of the last
     *     internal node (ints)</li>
     * <li>for each node, its token id (int) and log probability (float); for
     *     internal nodes, also the log backoff weight (float) and first-daughter
     *     index (int)</li>
     * </ol>
     *
     * <p><b>Security:</b> this deserializes objects with
     * {@link ObjectInput#readObject()}. It may also reflectively instantiate a
     * class whose name is read from the stream. Only read models from trusted
     * sources.
     *
     * @param in input positioned at the start of a compiled model
     * @throws IOException if reading fails or the stored node counts are out
     *     of range
     * @throws ClassNotFoundException if a serialized class or the named
     *     tokenizer factory class cannot be found or instantiated
     */
    CompiledTokenizedLM(ObjectInput in) throws IOException, ClassNotFoundException {
        String tokenizerClassName = in.readUTF();
        mTokenizerFactory = tokenizerClassName.length() == 0
            ? (TokenizerFactory) in.readObject()
            : newTokenizerFactory(tokenizerClassName);
        mSymbolTable = (SymbolTable) in.readObject();
        mUnknownTokenModel = (LanguageModel.Sequence) in.readObject();
        mWhitespaceModel = (LanguageModel.Sequence) in.readObject();
        mMaxNGram = in.readInt();

        int numNodes = in.readInt();
        int lastInternalNodeIndex = in.readInt();
        // Reject corrupt counts up front. Otherwise they surface as an
        // undeclared NegativeArraySizeException from the allocations below.
        if (numNodes < 0) {
            throw new IOException("Corrupt model: numNodes=" + numNodes);
        }
        if (lastInternalNodeIndex < -1 || lastInternalNodeIndex > Integer.MAX_VALUE - 2) {
            throw new IOException("Corrupt model: lastInternalNodeIndex="
                                  + lastInternalNodeIndex);
        }

        mTokens = new int[numNodes];
        mLogProbs = new float[numNodes];
        mLogLambdas = new float[lastInternalNodeIndex + 1];
        // One extra slot so mFirstChild[n + 1] bounds the daughters of the last
        // internal node.
        mFirstChild = new int[lastInternalNodeIndex + 2];
        mFirstChild[mFirstChild.length - 1] = numNodes;
        for (int node = 0; node < numNodes; ++node) {
            mTokens[node] = in.readInt();
            mLogProbs[node] = in.readFloat();
            if (node <= lastInternalNodeIndex) {
                mLogLambdas[node] = in.readFloat();
                mFirstChild[node] = in.readInt();
            }
        }
    }

    /**
     * Instantiates a tokenizer factory through its public no-argument
     * constructor.
     *
     * @param className fully qualified name of the factory class
     * @return a new factory instance
     * @throws ClassNotFoundException if the class cannot be found, or if it
     *     cannot be constructed (the reflective failure is kept as the cause)
     */
    private static TokenizerFactory newTokenizerFactory(String className)
            throws ClassNotFoundException {
        try {
            Class<?> tokenizerClass = Class.forName(className);
            Constructor<?> cons = tokenizerClass.getConstructor();
            return (TokenizerFactory) cons.newInstance();
        } catch (NoSuchMethodException e) {
            throw new ClassNotFoundException("Constructing " + className, e);
        } catch (InstantiationException e) {
            throw new ClassNotFoundException("Constructing " + className, e);
        } catch (IllegalAccessException e) {
            throw new ClassNotFoundException("Constructing " + className, e);
        } catch (InvocationTargetException e) {
            throw new ClassNotFoundException("Constructing " + className, e);
        }
    }

    /**
     * Returns a multi-line dump of the component models and the flattened
     * trie. Each trie line lists the node index, token id and log probability.
     * For internal nodes it also lists the first-daughter index and log backoff
     * weight.
     *
     * @return a debugging representation of this model
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append("Tokenizer Class Name=").append(mTokenizerFactory).append('\n');
        sb.append("Symbol Table=").append(mSymbolTable).append('\n');
        sb.append("Unknown Token Model=").append(mUnknownTokenModel).append('\n');
        sb.append("Whitespace Model=").append(mWhitespaceModel).append('\n');
        sb.append("Token Trie").append('\n');
        sb.append("Nodes=").append(mTokens.length)
            .append(" Internal=").append(mLogLambdas.length).append('\n');
        sb.append("Index Tok logP firstDtr log(1-L)").append('\n');
        for (int i = 0; i < mTokens.length; ++i) {
            sb.append(i).append('\t').append(mTokens[i]).append('\t').append(mLogProbs[i]);
            // Only internal nodes have daughters. The trailing mFirstChild slot
            // is an end sentinel, not node data, so it is not printed.
            if (i < mLogLambdas.length) {
                sb.append('\t').append(mFirstChild[i]);
                sb.append('\t').append(mLogLambdas[i]);
            }
            sb.append('\n');
        }
        return sb.toString();
    }

    /**
     * Returns the log (base 2) estimate of the specified character sequence.
     *
     * @param cSeq character sequence to estimate
     * @return log2 estimate of {@code cSeq}
     */
    @Override
    public double log2Estimate(CharSequence cSeq) {
        char[] cs = Strings.toCharArray(cSeq);
        return log2Estimate(cs, 0, cs.length);
    }

    /**
     * Returns the log (base 2) estimate of {@code cs[start..end)}.
     *
     * <p>The estimate is the sum of three parts:
     * <ul>
     * <li>the whitespace model's estimate of every whitespace span, including
     *     the spans before the first token and after the last</li>
     * <li>the unknown-token model's estimate of every token the symbol table
     *     maps to a negative id</li>
     * <li>the n-gram estimate of each token and of the trailing boundary,
     *     conditioned on a leading boundary</li>
     * </ul>
     *
     * @param cs underlying characters
     * @param start index of the first character (inclusive)
     * @param end index one past the last character
     * @return log2 estimate of the slice
     */
    @Override
    public double log2Estimate(char[] cs, int start, int end) {
        Strings.checkArgsStartEnd(cs, start, end);
        double logEstimate = 0.0;

        Tokenizer tokenizer = mTokenizerFactory.tokenizer(cs, start, end - start);
        List<String> tokenList = new ArrayList<String>();
        while (true) {
            String whitespace = tokenizer.nextWhitespace();
            logEstimate += mWhitespaceModel.log2Estimate(whitespace);
            String token = tokenizer.nextToken();
            if (token == null) {
                break;
            }
            tokenList.add(token);
        }

        // Token ids padded with a boundary id at each end.
        int[] tokIds = new int[tokenList.size() + 2];
        tokIds[0] = BOUNDARY_TOKEN_ID;
        tokIds[tokIds.length - 1] = BOUNDARY_TOKEN_ID;
        int pos = 1;
        for (String token : tokenList) {
            tokIds[pos] = mSymbolTable.symbolToID(token);
            if (tokIds[pos] < 0) {
                logEstimate += mUnknownTokenModel.log2Estimate(token);
            }
            ++pos;
        }

        // Predict every id after the leading boundary, including the trailing one.
        for (int tokEnd = 2; tokEnd <= tokIds.length; ++tokEnd) {
            logEstimate += conditionalTokenEstimate(tokIds, 0, tokEnd);
        }
        return logEstimate;
    }

    /**
     * Returns the log (base 2) estimate of {@code tokIds[end - 1]} given up to
     * {@code mMaxNGram - 1} preceding ids, none of them before {@code start}.
     *
     * <p>Contexts are tried from longest to shortest; contexts absent from the
     * trie are skipped. At each context that is present:
     * <ul>
     * <li>if the context has the outcome as a daughter, the result is that
     *     daughter's log probability plus the log backoff weights accumulated
     *     so far</li>
     * <li>otherwise, if the context has daughters, its log backoff weight is
     *     added and the next shorter context is tried</li>
     * </ul>
     * An {@link #UNKNOWN_TOKEN_ID} outcome is never looked up. It receives only
     * the accumulated backoff weights.
     *
     * @param tokIds token ids
     * @param start earliest index usable as context
     * @param end one past the index of the token being predicted
     * @return accumulated log2 estimate
     */
    private double conditionalTokenEstimate(int[] tokIds, int start, int end) {
        double estimate = 0.0;
        int contextEnd = end - 1;
        int tokId = tokIds[contextEnd];
        int longestContext = Math.min(contextEnd - start, mMaxNGram - 1);
        for (int contextLength = longestContext; contextLength >= 0; --contextLength) {
            int contextIndex = getIndex(tokIds, contextEnd - contextLength, contextEnd);
            if (contextIndex == -1) {
                continue;
            }
            if (tokId != UNKNOWN_TOKEN_ID) {
                int outcomeIndex = getIndex(contextIndex, tokId);
                if (outcomeIndex != -1) {
                    return estimate + mLogProbs[outcomeIndex];
                }
            }
            if (hasDtrs(contextIndex)) {
                estimate += mLogLambdas[contextIndex];
            }
        }
        return estimate;
    }

    /**
     * Returns the log (base 2) probability of {@code tokens[start..end)}.
     *
     * <p>Each token is conditioned only on tokens at or after {@code start}.
     * No boundary tokens are added, and the whitespace and unknown-token
     * character models are not applied.
     *
     * @param tokens token array
     * @param start index of the first token (inclusive)
     * @param end index one past the last token
     * @return summed log2 conditional estimates
     */
    @Override
    public double tokenLog2Probability(String[] tokens, int start, int end) {
        // Only [start, end) is ever read by conditionalTokenEstimate here.
        int[] tokIds = new int[tokens.length];
        for (int i = start; i < end; ++i) {
            tokIds[i] = mSymbolTable.symbolToID(tokens[i]);
        }
        double sum = 0.0;
        for (int i = start + 1; i <= end; ++i) {
            sum += conditionalTokenEstimate(tokIds, start, i);
        }
        return sum;
    }

    /**
     * Returns the probability of {@code tokens[start..end)}, computed as
     * {@code 2^tokenLog2Probability(tokens, start, end)}.
     *
     * @param tokens token array
     * @param start index of the first token (inclusive)
     * @param end index one past the last token
     * @return linear-scale probability
     */
    @Override
    public double tokenProbability(String[] tokens, int start, int end) {
        return Math.pow(2.0, tokenLog2Probability(tokens, start, end));
    }

    /**
     * Returns {@code true} if the node is internal and its log backoff weight
     * is not NaN.
     *
     * @param contextIndex trie node index
     * @return whether the node is treated as having daughters
     */
    boolean hasDtrs(int contextIndex) {
        return contextIndex < mLogLambdas.length && !Float.isNaN(mLogLambdas[contextIndex]);
    }

    /**
     * Binary-searches the daughters of a node for a token id.
     *
     * @param fromIndex parent node index
     * @param tokId token id to find
     * @return index of the matching daughter, or -1 if {@code fromIndex} is
     *     not an internal node or has no such daughter
     */
    private int getIndex(int fromIndex, int tokId) {
        if (fromIndex + 1 >= mFirstChild.length) {
            return -1;
        }
        int low = mFirstChild[fromIndex];
        int high = mFirstChild[fromIndex + 1] - 1;
        while (low <= high) {
            int mid = (low + high) >>> 1;
            int midTok = mTokens[mid];
            if (midTok < tokId) {
                low = mid + 1;
            } else if (midTok > tokId) {
                high = mid - 1;
            } else {
                return mid;
            }
        }
        return -1;
    }

    /**
     * Walks down from the root along {@code tokIds[start..end)}.
     *
     * @param tokIds token ids
     * @param start first id of the path (inclusive)
     * @param end one past the last id of the path
     * @return node index reached (0, the root, for an empty path), or -1 if
     *     the path leaves the trie
     */
    private int getIndex(int[] tokIds, int start, int end) {
        int index = 0;
        for (int i = start; i < end && index != -1; ++i) {
            index = getIndex(index, tokIds[i]);
        }
        return index;
    }
}

