/*
 * Decompiled with CFR 0.152.
 */
package nuroko.testing;

import java.util.List;
import mikera.util.Tools;
import mikera.vectorz.AVector;
import mikera.vectorz.Vector;
import mikera.vectorz.Vectorz;
import mikera.vectorz.impl.Vector0;
import nuroko.core.Components;
import nuroko.core.IComponent;
import nuroko.core.IInputState;
import nuroko.core.IModule;
import nuroko.core.IParameterised;
import nuroko.core.IThinker;
import org.junit.Assert;

public class GenericModuleTests {
    private static void testDecomposeParams(IComponent c) {
        List<IComponent> cs = (c = c.clone()).getComponents();
        int ccount = cs.size();
        if (ccount == 0) {
            return;
        }
        Vector0 cp = Vector0.INSTANCE;
        Vector0 cg = Vector0.INSTANCE;
        for (int i = 0; i < ccount; ++i) {
            IComponent ch = cs.get(i);
            cp = cp.join(ch.getParameters());
            cg = cg.join(ch.getGradient());
        }
        Vectorz.fillGaussian((AVector)cp);
        Vectorz.fillGaussian((AVector)cg);
        Assert.assertEquals((long)c.getParameterLength(), (long)cp.length());
        Assert.assertEquals((long)c.getParameterLength(), (long)cg.length());
        Assert.assertEquals((Object)cp, (Object)c.getParameters());
        Assert.assertEquals((Object)cg, (Object)c.getGradient());
    }

    private static void testFill(IParameterised p) {
        p = p.clone();
        AVector param = p.getParameters();
        AVector grad = p.getGradient();
        param.fill(1.0);
        grad.fill(1.0);
        Assert.assertTrue((boolean)param.epsilonEquals(grad));
        for (int i = 0; i < param.length(); ++i) {
            Assert.assertEquals((double)1.0, (double)param.get(i), (double)0.0);
            param.set(i, 2.0);
        }
        grad.add(grad);
        Assert.assertTrue((boolean)param.epsilonEquals(grad));
    }

    private static void testCloneNotLinked(IParameterised p) {
        if (p.getParameterLength() == 0) {
            return;
        }
        p = p.clone();
        AVector param = p.getParameters();
        AVector grad = p.getGradient();
        param.fill(1.0);
        grad.fill(1.0);
        IParameterised p2 = p.clone();
        p2.getParameters().fill(2.0);
        p2.getGradient().fill(2.0);
        Assert.assertEquals((double)1.0, (double)Vectorz.maxValue((AVector)param), (double)0.0);
        Assert.assertEquals((double)1.0, (double)Vectorz.maxValue((AVector)grad), (double)0.0);
    }

    private static void testCloneCopyParameters(IParameterised p) {
        if (p instanceof IThinker) {
            p = p.clone();
            AVector param = p.getParameters();
            Vectorz.fillGaussian((AVector)param);
            IParameterised p2 = p.clone();
            p2.getParameters().fill(2.0);
            p2.getParameters().set(param);
            AVector input = Vectorz.newVector((int)((IThinker)((Object)p)).getInputLength());
            AVector output = Vectorz.newVector((int)((IThinker)((Object)p)).getOutputLength());
            AVector output2 = Vectorz.newVector((int)((IThinker)((Object)p)).getOutputLength());
            Vectorz.fillRandom((AVector)input);
            ((IThinker)((Object)p)).think(input, output);
            ((IThinker)((Object)p2)).think(input, output2);
            Assert.assertTrue((boolean)output.epsilonEquals(output2));
        }
    }

    private static void testParameterVectors(IParameterised p) {
        int pl = p.getParameterLength();
        AVector parameters = p.getParameters();
        AVector gradient = p.getGradient();
        Assert.assertEquals((long)pl, (long)parameters.length());
        Assert.assertEquals((long)pl, (long)gradient.length());
        if (pl > 0) {
            Assert.assertTrue((boolean)Tools.distinctObjects((Object[])new Object[]{p.getParameters(), p.getGradient()}));
        }
    }

    private static void testParameterized(IParameterised p) {
        p = p.clone();
        GenericModuleTests.testParameterVectors(p);
        GenericModuleTests.testFill(p);
        GenericModuleTests.testCloneNotLinked(p);
        GenericModuleTests.testCloneCopyParameters(p);
    }

    private static void testOverwriteOutput(IThinker p) {
        p = p.clone();
        AVector input = Vectorz.newVector((int)p.getInputLength());
        AVector output = Vectorz.newVector((int)p.getOutputLength());
        output.fill(Double.NaN);
        p.think(input, output);
        for (int i = 0; i < output.length(); ++i) {
            Assert.assertTrue((output.get(i) != Double.NaN ? 1 : 0) != 0);
        }
    }

    private static void testOverwriteInputGradient(IComponent p) {
        p = p.clone();
        AVector og = Vectorz.newVector((int)p.getOutputLength());
        Vectorz.fillGaussian((AVector)og);
        p.getInputGradient().fill(Double.NaN);
        p.getOutputGradient().set(og);
        p.trainGradientInternal(1.0);
        AVector ig = p.getInputGradient();
        for (int i = 0; i < ig.length(); ++i) {
            Assert.assertTrue((ig.get(i) != Double.NaN ? 1 : 0) != 0);
        }
    }

    private static void testJoinedGradient(IComponent p) {
        int i;
        if (p.isStochastic()) {
            return;
        }
        p = p.clone();
        int il = p.getInputLength();
        int pl = p.getParameterLength();
        AVector og = Vectorz.newVector((int)p.getOutputLength());
        Vectorz.fillGaussian((AVector)og);
        og = og.join(og);
        p = Components.join(p, p.clone());
        p.getInputGradient().fill(Double.NaN);
        p.getOutputGradient().set(og);
        p.trainGradientInternal(1.0);
        AVector ig = p.getInputGradient();
        for (i = 0; i < ig.length(); ++i) {
            Assert.assertTrue((ig.get(i) != Double.NaN ? 1 : 0) != 0);
        }
        for (i = 0; i < il; ++i) {
            Assert.assertTrue((ig.get(i) == ig.get(i + il) ? 1 : 0) != 0);
        }
        AVector params = p.getParameters();
        Assert.assertEquals((long)(pl * 2), (long)params.length());
        params.addMultiple(p.getGradient(), 1.0E-4);
        for (int i2 = 0; i2 < pl; ++i2) {
            Assert.assertTrue((params.get(i2) == params.get(i2 + pl) ? 1 : 0) != 0);
        }
    }

    private static void testThinker(IThinker p) {
        p = p.clone();
        GenericModuleTests.testOverwriteOutput(p);
    }

    private static void testInput(IInputState o) {
        AVector input = o.getInput();
        Assert.assertEquals((long)o.getInputLength(), (long)input.length());
        Assert.assertEquals((long)input.length(), (long)o.getInputGradient().length());
    }

    private static void testModule(IModule o) {
        for (IModule iModule : o.getModules()) {
            GenericModuleTests.test(iModule);
        }
    }

    private static void testGeneralThinking(IComponent p) {
        p = p.clone();
        AVector input = Vectorz.createUniformRandomVector((int)p.getInputLength());
        AVector output = Vectorz.createUniformRandomVector((int)p.getOutputLength());
        p.think(input, output);
        for (int i = 0; i < output.length(); ++i) {
            Assert.assertTrue((output.get(i) != Double.NaN ? 1 : 0) != 0);
        }
        AVector res = p.think(input);
        if (!p.isStochastic()) {
            Assert.assertEquals((Object)res, (Object)output);
        }
        Assert.assertEquals((Object)output, (Object)p.getOutput());
    }

    private static void testStates(IComponent p) {
        Assert.assertTrue((boolean)Tools.distinctObjects((Object[])new Object[]{p.getInput(), p.getOutput(), p.getInputGradient(), p.getOutputGradient()}));
    }

    private static void testParameterUpdates(IComponent p) {
        p = p.clone();
        int ol = p.getOutputLength();
        int il = p.getInputLength();
        AVector grad = p.getGradient();
        grad.fill(0.0);
        Vector output = Vector.createLength((int)ol);
        Vector target = Vector.createLength((int)ol);
        Vector input = Vector.createLength((int)il);
        Vectorz.fillGaussian((AVector)target);
        Vectorz.fillGaussian((AVector)input);
        p.think((AVector)input, (AVector)output);
        Assert.assertTrue((boolean)grad.isZero());
        p.train((AVector)input, (AVector)target);
        AVector tg = grad.clone();
        if (!p.isStochastic() && !p.isSynthesiser()) {
            p.train((AVector)input, (AVector)output);
            Assert.assertEquals((Object)tg, (Object)grad);
        }
    }

    private static void testSubComponents(IComponent o) {
        for (IComponent m : o.getComponents()) {
            GenericModuleTests.test(m);
        }
    }

    private static void testComponent(IComponent o) {
        GenericModuleTests.testGeneralThinking(o);
        GenericModuleTests.testStates(o);
        GenericModuleTests.testSubComponents(o);
        GenericModuleTests.testDecomposeParams(o);
        GenericModuleTests.testParameterUpdates(o);
        Assert.assertTrue((o.getInputState().getInput() == o.getInput() ? 1 : 0) != 0);
    }

    public static void test(Object o) {
        if (o instanceof IParameterised) {
            GenericModuleTests.testParameterized((IParameterised)o);
        }
        if (o instanceof IInputState) {
            GenericModuleTests.testInput((IInputState)o);
        }
        if (o instanceof IThinker) {
            GenericModuleTests.testThinker((IThinker)o);
        }
        if (o instanceof IModule) {
            GenericModuleTests.testModule((IModule)o);
        }
        if (o instanceof IComponent) {
            IComponent c = (IComponent)o;
            GenericModuleTests.testComponent(c);
            GenericModuleTests.testOverwriteInputGradient(c);
            GenericModuleTests.testJoinedGradient(c);
        }
    }
}

