/*
 * Decompiled with CFR 0.152.
 */
package com.flipkart.krystal.visualization;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.flipkart.krystal.vajram.ComputeVajram;
import com.flipkart.krystal.vajram.IOVajram;
import com.flipkart.krystal.vajram.Vajram;
import com.flipkart.krystal.vajram.VajramID;
import com.flipkart.krystal.vajram.exec.VajramDefinition;
import com.flipkart.krystal.vajram.facets.DependencyDef;
import com.flipkart.krystal.vajram.facets.InputDef;
import com.flipkart.krystal.vajram.facets.VajramFacetDefinition;
import com.flipkart.krystal.vajramexecutor.krystex.VajramKryonGraph;
import com.flipkart.krystal.visualization.StaticCallGraphHtml;
import com.flipkart.krystal.visualization.models.Graph;
import com.flipkart.krystal.visualization.models.GraphGenerationResult;
import com.flipkart.krystal.visualization.models.Input;
import com.flipkart.krystal.visualization.models.Link;
import com.flipkart.krystal.visualization.models.Node;
import com.flipkart.krystal.visualization.models.VajramType;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableMap;
import java.lang.annotation.Annotation;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import lombok.Generated;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Builds a static call graph from the vajram definitions of a {@link VajramKryonGraph} and renders
 * it as an HTML page.
 *
 * <p>Every vajram definition becomes a {@link Node}; every dependency facet whose target is also a
 * known vajram becomes a {@link Link} from the depending vajram to its dependency.
 */
public class StaticCallGraphGenerator {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(StaticCallGraphGenerator.class);

    /** Shared JSON serializer, created once instead of on every {@link #graphToJson} call. */
    private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper();

    /**
     * Generates the HTML visualization of the static call graph.
     *
     * @param vajramKryonGraph graph whose vajram definitions are visualized
     * @param startVajram name of the node to start from; when {@code null} or blank the complete
     *     graph is rendered, otherwise only the nodes reachable from the named node are kept
     * @return a result holding the generated HTML
     * @throws IllegalArgumentException if {@code startVajram} is given but no node has that name,
     *     or if a vajram is neither a compute nor an IO vajram
     * @throws ClassNotFoundException propagated unchanged from the graph construction step
     */
    public static GraphGenerationResult generateStaticCallGraphContent(VajramKryonGraph vajramKryonGraph, @Nullable String startVajram) throws ClassNotFoundException {
        Graph fullGraph = createGraphData(vajramKryonGraph);
        Graph graphToVisualize = fullGraph;
        if (startVajram != null && !startVajram.isBlank()) {
            Node startNode = fullGraph.getNodes().stream()
                    .filter(node -> node.getName().equals(startVajram))
                    .findFirst()
                    .orElseThrow(() -> new IllegalArgumentException("Start vajram: " + startVajram + " does not exist"));
            graphToVisualize = filterGraph(fullGraph, startNode.getId());
        }
        String jsonGraph = graphToJson(graphToVisualize);
        String htmlContent = StaticCallGraphHtml.generateStaticCallGraphHtml(jsonGraph);
        return GraphGenerationResult.builder().html(htmlContent).build();
    }

    /**
     * Converts the vajram definitions of the given graph into nodes and links.
     *
     * <p>All nodes are built before any link, so an unknown vajram type is reported before link
     * construction starts.
     *
     * @param vajramKryonGraph source of the vajram definitions
     * @return graph containing one node per vajram definition and one link per dependency facet
     *     that points to another known vajram
     * @throws IllegalArgumentException if a vajram is neither a compute nor an IO vajram
     * @throws ClassNotFoundException propagated from the vajram metadata accessors used here
     *     (presumably the input type resolution - TODO confirm against the facet/type API)
     */
    private static Graph createGraphData(VajramKryonGraph vajramKryonGraph) throws ClassNotFoundException {
        Map<VajramID, VajramDefinition> vajramDefinitions = vajramKryonGraph.vajramDefinitions();

        List<Node> nodes = new ArrayList<>();
        for (Map.Entry<VajramID, VajramDefinition> entry : vajramDefinitions.entrySet()) {
            VajramID vajramId = entry.getKey();
            VajramDefinition definition = entry.getValue();

            var inputs = new ArrayList<Input>();
            for (VajramFacetDefinition facet : definition.vajram().getFacetDefinitions()) {
                if (facet instanceof InputDef<?> inputDef) {
                    inputs.add(Input.builder()
                            .name(inputDef.name())
                            .type(inputDef.type().javaReflectType().getTypeName())
                            .isMandatory(inputDef.isMandatory())
                            .documentation(inputDef.documentation())
                            .build());
                }
            }

            List<String> annotationTags = definition.vajramTags().annotations().stream()
                    .map(Annotation::toString)
                    .toList();

            VajramType vajramType = getVajramType(definition.vajram());
            if (vajramType == VajramType.UNKNOWN) {
                throw new IllegalArgumentException("Unknown vajram type for: " + definition.vajram());
            }

            nodes.add(Node.builder()
                    .id(vajramId.vajramId())
                    .name(definition.vajramDefClass().getSimpleName())
                    .vajramType(vajramType)
                    .inputs(inputs)
                    .annotationTags(annotationTags)
                    .build());
        }

        List<Link> links = new ArrayList<>();
        for (Map.Entry<VajramID, VajramDefinition> entry : vajramDefinitions.entrySet()) {
            VajramID vajramId = entry.getKey();
            for (VajramFacetDefinition facet : entry.getValue().vajram().getFacetDefinitions()) {
                if (facet instanceof DependencyDef<?> dependencyDef) {
                    // NOTE(review): assumes every dependency's data access spec is a VajramID; any
                    // other spec type fails here with a ClassCastException - confirm that is intended.
                    VajramID dependencyId = (VajramID) dependencyDef.dataAccessSpec();
                    // The source vajram comes from this very map, so only the target needs checking.
                    if (vajramDefinitions.containsKey(dependencyId)) {
                        links.add(Link.builder()
                                .source(vajramId.vajramId())
                                .target(dependencyId.vajramId())
                                .name(dependencyDef.name())
                                .isMandatory(dependencyDef.isMandatory())
                                .canFanout(dependencyDef.canFanout())
                                .documentation(dependencyDef.documentation())
                                .build());
                    }
                }
            }
        }
        return Graph.builder().nodes(nodes).links(links).build();
    }

    /**
     * Classifies a vajram by its implementation type.
     *
     * @param vajram vajram to classify
     * @return {@link VajramType#COMPUTE} for a {@link ComputeVajram}, {@link VajramType#IO} for an
     *     {@link IOVajram}, otherwise {@link VajramType#UNKNOWN}
     */
    private static VajramType getVajramType(Vajram<?> vajram) {
        if (vajram instanceof ComputeVajram) {
            return VajramType.COMPUTE;
        }
        if (vajram instanceof IOVajram) {
            return VajramType.IO;
        }
        return VajramType.UNKNOWN;
    }

    /**
     * Serializes the graph to a JSON string.
     *
     * @param graph graph to serialize
     * @return JSON representation of {@code graph}
     * @throws RuntimeException wrapping any exception raised during serialization
     */
    private static String graphToJson(Graph graph) {
        try {
            return OBJECT_MAPPER.writeValueAsString(graph);
        }
        catch (Exception e) {
            throw new RuntimeException("Error converting graph data to JSON", e);
        }
    }

    /**
     * Restricts the graph to the nodes reachable from a start node by following links from source
     * to target.
     *
     * @param fullGraph complete graph
     * @param startNodeId id of the node the traversal starts from; always part of the reachable set
     * @return graph holding the reachable nodes and the links whose source and target are both
     *     reachable, each in the order they appear in {@code fullGraph}
     */
    private static Graph filterGraph(Graph fullGraph, String startNodeId) {
        Map<String, List<String>> adjacency = new HashMap<>();
        for (Link link : fullGraph.getLinks()) {
            adjacency.computeIfAbsent(link.getSource(), source -> new ArrayList<>()).add(link.getTarget());
        }

        // Iterative depth-first traversal; the visited set also guards against cycles.
        HashSet<String> reachable = new HashSet<>();
        ArrayDeque<String> stack = new ArrayDeque<>();
        stack.push(startNodeId);
        while (!stack.isEmpty()) {
            String current = stack.pop();
            if (reachable.add(current)) {
                adjacency.getOrDefault(current, List.of()).forEach(stack::push);
            }
        }

        List<Node> filteredNodes = fullGraph.getNodes().stream()
                .filter(node -> reachable.contains(node.getId()))
                .toList();
        List<Link> filteredLinks = fullGraph.getLinks().stream()
                .filter(link -> reachable.contains(link.getSource()) && reachable.contains(link.getTarget()))
                .toList();
        return Graph.builder().nodes(filteredNodes).links(filteredLinks).build();
    }
}

