/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.krystex.node;

import com.flipkart.krystal.data.Inputs;
import com.flipkart.krystal.data.ValueOrError;
import com.flipkart.krystal.krystex.KrystalExecutor;
import com.flipkart.krystal.krystex.MainLogicDefinition;
import com.flipkart.krystal.krystex.commands.Flush;
import com.flipkart.krystal.krystex.commands.ForwardBatch;
import com.flipkart.krystal.krystex.commands.ForwardGranule;
import com.flipkart.krystal.krystex.commands.NodeCommand;
import com.flipkart.krystal.krystex.decoration.InitiateActiveDepChains;
import com.flipkart.krystal.krystex.decoration.LogicExecutionContext;
import com.flipkart.krystal.krystex.decoration.MainLogicDecorator;
import com.flipkart.krystal.krystex.decoration.MainLogicDecoratorConfig;
import com.flipkart.krystal.krystex.node.BatchNode;
import com.flipkart.krystal.krystex.node.BatchResponse;
import com.flipkart.krystal.krystex.node.DependantChain;
import com.flipkart.krystal.krystex.node.DisabledDependantChainException;
import com.flipkart.krystal.krystex.node.GranularNode;
import com.flipkart.krystal.krystex.node.GranuleResponse;
import com.flipkart.krystal.krystex.node.KrystalNodeExecutorConfig;
import com.flipkart.krystal.krystex.node.KrystalNodeExecutorMetrics;
import com.flipkart.krystal.krystex.node.NodeDefinition;
import com.flipkart.krystal.krystex.node.NodeDefinitionRegistry;
import com.flipkart.krystal.krystex.node.NodeExecutionConfig;
import com.flipkart.krystal.krystex.node.NodeId;
import com.flipkart.krystal.krystex.node.NodeRegistry;
import com.flipkart.krystal.krystex.node.NodeResponse;
import com.flipkart.krystal.krystex.request.IntReqGenerator;
import com.flipkart.krystal.krystex.request.RequestId;
import com.flipkart.krystal.krystex.request.RequestIdGenerator;
import com.flipkart.krystal.krystex.request.StringReqGenerator;
import com.flipkart.krystal.utils.Futures;
import com.flipkart.krystal.utils.MultiLeasePool;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.RejectedExecutionException;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link KrystalExecutor} implementation that runs nodes described by a {@link NodeDefinitionRegistry}.
 *
 * <p>Calls to {@link #executeNode} only record a pending execution; nothing is forwarded to the
 * nodes until {@link #flush()} is called. On flush, pending executions are forwarded either one
 * request at a time to {@link GranularNode}s ({@link NodeExecStrategy#GRANULAR}) or grouped per node
 * to {@link BatchNode}s ({@link NodeExecStrategy#BATCH}).
 *
 * <p>Bookkeeping work is submitted through {@link #enqueueCommand}, which runs tasks on the leased
 * command-queue executor. The collections below are plain {@code LinkedHashMap}/{@code LinkedHashSet}s.
 * NOTE(review): this presumably relies on the leased executor running one task at a time — confirm
 * against how the {@code MultiLeasePool} is configured.
 */
public final class KrystalNodeExecutor
implements KrystalExecutor {
    private static final Logger log = LoggerFactory.getLogger(KrystalNodeExecutor.class);
    /** Source of node definitions and of the root dependant chain. */
    private final NodeDefinitionRegistry nodeDefinitionRegistry;
    private final KrystalNodeExecutorConfig executorConfig;
    /** Lease on the executor that runs every task passed to {@link #enqueueCommand}; released by {@link #close()}. */
    private final MultiLeasePool.Lease<? extends ExecutorService> commandQueueLease;
    /** Combined with each execution id to build request ids. */
    private final String instanceId;
    /** Executor-wide request-scoped decorator configs, keyed by decorator type. */
    private final ImmutableMap<String, List<MainLogicDecoratorConfig>> requestScopedLogicDecoratorConfigs;
    /** Decorators created so far: decorator type -> decorator instance id -> decorator. */
    private final Map<String, Map<String, MainLogicDecorator>> requestScopedMainDecorators = new LinkedHashMap<String, Map<String, MainLogicDecorator>>();
    /** Node objects created on demand by {@link #createNodeIfAbsent}. */
    private final NodeRegistry<?> nodeRegistry = new NodeRegistry();
    private final KrystalNodeExecutorMetrics krystalNodeMetrics;
    /** Set by {@link #close()}; once true, {@link #executeNode} rejects new requests. */
    private volatile boolean closed;
    /** Every execution accepted by {@link #executeNode}, keyed by request id. */
    private final Map<RequestId, NodeExecution> allExecutions = new LinkedHashMap<RequestId, NodeExecution>();
    /**
     * Request ids accepted by {@link #executeNode}. Entries are never removed in this class, so each
     * {@link #flush()} iterates over every execution seen so far.
     */
    private final Set<RequestId> unFlushedExecutions = new LinkedHashSet<RequestId>();
    /** For every created node, the dependant chains through which it was reached. */
    private final Map<NodeId, Set<DependantChain>> dependantChainsPerNode = new LinkedHashMap<NodeId, Set<DependantChain>>();
    /** {@link StringReqGenerator} in debug mode, {@link IntReqGenerator} otherwise. */
    private final RequestIdGenerator preferredReqGenerator;
    /**
     * Chains disabled in every tracked execution, plus the executor-wide disabled chains. This set is
     * recomputed on each flush, and commands on these chains fail in {@link #validate}.
     */
    private final Set<DependantChain> depChainsDisabledInAllExecutions = new LinkedHashSet<DependantChain>();

    /**
     * Creates an executor and leases a command-queue executor from {@code commandQueuePool}.
     *
     * @param nodeDefinitionRegistry registry used to look up node definitions
     * @param commandQueuePool pool from which the command-queue executor is leased
     * @param executorConfig executor-wide configuration
     * @param instanceId prefix for request ids generated by this executor
     */
    public KrystalNodeExecutor(NodeDefinitionRegistry nodeDefinitionRegistry, MultiLeasePool<? extends ExecutorService> commandQueuePool, KrystalNodeExecutorConfig executorConfig, String instanceId) {
        this.nodeDefinitionRegistry = nodeDefinitionRegistry;
        this.executorConfig = executorConfig;
        this.commandQueueLease = commandQueuePool.lease();
        this.instanceId = instanceId;
        this.requestScopedLogicDecoratorConfigs = ImmutableMap.copyOf(executorConfig.requestScopedLogicDecoratorConfigs());
        this.krystalNodeMetrics = new KrystalNodeExecutorMetrics();
        this.preferredReqGenerator = executorConfig.debug() ? new StringReqGenerator() : new IntReqGenerator();
    }

    /**
     * Resolves the request-scoped decorators that apply to one logic execution.
     *
     * <p>The method considers the node's own decorator configs first and then the executor-wide ones.
     * For each config whose {@code shouldDecorate} predicate accepts the context, it reuses the
     * decorator instance with the same generated instance id, or creates one. It then tells the
     * decorator which dependant chains are active for the node. Only the first matching decorator of
     * each type is returned.
     *
     * @param logicExecutionContext context of the logic being executed
     * @return decorators keyed by decorator type, in discovery order
     */
    private ImmutableMap<String, MainLogicDecorator> getRequestScopedDecorators(LogicExecutionContext logicExecutionContext) {
        NodeId nodeId = logicExecutionContext.nodeId();
        NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
        MainLogicDefinition mainLogicDefinition = nodeDefinition.getMainLogicDefinition();
        LinkedHashMap decorators = new LinkedHashMap();
        Stream.concat(mainLogicDefinition.getRequestScopedLogicDecoratorConfigs().entrySet().stream(), this.requestScopedLogicDecoratorConfigs.entrySet().stream()).forEach(entry -> {
            String decoratorType = (String)entry.getKey();
            ArrayList decoratorConfigList = new ArrayList((Collection)entry.getValue());
            decoratorConfigList.forEach(decoratorConfig -> {
                String instanceId = decoratorConfig.instanceIdGenerator().apply(logicExecutionContext);
                if (decoratorConfig.shouldDecorate().test(logicExecutionContext)) {
                    // Decorators are shared across executions that generate the same instance id.
                    MainLogicDecorator mainLogicDecorator = this.requestScopedMainDecorators.computeIfAbsent(decoratorType, t -> new LinkedHashMap()).computeIfAbsent(instanceId, _i -> decoratorConfig.factory().apply(new MainLogicDecoratorConfig.DecoratorContext(instanceId, logicExecutionContext)));
                    mainLogicDecorator.executeCommand(new InitiateActiveDepChains(nodeId, (ImmutableSet<DependantChain>)ImmutableSet.copyOf((Collection)this.dependantChainsPerNode.get(nodeId))));
                    decorators.putIfAbsent(decoratorType, mainLogicDecorator);
                }
            });
        });
        return ImmutableMap.copyOf(decorators);
    }

    /**
     * Registers an execution of {@code nodeId} with {@code inputs}. The execution is only forwarded
     * to the node on the next {@link #flush()}.
     *
     * <p>The returned future completes exceptionally with {@link IllegalArgumentException} if another
     * execution with the same instance id and execution id was already registered.
     *
     * @throws RejectedExecutionException if this executor has been closed
     * @throws IllegalArgumentException if {@code executionConfig} or its execution id is null
     */
    @Override
    public <T> CompletableFuture<T> executeNode(NodeId nodeId, Inputs inputs, NodeExecutionConfig executionConfig) {
        if (this.closed) {
            throw new RejectedExecutionException("KrystalNodeExecutor is already closed");
        }
        Preconditions.checkArgument(executionConfig != null, "executionConfig can not be null");
        String executionId = executionConfig.executionId();
        Preconditions.checkArgument(executionId != null, "executionConfig.executionId can not be null");
        RequestId requestId = this.preferredReqGenerator.newRequest("%s:%s".formatted(this.instanceId, executionId));
        return this.enqueueCommand(() -> {
            this.createDependencyNodes(nodeId, this.nodeDefinitionRegistry.getDependantChainsStart(), executionConfig);
            CompletableFuture<Object> future = new CompletableFuture<Object>();
            if (this.allExecutions.containsKey(requestId)) {
                future.completeExceptionally(new IllegalArgumentException("Received duplicate requests for same instanceId '%s' and execution Id '%s'".formatted(this.instanceId, executionId)));
            } else {
                this.allExecutions.put(requestId, new NodeExecution(nodeId, requestId, inputs, executionConfig, future));
                this.unFlushedExecutions.add(requestId);
            }
            return future;
        }).thenCompose(Function.identity());
    }

    /**
     * Recursively creates {@code nodeId} and its dependency nodes, and records {@code dependantChain}
     * as a chain reaching {@code nodeId}. Chains disabled at executor level or for this execution are
     * skipped together with everything below them.
     */
    private void createDependencyNodes(NodeId nodeId, DependantChain dependantChain, NodeExecutionConfig executionConfig) {
        NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
        if (!Sets.union(this.executorConfig.disabledDependantChains(), executionConfig.disabledDependantChains()).contains((Object)dependantChain)) {
            this.createNodeIfAbsent(nodeId, nodeDefinition);
            ImmutableMap<String, NodeId> dependencyNodes = nodeDefinition.dependencyNodes();
            dependencyNodes.forEach((dependencyName, depNodeId) -> this.createDependencyNodes((NodeId)depNodeId, dependantChain.extend(nodeId, (String)dependencyName), executionConfig));
            this.dependantChainsPerNode.computeIfAbsent(nodeId, _n -> new LinkedHashSet()).add(dependantChain);
        }
    }

    /** Creates a {@link GranularNode} or {@link BatchNode} for {@code nodeId}, depending on the exec strategy, if none exists yet. */
    private void createNodeIfAbsent(NodeId nodeId, NodeDefinition nodeDefinition) {
        if (this.isGranular()) {
            this.nodeRegistry.createIfAbsent(nodeId, _n -> new GranularNode(nodeDefinition, this, this::getRequestScopedDecorators, this.executorConfig.logicDecorationOrdering(), this.executorConfig.resolverExecStrategy()));
        } else {
            NodeRegistry<?> batchNodeRegistry = this.nodeRegistry;
            batchNodeRegistry.createIfAbsent(nodeId, _n -> new BatchNode(nodeDefinition, this, this::getRequestScopedDecorators, this.executorConfig.logicDecorationOrdering(), this.executorConfig.resolverExecStrategy(), this.preferredReqGenerator));
        }
    }

    /** Returns true when the configured node exec strategy is {@link NodeExecStrategy#GRANULAR}. */
    private boolean isGranular() {
        return NodeExecStrategy.GRANULAR.equals((Object)this.executorConfig.nodeExecStrategy());
    }

    /** Queues a task that obtains the command from {@code nodeCommand} and executes it; the result is flattened. */
    <R extends NodeResponse> CompletableFuture<R> enqueueNodeCommand(Supplier<? extends NodeCommand> nodeCommand) {
        return this.enqueueCommand(() -> this._executeCommand((NodeCommand)nodeCommand.get())).thenCompose(Function.identity());
    }

    /**
     * Executes a node command. With {@link GraphTraversalStrategy#BREADTH}, the command is placed on the
     * command queue. Otherwise it runs immediately on the calling thread, and the queue bypass is
     * recorded in the metrics.
     */
    <T extends NodeResponse> CompletableFuture<T> executeCommand(NodeCommand nodeCommand) {
        if (GraphTraversalStrategy.BREADTH.equals((Object)this.executorConfig.graphTraversalStrategy())) {
            return this.enqueueNodeCommand(() -> nodeCommand);
        }
        this.krystalNodeMetrics.commandQueueBypassed();
        return this._executeCommand(nodeCommand);
    }

    /**
     * Validates {@code nodeCommand} and dispatches it to its target node.
     *
     * @return a failed future if validation throws; an already-completed {@code null} future for
     *     {@link Flush}, which does not wait for the node; otherwise the node's response future
     */
    private <R extends NodeResponse> CompletableFuture<R> _executeCommand(NodeCommand nodeCommand) {
        try {
            this.validate(nodeCommand);
        }
        catch (Throwable e) {
            return CompletableFuture.failedFuture(e);
        }
        if (nodeCommand instanceof Flush) {
            Flush flush = (Flush)nodeCommand;
            this.nodeRegistry.get(flush.nodeId()).executeCommand(flush);
            return CompletableFuture.completedFuture(null);
        }
        return this.nodeRegistry.get(nodeCommand.nodeId()).executeCommand(nodeCommand);
    }

    /** Rejects commands whose dependant chain is disabled in every tracked execution. */
    private void validate(NodeCommand nodeCommand) {
        DependantChain dependantChain = nodeCommand.dependantChain();
        if (this.depChainsDisabledInAllExecutions.contains(dependantChain)) {
            throw new DisabledDependantChainException(dependantChain);
        }
    }

    /**
     * Queues a task that forwards all tracked executions to their nodes and then sends a
     * {@link Flush} to each distinct root node.
     *
     * <p>NOTE(review): {@code unFlushedExecutions} is not cleared here. Clearing it would make
     * {@link #computeDisabledDependantChains} ignore executions that are still in flight, so
     * {@link #validate} could reject their commands. On the other hand, in BATCH mode a repeated flush
     * resubmits every tracked request, while the GRANULAR path skips futures that are already done.
     * Confirm that this is intended.
     */
    @Override
    public void flush() {
        this.enqueueRunnable(() -> {
            this.computeDisabledDependantChains();
            if (this.isGranular()) {
                this.unFlushedExecutions.forEach(requestId -> {
                    NodeExecution nodeExecution = this.allExecutions.get(requestId);
                    NodeId nodeId = nodeExecution.nodeId();
                    if (nodeExecution.future().isDone()) {
                        return;
                    }
                    NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get(nodeId);
                    this.submitGranular((RequestId)requestId, nodeExecution, nodeId, nodeDefinition);
                });
            } else {
                this.submitBatch(this.unFlushedExecutions);
            }
            this.unFlushedExecutions.stream().map(requestId -> this.allExecutions.get(requestId).nodeId()).distinct().forEach(nodeId -> this.executeCommand(new Flush((NodeId)nodeId, this.nodeDefinitionRegistry.getDependantChainsStart())));
        });
    }

    /**
     * Recomputes {@link #depChainsDisabledInAllExecutions}. The result is the intersection of every
     * tracked execution's disabled chains, plus the executor-wide disabled chains.
     */
    private void computeDisabledDependantChains() {
        this.depChainsDisabledInAllExecutions.clear();
        List<ImmutableSet> disabledDependantChainsPerExecution = this.unFlushedExecutions.stream().map(this.allExecutions::get).filter(Objects::nonNull).map(NodeExecution::executionConfig).map(NodeExecutionConfig::disabledDependantChains).toList();
        // Seed with any non-empty set, then intersect with all others (an empty set empties the result).
        disabledDependantChainsPerExecution.stream().filter(x -> !x.isEmpty()).findAny().ifPresent(this.depChainsDisabledInAllExecutions::addAll);
        for (Set set : disabledDependantChainsPerExecution) {
            if (this.depChainsDisabledInAllExecutions.isEmpty()) break;
            this.depChainsDisabledInAllExecutions.retainAll(set);
        }
        this.depChainsDisabledInAllExecutions.addAll((Collection<DependantChain>)this.executorConfig.disabledDependantChains());
    }

    /**
     * Forwards one execution to its granular node and links the node's response to the caller's future.
     * The forwarded inputs are the main-logic inputs that are not dependencies. A {@code ValueOrError}
     * carrying an error completes the caller's future exceptionally.
     */
    private void submitGranular(RequestId requestId, NodeExecution nodeExecution, NodeId nodeId, NodeDefinition nodeDefinition) {
        CompletionStage submissionResult = ((CompletableFuture)this.executeCommand(new ForwardGranule(nodeId, (ImmutableSet<String>)((ImmutableSet)nodeDefinition.getMainLogicDefinition().inputNames().stream().filter(s -> !nodeDefinition.dependencyNodes().containsKey(s)).collect(ImmutableSet.toImmutableSet())), nodeExecution.inputs(), this.nodeDefinitionRegistry.getDependantChainsStart(), requestId)).thenApply(GranuleResponse::response)).thenApply(valueOrError -> {
            if (valueOrError.error().isPresent()) {
                throw new RuntimeException((Throwable)valueOrError.error().get());
            }
            return valueOrError.value().orElse(null);
        });
        Futures.linkFutures((CompletableFuture)submissionResult, nodeExecution.future());
    }

    /**
     * Groups the given executions by node, sends one {@link ForwardBatch} per node and completes each
     * caller's future from the batch response.
     *
     * <p>A failed batch fails every execution in it. A per-request error fails only that execution's
     * future; previously such errors were dropped and the future completed with {@code null}. A request
     * missing from the response completes with {@code null}. Cancelling all caller futures of a batch
     * is propagated to the batch response future.
     */
    private void submitBatch(Set<RequestId> unFlushedRequests) {
        unFlushedRequests.stream().map(this.allExecutions::get).collect(Collectors.groupingBy(NodeExecution::nodeId)).forEach((nodeId, nodeResults) -> {
            NodeDefinition nodeDefinition = this.nodeDefinitionRegistry.get((NodeId)nodeId);
            CompletableFuture batchResponseFuture = this.executeCommand(new ForwardBatch((NodeId)nodeId, (ImmutableSet<String>)((ImmutableSet)nodeDefinition.getMainLogicDefinition().inputNames().stream().filter(s -> !nodeDefinition.dependencyNodes().containsKey(s)).collect(ImmutableSet.toImmutableSet())), (ImmutableMap<RequestId, Inputs>)((ImmutableMap)nodeResults.stream().collect(ImmutableMap.toImmutableMap(NodeExecution::instanceExecutionId, NodeExecution::inputs))), this.nodeDefinitionRegistry.getDependantChainsStart(), (ImmutableMap<RequestId, String>)ImmutableMap.of()));
            ((CompletableFuture)batchResponseFuture.thenApply(BatchResponse::responses)).whenComplete((responses, throwable) -> {
                for (NodeExecution nodeExecution : nodeResults) {
                    if (throwable != null) {
                        nodeExecution.future().completeExceptionally((Throwable)throwable);
                        continue;
                    }
                    ValueOrError result = (ValueOrError)responses.getOrDefault((Object)nodeExecution.instanceExecutionId(), (Object)ValueOrError.empty());
                    if (result.error().isPresent()) {
                        // Surface per-request failures, as submitGranular does, instead of completing with null.
                        nodeExecution.future().completeExceptionally((Throwable)result.error().get());
                        continue;
                    }
                    nodeExecution.future().complete(result.value().orElse(null));
                }
            });
            Futures.propagateCancellation(CompletableFuture.allOf((CompletableFuture[])nodeResults.stream().map(NodeExecution::future).toArray(CompletableFuture[]::new)), batchResponseFuture);
        });
    }

    /** Returns the metrics collected by this executor. */
    public KrystalNodeExecutorMetrics getKrystalNodeMetrics() {
        return this.krystalNodeMetrics;
    }

    /**
     * Marks the executor closed and flushes pending executions. It then queues a task that releases the
     * command-queue lease once every tracked execution's future has completed. Calls after the first
     * return immediately.
     */
    @Override
    public void close() {
        if (this.closed) {
            return;
        }
        this.closed = true;
        this.flush();
        this.enqueueCommand(() -> CompletableFuture.allOf((CompletableFuture[])this.allExecutions.values().stream().map(NodeExecution::future).toArray(CompletableFuture[]::new)).whenComplete((unused, throwable) -> this.commandQueueLease.close()));
    }

    /** Queues {@code command} on the command queue. */
    private CompletableFuture<Void> enqueueRunnable(Runnable command) {
        return this.enqueueCommand(() -> {
            command.run();
            return null;
        });
    }

    /** Runs {@code command} asynchronously on the leased command-queue executor, recording a queued-command metric when it runs. */
    private <T> CompletableFuture<T> enqueueCommand(Supplier<T> command) {
        return CompletableFuture.supplyAsync(() -> {
            this.krystalNodeMetrics.commandQueued();
            return command.get();
        }, (Executor)this.commandQueueLease.get());
    }

    /** How executions are forwarded: per request to {@link GranularNode}s, or grouped per node to {@link BatchNode}s. */
    public static enum NodeExecStrategy {
        GRANULAR,
        BATCH;

    }

    /** {@link #BREADTH} queues node commands on the command queue; {@link #DEPTH} executes them immediately. */
    public static enum GraphTraversalStrategy {
        DEPTH,
        BREADTH;

    }

    /** A registered execution: target node, request id, inputs, per-execution config and the caller's future. */
    private record NodeExecution(NodeId nodeId, RequestId instanceExecutionId, Inputs inputs, NodeExecutionConfig executionConfig, CompletableFuture<Object> future) {
    }

    /** Resolver execution strategy options; the configured value is passed to every node on creation. */
    public static enum ResolverExecStrategy {
        SINGLE,
        MULTI;

    }
}

