/*
 * Decompiled with CFR 0.152.
 */
package it.uniroma2.tk;

import edu.berkeley.compbio.jlibsvm.kernel.KernelFunction;
import it.uniroma2.util.tree.Tree;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Vector;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * Route tree kernel for {@link Tree} instances.
 * <p>
 * The kernel value of two trees is a sum over three things: every node {@code x} of the first
 * tree, every node {@code y} of the second tree, and every level {@code L} from 1 to the smaller
 * of the two tree depths. Each term is a per-pair score {@code delta(x, y, L)}:
 * <ul>
 * <li>{@code 0} when {@code x} and {@code y} do not match. Nodes are compared by root label, or
 * by production when {@link #useProductions} is set.</li>
 * <li>{@code 1} at level 1 for matching nodes.</li>
 * <li>{@code lambda * delta(x, y, L - 1)} at level {@code L > 1}, when both nodes have a defined
 * and equal child position at height {@code L - 1} (see {@code chpos}). Otherwise {@code 0}.</li>
 * </ul>
 * A matching pair therefore contributes {@code lambda^(L-1)} at level {@code L}. It keeps
 * contributing for as long as the child positions along both upward routes agree.
 */
public class RouteTreeKernel
implements KernelFunction<Tree> {
    /** Factor applied once per additional level of route agreement (read on every evaluation). */
    public static double lambda = 1.0;
    /**
     * Controls how nodes are matched.
     * <ul>
     * <li>{@code true}: nodes must have the same root label and the same non-empty sequence of
     * child root labels.</li>
     * <li>{@code false} (the default): equal root labels are enough.</li>
     * </ul>
     */
    public static boolean useProductions = false;

    /**
     * Computes the route tree kernel between two trees.
     * <p>
     * Calls {@code initializeParents()} on both trees before following parent links. The delta
     * values are computed level by level in local tables, replacing memo maps kept in static
     * fields. As a result:
     * <ul>
     * <li>overlapping calls do not clobber each other's intermediate state;</li>
     * <li>the result does not depend on how {@code Tree} implements {@code equals}/{@code hashCode}.</li>
     * </ul>
     * NOTE(review): whether concurrent calls on the <em>same</em> {@code Tree} objects are safe
     * depends on {@code initializeParents()}, which is not visible here.
     *
     * @param a first tree
     * @param b second tree
     * @return the sum of {@code delta(x, y, L)} over all node pairs and levels
     */
    public static double value(Tree a, Tree b) {
        a.initializeParents();
        b.initializeParents();
        List<Tree> aNodes = RouteTreeKernel.allNodes(a);
        List<Tree> bNodes = RouteTreeKernel.allNodes(b);
        // Loop-invariant: computed once instead of on every outer-loop test. maxDepth() >= 1.
        int maxLevel = Math.min(RouteTreeKernel.maxDepth(a), RouteTreeKernel.maxDepth(b));
        int na = aNodes.size();
        int nb = bNodes.size();
        // matches[x][y]: node-match result for the pair. It does not change between levels.
        boolean[][] matches = new boolean[na][nb];
        // delta[x][y]: delta(x, y, level) for the level most recently processed.
        double[][] delta = new double[na][nb];
        double sum = 0.0;

        // Level 1: a matching pair scores 1. The summation order (level, then x, then y) matches
        // the original nested loops, so the floating-point result is identical.
        for (int x = 0; x < na; ++x) {
            Tree aa = aNodes.get(x);
            for (int y = 0; y < nb; ++y) {
                boolean match = RouteTreeKernel.deltaCompare(aa, bNodes.get(y));
                matches[x][y] = match;
                delta[x][y] = match ? 1.0 : 0.0;
                sum += delta[x][y];
            }
        }

        // Levels 2..maxLevel: extend the previous level's value by one more route step.
        for (int level = 2; level <= maxLevel; ++level) {
            int[] posA = RouteTreeKernel.routePositions(aNodes, level - 1);
            int[] posB = RouteTreeKernel.routePositions(bNodes, level - 1);
            for (int x = 0; x < na; ++x) {
                for (int y = 0; y < nb; ++y) {
                    double k = 0.0;
                    // posB[y] >= 0 is implied by posA[x] >= 0 together with posA[x] == posB[y].
                    if (matches[x][y] && posA[x] >= 0 && posA[x] == posB[y]) {
                        k = lambda * delta[x][y];
                    }
                    delta[x][y] = k;
                    sum += k;
                }
            }
        }
        return sum;
    }

    /**
     * Computes {@code chpos(node, level)} for every node, keeping the list order.
     *
     * @param nodes nodes whose route positions are needed
     * @param level height passed to {@code chpos}
     * @return one position per node, or {@code -1} where it is undefined
     */
    private static int[] routePositions(List<Tree> nodes, int level) {
        int[] positions = new int[nodes.size()];
        for (int i = 0; i < positions.length; ++i) {
            positions[i] = RouteTreeKernel.chpos(nodes.get(i), level);
        }
        return positions;
    }

    /**
     * Returns the child position taken by the upward route from {@code a} at the given height.
     * <p>
     * This is the index of {@code a}'s ancestor {@code level - 1} steps up within the children of
     * its ancestor {@code level} steps up. For {@code level <= 1}, it is the index of {@code a} in
     * its parent.
     *
     * @param a     starting node (parent links must already be initialized)
     * @param level height of the route step
     * @return the child index, or {@code -1} if {@code a} has no ancestor {@code level} steps up
     */
    private static int chpos(Tree a, int level) {
        Tree child = a;
        Tree parent = a.getParent();
        for (int remaining = level; parent != null && remaining > 1; --remaining) {
            child = parent;
            parent = parent.getParent();
        }
        if (parent == null) {
            return -1;
        }
        // NOTE(review): indexOf uses Tree.equals. If Tree defines structural equality,
        // identical sibling subtrees would all report the first one's index. Confirm against Tree.
        return parent.getChildren().indexOf(child);
    }

    /**
     * Decides whether two nodes match, according to {@link #useProductions}.
     *
     * @param a first node
     * @param b second node
     * @return {@code true} if the nodes match
     */
    private static boolean deltaCompare(Tree a, Tree b) {
        if (useProductions) {
            return RouteTreeKernel.productionCompare(a, b);
        }
        return a.getRootLabel().equals(b.getRootLabel());
    }

    /**
     * Compares two nodes by production.
     * <p>
     * Requires the same root label, the same number of children (which must be non-zero), and
     * pairwise-equal child root labels. Leaves therefore never match under this comparison.
     *
     * @param a first node
     * @param b second node
     * @return {@code true} if both nodes expand through the same production
     */
    private static boolean productionCompare(Tree a, Tree b) {
        if (!a.getRootLabel().equals(b.getRootLabel())) {
            return false;
        }
        int arity = a.getChildren().size();
        if (arity != b.getChildren().size() || arity == 0) {
            return false;
        }
        for (int i = 0; i < arity; ++i) {
            if (!a.getChildren().get(i).getRootLabel().equals(b.getChildren().get(i).getRootLabel())) {
                return false;
            }
        }
        return true;
    }

    /**
     * Lists every node of the tree rooted at {@code node} in pre-order (node, then its children
     * left to right).
     *
     * @param node root of the tree to traverse
     * @return a new mutable list of all nodes
     */
    private static List<Tree> allNodes(Tree node) {
        List<Tree> all = new ArrayList<Tree>();
        RouteTreeKernel.collectNodes(node, all);
        return all;
    }

    /**
     * Appends {@code node} and its descendants to {@code out} in pre-order. A single accumulator
     * is shared, so sub-lists are never copied.
     */
    private static void collectNodes(Tree node, List<Tree> out) {
        out.add(node);
        for (Tree child : node.getChildren()) {
            RouteTreeKernel.collectNodes(child, out);
        }
    }

    /**
     * Returns the depth of the tree rooted at {@code root}, counted in nodes.
     *
     * @param root root of the tree
     * @return {@code 1} for a single node, {@code 1 + max child depth} otherwise
     */
    private static int maxDepth(Tree root) {
        int deepestChild = 0;
        for (Tree child : root.getChildren()) {
            deepestChild = Math.max(deepestChild, RouteTreeKernel.maxDepth(child));
        }
        return deepestChild + 1;
    }

    /**
     * {@link KernelFunction} entry point; delegates to {@link #value(Tree, Tree)}.
     */
    @Override
    public double evaluate(Tree arg0, Tree arg1) {
        return RouteTreeKernel.value(arg0, arg1);
    }
}

