/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import java.util.ArrayList;
import java.util.List;
import mikera.vectorz.AVector;
import nuroko.core.IComponent;
import nuroko.core.IInputState;
import nuroko.module.ACompoundComponent;
import nuroko.module.loss.LossFunction;

/**
 * A sequential composition of components: each component's output is
 * forwarded as the input of the component that follows it.
 *
 * <p>The first component in the list provides this stack's input, input
 * gradient and input state. The last component provides its output and
 * output gradient. Every accessor indexes position {@code 0} or
 * {@code componentCount - 1} directly, so the stack is assumed to hold at
 * least one component.
 */
public class Stack extends ACompoundComponent {

    /**
     * Builds a stack over the given components, ordered from the input side
     * (index 0) to the output side (last index).
     *
     * @param comps the components to chain, in evaluation order
     */
    public Stack(List<? extends IComponent> comps) {
        super(comps);
    }

    /**
     * Returns the component stored at the last position of the chain.
     *
     * @return the last component in the list
     */
    @Override
    public IComponent topComponent() {
        int lastIndex = componentCount - 1;
        return (IComponent) components.get(lastIndex);
    }

    /**
     * Delegates to the top component's default loss function.
     *
     * @return whatever {@link #topComponent()} reports as its default loss
     */
    @Override
    public LossFunction getDefaultLossFunction() {
        IComponent top = topComponent();
        return top.getDefaultLossFunction();
    }

    /** Returns the input vector of the first component in the chain. */
    @Override
    public AVector getInput() {
        IComponent first = (IComponent) components.get(0);
        return first.getInput();
    }

    /** Returns the input gradient of the first component in the chain. */
    @Override
    public AVector getInputGradient() {
        IComponent first = (IComponent) components.get(0);
        return first.getInputGradient();
    }

    /** Returns the input length reported by the first component. */
    @Override
    public int getInputLength() {
        IComponent first = (IComponent) components.get(0);
        return first.getInputLength();
    }

    /**
     * Runs every component in order. Before component {@code k > 0} runs,
     * the output of component {@code k - 1} is set as its input.
     */
    @Override
    public void thinkInternal() {
        int count = componentCount;
        for (int idx = 0; idx < count; idx++) {
            IComponent current = getComponent(idx);
            if (idx > 0) {
                // Feed the predecessor's freshly computed output forward.
                current.setInput(getComponent(idx - 1).getOutput());
            }
            current.thinkInternal();
        }
    }

    /**
     * Training-mode counterpart of {@link #thinkInternal()}: runs each
     * component's {@code thinkInternalTraining()} in order and forwards
     * outputs the same way.
     */
    @Override
    public void thinkInternalTraining() {
        int count = componentCount;
        for (int idx = 0; idx < count; idx++) {
            IComponent current = getComponent(idx);
            if (idx > 0) {
                current.setInput(getComponent(idx - 1).getOutput());
            }
            current.thinkInternalTraining();
        }
    }

    /**
     * Returns a new {@code Stack} built from a {@code clone()} of each
     * component, kept in the same order.
     *
     * <p>NOTE(review): only the cloned component list is passed to the new
     * instance. Any other state held by the superclass (for example the
     * value behind {@code getLearnFactor()}) is not explicitly transferred.
     * Confirm this is intended.
     *
     * @return a stack over cloned components
     */
    @Override
    public Stack clone() {
        List<IComponent> copies = new ArrayList<IComponent>();
        for (IComponent child : components) {
            copies.add(child.clone());
        }
        return new Stack(copies);
    }

    /**
     * Sets the input of the first component in the chain.
     *
     * @param input the vector to feed into the first component
     */
    @Override
    public void setInput(AVector input) {
        IComponent first = (IComponent) components.get(0);
        first.setInput(input);
    }

    /** Returns the input state of the first component in the chain. */
    @Override
    public IInputState getInputState() {
        IComponent first = (IComponent) components.get(0);
        return first.getInputState();
    }

    /**
     * Trains the chain from the top component down to the first.
     *
     * <p>The incoming factor is multiplied by this stack's learn factor
     * once, and that scaled value is used for every component. Before each
     * lower component is trained, its output gradient is overwritten with
     * the input gradient of the component directly above it.
     *
     * @param factor the training factor supplied by the caller
     */
    @Override
    public void trainGradientInternal(double factor) {
        int count = componentCount;
        // Resolve the top component before reading the learn factor,
        // preserving the original evaluation order.
        IComponent top = topComponent();
        double scaledFactor = factor * getLearnFactor();
        top.trainGradientInternal(scaledFactor);
        for (int below = count - 2; below >= 0; below--) {
            AVector upstreamGradient = getComponent(below + 1).getInputGradient();
            IComponent target = getComponent(below);
            target.getOutputGradient().set(upstreamGradient);
            target.trainGradientInternal(scaledFactor);
        }
    }

    /** Returns the output vector of the last component in the chain. */
    @Override
    public AVector getOutput() {
        IComponent last = getComponent(componentCount - 1);
        return last.getOutput();
    }

    /** Returns the output gradient of the last component in the chain. */
    @Override
    public AVector getOutputGradient() {
        IComponent last = getComponent(componentCount - 1);
        return last.getOutputGradient();
    }
}

