/*
 * Decompiled with CFR 0.152.
 */
package edu.berkeley.compbio.ml.cluster.kohonen;

import com.davidsoergel.dsutils.DSArrayUtils;
import com.davidsoergel.dsutils.GenericFactory;
import com.davidsoergel.dsutils.GenericFactoryException;
import com.davidsoergel.dsutils.Labellable;
import com.davidsoergel.stats.DissimilarityMeasure;
import com.davidsoergel.stats.SimpleFunction;
import edu.berkeley.compbio.ml.cluster.AbstractUnsupervisedOnlineClusteringMethod;
import edu.berkeley.compbio.ml.cluster.AdditiveClusterable;
import edu.berkeley.compbio.ml.cluster.CentroidClusteringUtils;
import edu.berkeley.compbio.ml.cluster.ClusterException;
import edu.berkeley.compbio.ml.cluster.ClusterMove;
import edu.berkeley.compbio.ml.cluster.ClusterRuntimeException;
import edu.berkeley.compbio.ml.cluster.ClusterableIterator;
import edu.berkeley.compbio.ml.cluster.ClusterableIteratorFactory;
import edu.berkeley.compbio.ml.cluster.NoGoodClusterException;
import edu.berkeley.compbio.ml.cluster.ProhibitionModel;
import edu.berkeley.compbio.ml.cluster.kohonen.KohonenSOM;
import edu.berkeley.compbio.ml.cluster.kohonen.KohonenSOM2DSearchStrategy;
import edu.berkeley.compbio.ml.cluster.kohonen.KohonenSOMCell;
import edu.berkeley.compbio.ml.cluster.kohonen.LabelDiffuser;
import java.io.ByteArrayOutputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.NotImplementedException;
import org.apache.log4j.Logger;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
/**
 * A two-dimensional Kohonen self-organizing map.
 *
 * <p>Cells are laid out on a {@code cellsPerDimension[0]} (width, x) by
 * {@code cellsPerDimension[1]} (height, y) grid, stored row-major. Edges may optionally wrap
 * (toroidal grid).
 *
 * <p>Each call to {@link #add} does the following:
 * <ol>
 *   <li>Finds the best-matching cell via the configured {@link KohonenSOM2DSearchStrategy}.</li>
 *   <li>Pulls every cell within the current neighborhood radius toward the sample. The pull is
 *   scaled by the move factor and by a per-cell weight that depends on the cell's distance from
 *   the winner.</li>
 * </ol>
 * Both the move factor and the radius are functions of {@link #time}, the number of samples
 * added so far.
 */
public class KohonenSOM2D<T extends AdditiveClusterable<T>>
extends AbstractUnsupervisedOnlineClusteringMethod<T, KohonenSOMCell<T>>
implements KohonenSOM<T> {
    private static final Logger logger = Logger.getLogger(KohonenSOM2D.class);

    /** Grid size: index 0 is the width (x extent), index 1 the height (y extent). */
    final int[] cellsPerDimension;
    /** Upper bound on the neighborhood radius; initialized to half the grid diagonal. */
    double maxRadius;
    /** Lower bound on the neighborhood radius. */
    final double minRadius;
    /** Number of samples passed to {@link #add}; drives the move-factor and radius schedules. */
    int time = 0;
    /** Number of {@link #add} calls that changed a sample's assignment since the last reset. */
    int changed = 0;
    private final int dimensions;
    private final boolean edgesWrap;
    private final boolean decrementLosingNeighborhood;
    private final SimpleFunction moveFactorFunction;
    private final SimpleFunction radiusFunction;
    /** Maps relative distance (distance / radius, in [0, 1]) to a neighbor weight; null means 1.0. */
    private final SimpleFunction weightFunction;
    /** Cache of filled-disk neighborhood masks, keyed by integer radius. */
    private final Map<Integer, WeightedMask> weightedMasks = new HashMap<Integer, WeightedMask>();
    /** Cache of ring ("shell") masks, keyed by integer radius. */
    private final Map<Integer, WeightedMask> shellMasks = new HashMap<Integer, WeightedMask>();
    private final KohonenSOM2DSearchStrategy<T> searchStrategy;
    private LabelDiffuser<T, KohonenSOMCell<T>> labeler;

    /**
     * Creates the map.
     *
     * <p>The cells themselves are not created here; they are created by
     * {@link #setPrototypeFactory}.
     *
     * @param cellsPerDimension grid size; must have exactly two entries (width, height)
     * @param moveFactorFunction move factor as a function of {@code time}; its result is clamped
     *     to [0, 1]
     * @param radiusFunction neighborhood radius as a function of {@code time}; its result is
     *     clamped to [minRadius, maxRadius]
     * @param weightFunction neighbor weight as a function of relative distance; may be null,
     *     meaning every neighbor gets weight 1.0
     * @param decrementLosingNeighborhood if true, a sample that moves is also removed from its
     *     previous cell's neighborhood (centroids and labels)
     * @param edgesWrap whether the grid is toroidal
     * @throws ClusterRuntimeException if the grid is not two-dimensional
     */
    public KohonenSOM2D(DissimilarityMeasure<T> dm, Set<String> potentialTrainingBins,
                        Map<String, Set<String>> predictLabelSets, ProhibitionModel<T> prohibitionModel,
                        Set<String> testLabels, @NotNull Integer[] cellsPerDimension,
                        SimpleFunction moveFactorFunction, SimpleFunction radiusFunction,
                        SimpleFunction weightFunction, boolean decrementLosingNeighborhood,
                        boolean edgesWrap, double minRadius, KohonenSOM2DSearchStrategy<T> searchStrategy) {
        super(dm, potentialTrainingBins, predictLabelSets, prohibitionModel, testLabels);
        this.cellsPerDimension = DSArrayUtils.toPrimitive(cellsPerDimension);
        this.dimensions = cellsPerDimension.length;
        this.moveFactorFunction = moveFactorFunction;
        this.radiusFunction = radiusFunction;
        this.weightFunction = weightFunction;
        this.decrementLosingNeighborhood = decrementLosingNeighborhood;
        this.edgesWrap = edgesWrap;
        this.minRadius = minRadius;
        this.searchStrategy = searchStrategy;
        if (this.dimensions != 2) {
            throw new ClusterRuntimeException("KohonenSOM2D accepts only two-dimensional grid.");
        }
        this.setNumClusters(this.cellsPerDimension[0] * this.cellsPerDimension[1]);
        this.maxRadius = DSArrayUtils.norm(this.cellsPerDimension) / 2.0;
        searchStrategy.setDistanceMeasure(this.measure);
    }

    /** Returns how many {@link #add} calls changed an assignment since {@link #resetChanged}. */
    public int getChanged() {
        return this.changed;
    }

    /** Sets the label diffuser run at the end of {@link #train(ClusterableIteratorFactory, int)}. */
    public void setLabeler(LabelDiffuser<T, KohonenSOMCell<T>> labeler) {
        this.labeler = labeler;
    }

    /** Intentionally a no-op: every grid cell is kept, even when empty, to preserve the grid topology. */
    @Override
    protected void removeEmptyClusters() {
    }

    @Override
    public String shortClusteringStats() {
        return CentroidClusteringUtils.shortClusteringStats(this.getClusters(), this.measure);
    }

    @Override
    public void computeClusterStdDevs(ClusterableIterator<T> theDataPointProvider) {
        CentroidClusteringUtils.computeClusterStdDevs(this.getClusters(), this.measure, this.getAssignments(), theDataPointProvider);
    }

    @Override
    public String clusteringStats() {
        ByteArrayOutputStream b = new ByteArrayOutputStream();
        CentroidClusteringUtils.writeClusteringStatsToStream(this.getClusters(), this.measure, b);
        return b.toString();
    }

    @Override
    public void writeClusteringStatsToStream(OutputStream outf) {
        CentroidClusteringUtils.writeClusteringStatsToStream(this.getClusters(), this.measure, outf);
    }

    /**
     * Returns an unbounded iterator over successive rings of cells around {@code cell}.
     *
     * <p>The first element is the cell itself (radius 0); each later call returns the ring at the
     * next integer radius. {@code hasNext()} is always true, so callers must stop on their own.
     */
    @Override
    public Iterator<Set<KohonenSOMCell<T>>> getNeighborhoodShellIterator(KohonenSOMCell<T> cell) {
        return new NeighborhoodShellIterator(cell);
    }

    /**
     * Assigns {@code p} to its best-matching cell and pulls that cell's neighborhood toward it.
     *
     * <p>If {@code decrementLosingNeighborhood} is set and {@code p} had a previous cell, the
     * following happens first:
     * <ul>
     *   <li>{@code p}'s labels are removed from that previous cell.</li>
     *   <li>The previous cell's neighborhood is pushed away from {@code p}.</li>
     * </ul>
     * {@link #time} is advanced once per call.
     *
     * @return always true
     * @throws NoGoodClusterException if the search strategy finds no acceptable cell
     */
    @Override
    public boolean add(T p) throws NoGoodClusterException {
        final ClusterMove<T, KohonenSOMCell<T>> cm = this.bestClusterMove(p);
        if (cm.isChanged()) {
            ++this.changed;
            this.putAssignment(p.getId(), cm.bestCluster);
        }
        final KohonenSOMCell<T> loser = (KohonenSOMCell<T>) cm.oldCluster;
        final KohonenSOMCell<T> winner = (KohonenSOMCell<T>) cm.bestCluster;

        // Clamp the scheduled move factor into [0, 1].
        final double moveFactor = Math.max(Math.min(this.moveFactorFunction.f(this.time), 1.0), 0.0);
        final double radius = this.getCurrentRadius();
        logger.trace("Adding point with neighborhood radius " + radius + ", moveFactor " + moveFactor);
        final WeightedMask mask = this.getWeightedMask((int) radius);

        if (this.decrementLosingNeighborhood && loser != null) {
            // Undo p's previous contribution: its labels leave the OLD cell (not the winner),
            // and the old neighborhood is pushed away from p.
            loser.getMutableWeightedLabels().removeAll(p.getMutableWeightedLabels());
            for (Iterator<WeightedCell> it = mask.iterator(loser); it.hasNext(); ) {
                final WeightedCell v = it.next();
                v.theCell.recenterByRemovingWeighted(p, moveFactor * v.weight);
            }
        }

        winner.getMutableWeightedLabels().addAll(p.getMutableWeightedLabels());
        for (Iterator<WeightedCell> it = mask.iterator(winner); it.hasNext(); ) {
            final WeightedCell v = it.next();
            v.theCell.recenterByAddingWeighted(p, moveFactor * v.weight);
        }
        ++this.time;
        return true;
    }

    /**
     * Trains via the superclass, then propagates labels across the grid.
     *
     * <p>If no {@link LabelDiffuser} was set, propagation is skipped with a warning, so the
     * training result is not lost.
     */
    @Override
    public void train(ClusterableIteratorFactory<T> trainingCollectionIteratorFactory, int iterations) throws ClusterException {
        super.train(trainingCollectionIteratorFactory, iterations);
        if (this.labeler != null) {
            this.labeler.propagateLabels(this);
        } else {
            logger.warn("No LabelDiffuser set on KohonenSOM2D; skipping label propagation");
        }
    }

    /** Creates one cell per grid position from {@code prototypeFactory} and registers the map with the search strategy. */
    @Override
    public void setPrototypeFactory(GenericFactory<T> prototypeFactory) throws GenericFactoryException {
        this.createClusters(this.cellsPerDimension[0] * this.cellsPerDimension[1], prototypeFactory);
        this.searchStrategy.setSOM(this);
    }

    /** Seeds the map by adding {@code initSamples} samples, each to a randomly chosen cell. */
    @Override
    public void initializeWithSamples(ClusterableIterator<T> initIterator, int initSamples) {
        for (int i = 0; i < initSamples; ++i) {
            this.addToRandomCell(initIterator.nextFullyLabelled());
            if (i % 100 == 0) {
                logger.debug("Initialized with " + i + " samples.");
            }
        }
    }

    /**
     * Pulls a random cell's neighborhood toward {@code p}.
     *
     * <p>Uses the maximum radius and a fixed move factor of 0.5. The sample is not assigned to a
     * cell and labels are not touched.
     */
    public void addToRandomCell(T p) {
        final KohonenSOMCell<T> winner = (KohonenSOMCell<T>) this.chooseRandomCluster();
        final double moveFactor = 0.5;
        final double radius = this.maxRadius;
        logger.trace("Adding point with neighborhood radius " + radius + ", moveFactor " + moveFactor);
        for (Iterator<WeightedCell> it = this.getWeightedMask((int) radius).iterator(winner); it.hasNext(); ) {
            final WeightedCell v = it.next();
            v.theCell.recenterByAddingWeighted(p, moveFactor * v.weight);
        }
    }

    @Override
    public ClusterMove<T, KohonenSOMCell<T>> bestClusterMove(T p) throws NoGoodClusterException {
        return this.searchStrategy.bestClusterMove(p);
    }

    /**
     * Computes, for each cell, the mean centroid distance to its 4-connected grid neighbors.
     *
     * <p>This is the basis of a U-matrix. Neighbors off the edge of a non-wrapping grid are
     * excluded, and each cell is divided by its actual neighbor count.
     *
     * @return one value per cell, indexed like the cluster list (row-major)
     */
    public double[] computeCellAverageNeighborDistances() {
        final int width = this.cellsPerDimension[0];
        final int height = this.cellsPerDimension[1];
        final double[] sums = new double[this.getNumClusters()];
        final int[] counts = new int[sums.length];
        // Visit every cell once, accumulating its right and down edges; each edge credits both ends.
        for (int y = 0; y < height; ++y) {
            for (int x = 0; x < width; ++x) {
                if (x + 1 < width || this.edgesWrap) {
                    this.accumulateNeighborDistance(sums, counts, x, y, x + 1, y);
                }
                if (y + 1 < height || this.edgesWrap) {
                    this.accumulateNeighborDistance(sums, counts, x, y, x, y + 1);
                }
            }
        }
        for (int n = 0; n < sums.length; ++n) {
            if (counts[n] > 0) {
                sums[n] /= counts[n];
            }
        }
        return sums;
    }

    /** Adds the centroid distance between two adjacent cells to both cells' running sums and counts. */
    private void accumulateNeighborDistance(double[] sums, int[] counts, int x, int y, int nx, int ny) {
        final int here = this.listIndexFor(x, y);
        final int there = this.listIndexFor(nx, ny);
        if (here == there) {
            return; // a 1-cell-wide wrapping grid makes a cell its own neighbor; ignore that edge
        }
        final double d = this.measure.distanceFromTo(this.clusterAt(x, y).getCentroid(), this.clusterAt(nx, ny).getCentroid());
        sums[here] += d;
        ++counts[here];
        sums[there] += d;
        ++counts[there];
    }

    /** Returns the cell at grid position (x, y); coordinates wrap if the grid wraps. */
    public KohonenSOMCell<T> clusterAt(int x, int y) {
        return (KohonenSOMCell<T>) this.getCluster(this.listIndexFor(x, y));
    }

    /** Reduces {@code value} into {@code [0, modulus)}, including for negative values. */
    private static int wrapIndex(int value, int modulus) {
        final int m = value % modulus;
        return m < 0 ? m + modulus : m;
    }

    /** Converts grid coordinates to a row-major list index, wrapping coordinates if the grid wraps. */
    private int listIndexFor(int x, int y) {
        if (this.edgesWrap) {
            x = wrapIndex(x, this.cellsPerDimension[0]);
            y = wrapIndex(y, this.cellsPerDimension[1]);
        }
        return y * this.cellsPerDimension[0] + x;
    }

    /** Inverse of {@link #listIndexFor}: returns {x, y} for a row-major list index. */
    private int[] cellPositionFor(int listIndex) {
        return new int[]{listIndex % this.cellsPerDimension[0], listIndex / this.cellsPerDimension[0]};
    }

    /**
     * Adds {@code totalCells} cells, naming the i-th prototype {@code String.valueOf(i)}.
     *
     * <p>A null factory yields null centroids. Such cells are not finalized and cannot be
     * recentered until a centroid is supplied.
     */
    private void createClusters(int totalCells, GenericFactory<T> prototypeFactory) throws GenericFactoryException {
        for (int i = 0; i < totalCells; ++i) {
            final T centroid = prototypeFactory == null ? null : prototypeFactory.create(String.valueOf(i));
            if (centroid != null) {
                centroid.doneLabelling();
            }
            this.addCluster(new KohonenSOMCell<T>(i, centroid));
        }
    }

    /**
     * Returns the scheduled neighborhood radius for the current {@link #time}.
     *
     * <p>The value is clamped to [minRadius, maxRadius]; minRadius wins if the bounds cross.
     */
    public double getCurrentRadius() {
        double radius = this.radiusFunction.f(this.time);
        radius = Math.min(radius, this.maxRadius);
        radius = Math.max(radius, this.minRadius);
        return radius;
    }

    /**
     * Returns the ring of offsets at exactly integer radius {@code radius}: those in the radius
     * disk but not the radius-1 disk.
     *
     * <p>All weights are 1.0. Radius &lt; 1 yields the center-only mask. Results are cached.
     */
    WeightedMask getShellMask(int radius) {
        WeightedMask result = this.shellMasks.get(radius);
        if (result == null) {
            if (radius < 1) {
                result = this.getWeightedMask(0);
            } else {
                final WeightedMask outerMask = this.getWeightedMask(radius);
                final WeightedMask innerMask = this.getWeightedMask(radius - 1);
                final int[] xs = new int[outerMask.numCells];
                final int[] ys = new int[outerMask.numCells];
                int n = 0;
                for (int i = 0; i < outerMask.numCells; ++i) {
                    final int dx = outerMask.deltaX[i];
                    final int dy = outerMask.deltaY[i];
                    if (!innerMask.containsPoint(dx, dy)) {
                        xs[n] = dx;
                        ys[n] = dy;
                        ++n;
                    }
                }
                result = new WeightedMask();
                result.deltaX = Arrays.copyOf(xs, n);
                result.deltaY = Arrays.copyOf(ys, n);
                result.weight = new double[n];
                Arrays.fill(result.weight, 1.0);
                result.numCells = n;
            }
            this.shellMasks.put(radius, result);
        }
        return result;
    }

    /** Returns the (cached) filled-disk neighborhood mask for integer radius {@code radius}. */
    WeightedMask getWeightedMask(int radius) {
        WeightedMask result = this.weightedMasks.get(radius);
        if (result == null) {
            result = new WeightedMask(radius);
            this.weightedMasks.put(radius, result);
        }
        return result;
    }

    /** Evaluates the weight function at a relative distance; returns 1.0 when no function is set. */
    private double neighborhoodWeight(double relativeDistance) {
        return this.weightFunction == null ? 1.0 : this.weightFunction.f(relativeDistance);
    }

    /** Resets the {@link #getChanged()} counter to zero. */
    public void resetChanged() {
        this.changed = 0;
    }

    /**
     * Delegates to {@link #train(ClusterableIteratorFactory, int)}.
     *
     * <p>Note that {@code prototypeFactory} is not used here; cells must already have been
     * created via {@link #setPrototypeFactory}.
     */
    public void train(ClusterableIteratorFactory<T> trainingCollectionIteratorFactory, GenericFactory<T> prototypeFactory, int trainingEpochs) throws ClusterException {
        this.train(trainingCollectionIteratorFactory, trainingEpochs);
    }

    /** Unbounded iterator over the rings (radius 0, 1, 2, ...) of cells around a center cell. */
    private class NeighborhoodShellIterator
    implements Iterator<Set<KohonenSOMCell<T>>> {
        int radius = 0;
        private final KohonenSOMCell<T> center;

        public NeighborhoodShellIterator(KohonenSOMCell<T> center) {
            this.center = center;
        }

        /** Always true; callers must decide when to stop. */
        @Override
        public boolean hasNext() {
            return true;
        }

        /** Returns the cells in the ring at the current radius, then advances the radius by one. */
        @Override
        public Set<KohonenSOMCell<T>> next() {
            final Set<KohonenSOMCell<T>> shell = new HashSet<KohonenSOMCell<T>>();
            final Iterator<WeightedCell> cells = KohonenSOM2D.this.getShellMask(this.radius).iterator(this.center);
            while (cells.hasNext()) {
                shell.add(cells.next().theCell);
            }
            ++this.radius;
            return shell;
        }

        @Override
        public void remove() {
            throw new NotImplementedException();
        }
    }

    /** A cell paired with its neighborhood weight relative to some center. */
    class WeightedCell {
        final KohonenSOMCell<T> theCell;
        final double weight;

        private WeightedCell(KohonenSOMCell<T> theCell, double weight) {
            this.theCell = theCell;
            this.weight = weight;
        }
    }

    /**
     * A set of grid offsets (deltaX[i], deltaY[i]) with a weight for each.
     *
     * <p>Only the first {@code numCells} entries are meaningful.
     */
    class WeightedMask {
        int[] deltaX;
        int[] deltaY;
        double[] weight;
        int numCells;

        private WeightedMask() {
        }

        /**
         * Builds the filled disk of all integer offsets with {@code dx*dx + dy*dy <= radius*radius}.
         *
         * <p>The center is at index 0. Each offset is weighted by the weight function evaluated at
         * distance / radius (the center at 0.0). A negative radius is treated as 0.
         */
        private WeightedMask(int radius) {
            final int r = Math.max(radius, 0);
            final int rSquared = r * r;

            int count = 0;
            for (int dx = -r; dx <= r; ++dx) {
                for (int dy = -r; dy <= r; ++dy) {
                    if (dx * dx + dy * dy <= rSquared) {
                        ++count;
                    }
                }
            }
            this.deltaX = new int[count];
            this.deltaY = new int[count];
            this.weight = new double[count];

            // The center is always present (count >= 1) and kept first.
            this.weight[0] = KohonenSOM2D.this.neighborhoodWeight(0.0);
            int i = 1;
            for (int dx = -r; dx <= r; ++dx) {
                for (int dy = -r; dy <= r; ++dy) {
                    final int d2 = dx * dx + dy * dy;
                    if (d2 == 0 || d2 > rSquared) {
                        continue;
                    }
                    final double w = KohonenSOM2D.this.neighborhoodWeight(Math.sqrt(d2) / r);
                    assert (w > 0.0);
                    this.deltaX[i] = dx;
                    this.deltaY[i] = dy;
                    this.weight[i] = w;
                    ++i;
                }
            }
            this.numCells = count;
        }

        /** Returns true if offset (x, y) is one of this mask's valid entries. */
        public boolean containsPoint(int x, int y) {
            for (int i = 0; i < this.numCells; ++i) {
                if (this.deltaX[i] == x && this.deltaY[i] == y) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Iterates the cells covered by this mask when centered on {@code center}.
         *
         * <p>On a wrapping grid, offsets wrap around; otherwise off-grid offsets are skipped.
         */
        public Iterator<WeightedCell> iterator(KohonenSOMCell<T> center) {
            return new MaskIterator(center);
        }

        /** Look-ahead iterator over the on-grid cells of the enclosing mask around one center. */
        private class MaskIterator
        implements Iterator<WeightedCell> {
            WeightedCell nextCell;
            final int xCenter;
            final int yCenter;
            int trav = -1;

            public MaskIterator(KohonenSOMCell<T> center) {
                int[] c = KohonenSOM2D.this.cellPositionFor(KohonenSOM2D.this.getClusterIndexOf(center));
                this.xCenter = c[0];
                this.yCenter = c[1];
                this.nextCell = this.findNextCell();
            }

            /** Advances to the next mask entry that maps onto the grid; returns null when exhausted. */
            @Nullable
            private WeightedCell findNextCell() {
                final int width = KohonenSOM2D.this.cellsPerDimension[0];
                final int height = KohonenSOM2D.this.cellsPerDimension[1];
                while (++this.trav < WeightedMask.this.numCells) {
                    int realX = this.xCenter + WeightedMask.this.deltaX[this.trav];
                    int realY = this.yCenter + WeightedMask.this.deltaY[this.trav];
                    if (KohonenSOM2D.this.edgesWrap) {
                        realX = wrapIndex(realX, width);
                        realY = wrapIndex(realY, height);
                    } else if (realX < 0 || realX >= width || realY < 0 || realY >= height) {
                        continue; // off the edge of a non-wrapping grid
                    }
                    return new WeightedCell((KohonenSOMCell<T>) KohonenSOM2D.this.getCluster(KohonenSOM2D.this.listIndexFor(realX, realY)), WeightedMask.this.weight[this.trav]);
                }
                return null;
            }

            @Override
            public boolean hasNext() {
                return this.nextCell != null;
            }

            @Override
            public WeightedCell next() {
                final WeightedCell current = this.nextCell;
                this.nextCell = this.findNextCell();
                return current;
            }

            @Override
            public void remove() {
                throw new NotImplementedException();
            }
        }
    }
}

