/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.ml.mcmc;

import com.davidsoergel.dsutils.GenericFactory;
import com.davidsoergel.dsutils.GenericFactoryException;
import com.davidsoergel.runutils.Property;
import com.davidsoergel.runutils.PropertyConsumer;
import edu.berkeley.compbio.ml.mcmc.DataCollector;
import edu.berkeley.compbio.ml.mcmc.EnergyMove;
import edu.berkeley.compbio.ml.mcmc.MonteCarloState;
import edu.berkeley.compbio.ml.mcmc.Move;
import edu.berkeley.compbio.ml.mcmc.MoveTypeSet;
import edu.berkeley.compbio.ml.mcmc.ProbabilityMove;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.log4j.Logger;

@PropertyConsumer
public abstract class MonteCarlo {
    private static final Logger logger = Logger.getLogger(MonteCarlo.class);

    /** Number of steps executed by {@link #burnIn()} before counters and the step index are reset. */
    @Property(helpmessage="", defaultvalue="0")
    public int burnIn;

    /** Number of steps executed by {@link #runNoBurnIn()}. */
    @Property(helpmessage="", defaultvalue="100000")
    public int numSteps;

    /** Status is reported on every step that is a multiple of this value; 0 disables reporting. */
    @Property(helpmessage="Write status to the console every n samples", defaultvalue="1000")
    public int writeToConsoleInterval;

    /** Data is sent to the data collector on every step that is a multiple of this value; 0 disables it. */
    @Property(helpmessage="Collect data to disk every n samples", defaultvalue="100")
    public int collectDataToDiskInterval;

    /** Source of proposed moves; its plugins/factories also define which move classes are counted. */
    @Property(defaultvalue="edu.berkeley.compbio.ml.mcmc.MoveTypeSet")
    public MoveTypeSet movetypes;

    /**
     * Value passed to every move's {@code doMove} and used to scale log-likelihoods in
     * {@link #unnormalizedLogLikelihood(MonteCarloState)}. The help text requires {@code >= 1};
     * that is not validated here.
     */
    @Property(helpmessage="= kT.  Must be >= 1.  1 is cold chain", defaultvalue="1")
    public double heatFactor = 1.0;

    /** Receives sampled data; may be null, in which case data collection is skipped. */
    protected DataCollector dataCollector;
    /** Number of accepted moves since the last {@link #resetCounts()}. */
    protected int acceptedCount;
    /** Label used as a prefix in log messages and data-collector keys. */
    protected String id;
    /** Accepted-move counts per move class since the last {@link #resetCounts()}. */
    protected final Map<Class<Move>, Integer> accepted = new HashMap<Class<Move>, Integer>();
    /** Proposed-move counts per move class since the last {@link #resetCounts()}. */
    protected final Map<Class<Move>, Integer> proposed = new HashMap<Class<Move>, Integer>();
    /** Only a chain flagged as coldest writes its state to the data collector. */
    protected boolean isColdest = true;
    /** Number of proposed moves since the last {@link #resetCounts()}. */
    private int proposedCount;
    /** Steps taken since construction or the last burn-in; incremented at the start of each step. */
    private int step = 0;

    /** Returns the collector that receives sampled data, or null if none has been set. */
    public DataCollector getDataCollector() {
        return this.dataCollector;
    }

    /** Sets the collector that receives sampled data; null disables data collection. */
    public void setDataCollector(DataCollector dc) {
        this.dataCollector = dc;
    }

    /** Returns the heat factor applied to moves and log-likelihoods. */
    public double getHeatFactor() {
        return this.heatFactor;
    }

    /** Sets the heat factor applied to moves and log-likelihoods. */
    public void setHeatFactor(double t) {
        this.heatFactor = t;
    }

    /** Returns the label used in log messages and data-collector keys. */
    public String getId() {
        return this.id;
    }

    /** Sets the label used in log messages and data-collector keys. */
    public void setId(String id) {
        this.id = id;
    }

    /** Returns the move source. */
    public MoveTypeSet getMovetypes() {
        return this.movetypes;
    }

    /** Sets the move source. */
    public void setMovetypes(MoveTypeSet movetypes) {
        this.movetypes = movetypes;
    }

    /** Prepares the chain for sampling by zeroing all acceptance counters. */
    public void init() {
        this.resetCounts();
    }

    /**
     * Runs the burn-in phase followed by {@code numSteps} sampling steps.
     *
     * @throws IOException if writing sampled data fails
     * @throws GenericFactoryException if a move cannot be created
     */
    public void run() throws IOException, GenericFactoryException {
        this.burnIn();
        this.runNoBurnIn();
    }

    /**
     * Executes {@code burnIn} steps, then discards their acceptance statistics and restarts the
     * step counter at 0 so that sampling-phase reporting intervals start fresh.
     *
     * @throws IOException if writing sampled data fails
     * @throws GenericFactoryException if a move cannot be created
     */
    public void burnIn() throws IOException, GenericFactoryException {
        for (int i = 0; i < this.burnIn; ++i) {
            this.doStep();
        }
        this.resetCounts();
        this.step = 0;
    }

    /**
     * Performs one Monte Carlo step. It proposes a move from the current state, applies it at
     * the current heat factor, updates the acceptance counters, and installs the resulting state.
     * On collection/report intervals it then records data and/or status for the post-move state.
     *
     * @throws IOException if writing sampled data fails
     * @throws GenericFactoryException if a move cannot be created
     * @throws IllegalStateException if the proposed move is neither an EnergyMove nor a ProbabilityMove
     */
    public void doStep() throws IOException, GenericFactoryException {
        ++this.step;
        boolean writeToConsole = isMultipleOf(this.step, this.writeToConsoleInterval);
        boolean collectDataToDisk = isMultipleOf(this.step, this.collectDataToDiskInterval);

        MonteCarloState currentState = this.getCurrentState();
        Move m = this.movetypes.newMove(currentState);
        MonteCarloState newState = this.applyMove(m);

        Class<?> movetype = m.getClass();
        ++this.proposedCount;
        increment(this.proposed, movetype);
        // A move that hands back the identical state instance is counted as rejected.
        if (currentState != newState) {
            ++this.acceptedCount;
            increment(this.accepted, movetype);
        }
        this.setCurrentState(newState);

        // Report newState (not the pre-move state) so the recorded state matches this step number
        // and the counters, which already include this step's move.
        if (collectDataToDisk) {
            this.collectData(newState);
        }
        if (writeToConsole) {
            if (logger.isInfoEnabled()) {
                this.writeStatus(newState);
            }
            // Counters are per reporting interval; resetting must not depend on the log level,
            // otherwise the collected proposed/accepted values would change with logging config.
            this.resetCounts();
        }
    }

    /** Returns true when {@code interval} is non-zero and {@code step} is a multiple of it. */
    private static boolean isMultipleOf(int step, int interval) {
        return interval != 0 && step % interval == 0;
    }

    /** Applies the move at the current heat factor, dispatching on the supported move kinds. */
    private MonteCarloState applyMove(Move m) throws IOException, GenericFactoryException {
        if (m instanceof EnergyMove) {
            return ((EnergyMove) m).doMove(this.heatFactor);
        }
        if (m instanceof ProbabilityMove) {
            return ((ProbabilityMove) m).doMove(this.heatFactor);
        }
        throw new IllegalStateException("Unsupported move type " + m.getClass().getName()
                + ": expected an EnergyMove or a ProbabilityMove");
    }

    /** Sends the state (coldest chain only) and per-move-class counters to the data collector, if any. */
    private void collectData(MonteCarloState state) throws IOException, GenericFactoryException {
        if (this.dataCollector == null) {
            return;
        }
        if (this.isColdest()) {
            state.writeToDataCollector(this.step, this.dataCollector);
        }
        for (GenericFactory<Move> f : this.movetypes.getFactories()) {
            Class<?> c = f.getCreatesClass();
            String prefix = this.id + "." + c.getSimpleName();
            this.dataCollector.setTimecourseValue(prefix + ".proposed", countOf(this.proposed, c));
            this.dataCollector.setTimecourseValue(prefix + ".accepted", countOf(this.accepted, c));
        }
    }

    /** Logs acceptance statistics at INFO and prints the state and data collector to stdout. */
    private void writeStatus(MonteCarloState state) throws IOException, GenericFactoryException {
        logger.info("Step " + this.step);
        logger.info("[ " + this.id + " ] Accepted " + this.acceptedCount + " out of " + this.proposedCount + " proposed total moves.");
        for (GenericFactory<Move> f : this.movetypes.getFactories()) {
            Class<?> c = f.getCreatesClass();
            logger.info("[ " + this.id + " ] Accepted " + countOf(this.accepted, c) + " out of " + countOf(this.proposed, c) + " proposed " + c + " moves.");
        }
        System.out.println(state);
        if (this.dataCollector != null) {
            System.out.println(this.dataCollector.toString());
        }
    }

    /** Increments the counter for {@code key}, treating a missing entry as 0. */
    @SuppressWarnings("unchecked")
    private static void increment(Map<Class<Move>, Integer> counts, Class<?> key) {
        // The maps are keyed by Class<Move> for API compatibility; the key is only used for lookup.
        Class<Move> k = (Class<Move>) key;
        Integer old = counts.get(k);
        counts.put(k, old == null ? 1 : old + 1);
    }

    /** Returns the counter for {@code key}, or 0 if it has never been recorded. */
    private static int countOf(Map<Class<Move>, Integer> counts, Class<?> key) {
        Integer n = counts.get(key);
        return n == null ? 0 : n;
    }

    /** Returns the chain's current state. */
    public abstract MonteCarloState getCurrentState();

    /** Installs {@code state} as the chain's current state. */
    public abstract void setCurrentState(MonteCarloState state);

    /** Returns whether this chain is flagged as coldest; only then is its state written to the data collector. */
    public boolean isColdest() {
        return this.isColdest;
    }

    /**
     * Zeroes the total proposed/accepted counters and seeds a 0 entry for every move class
     * available from {@code movetypes.pluginMap}.
     */
    public void resetCounts() {
        this.proposedCount = 0;
        this.acceptedCount = 0;
        for (Class c : this.movetypes.pluginMap.getAvailablePlugins()) {
            this.proposed.put(c, 0);
            this.accepted.put(c, 0);
        }
    }

    /**
     * Executes {@code numSteps} steps without a preceding burn-in.
     *
     * @throws IOException if writing sampled data fails
     * @throws GenericFactoryException if a move cannot be created
     */
    public void runNoBurnIn() throws IOException, GenericFactoryException {
        for (int i = 0; i < this.numSteps; ++i) {
            this.doStep();
        }
    }

    /** Flags whether this chain is the coldest one (controls writing its state to the data collector). */
    public void setColdest(boolean coldest) {
        this.isColdest = coldest;
    }

    /**
     * Returns the state's unnormalized log-likelihood multiplied by {@code 1 / heatFactor}.
     *
     * @param mcs the state to evaluate
     * @return the scaled log-likelihood
     */
    public double unnormalizedLogLikelihood(MonteCarloState mcs) {
        // Evaluate once; the state's likelihood may be expensive and was previously computed three times.
        double raw = mcs.unnormalizedLogLikelihood();
        double scaled = raw * (1.0 / this.heatFactor);
        if (logger.isDebugEnabled()) {
            logger.debug(String.format("unnormalizedLogLikelihood: %f, heatFactor = %f, product = %f", raw, this.heatFactor, scaled));
        }
        return scaled;
    }
}

