package edu.berkeley.nlp.bp;

import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.CollectionUtils;
import edu.berkeley.nlp.util.functional.Function;
import edu.berkeley.nlp.util.functional.FunctionalUtils;
import fig.basic.Pair;
import java.util.ArrayList;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:edu/berkeley/nlp/bp/TreeFactorGraph.class */
/**
 * Runs belief propagation over a factor graph whose structure mirrors a {@link Tree}.
 *
 * <p>Every node yielded by the tree's iterator becomes one {@link Variable}. Each node contributes a
 * unary {@link NodeFactorPotential}. Each parent/child pair may contribute a pairwise
 * {@link EdgeFactorPotential}.
 */
public class TreeFactorGraph {

    /**
     * Builds a tree-shaped factor graph, runs belief propagation on it and returns the marginals of
     * every factor.
     *
     * @param tree the tree whose nodes (as yielded by {@code tree.iterator()}) and parent/child
     *     links define the factor graph
     * @param function maps a tree node to the {@link Variable} representing it
     * @param function2 supplies the unary potential table for a variable
     * @param function3 supplies the pairwise potential table for a (parent, child) variable pair;
     *     returning {@code null} means no edge factor is added for that pair
     * @return a pair of (node marginals, edge marginals), each listed in the iteration order of the
     *     graph's factors
     * @throws IllegalStateException if the graph contains a factor whose potential is neither a
     *     {@link NodeFactorPotential} nor an {@link EdgeFactorPotential}
     */
    public static <L> Pair<List<NodeMarginal>, List<EdgeMarginal>> runBP(Tree<L> tree, Function<Tree<L>, Variable> function, Function<Variable, double[]> function2, Function<Pair<Variable, Variable>, double[][]> function3) {
        FactorGraph factorGraph = buildFactorGraph(tree, function, function2, function3);
        new BeliefPropogation().run(factorGraph);
        return collectMarginals(factorGraph);
    }

    /** Creates one variable per tree node plus its unary factor and optional parent/child factors. */
    private static <L> FactorGraph buildFactorGraph(Tree<L> tree, Function<Tree<L>, Variable> nodeToVariable, Function<Variable, double[]> nodePotentials, Function<Pair<Variable, Variable>, double[][]> edgePotentials) {
        // Keyed by identity, not equals(): two structurally equal subtrees at different
        // positions in the tree must still get distinct variables.
        Map<Tree<L>, Variable> variables = FunctionalUtils.mapPairs(tree, nodeToVariable, new IdentityHashMap<Tree<L>, Variable>());
        FactorGraph factorGraph = new FactorGraph(variables.values());
        Iterator<Tree<L>> nodes = tree.iterator();
        while (nodes.hasNext()) {
            Tree<L> node = nodes.next();
            Variable parent = variables.get(node);
            factorGraph.addFactor(Collections.singletonList(parent), new NodeFactorPotential(nodePotentials.apply(parent)));
            for (Tree<L> childNode : node.getChildren()) {
                Variable child = variables.get(childNode);
                double[][] table = edgePotentials.apply(Pair.newPair(parent, child));
                // A null table means "no interaction modeled" for this edge.
                if (table != null) {
                    // Variable order is (parent, child); collectMarginals relies on this order.
                    factorGraph.addFactor(CollectionUtils.makeList(parent, child), new EdgeFactorPotential(table));
                }
            }
        }
        return factorGraph;
    }

    /** Splits the factor marginals into node marginals and edge marginals after inference. */
    private static Pair<List<NodeMarginal>, List<EdgeMarginal>> collectMarginals(FactorGraph factorGraph) {
        List<NodeMarginal> nodeMarginals = new ArrayList<NodeMarginal>();
        List<EdgeMarginal> edgeMarginals = new ArrayList<EdgeMarginal>();
        for (Factor factor : factorGraph.factors) {
            List<Variable> vars = factor.vars;
            if (factor.potential instanceof NodeFactorPotential) {
                nodeMarginals.add(new NodeMarginal(vars.get(0), (double[]) factor.marginals));
            } else if (factor.potential instanceof EdgeFactorPotential) {
                // vars are (parent, child) as added in buildFactorGraph.
                edgeMarginals.add(new EdgeMarginal(vars.get(0), vars.get(1), (double[][]) factor.marginals));
            } else {
                String kind = factor.potential == null ? "null" : factor.potential.getClass().getName();
                throw new IllegalStateException("Unrecognized Factor Potential: " + kind);
            }
        }
        return Pair.newPair(nodeMarginals, edgeMarginals);
    }
}
