/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.branchratemodel;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.branchratemodel.BranchRateModel;
import dr.evomodel.branchratemodel.CountableBranchCategoryProvider;
import dr.evomodel.tree.TreeModel;
import dr.evomodel.tree.TreeParameterModel;
import dr.inference.markovjumps.MarkovReward;
import dr.inference.markovjumps.TwoStateOccupancyMarkovReward;
import dr.inference.model.AbstractModelLikelihood;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.Variable;
import java.util.ArrayList;
import java.util.List;

/**
 * Branch rate model in which each branch spends a proportion of its length in a latent state
 * that contributes no rate: the effective branch rate is the non-latent rate multiplied by
 * {@code 1 - latentProportion}.
 *
 * <p>As a likelihood, it scores the latent proportions of all non-root branches under a
 * two-state continuous-time Markov chain whose generator is built from a transition-rate and a
 * transition-frequency parameter, using a {@link MarkovReward} (occupancy) series. Per-branch
 * log densities are cached and only recomputed for branches/categories flagged as dirty.
 *
 * <p>Latent proportions come either from one value per branch (wrapped in a
 * {@link TreeParameterModel}) or, when a {@link CountableBranchCategoryProvider} is supplied, from
 * one value per branch category.
 */
public class SericolaLatentStateBranchRateModel
extends AbstractModelLikelihood
implements BranchRateModel {
    public static final String LATENT_STATE_BRANCH_RATE_MODEL = "latentStateBranchRateModel";
    public static final boolean USE_CACHING = true;
    private final TreeModel tree;
    private final BranchRateModel nonLatentRateModel;
    private final Parameter latentTransitionRateParameter;
    private final Parameter latentTransitionFrequencyParameter;
    // Exactly one of latentStateProportions / (latentStateProportionParameter + branchCategoryProvider)
    // is in use, depending on which constructor path was taken.
    private final TreeParameterModel latentStateProportions;
    private final Parameter latentStateProportionParameter;
    private final CountableBranchCategoryProvider branchCategoryProvider;
    // Lazily built from the current transition parameters; reset to null whenever they change.
    private MarkovReward series;
    private MarkovReward storedSeries;
    private boolean likelihoodKnown = false;
    private boolean storedLikelihoodKnown;
    private double logLikelihood;
    private double storedLogLikelihood;
    // Cached log density per branch, indexed by node number.
    private double[] branchLikelihoods;
    private double[] storedbranchLikelihoods;
    // Dirty flags per node number / per branch category (null when no category provider is used).
    private boolean[] updateBranch;
    private boolean[] storedUpdateBranch;
    private boolean[] updateCategory;
    private boolean[] storedUpdateCategory;

    /**
     * Creates the model for use within an analysis.
     *
     * @param name model identifier passed to the superclass
     * @param treeModel tree whose branches are scored
     * @param nonLatentRateModel rate model giving the rate outside the latent state
     * @param latentTransitionRateParameter overall rate of the two-state latent chain (first value used)
     * @param latentTransitionFrequencyParameter frequency parameter of the latent chain (first value used)
     * @param latentStateProportionParameter latent proportions: one per branch when
     *        {@code branchCategoryProvider} is null, otherwise resized to one per category
     * @param branchCategoryProvider optional provider mapping branches to categories; may be null
     */
    public SericolaLatentStateBranchRateModel(String name, TreeModel treeModel, BranchRateModel nonLatentRateModel, Parameter latentTransitionRateParameter, Parameter latentTransitionFrequencyParameter, Parameter latentStateProportionParameter, CountableBranchCategoryProvider branchCategoryProvider) {
        super(name);
        this.tree = treeModel;
        this.addModel(this.tree);
        this.nonLatentRateModel = nonLatentRateModel;
        this.addModel(nonLatentRateModel);
        this.latentTransitionRateParameter = latentTransitionRateParameter;
        this.addVariable(latentTransitionRateParameter);
        this.latentTransitionFrequencyParameter = latentTransitionFrequencyParameter;
        this.addVariable(latentTransitionFrequencyParameter);
        if (branchCategoryProvider == null) {
            // One latent proportion per branch, exposed as a branch-intent tree trait.
            this.latentStateProportions = new TreeParameterModel(this.tree, latentStateProportionParameter, false, TreeTrait.Intent.BRANCH);
            this.addModel(this.latentStateProportions);
            this.latentStateProportionParameter = null;
            this.branchCategoryProvider = null;
        } else {
            // One latent proportion per branch category.
            this.latentStateProportions = null;
            this.branchCategoryProvider = branchCategoryProvider;
            this.latentStateProportionParameter = latentStateProportionParameter;
            int categoryCount = branchCategoryProvider.getCategoryCount();
            this.latentStateProportionParameter.setDimension(categoryCount);
            this.updateCategory = new boolean[categoryCount];
            this.storedUpdateCategory = new boolean[categoryCount];
            this.setUpdateAllCategories();
            this.addVariable(latentStateProportionParameter);
        }
        int nodeCount = this.tree.getNodeCount();
        this.branchLikelihoods = new double[nodeCount];
        this.updateBranch = new boolean[nodeCount];
        this.storedUpdateBranch = new boolean[nodeCount];
        this.storedbranchLikelihoods = new double[nodeCount];
        this.setUpdateAllBranches();
    }

    /**
     * Creates a stand-alone instance with only the latent-chain parameters, used by {@link #main}
     * to evaluate {@link #getBranchRewardDensity}. No tree, rate model or proportions are attached,
     * so the branch-rate and likelihood methods must not be used on such an instance.
     */
    public SericolaLatentStateBranchRateModel(Parameter latentTransitionRateParameter, Parameter latentTransitionFrequencyParameter) {
        super(LATENT_STATE_BRANCH_RATE_MODEL);
        this.tree = null;
        this.nonLatentRateModel = null;
        this.latentTransitionRateParameter = latentTransitionRateParameter;
        this.latentTransitionFrequencyParameter = latentTransitionFrequencyParameter;
        this.latentStateProportions = null;
        this.latentStateProportionParameter = null;
        this.branchCategoryProvider = null;
    }

    /**
     * Builds the 2x2 generator of the latent chain in row-major order:
     * {@code q01 = rate * freq}, {@code q10 = rate * (1 - freq)}, diagonals make rows sum to zero.
     */
    private double[] createLatentInfinitesimalMatrix() {
        double rate = this.latentTransitionRateParameter.getParameterValue(0);
        double freq = this.latentTransitionFrequencyParameter.getParameterValue(0);
        return new double[]{-rate * freq, rate * freq, rate * (1.0 - freq), -rate * (1.0 - freq)};
    }

    private MarkovReward createSeries() {
        return new TwoStateOccupancyMarkovReward(this.createLatentInfinitesimalMatrix());
    }

    /** Returns the non-latent rate of the branch scaled by the fraction of time spent outside the latent state. */
    @Override
    public double getBranchRate(Tree tree, NodeRef nodeRef) {
        double nonLatentRate = this.nonLatentRateModel.getBranchRate(tree, nodeRef);
        double latentProportion = this.getLatentProportion(tree, nodeRef);
        return this.calculateBranchRate(nonLatentRate, latentProportion);
    }

    /** Returns the latent proportion of the branch above {@code nodeRef}, per-branch or per-category. */
    public double getLatentProportion(Tree tree, NodeRef nodeRef) {
        if (this.latentStateProportions != null) {
            return this.latentStateProportions.getNodeValue(tree, nodeRef);
        }
        return this.latentStateProportionParameter.getParameterValue(this.branchCategoryProvider.getBranchCategory(tree, nodeRef));
    }

    private double calculateBranchRate(double nonLatentRate, double latentProportion) {
        return nonLatentRate * (1.0 - latentProportion);
    }

    @Override
    protected void handleModelChangedEvent(Model model, Object object, int n) {
        if (model == this.tree || model == this.latentStateProportions) {
            this.likelihoodKnown = false;
            if (n == -1) {
                this.setUpdateAllBranches();
            } else {
                this.setUpdateBranch(n);
            }
            // NOTE(review): a single-node tree event only dirties branch n; confirm that the
            // TreeModel event protocol covers child branches whose lengths change with node n.
        }
        // Changes in nonLatentRateModel alter branch rates but not the latent-proportion density,
        // so they only need to be propagated to listeners.
        this.fireModelChanged();
    }

    @Override
    protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        if (variable == this.latentTransitionFrequencyParameter || variable == this.latentTransitionRateParameter) {
            // The chain changed: rebuild the series lazily and rescore every branch. Branch rates
            // do not depend on these parameters, so listeners are not notified here.
            this.series = null;
            this.setUpdateAllBranches();
            this.likelihoodKnown = false;
        } else if (variable == this.latentStateProportionParameter) {
            if (n == -1) {
                this.setUpdateAllBranches();
            } else {
                this.setUpdateBranchCategory(n);
            }
            this.likelihoodKnown = false;
            this.fireModelChanged();
        }
    }

    private void setUpdateBranch(int n) {
        this.updateBranch[n] = true;
    }

    private void setUpdateAllBranches() {
        // updateBranch is null for the stand-alone (parameter-only) constructor.
        if (this.updateBranch != null) {
            for (int i = 0; i < this.updateBranch.length; ++i) {
                this.updateBranch[i] = true;
            }
        }
    }

    private void clearUpdateAllBranches() {
        if (this.updateBranch != null) {
            for (int i = 0; i < this.updateBranch.length; ++i) {
                this.updateBranch[i] = false;
            }
        }
    }

    private void setUpdateBranchCategory(int n) {
        this.updateCategory[n] = true;
    }

    private void setUpdateAllCategories() {
        for (int i = 0; i < this.updateCategory.length; ++i) {
            this.updateCategory[i] = true;
        }
    }

    private void clearAllCategories() {
        if (this.updateCategory != null) {
            for (int i = 0; i < this.updateCategory.length; ++i) {
                this.updateCategory[i] = false;
            }
        }
    }

    @Override
    protected void storeState() {
        this.storedSeries = this.series;
        this.storedLogLikelihood = this.logLikelihood;
        this.storedLikelihoodKnown = this.likelihoodKnown;
        System.arraycopy(this.branchLikelihoods, 0, this.storedbranchLikelihoods, 0, this.branchLikelihoods.length);
        System.arraycopy(this.updateBranch, 0, this.storedUpdateBranch, 0, this.updateBranch.length);
        if (this.updateCategory != null) {
            System.arraycopy(this.updateCategory, 0, this.storedUpdateCategory, 0, this.updateCategory.length);
        }
    }

    /** Restores the stored state by swapping the current and stored buffers (no copying). */
    @Override
    protected void restoreState() {
        this.series = this.storedSeries;
        this.logLikelihood = this.storedLogLikelihood;
        this.likelihoodKnown = this.storedLikelihoodKnown;
        double[] likelihoodSwap = this.branchLikelihoods;
        this.branchLikelihoods = this.storedbranchLikelihoods;
        this.storedbranchLikelihoods = likelihoodSwap;
        boolean[] branchSwap = this.updateBranch;
        this.updateBranch = this.storedUpdateBranch;
        this.storedUpdateBranch = branchSwap;
        boolean[] categorySwap = this.updateCategory;
        this.updateCategory = this.storedUpdateCategory;
        this.storedUpdateCategory = categorySwap;
    }

    @Override
    protected void acceptState() {
    }

    @Override
    public Model getModel() {
        return this;
    }

    @Override
    public double getLogLikelihood() {
        if (!this.likelihoodKnown) {
            this.logLikelihood = this.calculateLogLikelihood();
            this.likelihoodKnown = true;
        }
        return this.logLikelihood;
    }

    /**
     * Sums the cached log densities over every non-root branch, recomputing only the branches
     * whose node or category is flagged dirty, then clears all dirty flags.
     */
    private double calculateLogLikelihood() {
        double total = 0.0;
        NodeRef root = this.tree.getRoot();
        // Every non-root node (tips included) owns a branch with a latent proportion.
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef node = this.tree.getNode(i);
            if (node == root) {
                continue;
            }
            int index = node.getNumber();
            if (this.updateNeededForNode(this.tree, node)) {
                double branchLength = this.tree.getBranchLength(node);
                double latentProportion = this.getLatentProportion(this.tree, node);
                assert (latentProportion < 1.0);
                double reward = branchLength * latentProportion;
                double density = this.getBranchRewardDensity(reward, branchLength);
                this.branchLikelihoods[index] = Math.log(density);
            }
            total += this.branchLikelihoods[index];
        }
        this.clearUpdateAllBranches();
        this.clearAllCategories();
        return total;
    }

    private boolean updateNeededForNode(Tree tree, NodeRef nodeRef) {
        boolean categoryDirty = this.updateCategory != null && this.updateCategory[this.branchCategoryProvider.getBranchCategory(tree, nodeRef)];
        return categoryDirty || this.updateBranch[nodeRef.getNumber()];
    }

    /**
     * Returns the density of accumulating {@code d} units of reward over a branch of length
     * {@code d2}, computed as {@code series.computePdf(d, d2, 0, 0)} divided by
     * {@code series.computeConditionalProbability(d2, 0, 0)}.
     * NOTE(review): the trailing {@code 0, 0} arguments presumably fix the start and end state of
     * the latent chain to state 0 — confirm against {@link MarkovReward}.
     *
     * @param d accumulated reward (branch length times latent proportion)
     * @param d2 branch length (elapsed time)
     */
    public double getBranchRewardDensity(double d, double d2) {
        MarkovReward markovReward = this.getSeries();
        double pdf = markovReward.computePdf(d, d2, 0, 0);
        double conditional = markovReward.computeConditionalProbability(d2, 0, 0);
        return pdf / conditional;
    }

    @Override
    public void makeDirty() {
        this.likelihoodKnown = false;
        this.series = null;
        this.setUpdateAllBranches();
    }

    @Override
    public String getTraitName() {
        return "rate";
    }

    @Override
    public TreeTrait.Intent getIntent() {
        return TreeTrait.Intent.BRANCH;
    }

    @Override
    public TreeTrait getTreeTrait(String string) {
        if (string.equals("rate")) {
            return this;
        }
        if (this.latentStateProportions != null && string.equals(this.latentStateProportions.getTraitName())) {
            return this.latentStateProportions;
        }
        if (this.branchCategoryProvider != null && string.equals(this.branchCategoryProvider.getTraitName())) {
            return this.branchCategoryProvider;
        }
        throw new IllegalArgumentException("Unrecognised Tree Trait key, " + string);
    }

    /**
     * Returns this rate trait plus whichever proportion source is in use. Null sources are
     * omitted, so callers never receive null entries.
     */
    @Override
    public TreeTrait[] getTreeTraits() {
        List<TreeTrait> traits = new ArrayList<TreeTrait>();
        traits.add(this);
        if (this.latentStateProportions != null) {
            traits.add(this.latentStateProportions);
        }
        if (this.branchCategoryProvider != null) {
            traits.add(this.branchCategoryProvider);
        }
        return traits.toArray(new TreeTrait[0]);
    }

    @Override
    public Class getTraitClass() {
        return Double.class;
    }

    @Override
    public boolean getLoggable() {
        return true;
    }

    @Override
    public Double getTrait(Tree tree, NodeRef nodeRef) {
        return this.getBranchRate(tree, nodeRef);
    }

    @Override
    public String getTraitString(Tree tree, NodeRef nodeRef) {
        return Double.toString(this.getBranchRate(tree, nodeRef));
    }

    /**
     * Returns the grid point with the largest pdf value (first one on ties).
     *
     * @param list pdf values
     * @param list2 reward values, parallel to {@code list}
     * @throws IndexOutOfBoundsException if {@code list} is empty
     */
    static Mode findMode(List<Double> list, List<Double> list2) {
        Mode mode = new Mode(list.get(0), list2.get(0));
        for (int i = 1; i < list.size(); ++i) {
            if (list.get(i) > mode.pdf) {
                mode.pdf = list.get(i);
                mode.reward = list2.get(i);
            }
        }
        return mode;
    }

    /**
     * Returns the pdf-weighted mean reward over a grid: {@code sum(r_i * p_i) / sum(p_i)}.
     * On an evenly spaced grid the spacing cancels, so this approximates the expected reward.
     *
     * @param list pdf values
     * @param list2 reward values, parallel to {@code list}
     */
    static double calculateExpectation(List<Double> list, List<Double> list2) {
        double totalWeight = 0.0;
        double weightedSum = 0.0;
        for (int i = 0; i < list.size(); ++i) {
            double pdf = list.get(i);
            totalWeight += pdf;
            weightedSum += list2.get(i) * pdf;
        }
        return weightedSum / totalWeight;
    }

    /**
     * Diagnostic driver: for branch lengths 0.1..10, prints the branch length, the modal latent
     * proportion and the expected latent proportion under rate 2.0 and frequency 0.5.
     */
    public static void main(String[] stringArray) {
        Parameter.Default rateParameter = new Parameter.Default(2.0);
        Parameter.Default frequencyParameter = new Parameter.Default(0.5);
        SericolaLatentStateBranchRateModel model = new SericolaLatentStateBranchRateModel(rateParameter, frequencyParameter);
        for (double d = 0.1; d <= 10.0; d += 0.1) {
            ArrayList<Double> densities = new ArrayList<Double>();
            ArrayList<Double> rewards = new ArrayList<Double>();
            for (double reward = 0.0; reward <= d; reward += 0.01 * d) {
                double density = model.getBranchRewardDensity(reward, d);
                rewards.add(reward);
                densities.add(density);
            }
            Mode mode = SericolaLatentStateBranchRateModel.findMode(densities, rewards);
            System.out.println(d + " " + mode.reward / d + " " + SericolaLatentStateBranchRateModel.calculateExpectation(densities, rewards) / d);
        }
    }

    /** Returns the occupancy series for the current transition parameters, building it on first use. */
    public MarkovReward getSeries() {
        if (this.series == null) {
            this.series = this.createSeries();
        }
        return this.series;
    }

    /** Mutable (pdf, reward) pair used by {@link #findMode}. */
    static class Mode {
        double pdf;
        double reward;

        Mode(double d, double d2) {
            this.pdf = d;
            this.reward = d2;
        }
    }
}

