/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.krystex.node;

import com.flipkart.krystal.data.InputValue;
import com.flipkart.krystal.data.Inputs;
import com.flipkart.krystal.data.Results;
import com.flipkart.krystal.data.ValueOrError;
import com.flipkart.krystal.krystex.LogicDefinition;
import com.flipkart.krystal.krystex.MainLogic;
import com.flipkart.krystal.krystex.MainLogicDefinition;
import com.flipkart.krystal.krystex.commands.BatchCommand;
import com.flipkart.krystal.krystex.commands.CallbackBatch;
import com.flipkart.krystal.krystex.commands.Flush;
import com.flipkart.krystal.krystex.commands.ForwardBatch;
import com.flipkart.krystal.krystex.decoration.FlushCommand;
import com.flipkart.krystal.krystex.decoration.LogicDecorationOrdering;
import com.flipkart.krystal.krystex.decoration.LogicExecutionContext;
import com.flipkart.krystal.krystex.decoration.MainLogicDecorator;
import com.flipkart.krystal.krystex.node.AbstractNode;
import com.flipkart.krystal.krystex.node.BatchResponse;
import com.flipkart.krystal.krystex.node.DependantChain;
import com.flipkart.krystal.krystex.node.DuplicateRequestException;
import com.flipkart.krystal.krystex.node.KrystalNodeExecutor;
import com.flipkart.krystal.krystex.node.MainLogicInputs;
import com.flipkart.krystal.krystex.node.NodeDefinition;
import com.flipkart.krystal.krystex.node.NodeId;
import com.flipkart.krystal.krystex.node.NodeLogicId;
import com.flipkart.krystal.krystex.node.NodeUtils;
import com.flipkart.krystal.krystex.request.RequestId;
import com.flipkart.krystal.krystex.request.RequestIdGenerator;
import com.flipkart.krystal.krystex.resolution.DependencyResolutionRequest;
import com.flipkart.krystal.krystex.resolution.MultiResolverDefinition;
import com.flipkart.krystal.krystex.resolution.ResolverCommand;
import com.flipkart.krystal.krystex.resolution.ResolverDefinition;
import com.flipkart.krystal.utils.Futures;
import com.flipkart.krystal.utils.ImmutableMapView;
import com.flipkart.krystal.utils.SkippedExecutionException;
import com.google.common.base.Functions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
 * Node that processes commands in batches, one batch per {@link DependantChain}.
 *
 * <p>A {@link ForwardBatch} carries the inputs (and skip reasons) for a set of requests. A {@link
 * CallbackBatch} carries the results of one dependency back to this node. After each command:
 *
 * <ul>
 *   <li>every not-yet-executed dependency whose resolvers' {@code boundFrom} inputs are all
 *       available is triggered (a ForwardBatch is sent to the dependency node);
 *   <li>once all main-logic input names are available, the decorated main logic is executed for
 *       every executable request of the forward batch. If the forward batch is marked as skipped,
 *       the result instead fails with {@link SkippedExecutionException}.
 * </ul>
 *
 * <p>NOTE(review): this is CFR-decompiled output. The state below is held in plain
 * {@code LinkedHashMap}/{@code LinkedHashSet} fields without synchronization, so presumably the
 * executor confines each node instance to a single thread — confirm.
 */
final class BatchNode
extends AbstractNode<BatchCommand, BatchResponse> {
    // Per dependant chain: input names received via ForwardBatch plus dependency names whose
    // results were received via CallbackBatch.
    private final Map<DependantChain, Set<String>> availableInputsByDepChain = new LinkedHashMap<DependantChain, Set<String>>();
    // Per dependant chain: the single ForwardBatch received for it (duplicates are rejected).
    private final Map<DependantChain, ForwardBatch> inputsValueCollector = new LinkedHashMap<DependantChain, ForwardBatch>();
    // Per dependant chain: dependency name -> the CallbackBatch holding that dependency's results.
    private final Map<DependantChain, Map<String, CallbackBatch>> dependencyValuesCollector = new LinkedHashMap<DependantChain, Map<String, CallbackBatch>>();
    // Per dependant chain: the future returned from executeCommand(BatchCommand).
    private final Map<DependantChain, CompletableFuture<BatchResponse>> resultsByDepChain = new LinkedHashMap<DependantChain, CompletableFuture<BatchResponse>>();
    // Main-logic result futures keyed by the provided (non-dependency) inputs. Shared across all
    // dependant chains, so identical inputs reuse one main-logic execution.
    private final Map<Inputs, CompletableFuture<Object>> resultsCache = new LinkedHashMap<Inputs, CompletableFuture<Object>>();
    // Per dependant chain: names of dependencies that have already been triggered.
    private final Map<DependantChain, Set<String>> executedDependencies = new LinkedHashMap<DependantChain, Set<String>>();
    // Per dependant chain: request ids of the ForwardBatch; used to detect duplicate batches.
    private final Map<DependantChain, Set<RequestId>> requestsByDependantChain = new LinkedHashMap<DependantChain, Set<RequestId>>();
    // Dependant chains for which a Flush command has been received.
    private final Set<DependantChain> flushedDependantChain = new LinkedHashSet<DependantChain>();
    // Per dependant chain: true once executeMainLogic has run for it.
    private final Map<DependantChain, Boolean> mainLogicExecuted = new LinkedHashMap<DependantChain, Boolean>();

    BatchNode(NodeDefinition nodeDefinition, KrystalNodeExecutor krystalNodeExecutor, Function<LogicExecutionContext, ImmutableMap<String, MainLogicDecorator>> requestScopedDecoratorsSupplier, LogicDecorationOrdering logicDecorationOrdering, KrystalNodeExecutor.ResolverExecStrategy resolverExecStrategy, RequestIdGenerator requestIdGenerator) {
        super(nodeDefinition, krystalNodeExecutor, requestScopedDecoratorsSupplier, logicDecorationOrdering, resolverExecStrategy, requestIdGenerator);
    }

    /**
     * Marks the command's dependant chain as flushed. Then it propagates a {@link Flush} to every
     * dependency already triggered for that chain, and flushes this node's decorators if the main
     * logic has executed or the forward batch is skipped.
     *
     * <p>Dependencies triggered later are flushed from {@code triggerDependency}, because the chain
     * is already recorded as flushed.
     *
     * <p>NOTE(review): if the main logic has not executed yet, {@code flushDecoratorsIfNeeded}
     * calls {@code getForwardCommand}. That throws IllegalArgumentException when no ForwardBatch
     * has been received for this chain — confirm a Flush can never precede the ForwardBatch.
     */
    @Override
    public void executeCommand(Flush flushCommand) {
        this.flushedDependantChain.add(flushCommand.dependantChain());
        this.flushAllDependenciesIfNeeded(flushCommand.dependantChain());
        this.flushDecoratorsIfNeeded(flushCommand.dependantChain());
    }

    /**
     * Handles a forward or callback batch for one dependant chain.
     *
     * <p>Steps:
     *
     * <ol>
     *   <li>records the incoming input values or dependency results;
     *   <li>triggers every dependency that has become triggerable;
     *   <li>executes the main logic if all of its inputs are now available.
     * </ol>
     *
     * <p>The same future is returned for every command of a given dependant chain. It is completed
     * from the main-logic future (via {@code Futures.linkFutures}, presumably propagating its
     * outcome). It is completed exceptionally if any step above throws.
     *
     * @param nodeCommand a {@link ForwardBatch} or {@link CallbackBatch}; for any other subtype no
     *     values are collected
     * @return the per-dependant-chain result future
     */
    @Override
    public CompletableFuture<BatchResponse> executeCommand(BatchCommand nodeCommand) {
        DependantChain dependantChain = nodeCommand.dependantChain();
        CompletableFuture resultForDepChain = this.resultsByDepChain.computeIfAbsent(dependantChain, r -> new CompletableFuture());
        try {
            if (nodeCommand instanceof ForwardBatch) {
                ForwardBatch forwardBatch = (ForwardBatch)nodeCommand;
                this.collectInputValues(forwardBatch);
            } else if (nodeCommand instanceof CallbackBatch) {
                CallbackBatch callbackBatch = (CallbackBatch)nodeCommand;
                this.collectDependencyValues(callbackBatch);
            }
            this.triggerDependencies(dependantChain, this.getTriggerableDependencies(dependantChain, nodeCommand.inputNames()));
            Optional<CompletableFuture<BatchResponse>> mainLogicFuture = this.executeMainLogicIfPossible(dependantChain);
            mainLogicFuture.ifPresent(f -> Futures.linkFutures((CompletableFuture)f, (CompletableFuture)resultForDepChain));
        }
        catch (Throwable e) {
            // Any failure (e.g. DuplicateRequestException) fails the whole chain's result.
            resultForDepChain.completeExceptionally(e);
        }
        return resultForDepChain;
    }

    /**
     * Computes the dependencies that can be triggered now.
     *
     * <p>Candidates are:
     *
     * <ul>
     *   <li>dependencies of resolvers registered under {@code Optional.empty()} (presumably
     *       resolvers bound to no input — confirm);
     *   <li>dependencies of resolvers registered under any of {@code newInputNames};
     *   <li>every dependency in {@code dependenciesWithNoResolvers}.
     * </ul>
     *
     * <p>A candidate is kept only if it has not already executed for this chain and every
     * {@code boundFrom} input of all its resolvers is available.
     *
     * @return dependency name -> its resolver definitions (empty set when it has none)
     */
    private Map<String, Set<ResolverDefinition>> getTriggerableDependencies(DependantChain dependantChain, Set<String> newInputNames) {
        Set availableInputs = this.availableInputsByDepChain.getOrDefault(dependantChain, Set.of());
        Set executedDeps = this.executedDependencies.getOrDefault(dependantChain, Set.of());
        // distinct() before toMap() guarantees no duplicate-key IllegalStateException.
        return Stream.concat(Stream.concat(Stream.of(Optional.empty()), newInputNames.stream().map(Optional::of)).map(arg_0 -> ((ImmutableMapView)this.resolverDefinitionsByInput).get(arg_0)).filter(Objects::nonNull).flatMap(Collection::stream).map(ResolverDefinition::dependencyName), this.dependenciesWithNoResolvers.stream()).distinct().filter(depName -> !executedDeps.contains(depName)).filter(depName -> ((ImmutableSet)this.resolverDefinitionsByDependencies.getOrDefault(depName, (Object)ImmutableSet.of())).stream().map(ResolverDefinition::boundFrom).flatMap(Collection::stream).allMatch(availableInputs::contains)).collect(Collectors.toMap(Functions.identity(), depName -> (Set)this.resolverDefinitionsByDependencies.getOrDefault(depName, (Object)ImmutableSet.of())));
    }

    /**
     * Builds resolver commands for each triggerable dependency and triggers them.
     *
     * <ul>
     *   <li>All skipped requests of the forward batch share one SkipDependency command, whose
     *       reason is the comma-joined skip reasons.
     *   <li>For each executable request, dependencies without resolvers are executed once with
     *       empty inputs.
     *   <li>Remaining dependencies are resolved by the node's multi-resolver, if one is configured.
     * </ul>
     *
     * @param triggerableDependencies output of {@link #getTriggerableDependencies}
     */
    private void triggerDependencies(DependantChain dependantChain, Map<String, Set<ResolverDefinition>> triggerableDependencies) {
        ForwardBatch forwardBatch = this.getForwardCommand(dependantChain);
        Optional<MultiResolverDefinition> multiResolverOpt = this.nodeDefinition.multiResolverLogicId().map(nodeLogicId -> this.nodeDefinition.nodeDefinitionRegistry().logicDefinitionRegistry().getMultiResolver((NodeLogicId)nodeLogicId));
        ImmutableMap<RequestId, String> skippedRequests = forwardBatch.skippedRequests();
        ImmutableSet executableRequests = forwardBatch.executableRequests().keySet();
        // dependency name -> (set of incoming request ids -> resolver command for them)
        LinkedHashMap<String, Map> commandsByDependency = new LinkedHashMap<String, Map>();
        if (!skippedRequests.isEmpty()) {
            ResolverCommand.SkipDependency skip = ResolverCommand.skip(String.join((CharSequence)", ", (Iterable<? extends CharSequence>)skippedRequests.values()));
            for (String string : triggerableDependencies.keySet()) {
                commandsByDependency.computeIfAbsent(string, _k -> new LinkedHashMap()).put(skippedRequests.keySet(), skip);
            }
        }
        Set<String> dependenciesWithNoResolvers = triggerableDependencies.entrySet().stream().filter(e -> ((Set)e.getValue()).isEmpty()).map(Map.Entry::getKey).collect(Collectors.toSet());
        for (RequestId requestId : executableRequests) {
            dependenciesWithNoResolvers.forEach(depName -> commandsByDependency.computeIfAbsent((String)depName, _k -> new LinkedHashMap()).put(Set.of(requestId), ResolverCommand.multiExecuteWith((ImmutableList<Inputs>)ImmutableList.of((Object)Inputs.empty()))));
            // Inputs for this request restricted to everything any triggerable resolver is bound from.
            Inputs inputs = this.getInputsFor(dependantChain, requestId, triggerableDependencies.values().stream().flatMap(Collection::stream).map(ResolverDefinition::boundFrom).flatMap(Collection::stream).collect(Collectors.toSet()));
            multiResolverOpt.map(LogicDefinition::logic).map(logic -> logic.resolve(triggerableDependencies.entrySet().stream().filter(e -> !((Set)e.getValue()).isEmpty()).map(e -> new DependencyResolutionRequest((String)e.getKey(), (Set)e.getValue())).toList(), inputs)).orElse(ImmutableMap.of()).forEach((depName, resolverCommand) -> commandsByDependency.computeIfAbsent((String)depName, _k -> new LinkedHashMap()).put(Set.of(requestId), resolverCommand));
        }
        for (Map.Entry entry : commandsByDependency.entrySet()) {
            String depName3 = (String)entry.getKey();
            Map resolverCommandsForDep = (Map)entry.getValue();
            // NOTE(review): if the multi-resolver ever returns a command for a dependency that is not
            // in triggerableDependencies, this get() yields null and triggerDependency will NPE.
            this.triggerDependency(depName3, dependantChain, resolverCommandsForDep, triggerableDependencies.get(depName3));
        }
    }

    /**
     * Returns the ForwardBatch recorded for {@code dependantChain}.
     *
     * @throws IllegalArgumentException if no ForwardBatch has been received for the chain
     */
    private ForwardBatch getForwardCommand(DependantChain dependantChain) {
        ForwardBatch forwardBatch = this.inputsValueCollector.get(dependantChain);
        if (forwardBatch == null) {
            throw new IllegalArgumentException("Missing Forward command. This should not be possible.");
        }
        return forwardBatch;
    }

    /**
     * Sends one ForwardBatch to dependency {@code depName} and arranges for its response to be fed
     * back to this node as a CallbackBatch.
     *
     * <p>Each incoming request id gets one or more sub-request ids, produced by {@code
     * requestIdGenerator.newSubRequest}:
     *
     * <ul>
     *   <li>a skip command gets one "[skip]" sub-request, shared by all incoming ids in its key;
     *   <li>a resolver that produced no inputs gets one "[skip]" sub-request per incoming id;
     *   <li>otherwise there is one "[n]" sub-request per resolved Inputs. The counter {@code n}
     *       runs across all incoming ids of the same command entry.
     * </ul>
     *
     * <p>After sending, the dependency is flushed if this chain has already been flushed.
     *
     * @param resolverCommandsByReq incoming request ids -> the resolver command for them
     * @param resolverDefinitions resolvers of this dependency; their resolved input names become the
     *     forward batch's input names
     */
    private void triggerDependency(String depName, DependantChain dependantChain, Map<Set<RequestId>, ResolverCommand> resolverCommandsByReq, Set<ResolverDefinition> resolverDefinitions) {
        NodeId depNodeId = (NodeId)this.nodeDefinition.dependencyNodes().get((Object)depName);
        LinkedHashMap<RequestId, Inputs> inputsByDepReq = new LinkedHashMap<RequestId, Inputs>();
        LinkedHashMap<RequestId, String> skipReasonsByReq = new LinkedHashMap<RequestId, String>();
        LinkedHashMap<RequestId, Set> depReqsByIncomingReq = new LinkedHashMap<RequestId, Set>();
        for (Map.Entry<Set<RequestId>, ResolverCommand> entry : resolverCommandsByReq.entrySet()) {
            Set<RequestId> incomingReqIds = entry.getKey();
            ResolverCommand resolverCommand = entry.getValue();
            if (resolverCommand instanceof ResolverCommand.SkipDependency) {
                ResolverCommand.SkipDependency skipDependency = (ResolverCommand.SkipDependency)resolverCommand;
                // One skip sub-request derived from the first incoming id, shared by all of them.
                RequestId depReqId = this.requestIdGenerator.newSubRequest(incomingReqIds.iterator().next(), () -> "%s[skip]".formatted(depName));
                incomingReqIds.forEach(incomingReqId -> depReqsByIncomingReq.computeIfAbsent((RequestId)incomingReqId, _k -> new LinkedHashSet()).add(depReqId));
                skipReasonsByReq.put(depReqId, skipDependency.reason());
                continue;
            }
            int count = 0;
            for (RequestId incomingReqId2 : incomingReqIds) {
                if (resolverCommand.getInputs().isEmpty()) {
                    RequestId depReqId = this.requestIdGenerator.newSubRequest(incomingReqId2, () -> "%s[skip]".formatted(depName));
                    // NOTE(review): unlike the SkipDependency branch above, this sub-request is not
                    // added to depReqsByIncomingReq, so the callback below yields an empty Results
                    // for incomingReqId2 — confirm this is intended.
                    skipReasonsByReq.put(depReqId, "Resolvers for dependency %s resolved to empty list".formatted(depName));
                    continue;
                }
                for (Inputs inputs : resolverCommand.getInputs()) {
                    // Effectively-final copy for use in the lambda below.
                    int currentCount = count++;
                    RequestId depReqId = this.requestIdGenerator.newSubRequest(incomingReqId2, () -> "%s[%s]".formatted(depName, currentCount));
                    depReqsByIncomingReq.computeIfAbsent(incomingReqId2, _k -> new LinkedHashSet()).add(depReqId);
                    inputsByDepReq.put(depReqId, inputs);
                }
            }
        }
        // Mark as executed before dispatch so getTriggerableDependencies won't re-trigger it.
        this.executedDependencies.computeIfAbsent(dependantChain, _k -> new LinkedHashSet()).add(depName);
        CompletableFuture depResponse = this.krystalNodeExecutor.executeCommand(new ForwardBatch(depNodeId, (ImmutableSet<String>)((ImmutableSet)resolverDefinitions.stream().map(ResolverDefinition::resolvedInputNames).flatMap(Collection::stream).collect(ImmutableSet.toImmutableSet())), (ImmutableMap<RequestId, Inputs>)ImmutableMap.copyOf(inputsByDepReq), dependantChain.extend(this.nodeId, depName), (ImmutableMap<RequestId, String>)ImmutableMap.copyOf(skipReasonsByReq)));
        depResponse.whenComplete((batchResponse, throwable) -> {
            // Map the dependency's per-sub-request responses back onto each incoming request id.
            // On failure, every incoming request gets a single error entry keyed by empty Inputs.
            Set requestIds = resolverCommandsByReq.keySet().stream().flatMap(Collection::stream).collect(Collectors.toSet());
            ImmutableMap results = (ImmutableMap)requestIds.stream().collect(ImmutableMap.toImmutableMap((Function)Functions.identity(), requestId -> {
                if (throwable != null) {
                    return new Results(ImmutableMap.of((Object)Inputs.empty(), (Object)ValueOrError.withError((Throwable)throwable)));
                }
                Set depReqIds = depReqsByIncomingReq.getOrDefault(requestId, Set.of());
                return new Results((ImmutableMap)depReqIds.stream().collect(ImmutableMap.toImmutableMap(depReqId -> inputsByDepReq.getOrDefault(depReqId, Inputs.empty()), depReqId -> (ValueOrError)batchResponse.responses().getOrDefault(depReqId, (Object)ValueOrError.empty()))));
            }));
            // Deliver the results to this node as a CallbackBatch (enqueue-vs-execute decided by NodeUtils).
            NodeUtils.enqueueOrExecuteCommand(() -> new CallbackBatch(this.nodeId, depName, (ImmutableMap<RequestId, Results<Object>>)results, dependantChain), depNodeId, this.nodeDefinition, this.krystalNodeExecutor);
        });
        this.flushDependencyIfNeeded(depName, dependantChain);
    }

    /**
     * Executes the main logic if every main-logic input name is available for the chain.
     *
     * @return empty if some input is still missing. Otherwise, if the forward batch is marked as
     *     skipped, a future failed with SkippedExecutionException (message = joined skip reasons);
     *     else the future from {@link #executeMainLogic} over the batch's executable requests.
     *
     * <p>NOTE(review): {@code mainLogicExecuted} is not consulted here, so a command arriving after
     * all inputs are available would run executeMainLogic again (resultsCache prevents re-running
     * the underlying logic for identical provided inputs) — confirm this cannot happen.
     */
    private Optional<CompletableFuture<BatchResponse>> executeMainLogicIfPossible(DependantChain dependantChain) {
        ForwardBatch forwardCommand = this.getForwardCommand(dependantChain);
        ImmutableSet<String> inputNames = this.nodeDefinition.getMainLogicDefinition().inputNames();
        if (this.availableInputsByDepChain.getOrDefault(dependantChain, (Set<String>)ImmutableSet.of()).containsAll((Collection<?>)inputNames)) {
            if (forwardCommand.shouldSkip()) {
                return Optional.of(CompletableFuture.failedFuture((Throwable)new SkippedExecutionException(BatchNode.getSkipMessage(forwardCommand))));
            }
            return Optional.of(this.executeMainLogic((Set<RequestId>)forwardCommand.executableRequests().keySet(), dependantChain));
        }
        return Optional.empty();
    }

    /**
     * Runs the decorated main logic for each request id.
     *
     * <p>Returns a future that completes once all per-request futures complete. It maps each
     * request id to its ValueOrError, defaulting to {@code ValueOrError.empty()} via
     * {@code getNow}. Marks the chain's main logic as executed and then flushes decorators if the
     * chain was already flushed.
     */
    private CompletableFuture<BatchResponse> executeMainLogic(Set<RequestId> requestIds, DependantChain dependantChain) {
        MainLogicDefinition<Object> mainLogicDefinition = this.nodeDefinition.getMainLogicDefinition();
        LinkedHashMap<RequestId, MainLogicInputs> mainLogicInputs = new LinkedHashMap<RequestId, MainLogicInputs>();
        for (RequestId requestId : requestIds) {
            mainLogicInputs.put(requestId, this.getInputsForMainLogic(dependantChain, requestId));
        }
        CompletableFuture<BatchResponse> resultForBatch = new CompletableFuture<BatchResponse>();
        Map<RequestId, CompletableFuture<ValueOrError<Object>>> results = this.executeDecoratedMainLogic(mainLogicDefinition, mainLogicInputs, dependantChain);
        // The per-request futures are produced via handle(...), so they do not complete
        // exceptionally; getNow is only reached after allOf completes.
        CompletableFuture.allOf((CompletableFuture[])results.values().toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> resultForBatch.complete(new BatchResponse((ImmutableMap<RequestId, ValueOrError<Object>>)((ImmutableMap)mainLogicInputs.keySet().stream().collect(ImmutableMap.toImmutableMap((Function)Functions.identity(), requestId -> ((CompletableFuture)results.get(requestId)).getNow(ValueOrError.empty())))))));
        this.mainLogicExecuted.put(dependantChain, true);
        this.flushDecoratorsIfNeeded(dependantChain);
        return resultForBatch;
    }

    /**
     * Wraps the main logic with the chain's sorted decorators and executes it once per request.
     *
     * <p>Decorators are applied in ascending sort order, each wrapping the previous result, so the
     * last one is outermost. Each request reuses a cached future keyed by its provided inputs, so
     * requests (across chains) with equal provided inputs share one execution. Exceptional
     * completion is converted into a ValueOrError via {@code handle(ValueOrError::valueOrError)}.
     */
    private Map<RequestId, CompletableFuture<ValueOrError<Object>>> executeDecoratedMainLogic(MainLogicDefinition<Object> mainLogicDefinition, Map<RequestId, MainLogicInputs> inputs, DependantChain dependantChain) {
        NavigableSet<MainLogicDecorator> sortedDecorators = this.getSortedDecorators(dependantChain);
        MainLogic logic = mainLogicDefinition::execute;
        for (MainLogicDecorator mainLogicDecorator : sortedDecorators) {
            logic = mainLogicDecorator.decorateLogic(logic, mainLogicDefinition);
        }
        MainLogic finalLogic = logic;
        LinkedHashMap<RequestId, CompletableFuture<ValueOrError<Object>>> resultsByRequest = new LinkedHashMap<RequestId, CompletableFuture<ValueOrError<Object>>>();
        inputs.forEach((requestId, mainLogicInputs) -> {
            CompletableFuture cachedResult = this.resultsCache.get(mainLogicInputs.providedInputs());
            if (cachedResult == null) {
                // Execute for a single Inputs and take the only resulting future.
                cachedResult = (CompletableFuture)finalLogic.execute((ImmutableList<Inputs>)ImmutableList.of((Object)mainLogicInputs.allInputsAndDependencies())).values().iterator().next();
                this.resultsCache.put(mainLogicInputs.providedInputs(), cachedResult);
            }
            resultsByRequest.put((RequestId)requestId, (CompletableFuture<ValueOrError<Object>>)cachedResult.handle(ValueOrError::valueOrError));
        });
        return resultsByRequest;
    }

    /** Calls {@link #flushDependencyIfNeeded} for every dependency declared by this node. */
    private void flushAllDependenciesIfNeeded(DependantChain dependantChain) {
        this.nodeDefinition.dependencyNodes().keySet().forEach(dependencyName -> this.flushDependencyIfNeeded((String)dependencyName, dependantChain));
    }

    /**
     * Sends a Flush to the dependency's node for the extended dependant chain. It does so only if
     * this chain has been flushed and the dependency has already been triggered for it.
     */
    private void flushDependencyIfNeeded(String dependencyName, DependantChain dependantChain) {
        if (!this.flushedDependantChain.contains(dependantChain)) {
            return;
        }
        if (this.executedDependencies.getOrDefault(dependantChain, Set.of()).contains(dependencyName)) {
            this.krystalNodeExecutor.executeCommand(new Flush((NodeId)this.nodeDefinition.dependencyNodes().get((Object)dependencyName), dependantChain.extend(this.nodeId, dependencyName)));
        }
    }

    /**
     * Sends a FlushCommand to every decorator of the chain, in descending sort order (the reverse
     * of the wrapping order in executeDecoratedMainLogic).
     *
     * <p>Equivalent decompiled logic:
     *
     * <pre>
     * if (!flushed) return;
     * if (mainLogicExecuted || forwardCommand.shouldSkip()) { flush decorators }
     * </pre>
     *
     * <p>{@code getForwardCommand} throws IllegalArgumentException if no ForwardBatch was received
     * for the chain and the main logic has not executed.
     */
    private void flushDecoratorsIfNeeded(DependantChain dependantChain) {
        block6: {
            block5: {
                if (!this.flushedDependantChain.contains(dependantChain)) {
                    return;
                }
                // Main logic already executed -> go flush decorators.
                if (this.mainLogicExecuted.getOrDefault(dependantChain, false).booleanValue()) break block5;
                // Neither executed nor skipped -> nothing to flush yet.
                if (!this.getForwardCommand(dependantChain).shouldSkip()) break block6;
            }
            Iterable reverseSortedDecorators = this.getSortedDecorators(dependantChain)::descendingIterator;
            for (MainLogicDecorator decorator : reverseSortedDecorators) {
                decorator.executeCommand(new FlushCommand(dependantChain));
            }
        }
    }

    /**
     * Collects the values of the {@code boundFrom} inputs for one request.
     *
     * <p>Each name is looked up first among the forward batch's inputs for the request. If absent,
     * it falls back to the dependency results received for that name (defaulting to
     * {@code Results.empty()} for this request). Names found in neither place are omitted.
     */
    private Inputs getInputsFor(DependantChain dependantChain, RequestId requestId, Set<String> boundFrom) {
        Inputs resolvableInputs = Optional.ofNullable(this.inputsValueCollector.get(dependantChain)).map(ForwardBatch::executableRequests).map(inputsByRequest -> (Inputs)inputsByRequest.get((Object)requestId)).orElse(Inputs.empty());
        Map depValues = this.dependencyValuesCollector.getOrDefault(dependantChain, Map.of());
        LinkedHashMap<String, InputValue> inputValues = new LinkedHashMap<String, InputValue>();
        for (String boundFromInput : boundFrom) {
            InputValue voe = (InputValue)resolvableInputs.values().get((Object)boundFromInput);
            if (voe == null) {
                CallbackBatch callbackBatch = (CallbackBatch)depValues.get(boundFromInput);
                if (callbackBatch == null) continue;
                inputValues.put(boundFromInput, (InputValue)callbackBatch.resultsByRequest().getOrDefault((Object)requestId, (Object)Results.empty()));
                continue;
            }
            inputValues.put(boundFromInput, voe);
        }
        return new Inputs(inputValues);
    }

    /**
     * Builds the main-logic inputs for one request.
     *
     * <p>The provided inputs come from the forward batch (empty if absent). The combined inputs are
     * {@code Inputs.union} of the per-dependency results for the request and those provided inputs.
     */
    private MainLogicInputs getInputsForMainLogic(DependantChain dependantChain, RequestId requestId) {
        ForwardBatch forwardBatch = this.inputsValueCollector.get(dependantChain);
        ImmutableMap depValues = (ImmutableMap)this.dependencyValuesCollector.getOrDefault(dependantChain, (Map<String, CallbackBatch>)ImmutableMap.of()).entrySet().stream().collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, e -> (Results)((CallbackBatch)e.getValue()).resultsByRequest().getOrDefault((Object)requestId, (Object)Results.empty())));
        Inputs inputValues = (Inputs)forwardBatch.executableRequests().getOrDefault((Object)requestId, (Object)Inputs.empty());
        Inputs allInputsAndDependencies = Inputs.union((Map)depValues, (Map)inputValues.values());
        return new MainLogicInputs(inputValues, allInputsAndDependencies);
    }

    /**
     * Records a ForwardBatch for its chain and marks its input names as available.
     *
     * @throws DuplicateRequestException if a ForwardBatch was already received for the chain
     */
    private void collectInputValues(ForwardBatch forwardBatch) {
        if (this.requestsByDependantChain.putIfAbsent(forwardBatch.dependantChain(), forwardBatch.requestIds()) != null) {
            throw new DuplicateRequestException("Duplicate batch request received for dependant chain %s".formatted(forwardBatch.dependantChain()));
        }
        ImmutableSet<String> inputNames = forwardBatch.inputNames();
        // NOTE(review): both maps are written only here and keyed by the same chain, so once the
        // check above passes this one cannot fail; kept as a defensive guard.
        if (this.inputsValueCollector.putIfAbsent(forwardBatch.dependantChain(), forwardBatch) != null) {
            throw new DuplicateRequestException("Duplicate data for inputs %s of node %s in dependant chain %s".formatted(inputNames, this.nodeId, forwardBatch.dependantChain()));
        }
        this.availableInputsByDepChain.computeIfAbsent(forwardBatch.dependantChain(), _k -> new LinkedHashSet()).addAll(inputNames);
    }

    /** Joins the forward batch's skip reasons with ", ". */
    private static String getSkipMessage(ForwardBatch forwardBatch) {
        return String.join((CharSequence)", ", (Iterable<? extends CharSequence>)forwardBatch.skippedRequests().values());
    }

    /**
     * Records a dependency's CallbackBatch for its chain and marks the dependency as available.
     *
     * <p>Note the dependency is marked available before the duplicate check.
     *
     * @throws DuplicateRequestException if results for this dependency were already received
     */
    private void collectDependencyValues(CallbackBatch callbackBatch) {
        String dependencyName = callbackBatch.dependencyName();
        this.availableInputsByDepChain.computeIfAbsent(callbackBatch.dependantChain(), _k -> new LinkedHashSet()).add(dependencyName);
        if (this.dependencyValuesCollector.computeIfAbsent(callbackBatch.dependantChain(), k -> new LinkedHashMap()).putIfAbsent(dependencyName, callbackBatch) != null) {
            throw new DuplicateRequestException("Duplicate data for dependency %s of node %s in dependant chain %s".formatted(dependencyName, this.nodeId, callbackBatch.dependantChain()));
        }
    }
}

