/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module;

import mikera.util.Rand;
import mikera.vectorz.AVector;
import mikera.vectorz.impl.Vector0;
import nuroko.module.AStateComponent;

/**
 * Inverted-dropout layer.
 *
 * <p>During training ({@link #thinkInternalTraining()}) each element is independently
 * zeroed when {@code Rand.chance(dropoutRate)} returns true, and every surviving element
 * is scaled by {@code 1 / (1 - dropoutRate)}. During inference ({@link #thinkInternal()})
 * the input is copied to the output unchanged.
 *
 * <p>The backward pass ({@link #trainGradientInternal(double)}) applies exactly the mask
 * and scale used by the most recent forward pass, so the gradient is consistent with
 * whichever forward pass (training or inference) ran last.
 *
 * <p>This layer has no trainable parameters.
 */
public class Dropout extends AStateComponent {
    /** Dropout probability used by {@link #Dropout(int)}. */
    private static final double DEFAULT_DROPOUT_RATE = 0.5;

    /** Probability in [0, 1] that an element is zeroed during a training pass. */
    private final double dropoutRate;

    /** Mask from the most recent forward pass; {@code true} where the element was zeroed. */
    private final boolean[] dropped;

    /** Factor applied to surviving elements in the most recent forward pass (1.0 at inference). */
    private double appliedScale = 1.0;

    /**
     * Creates a dropout layer with the default dropout rate of 0.5.
     *
     * @param length number of input (and output) elements
     */
    public Dropout(int length) {
        this(length, DEFAULT_DROPOUT_RATE);
    }

    /**
     * Creates a dropout layer.
     *
     * @param length number of input (and output) elements
     * @param dropoutRate probability in [0, 1] of zeroing each element during training
     * @throws IllegalArgumentException if {@code dropoutRate} is NaN or outside [0, 1]
     */
    public Dropout(int length, double dropoutRate) {
        super(length, length);
        // Written as a negated range check so that NaN is rejected as well.
        if (!(dropoutRate >= 0.0 && dropoutRate <= 1.0)) {
            throw new IllegalArgumentException("dropoutRate must be in [0, 1]: " + dropoutRate);
        }
        this.dropoutRate = dropoutRate;
        this.dropped = new boolean[length];
    }

    /** Inference pass: copies the input to the output and records an identity mask/scale. */
    @Override
    public void thinkInternal() {
        this.output.set((AVector)this.input);
        // Keep the backward pass consistent with this (unscaled, unmasked) forward pass.
        for (int i = 0; i < this.dropped.length; i++) {
            this.dropped[i] = false;
        }
        this.appliedScale = 1.0;
    }

    /**
     * Training pass: copies the input to the output, then zeroes a random subset of elements
     * and rescales the rest by {@code 1 / (1 - dropoutRate)}.
     */
    @Override
    public void thinkInternalTraining() {
        this.output.set((AVector)this.input);
        boolean active = this.dropoutRate > 0.0;
        double keepProbability = 1.0 - this.dropoutRate;
        // With keepProbability == 0 every element is dropped, so the scale is never used.
        double scale = (active && keepProbability > 0.0) ? 1.0 / keepProbability : 1.0;
        double[] out = this.output.getArray();
        for (int i = 0; i < this.dropped.length; i++) {
            // One random draw per element, used for both the output and the recorded mask.
            boolean drop = active && Rand.chance(this.dropoutRate);
            this.dropped[i] = drop;
            out[i] = drop ? 0.0 : out[i] * scale;
        }
        this.appliedScale = scale;
    }

    /** @return an empty vector; dropout has no trainable parameters */
    @Override
    public AVector getParameters() {
        return Vector0.INSTANCE;
    }

    /** @return an empty vector; dropout has no trainable parameters */
    @Override
    public AVector getGradient() {
        return Vector0.INSTANCE;
    }

    /**
     * Backward pass: sets the input gradient to the output gradient, zeroing dropped elements
     * and scaling kept elements by the same factor used in the most recent forward pass.
     *
     * @param factor unused; dropout has no parameters to update
     */
    @Override
    public void trainGradientInternal(double factor) {
        this.inputGradient.set((AVector)this.outputGradient);
        double[] grad = this.inputGradient.getArray();
        double scale = this.appliedScale;
        for (int i = 0; i < this.dropped.length; i++) {
            // d(out_i)/d(in_i) is 0 for dropped elements and `scale` for kept ones.
            grad[i] = this.dropped[i] ? 0.0 : grad[i] * scale;
        }
    }

    /** @return a new, independent dropout layer with the same length and dropout rate */
    @Override
    public Dropout clone() {
        return new Dropout(this.getInputLength(), this.dropoutRate);
    }
}

