/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import mikera.vectorz.AVector;
import mikera.vectorz.ArrayVector;
import mikera.vectorz.Vector;
import nuroko.core.ITrainable;
import nuroko.module.ALayerStack;
import nuroko.module.AWeightLayer;
import nuroko.module.CompoundStack;

/**
 * A network built by chaining {@link ALayerStack} components end to end: the output of
 * component {@code i} is fed as the input of component {@code i + 1}.
 *
 * <p>Activations and gradients are kept in buffers owned by this stack and allocated once at
 * construction. {@link #think} and {@link #train} write into those buffers without any
 * synchronization, so a single instance should not be used from several threads at once.
 */
public class NetworkStack
extends CompoundStack<ALayerStack>
implements ITrainable {
    /** {@code data[0]} holds the network input; {@code data[i + 1]} holds the output of component {@code i}. */
    private final Vector[] data;
    /** {@code grad[i]} holds the gradient associated with {@code data[i]} during backpropagation. */
    private final Vector[] grad;

    /**
     * Creates a network from the given components, in input-to-output order.
     *
     * @param components the layer stacks to chain; must contain at least one element
     */
    public NetworkStack(ALayerStack ... components) {
        this(Arrays.asList(components));
    }

    /**
     * Creates a network from the given components, in input-to-output order, and allocates the
     * activation and gradient buffers sized from each component's input/output lengths.
     *
     * @param components the layer stacks to chain; must contain at least one element
     * @throws IllegalArgumentException if {@code components} is empty
     */
    public NetworkStack(List<ALayerStack> components) {
        super(components);
        if (this.componentCount == 0) {
            // Without this check the failure would surface as an opaque error from getComponent(0).
            throw new IllegalArgumentException("NetworkStack requires at least one component");
        }
        this.data = new Vector[this.componentCount + 1];
        this.grad = new Vector[this.componentCount + 1];
        int inputLength = this.component(0).getInputLength();
        this.data[0] = Vector.createLength(inputLength);
        this.grad[0] = Vector.createLength(inputLength);
        for (int i = 0; i < this.componentCount; ++i) {
            int outputLength = this.component(i).getOutputLength();
            this.data[i + 1] = Vector.createLength(outputLength);
            this.grad[i + 1] = Vector.createLength(outputLength);
        }
    }

    /** Returns component {@code i} typed as a layer stack. */
    private ALayerStack component(int i) {
        return (ALayerStack) this.getComponent(i);
    }

    /** Returns the input length of the first component, as captured at construction. */
    @Override
    public int getInputLength() {
        return this.data[0].length();
    }

    /** Returns the output length of the last component, as captured at construction. */
    @Override
    public int getOutputLength() {
        return this.data[this.componentCount].length();
    }

    /**
     * Runs a forward pass: copies {@code input} into the internal input buffer and then runs
     * each component in order, feeding each one's output buffer to the next.
     *
     * @param input  the network input
     * @param output receives a copy of the final output; may be {@code null} to keep the result
     *               only in the internal buffers (as {@link #train} does)
     */
    @Override
    public void think(AVector input, AVector output) {
        this.data[0].set(input);
        for (int i = 0; i < this.componentCount; ++i) {
            this.component(i).think(this.data[i], this.data[i + 1]);
        }
        if (output != null) {
            output.set(this.data[this.componentCount]);
        }
    }

    /** Returns the total number of layers across all components. */
    @Override
    public int getLayerCount() {
        int result = 0;
        for (ALayerStack c : this.components) {
            result += c.getLayerCount();
        }
        return result;
    }

    /**
     * Returns a layer by its global index, counting layers across components in order.
     *
     * @param i global layer index in {@code [0, getLayerCount())}
     * @return the layer at that index
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public AWeightLayer getLayer(int i) {
        if (i >= 0) {
            int remaining = i;
            for (ALayerStack c : this.components) {
                int lc = c.getLayerCount();
                if (remaining < lc) {
                    return c.getLayer(remaining);
                }
                remaining -= lc;
            }
        }
        // Report the caller's index, not the residual left over after walking the components.
        throw new IndexOutOfBoundsException("Invalid index:" + i);
    }

    /**
     * Returns a data (activation) vector by its global index, delegating to the owning
     * component.
     *
     * <p>NOTE(review): this assumes {@link ALayerStack#getData(int)} accepts indices
     * {@code 0..getLayerCount()}, where the last index is that component's output. The same
     * "N stages, N+1 data vectors" layout is used by this class's own buffers. Confirm against
     * {@code ALayerStack}.
     *
     * @param i global data index in {@code [0, getLayerCount()]}; the last index is the final
     *          network output
     * @return the data vector at that index
     * @throws IndexOutOfBoundsException if {@code i} is out of range
     */
    @Override
    public AVector getData(int i) {
        if (i >= 0) {
            int remaining = i;
            ALayerStack last = null;
            for (ALayerStack c : this.components) {
                int lc = c.getLayerCount();
                if (remaining < lc) {
                    return c.getData(remaining);
                }
                remaining -= lc;
                last = c;
            }
            // Index == total layer count: the output of the final component. The strict "<"
            // above can never reach it, so handle it explicitly.
            if (remaining == 0 && last != null) {
                return last.getData(last.getLayerCount());
            }
        }
        throw new IndexOutOfBoundsException("Invalid index:" + i);
    }

    /**
     * Returns a new network built from clones of each component, with freshly allocated
     * buffers.
     */
    @Override
    public NetworkStack clone() {
        List<ALayerStack> copies = new ArrayList<ALayerStack>(this.componentCount);
        for (ALayerStack c : this.components) {
            copies.add(c.clone());
        }
        return new NetworkStack(copies);
    }

    /**
     * Performs one training step towards {@code target}.
     *
     * <p>The step runs a forward pass and sets the output gradient to
     * {@code target - output}. It then backpropagates with factor {@code 1.0} and
     * {@code skipTopDerivative = true}.
     *
     * @param input  the network input
     * @param target the desired output; its length must equal {@link #getOutputLength()}
     */
    @Override
    public void train(AVector input, AVector target) {
        assert (this.getOutputLength() == target.length());
        this.think(input, null);
        Vector outputGradient = this.grad[this.componentCount];
        outputGradient.set(target);
        outputGradient.sub(this.data[this.componentCount]);
        this.trainGradient(input, outputGradient, null, 1.0, true);
    }

    /**
     * Backpropagates {@code outputGradient} through the components in reverse order.
     *
     * <p>This uses the activations left in the internal buffers by the most recent
     * {@link #think} call. {@code input} is only checked for length and is not re-applied, so
     * callers must run {@code think} on the same input first.
     *
     * @param input             the input of the preceding forward pass (length-checked only)
     * @param outputGradient    gradient with respect to the network output
     * @param inputGradient     if non-null, the gradient with respect to the network input is
     *                          added to it
     * @param factor            passed unchanged to every component
     * @param skipTopDerivative passed to the last component only; lower components receive
     *                          {@code false}
     */
    @Override
    public void trainGradient(AVector input, AVector outputGradient, AVector inputGradient, double factor, boolean skipTopDerivative) {
        assert (this.getInputLength() == input.length());
        this.grad[this.componentCount].set(outputGradient);
        int top = this.componentCount - 1;
        for (int i = top; i >= 0; --i) {
            // Zero the lower gradient buffer first. The component presumably accumulates into it,
            // matching how this class adds into inputGradient below.
            this.grad[i].fill(0.0);
            this.component(i).trainGradient(this.data[i], this.grad[i + 1], this.grad[i], factor, skipTopDerivative && i == top);
        }
        if (inputGradient != null) {
            inputGradient.add(this.grad[0]);
        }
    }
}

