/*
 * Decompiled with CFR 0.152.
 */
package nuroko.module.loss;

import mikera.vectorz.AVector;
import nuroko.core.NurokoException;
import nuroko.module.loss.LossFunction;

/**
 * Cross-entropy style loss function operating element-wise on an output
 * vector and a target vector of equal length.
 *
 * <p>The class holds no instance fields; a shared instance is available as
 * {@link #INSTANCE}.
 */
public class CrossEntropyLoss extends LossFunction {
    /** Shared instance of this loss function. */
    public static CrossEntropyLoss INSTANCE = new CrossEntropyLoss();

    /**
     * Lower bound applied to the denominator {@code y * (1 - y)} when computing
     * the error derivative, so the division never blows up when an output
     * reaches exactly 0 or 1.
     */
    public static final double BOUND = 1.0E-12;

    /**
     * Writes the per-element error signal into {@code gradientOut}.
     *
     * <p>For each index {@code i} the stored value is
     * {@code (t - y) / max(BOUND, y * (1 - y))}, where {@code y} is the output
     * and {@code t} the target at that index. A NaN output propagates as a NaN
     * gradient entry.
     *
     * @param output      actual outputs
     * @param target      desired outputs; must have the same length as {@code output}
     * @param gradientOut receives one value per element; its length is not
     *                    checked here (presumably at least as long as
     *                    {@code output} - confirm against callers)
     * @throws NurokoException if {@code output} and {@code target} differ in length
     */
    @Override
    public void calculateErrorDerivative(AVector output, AVector target, AVector gradientOut) {
        int len = checkSameLength(output, target);
        for (int i = 0; i < len; ++i) {
            double y = output.get(i);
            double t = target.get(i);
            // The clamp keeps the denominator strictly positive (or NaN), so no
            // separate zero check is required before dividing.
            double k = Math.max(BOUND, y * (1.0 - y));
            gradientOut.set(i, (t - y) / k);
        }
    }

    /**
     * Computes the total error as the sum over all elements of
     * {@code -log(p * q + (1 - p) * (1 - q))}, where {@code p} is the target
     * and {@code q} the output at each index.
     *
     * <p>For a target of exactly 1 the term reduces to {@code -log(q)}, and for
     * a target of exactly 0 to {@code -log(1 - q)}. The logarithm's argument is
     * not clamped, so the result is infinite when that argument is zero and NaN
     * when it is negative or NaN.
     *
     * @param output actual outputs
     * @param target desired outputs; must have the same length as {@code output}
     * @return the summed error over all elements (0.0 for empty vectors)
     * @throws NurokoException if {@code output} and {@code target} differ in length
     */
    @Override
    public double calculateError(AVector output, AVector target) {
        int len = checkSameLength(output, target);
        double ce = 0.0;
        for (int i = 0; i < len; ++i) {
            double p = target.get(i);
            double q = output.get(i);
            ce -= Math.log(p * q + (1.0 - p) * (1.0 - q));
        }
        return ce;
    }

    /**
     * Verifies that both vectors have the same length and returns that length.
     *
     * @throws NurokoException if the lengths differ
     */
    private static int checkSameLength(AVector output, AVector target) {
        int olen = output.length();
        int tlen = target.length();
        if (olen != tlen) {
            throw new NurokoException("Target / output size mismatch: " + tlen + " vs. " + olen);
        }
        return olen;
    }
}

