/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import mikera.vectorz.AVector;
import mikera.vectorz.ArrayVector;
import mikera.vectorz.Op;
import mikera.vectorz.Vector;
import nuroko.core.IModule;
import nuroko.core.ITrainable;
import nuroko.core.Util;
import nuroko.module.ALayerStack;
import nuroko.module.AWeightLayer;

/**
 * A feed-forward stack of {@link AWeightLayer}s with an activation {@link Op}
 * applied to the output of every layer.
 *
 * <p>The network keeps one activation buffer and one gradient buffer per layer
 * boundary (index 0 is the network input, index {@code layerCount} is the
 * network output). These buffers are reused across calls, so an instance holds
 * the state of the most recent {@link #think} / {@link #train} call.
 */
public class NeuralNet
extends ALayerStack
implements ITrainable {
    /** Number of weight layers in the stack. */
    private final int layerCount;
    /** The weight layers, in input-to-output order. */
    private final AWeightLayer[] layers;
    /** Activation buffers; {@code data[i]} is the input of layer i, {@code data[i + 1]} its output. */
    private final Vector[] data;
    /** Gradient buffers, indexed the same way as {@link #data}. */
    private final Vector[] grad;
    /** Parameters of all layers joined in layer order. */
    private final AVector parameters;
    /** Gradients of all layers joined in layer order. */
    private final AVector gradient;
    /** Activation op applied to the output of each layer. */
    private final Op[] layerOps;

    /**
     * Creates a network with {@code Op.TANH} on hidden layers and
     * {@code Op.LOGISTIC} on the output layer.
     *
     * @param layers the weight layers, in input-to-output order
     */
    public NeuralNet(AWeightLayer ... layers) {
        this(layers, Op.LOGISTIC);
    }

    /**
     * Creates a network with {@code Op.TANH} on hidden layers and the given
     * op on the output layer.
     *
     * @param layers   the weight layers, in input-to-output order
     * @param outputOp activation op applied to the final layer's output
     */
    public NeuralNet(AWeightLayer[] layers, Op outputOp) {
        // Forward the caller's op; previously Op.LOGISTIC was hard-coded here
        // and outputOp was silently ignored.
        this(layers, Op.TANH, outputOp);
    }

    /**
     * Creates a network with explicit hidden and output activation ops.
     *
     * @param layers   the weight layers, in input-to-output order; must not be empty
     * @param hiddenOp activation op applied to every layer except the last
     * @param outputOp activation op applied to the final layer's output
     * @throws IllegalArgumentException if {@code layers} is empty
     */
    public NeuralNet(AWeightLayer[] layers, Op hiddenOp, Op outputOp) {
        if (layers.length == 0) {
            throw new IllegalArgumentException("NeuralNet requires at least one layer");
        }
        this.layerCount = layers.length;
        // Defensive copy so later changes to the caller's array do not affect this network.
        this.layers = (AWeightLayer[])layers.clone();

        this.layerOps = new Op[this.layerCount];
        int last = this.layerCount - 1;
        for (int i = 0; i < last; ++i) {
            this.layerOps[i] = hiddenOp;
        }
        this.layerOps[last] = outputOp;

        // One buffer per layer boundary: slot 0 is the network input,
        // slot i + 1 is the output of layer i.
        this.data = new Vector[this.layerCount + 1];
        this.grad = new Vector[this.layerCount + 1];
        this.data[0] = Vector.createLength((int)layers[0].getInputLength());
        this.grad[0] = Vector.createLength((int)layers[0].getInputLength());
        for (int i = 0; i < this.layerCount; ++i) {
            this.data[i + 1] = Vector.createLength((int)layers[i].getOutputLength());
            this.grad[i + 1] = Vector.createLength((int)layers[i].getOutputLength());
        }

        // Join the per-layer parameter and gradient vectors in layer order.
        AVector params = layers[0].getParameters();
        AVector g = layers[0].getGradient();
        for (int i = 1; i < this.layerCount; ++i) {
            params = params.join(layers[i].getParameters());
            g = g.join(layers[i].getGradient());
        }
        this.parameters = params;
        this.gradient = g;
    }

    /**
     * Returns the layers of this network as a new, independent list of modules.
     *
     * @return a fresh list containing every layer in input-to-output order
     */
    public List<IModule> getComponents() {
        ArrayList<IModule> al = new ArrayList<IModule>(this.layerCount);
        for (AWeightLayer m : this.layers) {
            al.add(m);
        }
        return al;
    }

    /**
     * Builds a network from the inverses of this network's layers, in reverse order.
     *
     * <p>The result is created through the default constructor, so it uses
     * {@code Op.TANH} / {@code Op.LOGISTIC} regardless of this network's ops.
     * NOTE(review): confirm whether the inverse should carry over this
     * network's activation ops.
     *
     * @return a new network made of the inverted layers
     */
    public NeuralNet getInverse() {
        AWeightLayer[] newLayers = new AWeightLayer[this.layerCount];
        for (int i = 0; i < this.layerCount; ++i) {
            newLayers[i] = this.layers[this.layerCount - 1 - i].getInverse();
        }
        return new NeuralNet(newLayers);
    }

    /** Calls {@code initRandom()} on every layer. */
    public void initRandom() {
        for (AWeightLayer wl : this.layers) {
            wl.initRandom();
        }
    }

    /**
     * Runs a forward pass on {@code input}, sets the output gradient to
     * {@code target - output}, and back-propagates it with a factor of 1.0,
     * skipping the derivative of the output layer's op.
     *
     * @param input  the input vector
     * @param target the desired output; asserted to match the output length
     */
    @Override
    public void train(AVector input, AVector target) {
        assert (this.getOutputLength() == target.length());
        this.think(input, null);
        // Output-side gradient = target - actual output.
        this.grad[this.layerCount].set(target);
        this.grad[this.layerCount].sub((ArrayVector)this.data[this.layerCount]);
        this.backpropGradient(1.0, true);
    }

    /**
     * Runs a forward pass on {@code input} and back-propagates the supplied
     * output gradient through all layers.
     *
     * @param input              the input vector; asserted to match the input length
     * @param outputGradient     gradient at the network output
     * @param inputGradient      if non-null, the gradient at the network input is added to it
     * @param factor             factor passed to each layer's {@code trainGradient}
     * @param skipTopDerivative  if true, {@code Op.LINEAR} is used instead of the
     *                           output layer's op when scaling the top gradient
     */
    @Override
    public void trainGradient(AVector input, AVector outputGradient, AVector inputGradient, double factor, boolean skipTopDerivative) {
        assert (this.getInputLength() == input.length());
        this.think(input, null);
        this.grad[this.layerCount].set(outputGradient);
        this.backpropGradient(factor, skipTopDerivative);
        if (inputGradient != null) {
            inputGradient.add((AVector)this.grad[0]);
        }
    }

    /**
     * Propagates {@code grad[layerCount]} back to {@code grad[0]}, training
     * each layer on the way from the last layer to the first.
     *
     * @param factor            factor passed to each layer's {@code trainGradient}
     * @param skipTopDerivative if true, the last layer is scaled with
     *                          {@code Op.LINEAR} rather than its own op
     */
    private void backpropGradient(double factor, boolean skipTopDerivative) {
        for (int i = this.layerCount - 1; i >= 0; --i) {
            // Clear the input-side gradient buffer before the layer writes into it.
            this.grad[i].fill(0.0);
            Op op = skipTopDerivative && i == this.layerCount - 1 ? Op.LINEAR : this.getLayerOp(i);
            // Presumably scales grad[i + 1] by op's derivative at the layer output
            // (behaviour of Util.scaleByDerivative is not visible here).
            Util.scaleByDerivative(op, this.data[i + 1], this.grad[i + 1]);
            this.layers[i].trainGradient((AVector)this.data[i], (AVector)this.grad[i + 1], (AVector)this.grad[i], factor);
        }
    }

    /**
     * Returns the activation op applied to the output of layer {@code i}.
     *
     * @param i zero-based layer index
     * @return the op for that layer
     */
    public Op getLayerOp(int i) {
        return this.layerOps[i];
    }

    /**
     * Runs a forward pass: copies {@code input} into the input buffer, then for
     * each layer computes its output and applies that layer's op in place.
     *
     * @param input  the input vector
     * @param output if non-null, receives a copy of the network output
     */
    @Override
    public void think(AVector input, AVector output) {
        this.data[0].set(input);
        for (int i = 0; i < this.layerCount; ++i) {
            this.layers[i].think((AVector)this.data[i], (AVector)this.data[i + 1]);
            this.getLayerOp(i).applyTo(this.data[i + 1].getArray());
        }
        if (output != null) {
            output.set((AVector)this.data[this.layerCount]);
        }
    }

    /** Returns layer {@code i} (zero-based, input-to-output order). */
    @Override
    public AWeightLayer getLayer(int i) {
        return this.layers[i];
    }

    /** Returns the length of the network's input buffer. */
    @Override
    public int getInputLength() {
        return this.data[0].length();
    }

    /** Returns the length of the network's output buffer. */
    @Override
    public int getOutputLength() {
        return this.data[this.layerCount].length();
    }

    /** Returns the length of the joined parameter vector. */
    @Override
    public int getParameterLength() {
        return this.parameters.length();
    }

    /** Returns the joined parameter vector of all layers. */
    @Override
    public AVector getParameters() {
        return this.parameters;
    }

    /** Returns the joined gradient vector of all layers. */
    @Override
    public AVector getGradient() {
        return this.gradient;
    }

    /** Returns a fixed-size list view backed by the internal layer array. */
    @Override
    public List<AWeightLayer> getLayers() {
        return Arrays.asList(this.layers);
    }

    /** Returns the gradient buffer at the network input (not a copy). */
    public AVector getInputSignal() {
        return this.grad[0];
    }

    /** Returns the gradient buffer at the network output (not a copy). */
    public AVector getOutputSignal() {
        return this.grad[this.layerCount];
    }

    /** Returns the activation buffer at the network output (not a copy). */
    @Override
    public AVector getOutput() {
        return this.data[this.layerCount];
    }

    /** Returns the activation buffer at the network input (not a copy). */
    @Override
    public AVector getInput() {
        return this.data[0];
    }

    /**
     * Returns activation buffer {@code i} (not a copy); 0 is the network
     * input and {@code getLayerCount()} is the network output.
     */
    @Override
    public AVector getData(int i) {
        return this.data[i];
    }

    /**
     * Creates a network with cloned layers and the same activation ops as
     * this one. Activation and gradient buffers are freshly allocated.
     *
     * @return an independent copy of this network
     */
    @Override
    public NeuralNet clone() {
        AWeightLayer[] newlayers = new AWeightLayer[this.layerCount];
        for (int i = 0; i < this.layerCount; ++i) {
            newlayers[i] = this.layers[i].clone();
        }
        // Carry the activation ops over; the default constructor would reset
        // them to TANH/LOGISTIC. All hidden layers share one op, so layerOps[0]
        // represents them; with a single layer the hidden op is never used.
        Op hiddenOp = this.layerCount > 1 ? this.layerOps[0] : Op.TANH;
        Op outputOp = this.layerOps[this.layerCount - 1];
        return new NeuralNet(newlayers, hiddenOp, outputOp);
    }

    /** Returns the number of weight layers in this network. */
    @Override
    public int getLayerCount() {
        return this.layerCount;
    }
}

