/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.vajramexecutor.krystex;

import com.flipkart.krystal.annos.ExternallyInvocable;
import com.flipkart.krystal.core.VajramID;
import com.flipkart.krystal.facets.Dependency;
import com.flipkart.krystal.facets.Facet;
import com.flipkart.krystal.facets.FacetType;
import com.flipkart.krystal.facets.resolution.ResolverDefinition;
import com.flipkart.krystal.krystex.kryon.DefaultDependentChain;
import com.flipkart.krystal.krystex.kryon.DependentChain;
import com.flipkart.krystal.krystex.kryon.DependentChainStart;
import com.flipkart.krystal.krystex.kryon.KryonDefinitionRegistry;
import com.flipkart.krystal.krystex.logicdecoration.LogicExecutionContext;
import com.flipkart.krystal.krystex.logicdecoration.OutputLogicDecorator;
import com.flipkart.krystal.krystex.logicdecoration.OutputLogicDecoratorConfig;
import com.flipkart.krystal.vajram.IOVajramDef;
import com.flipkart.krystal.vajram.Vajram;
import com.flipkart.krystal.vajram.annos.VajramIdentifier;
import com.flipkart.krystal.vajram.batching.InputBatcher;
import com.flipkart.krystal.vajram.batching.InputBatcherImpl;
import com.flipkart.krystal.vajram.exec.VajramDefinition;
import com.flipkart.krystal.vajram.facets.resolution.InputResolver;
import com.flipkart.krystal.vajram.facets.specs.DependencySpec;
import com.flipkart.krystal.vajram.facets.specs.FacetSpec;
import com.flipkart.krystal.vajramexecutor.krystex.InputBatchingDecorator;
import com.flipkart.krystal.vajramexecutor.krystex.VajramKryonGraph;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
import java.util.function.Supplier;
import org.checkerframework.checker.nullness.qual.Nullable;

/**
 * Describes how input batching decorators are attached to a vajram's output logic.
 *
 * @param instanceIdGenerator derives, from a logic execution context, the id of the decorator
 *     instance to use for that context
 * @param shouldBatch decides, for a given {@link BatcherContext}, whether this config applies
 * @param decoratorFactory creates the {@link OutputLogicDecorator} (an {@link
 *     InputBatchingDecorator} for the factories in this class) for a given {@link BatcherContext}
 */
public record InputBatcherConfig(Function<LogicExecutionContext, String> instanceIdGenerator, Predicate<BatcherContext> shouldBatch, Function<BatcherContext, OutputLogicDecorator> decoratorFactory) {

    /**
     * Creates a config keyed on the full dependent chain of each execution context.
     *
     * <p>The instance id is the textual form of the context's dependent chain, batching is always
     * enabled, and each created decorator accepts only the dependent chain of the context it was
     * created for.
     *
     * @param inputBatcherSupplier called once per created decorator to obtain its {@link InputBatcher}
     * @return the new config
     */
    public static InputBatcherConfig simple(Supplier<InputBatcher> inputBatcherSupplier) {
        return new InputBatcherConfig(
                logicExecutionContext ->
                        generateInstanceId(
                                        logicExecutionContext.dependants(),
                                        logicExecutionContext.kryonDefinitionRegistry())
                                .toString(),
                batcherContext -> true,
                batcherContext ->
                        new InputBatchingDecorator(
                                batcherContext.logicDecoratorContext().instanceId(),
                                inputBatcherSupplier.get(),
                                dependantChain ->
                                        batcherContext
                                                .logicDecoratorContext()
                                                .logicExecutionContext()
                                                .dependants()
                                                .equals(dependantChain)));
    }

    /**
     * Varargs convenience overload of {@link #sharedBatcher(Supplier, String, ImmutableSet)}.
     *
     * @param inputBatcherSupplier called once per created decorator to obtain its {@link InputBatcher}
     * @param instanceId the fixed instance id used for every matching execution context
     * @param dependentChains the dependent chains this config applies to
     * @return the new config
     */
    public static InputBatcherConfig sharedBatcher(Supplier<InputBatcher> inputBatcherSupplier, String instanceId, DependentChain... dependentChains) {
        return sharedBatcher(inputBatcherSupplier, instanceId, ImmutableSet.copyOf(dependentChains));
    }

    /**
     * Creates a config that uses one fixed instance id for all the given dependent chains.
     *
     * <p>The config applies only to execution contexts whose dependent chain is in {@code
     * dependentChains}, and each created decorator accepts exactly those chains.
     *
     * @param inputBatcherSupplier called once per created decorator to obtain its {@link InputBatcher}
     * @param instanceId the fixed instance id used for every matching execution context
     * @param dependentChains the dependent chains this config applies to
     * @return the new config
     */
    public static InputBatcherConfig sharedBatcher(Supplier<InputBatcher> inputBatcherSupplier, String instanceId, ImmutableSet<DependentChain> dependentChains) {
        return new InputBatcherConfig(
                logicExecutionContext -> instanceId,
                batcherContext ->
                        dependentChains.contains(
                                batcherContext.logicDecoratorContext().logicExecutionContext().dependants()),
                batcherContext ->
                        new InputBatchingDecorator(instanceId, inputBatcherSupplier.get(), dependentChains::contains));
    }

    /**
     * Registers shared batchers for every IO vajram reachable from an externally invocable vajram,
     * with no dependent chains disabled.
     *
     * @param graph the graph to analyse and register batchers into
     * @param batchSizeSupplier provides the batch size for each IO vajram
     */
    public static void autoRegisterSharedBatchers(VajramKryonGraph graph, BatchSizeSupplier batchSizeSupplier) {
        autoRegisterSharedBatchers(graph, batchSizeSupplier, ImmutableSet.of());
    }

    /**
     * Registers shared batchers for every IO vajram reachable from an externally invocable vajram.
     *
     * <p>For each IO vajram that has at least one batched facet, the dependent chains reaching it
     * are grouped by computed depth (see {@link #dfs}). One {@link #sharedBatcher} config is
     * registered per depth group, using the vajram id as the instance id.
     *
     * @param graph the graph to analyse and register batchers into
     * @param batchSizeSupplier provides the batch size for each IO vajram
     * @param disabledDependentChains dependent chains (and everything below them) to exclude
     */
    public static void autoRegisterSharedBatchers(VajramKryonGraph graph, BatchSizeSupplier batchSizeSupplier, ImmutableSet<DependentChain> disabledDependentChains) {
        Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes = getIoVajrams(graph, disabledDependentChains);
        ioNodes.forEach(
                (vajramId, chainsByDepth) -> {
                    if (!isBatchingNeededForIoVajram(graph, vajramId)) {
                        return;
                    }
                    InputBatcherConfig[] inputModulatorConfigs =
                            chainsByDepth.values().stream()
                                    .map(
                                            depChains ->
                                                    sharedBatcher(
                                                            () -> new InputBatcherImpl(batchSizeSupplier.getBatchSize(vajramId)),
                                                            vajramId.id(),
                                                            ImmutableSet.copyOf(depChains)))
                                    .toArray(InputBatcherConfig[]::new);
                    graph.registerInputBatchers(vajramId, inputModulatorConfigs);
                });
    }

    /**
     * Builds a textual id for a dependent chain.
     *
     * <p>The start of the chain is rendered with its {@code toString()}. The first hop after the
     * start appends {@code >vajramIdentifier:dependency}, and every later hop appends {@code
     * >dependency}.
     *
     * @throws NoSuchElementException if the first kryon in the chain lacks a {@link
     *     VajramIdentifier} annotation
     * @throws UnsupportedOperationException if the chain is of an unknown implementation type
     */
    private static StringBuilder generateInstanceId(DependentChain dependentChain, KryonDefinitionRegistry kryonDefinitionRegistry) {
        if (dependentChain instanceof DependentChainStart dependantChainStart) {
            return new StringBuilder(dependantChainStart.toString());
        }
        if (dependentChain instanceof DefaultDependentChain defaultDependantChain) {
            DependentChain incoming = defaultDependantChain.incomingDependentChain();
            if (incoming instanceof DependentChainStart) {
                // The first hop is qualified with the identifier of the vajram it starts from.
                String vajramIdentifier =
                        kryonDefinitionRegistry
                                .getOrThrow(defaultDependantChain.kryonId())
                                .tags()
                                .getAnnotationByType(VajramIdentifier.class)
                                .map(VajramIdentifier::value)
                                .orElseThrow(
                                        () ->
                                                new NoSuchElementException(
                                                        "Could not find tag %s for kryon %s"
                                                                .formatted(VajramIdentifier.class, defaultDependantChain.kryonId())));
                return generateInstanceId(incoming, kryonDefinitionRegistry)
                        .append('>')
                        .append(vajramIdentifier)
                        .append(':')
                        .append(defaultDependantChain.latestDependency());
            }
            return generateInstanceId(incoming, kryonDefinitionRegistry)
                    .append('>')
                    .append(defaultDependantChain.latestDependency());
        }
        throw new UnsupportedOperationException(
                "Unsupported DependentChain implementation: %s"
                        .formatted(dependentChain == null ? null : dependentChain.getClass().getName()));
    }

    /** Returns true if any facet of the given vajram is marked as batched. */
    private static boolean isBatchingNeededForIoVajram(VajramKryonGraph graph, VajramID ioNode) {
        return graph.getVajramDefinition(ioNode).facetSpecs().stream().anyMatch(FacetSpec::isBatched);
    }

    /**
     * Collects, for every IO vajram reachable from an externally invocable vajram, the dependent
     * chains that reach it, grouped by computed depth.
     *
     * <p>Depth bookkeeping ({@code ioNodeDepths}) starts fresh for each root vajram, while the
     * result map is shared across roots.
     */
    private static Map<VajramID, Map<Integer, Set<DependentChain>>> getIoVajrams(VajramKryonGraph graph, ImmutableSet<DependentChain> disabledDependentChains) {
        Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes = new HashMap<>();
        for (VajramDefinition rootNode : externallyInvocableVajrams(graph)) {
            DependentChain chainStart = graph.kryonDefinitionRegistry().getDependantChainsStart();
            Map<VajramID, Integer> ioNodeDepths = new HashMap<>();
            dfs(rootNode, graph, ioNodes, 0, chainStart, ioNodeDepths, disabledDependentChains);
        }
        return ioNodes;
    }

    /** Returns the vajrams annotated with {@link ExternallyInvocable}. */
    private static List<VajramDefinition> externallyInvocableVajrams(VajramKryonGraph graph) {
        return graph.vajramDefinitions().values().stream()
                .filter(v -> v.vajramTags().getAnnotationByType(ExternallyInvocable.class).isPresent())
                .toList();
    }

    /**
     * Depth-first walk over the dependencies of {@code rootNode}, recording every dependent chain
     * that reaches an IO vajram.
     *
     * <p>Dependencies are visited in the order produced by {@link #getOrderedInputDef}. Consider a
     * dependency D whose resolvers read from another dependency P. While D's subtree is walked,
     * the depth of every IO vajram under P is incremented, and decremented again afterwards. So an
     * IO vajram that appears under both P and D is recorded at a different depth under D.
     * (Presumably because D can only be resolved after P completes, so their calls cannot be
     * batched together.)
     *
     * @param depth passed down the recursion; the value actually recorded for an IO vajram always
     *     comes from {@code ioNodeDepths}
     * @param ioNodes output: IO vajram id → depth → dependent chains reaching it at that depth
     * @param ioNodeDepths mutable current depth per IO vajram for this root
     * @param disabledDependentChains chains that are skipped along with their whole subtree
     */
    private static void dfs(VajramDefinition rootNode, VajramKryonGraph graph, Map<VajramID, Map<Integer, Set<DependentChain>>> ioNodes, int depth, DependentChain incomingDepChain, Map<VajramID, Integer> ioNodeDepths, ImmutableSet<DependentChain> disabledDependentChains) {
        Map<Facet, List<Facet>> inputDefGraph = new HashMap<>();
        VajramID vajramId = rootNode.vajramId();
        graph.loadKryonSubGraphIfNeeded(vajramId);
        for (Facet inputDef : getOrderedInputDef(rootNode, inputDefGraph)) {
            if (!(inputDef instanceof DependencySpec<?, ?, ?> dependency)) {
                continue;
            }
            List<ResolverDefinition> resolverDefinitions = getInputResolverDefinition(rootNode, dependency);
            VajramDefinition childNode = graph.getVajramDefinition(dependency.onVajramId());
            DependentChain dependentChain = incomingDepChain.extend(vajramId, dependency);
            // Check before adjusting depths: skipping after the increment would leave
            // ioNodeDepths permanently bumped for every later traversal from this root.
            if (disabledDependentChains.contains(dependentChain)) {
                continue;
            }
            List<Facet> prerequisites = inputDefGraph.getOrDefault(inputDef, List.of());
            shiftPrerequisiteIoDepths(prerequisites, resolverDefinitions, graph, ioNodeDepths, 1);
            int childDepth = depth;
            if (childNode.def() instanceof IOVajramDef) {
                childDepth = ioNodeDepths.computeIfAbsent(childNode.vajramId(), _v -> 0);
                ioNodes.computeIfAbsent(childNode.vajramId(), k -> new HashMap<>())
                        .computeIfAbsent(childDepth, k -> new LinkedHashSet<>())
                        .add(dependentChain);
            }
            dfs(childNode, graph, ioNodes, childDepth, dependentChain, ioNodeDepths, disabledDependentChains);
            shiftPrerequisiteIoDepths(prerequisites, resolverDefinitions, graph, ioNodeDepths, -1);
        }
    }

    /**
     * For each prerequisite dependency read by one of {@code resolverDefinitions}, adds {@code
     * delta} to the depth of every IO vajram in that dependency's subtree.
     */
    private static void shiftPrerequisiteIoDepths(List<Facet> prerequisites, List<ResolverDefinition> resolverDefinitions, VajramKryonGraph graph, Map<VajramID, Integer> ioNodeDepths, int delta) {
        for (Facet prerequisite : prerequisites) {
            VajramID prerequisiteVajramId = dependencyInputInChildNode(resolverDefinitions, prerequisite);
            if (prerequisiteVajramId != null) {
                shiftLeafIoNodeDepths(graph.getVajramDefinition(prerequisiteVajramId), graph, ioNodeDepths, delta);
            }
        }
    }

    /**
     * Recursively adds {@code delta} to the recorded depth of {@code node} (if it is an IO vajram)
     * and of every IO vajram below it.
     *
     * <p>A vajram with no recorded depth is set to 0 rather than {@code delta}, matching the
     * previous increment/decrement helpers. Nodes reachable via several paths are shifted once per
     * path.
     */
    private static void shiftLeafIoNodeDepths(VajramDefinition node, VajramKryonGraph graph, Map<VajramID, Integer> ioNodeDepth, int delta) {
        if (node.def() instanceof IOVajramDef) {
            ioNodeDepth.compute(node.vajramId(), (_vid, current) -> current == null ? 0 : current + delta);
        }
        for (Facet facet : node.facetSpecs()) {
            if (facet instanceof DependencySpec<?, ?, ?> dependency) {
                shiftLeafIoNodeDepths(graph.getVajramDefinition(dependency.onVajramId()), graph, ioNodeDepth, delta);
            }
        }
    }

    /**
     * Returns the target vajram of {@code inputDefinition} if it is a dependency that is a source of
     * any of {@code depInputs}; otherwise {@code null}.
     */
    private static @Nullable VajramID dependencyInputInChildNode(List<ResolverDefinition> depInputs, Facet inputDefinition) {
        if (!(inputDefinition instanceof DependencySpec<?, ?, ?> dependency)) {
            return null;
        }
        for (ResolverDefinition depInput : depInputs) {
            if (depInput.sources().contains(inputDefinition)) {
                return dependency.onVajramId();
            }
        }
        return null;
    }

    /** Returns the definitions of all resolvers of {@code rootNode} that target {@code dependency}. */
    private static List<ResolverDefinition> getInputResolverDefinition(VajramDefinition rootNode, DependencySpec<?, ?, ?> dependency) {
        return rootNode.inputResolvers().values().stream()
                .filter(inputResolver -> inputResolver.definition().target().dependency().id() == dependency.id())
                .map(InputResolver::definition)
                .toList();
    }

    /**
     * Returns the dependency facets of {@code rootNode} in an order where each dependency comes
     * after the dependencies its resolvers read from.
     *
     * @param graph output: populated with resolved dependency → source dependencies feeding its
     *     resolvers
     * @return the dependencies in post-order (prerequisites first)
     */
    private static Collection<Facet> getOrderedInputDef(VajramDefinition rootNode, Map<Facet, List<Facet>> graph) {
        var inputDefinitions = rootNode.facetSpecs();
        for (InputResolver resolver : rootNode.inputResolvers().values()) {
            ResolverDefinition resolverDefinition = resolver.definition();
            // The dependency this resolver feeds does not depend on the inner loop: look it up once.
            Facet resolvedDependency = getInputDefinitionDep((Facet) resolverDefinition.target().dependency(), inputDefinitions);
            if (resolvedDependency == null) {
                continue;
            }
            for (Facet facet : inputDefinitions) {
                if (facet.facetTypes().contains(FacetType.DEPENDENCY) && resolverDefinition.sources().contains(facet)) {
                    graph.computeIfAbsent(resolvedDependency, k -> new ArrayList<>()).add(facet);
                }
            }
        }
        Set<Facet> visited = new HashSet<>();
        Queue<Facet> ordered = new ArrayDeque<>();
        for (Facet facet : inputDefinitions) {
            if (facet.facetTypes().contains(FacetType.DEPENDENCY) && !visited.contains(facet)) {
                topologicalSortUtil(facet, visited, graph, ordered);
            }
        }
        return ordered;
    }

    /** Returns the dependency facet in {@code inputDefinitions} with the same id as {@code dep}, or null. */
    private static @Nullable Facet getInputDefinitionDep(Facet dep, ImmutableCollection<? extends Facet> inputDefinitions) {
        for (Facet facet : inputDefinitions) {
            if (facet.facetTypes().contains(FacetType.DEPENDENCY) && facet.id() == dep.id()) {
                return facet;
            }
        }
        return null;
    }

    /**
     * Depth-first post-order visit from {@code vid}: every unvisited facet reachable through {@code
     * graph} is appended to {@code stack} after its successors, if it is a dependency.
     *
     * <p>Despite its name, {@code stack} is filled FIFO-style (appended at the tail).
     */
    static void topologicalSortUtil(Facet vid, Set<Facet> visited, Map<Facet, List<Facet>> graph, Queue<Facet> stack) {
        visited.add(vid);
        for (Facet next : graph.getOrDefault(vid, List.of())) {
            if (!visited.contains(next)) {
                topologicalSortUtil(next, visited, graph, stack);
            }
        }
        if (vid.facetTypes().contains(FacetType.DEPENDENCY)) {
            stack.add(vid);
        }
    }

    /** Supplies the batch size to use for an IO vajram's shared batcher. */
    @FunctionalInterface
    public interface BatchSizeSupplier {
        int getBatchSize(VajramID vajramId);
    }

    /** Context passed to {@link #shouldBatch} and {@link #decoratorFactory}. */
    public record BatcherContext(OutputLogicDecoratorConfig.LogicDecoratorContext logicDecoratorContext) {
    }
}

