/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module.layers;

import mikera.indexz.Index;
import mikera.indexz.Indexz;
import mikera.matrixx.AMatrix;
import mikera.matrixx.impl.VectorMatrixMN;
import mikera.vectorz.AVector;
import mikera.vectorz.ArrayVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import nuroko.module.AWeightLayer;

/**
 * A fully-connected weight layer: output {@code j} is {@code bias[j]} plus the
 * dot product of weight row {@code j} with the entire input vector.
 *
 * <p>The parameter vector is laid out as the {@code outputLength} bias values
 * followed by the {@code outputLength} weight rows (each {@code inputLength}
 * long) in output order. The gradient vector uses exactly the same layout.
 *
 * <p>NOTE(review): {@link #clone()} and the optimizer contract via
 * {@link #getParameters()}/{@link #getGradient()} only work if
 * {@code Vector.join} returns a view sharing storage with the joined pieces
 * rather than a copy; confirm against the vectorz version in use.
 */
public final class FullWeightLayer
extends AWeightLayer {
    /** Bias term per output; length {@code outputLength}. */
    private final Vector bias;
    /** Accumulated bias gradient; length {@code outputLength}. */
    private final Vector biasGradient;
    /** One weight row per output; each row has length {@code inputLength}. */
    private final Vector[] weights;
    /** Accumulated weight gradients, same shape as {@link #weights}. */
    private final Vector[] weightGradients;
    /** {@link #bias} followed by every row of {@link #weights}, joined in order. */
    private final AVector parameters;
    /** {@link #biasGradient} followed by every row of {@link #weightGradients}. */
    private final AVector gradient;

    /**
     * Creates a layer whose bias and weight vectors are allocated with
     * {@code Vector.createLength}; call {@link #initRandom()} to randomise them.
     *
     * @param inputLength  number of inputs (length of each weight row)
     * @param outputLength number of outputs (number of weight rows and biases)
     */
    public FullWeightLayer(int inputLength, int outputLength) {
        super(inputLength, outputLength);
        this.bias = Vector.createLength(outputLength);
        this.weights = new Vector[outputLength];
        this.parameters = FullWeightLayer.joinRows(this.bias, this.weights, inputLength);
        this.biasGradient = Vector.createLength(outputLength);
        this.weightGradients = new Vector[outputLength];
        this.gradient = FullWeightLayer.joinRows(this.biasGradient, this.weightGradients, inputLength);
        // Parameters and gradient must line up element for element.
        assert (this.gradient.length() == this.parameters.length());
    }

    /**
     * Fills every slot of {@code rows} with a newly allocated vector of length
     * {@code rowLength}, and returns {@code head} joined with each row in order.
     */
    private static Vector joinRows(Vector head, Vector[] rows, int rowLength) {
        Vector joined = head;
        for (int j = 0; j < rows.length; ++j) {
            Vector row = Vector.createLength(rowLength);
            rows[j] = row;
            joined = joined.join((AVector)row);
        }
        return joined;
    }

    /** Returns the joined bias-then-weights parameter vector (not a copy). */
    @Override
    public AVector getParameters() {
        return this.parameters;
    }

    /** Returns {@code outputLength * (inputLength + 1)}, the total parameter count. */
    @Override
    public int getParameterLength() {
        return this.parameters.length();
    }

    /**
     * Copies {@code input} into this layer's input buffer, runs the forward pass
     * and copies the result into {@code output}.
     */
    @Override
    public void think(AVector input, AVector output) {
        assert (this.inputLength == input.length());
        assert (this.outputLength == output.length());
        this.input.set(input);
        this.thinkInternal();
        output.set((AVector)this.getOutput());
    }

    /** Forward pass: {@code output[j] = bias[j] + weights[j] . input} for every output. */
    @Override
    public void thinkInternal() {
        for (int j = 0; j < this.outputLength; ++j) {
            double val = this.bias.get(j) + this.weights[j].dotProduct(this.input);
            this.output.set(j, val);
        }
    }

    /** Returns the joined gradient vector, laid out like {@link #getParameters()}. */
    @Override
    public AVector getGradient() {
        return this.gradient;
    }

    /**
     * Backward pass.
     *
     * <p>Adds {@code factor * getLearnFactor()}-scaled contributions into the bias
     * and weight gradients. Those gradients are accumulated, not reset here.
     * The input gradient is first cleared and then set to the transpose of the
     * weights applied to {@code outputGradient}. It is not scaled by
     * {@code factor}.
     *
     * @param factor caller-supplied scale, multiplied by {@link #getLearnFactor()}
     */
    @Override
    public void trainGradientInternal(double factor) {
        double scaled = factor * this.getLearnFactor();
        this.inputGradient.fill(0.0);
        this.biasGradient.addMultiple((ArrayVector)this.outputGradient, scaled);
        for (int j = 0; j < this.outputLength; ++j) {
            double grad = this.outputGradient.get(j);
            this.weightGradients[j].addMultiple((ArrayVector)this.input, grad * scaled);
            this.inputGradient.addMultiple((ArrayVector)this.weights[j], grad);
        }
    }

    /**
     * Creates a new layer of the same shape and copies the parameter vector
     * (biases and weights) into it. Gradients are not copied. NOTE(review): other
     * state held by {@code AWeightLayer} (e.g. the learn factor) is not explicitly
     * copied here — confirm this is intended.
     */
    @Override
    public FullWeightLayer clone() {
        FullWeightLayer wl = new FullWeightLayer(this.getInputLength(), this.getOutputLength());
        wl.getParameters().set(this.getParameters());
        return wl;
    }

    /** Every output is connected to all {@code inputLength} inputs. */
    @Override
    public int getLinkCount(int outputIndex) {
        return this.inputLength;
    }

    /** Returns the weight from input {@code number} to output {@code outputIndex}. */
    @Override
    public double getLinkWeight(int outputIndex, int number) {
        return this.weights[outputIndex].data[number];
    }

    /** In a full layer the {@code number}-th link of any output comes from input {@code number}. */
    @Override
    public int getLinkSource(int outputIndex, int number) {
        return number;
    }

    /**
     * Fills biases from a Gaussian with mean 0 and standard deviation 0.3. Fills
     * each weight row from a Gaussian with mean 0 and standard deviation
     * {@code 1 / sqrt(row length)}.
     */
    @Override
    public void initRandom() {
        Vectorz.fillGaussian((AVector)this.bias, 0.0, 0.3);
        for (Vector v : this.weights) {
            Vectorz.fillGaussian((AVector)v, 0.0, 1.0 / Math.sqrt(v.length()));
        }
    }

    /**
     * Returns a new layer mapping outputs back to inputs whose weight matrix is
     * the transpose of this one. The new layer's bias is left as created and is
     * not derived from this layer.
     */
    @Override
    public FullWeightLayer getInverse() {
        FullWeightLayer wl = this.getInverseStructure();
        for (int i = 0; i < this.inputLength; ++i) {
            // Row i of the inverse collects column i of this layer's weights.
            AVector target = wl.getSourceWeights(i);
            for (int j = 0; j < this.outputLength; ++j) {
                target.set(j, this.weights[j].get(i));
            }
        }
        return wl;
    }

    /** Returns a fresh layer with input and output lengths swapped. */
    public FullWeightLayer getInverseStructure() {
        FullWeightLayer wl = new FullWeightLayer(this.getOutputLength(), this.getInputLength());
        return wl;
    }

    /** Returns the index sequence {@code 0 .. inputLength-1} (a new Index per call). */
    @Override
    public Index getSourceIndex(int outputIndex) {
        return Indexz.createSequence(this.inputLength);
    }

    /** Returns the live weight row for {@code outputIndex} (not a copy). */
    @Override
    public AVector getSourceWeights(int outputIndex) {
        return this.weights[outputIndex];
    }

    /** Wraps the weight rows as a matrix with one row per output. */
    @Override
    public AMatrix asMatrix() {
        return VectorMatrixMN.wrap((AVector[])this.weights);
    }

    /** This layer behaves identically during training and inference. */
    @Override
    public boolean hasDifferentTrainingThinking() {
        return false;
    }

    /** Returns the live bias vector (not a copy). */
    public Vector getBias() {
        return this.bias;
    }
}

