/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import mikera.vectorz.AVector;
import mikera.vectorz.Op;
import mikera.vectorz.Ops;
import mikera.vectorz.Vector;
import nuroko.core.Components;
import nuroko.core.IComponent;
import nuroko.core.IModule;
import nuroko.core.Util;
import nuroko.module.ALayerStack;
import nuroko.module.AWeightLayer;
import nuroko.module.loss.LossFunction;

public class NeuralNet
extends ALayerStack {
    private final int layerCount;
    private final AWeightLayer[] layers;
    private final Vector[] data;
    private final Vector[] grad;
    private final Vector outputGradient;
    private final AVector parameters;
    private final AVector gradient;
    private final Op[] layerOps;

    public NeuralNet(AWeightLayer ... layers) {
        this(layers, Ops.LOGISTIC);
    }

    public NeuralNet(AWeightLayer[] layers, Op outputOp) {
        this(layers, Ops.TANH, outputOp);
    }

    public NeuralNet(AWeightLayer layer, Op outputOp) {
        this(new AWeightLayer[]{layer}, null, outputOp);
    }

    public NeuralNet(AWeightLayer[] layers, Op hiddenOp, Op outputOp) {
        int i;
        this.layerCount = layers.length;
        this.layers = (AWeightLayer[])layers.clone();
        this.layerOps = new Op[this.layerCount];
        for (i = 0; i < this.layerCount - 1; ++i) {
            this.layerOps[i] = hiddenOp;
        }
        this.layerOps[this.layerCount - 1] = outputOp;
        this.data = new Vector[this.layerCount + 1];
        this.grad = new Vector[this.layerCount + 1];
        this.data[0] = Vector.createLength((int)layers[0].getInputLength());
        this.grad[0] = Vector.createLength((int)layers[0].getInputLength());
        for (i = 0; i < this.layerCount; ++i) {
            this.data[i + 1] = Vector.createLength((int)layers[i].getOutputLength());
            this.grad[i + 1] = Vector.createLength((int)layers[i].getOutputLength());
        }
        this.outputGradient = Vector.createLength((int)layers[this.layerCount - 1].getOutputLength());
        AVector params = layers[0].getParameters();
        for (int i2 = 1; i2 < this.layerCount; ++i2) {
            params = params.join(layers[i2].getParameters());
        }
        this.parameters = params;
        AVector g = layers[0].getGradient();
        for (int i3 = 1; i3 < this.layerCount; ++i3) {
            g = g.join(layers[i3].getGradient());
        }
        this.gradient = g;
    }

    @Override
    public List<IModule> getModules() {
        ArrayList<IModule> al = new ArrayList<IModule>();
        for (AWeightLayer m : this.layers) {
            al.add(m);
        }
        return al;
    }

    @Override
    public LossFunction getDefaultLossFunction() {
        Op topOp = this.layerOps[this.layerCount - 1];
        return Components.defaultLossFunction(topOp);
    }

    @Override
    public List<IComponent> getComponents() {
        return Collections.EMPTY_LIST;
    }

    public NeuralNet getInverse() {
        AWeightLayer[] newLayers = new AWeightLayer[this.layerCount];
        for (int i = 0; i < this.layerCount; ++i) {
            newLayers[i] = this.layers[this.layerCount - 1 - i].getInverse();
        }
        return new NeuralNet(newLayers);
    }

    @Override
    public void initRandom() {
        for (AWeightLayer wl : this.layers) {
            wl.initRandom();
        }
    }

    @Override
    public void trainGradientInternal(double factor) {
        this.backpropGradient(factor *= this.getLearnFactor(), false);
    }

    private void backpropGradient(double factor, boolean skipTopDerivative) {
        this.grad[this.layerCount].set((AVector)this.outputGradient);
        for (int i = this.layerCount - 1; i >= 0; --i) {
            this.grad[i].fill(0.0);
            Op op = skipTopDerivative && i == this.layerCount - 1 ? Ops.LINEAR : this.getLayerOp(i);
            Util.scaleByDerivative(op, this.data[i + 1], this.grad[i + 1]);
            this.layers[i].trainGradient((AVector)this.data[i], (AVector)this.grad[i + 1], (AVector)this.grad[i], factor);
        }
    }

    public Op getLayerOp(int i) {
        return this.layerOps[i];
    }

    @Override
    public void thinkInternal() {
        for (int i = 0; i < this.layerCount; ++i) {
            this.layers[i].think((AVector)this.data[i], (AVector)this.data[i + 1]);
            this.getLayerOp(i).applyTo(this.data[i + 1].getArray());
        }
    }

    @Override
    public AWeightLayer getLayer(int i) {
        return this.layers[i];
    }

    @Override
    public int getInputLength() {
        return this.data[0].length();
    }

    @Override
    public int getOutputLength() {
        return this.data[this.layerCount].length();
    }

    @Override
    public int getParameterLength() {
        return this.parameters.length();
    }

    @Override
    public AVector getParameters() {
        return this.parameters;
    }

    @Override
    public AVector getGradient() {
        return this.gradient;
    }

    @Override
    public List<AWeightLayer> getLayers() {
        return Arrays.asList(this.layers);
    }

    @Override
    public void applyConstraintsInternal() {
        for (AWeightLayer c : this.getLayers()) {
            c.applyConstraints();
        }
    }

    @Override
    public AVector getOutput() {
        return this.data[this.layerCount];
    }

    @Override
    public AVector getInput() {
        return this.data[0];
    }

    @Override
    public AVector getData(int i) {
        return this.data[i];
    }

    @Override
    public NeuralNet clone() {
        AWeightLayer[] newlayers = new AWeightLayer[this.layerCount];
        for (int i = 0; i < this.layerCount; ++i) {
            newlayers[i] = this.layers[i].clone();
        }
        NeuralNet ns = new NeuralNet(newlayers);
        return ns;
    }

    @Override
    public int getLayerCount() {
        return this.layerCount;
    }

    @Override
    public AVector getInputGradient() {
        return this.grad[0];
    }

    @Override
    public AVector getOutputGradient() {
        return this.outputGradient;
    }

    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }
}

