(ns orcl.analyzer
  (:require [orcl.utils.cursor :as cursor]
            [orcl.utils :as utils]
            [clojure.set :as set]
            [orcl.analyzer.vars :as vars]
            [orcl.analyzer.patterns :as patterns]
    #?(:clj
            [orcl.analyzer.macro :as macro]))
  #?(:cljs (:require-macros [orcl.analyzer.macro :as macro])))

(defn primitive?
  "Checks whether AST node `n` is atomic, i.e. a constant or a variable
  reference. Returns the matching node kind (:const or :var) when it is,
  nil otherwise."
  [n]
  (get #{:const :var} (:node n)))

(defn deflate-values
  "Hoists every non-primitive child of an AST node into a fresh variable.

  `cursors` is a seq of cursors, each pointing at one child of the node
  behind the cursor `node`. They are processed left to right. A child that
  is `primitive?` is left alone. Any other child is overwritten in place
  with a reference to a fresh variable (bound by `macro/with-fresh`). The
  result is then wrapped in

      {:node :pruning :pattern {:type :var :var fresh}
       :left <rest of the rewrite> :right <original child>}

  When the cursors are exhausted, `@node` is returned.
  NOTE(review): this assumes a `cursor/reset!` on a child cursor is visible
  through `@node`. Confirm against orcl.utils.cursor."
  [[c & cursors] node]
  (cond
    ;; No children left: the rewritten node is the innermost expression.
    (nil? c) @node
    ;; Constants and variable references need no hoisting.
    (primitive? @c) (recur cursors node)
    :else (macro/with-fresh fresh
            ;; Keep the original child; it becomes the :right of the pruning.
            (let [orig @c]
              (cursor/reset! c {:node :var :var fresh})
              ;; The recursive call runs after reset!, so the final @node
              ;; already contains this replacement.
              {:node    :pruning
               :pattern {:type :var :var fresh}
               :left    (deflate-values cursors node)
               :right   orig}))))

(defn bindings
  "Returns the set of variable names occurring free in `ast`.

  A :var node contributes its own name. A pattern's variables are removed
  only from the branch the pattern scopes over:
  - :pruning (`left <p< right`): the left branch.
  - :sequential (`left >p> right`): the right branch.
  :parallel and :otherwise bind nothing, and :const / :stop have no free
  variables. Any other node type throws (no matching `case` clause)."
  [ast]
  (letfn [(bound-by [pattern]
            ;; Patterns are maps, so their variable names must be extracted
            ;; before they can be subtracted from a set of names.
            (set (patterns/pattern-bindings pattern)))]
    (case (:node ast)
      :var #{(:var ast)}
      (:const :stop) #{}
      (:parallel :otherwise) (set/union (bindings (:left ast))
                                        (bindings (:right ast)))
      :pruning (set/union (set/difference (bindings (:left ast))
                                          (bound-by (:pattern ast)))
                          (bindings (:right ast)))
      :sequential (set/union (bindings (:left ast))
                             (set/difference (bindings (:right ast))
                                             (bound-by (:pattern ast)))))))

(defn strict-pattern?
  "True unless pattern `p` is a plain :var or a :wildcard. Only patterns
  for which this holds are lowered through `translate-pattern` by
  `translate*`. Always returns a boolean."
  [p]
  (not (contains? #{:wildcard :var} (:type p))))

(defn site
  "Builds a :var AST node named `n` whose :source describes a site coming
  from the prelude, with `n` as its definition."
  [n]
  (let [origin {:type       :site
                :source     {:type :prelude}
                :definition n}]
    {:node   :var
     :var    n
     :source origin}))

;; Pre-built :var references (see `site`) to the prelude sites that the
;; translation passes in this namespace emit calls to.
;; Used by translate-clause* and translate-pattern.
(def site-pattern-extract (site "_PatternExtract"))
;; Used by bind-pattern to read one variable out of an extraction result.
(def site-pattern-get (site "_PatternGet"))
;; The next four are used by make-match.
(def site-wrap-some (site "_WrapSome"))
(def site-unwrap-some (site "_UnwrapSome"))
(def site-none (site "_None"))
(def site-is-none (site "_IsNone"))
;; Used by translate-conditional for the then/else branches.
(def site-ift (site "Ift"))
(def site-iff (site "Iff"))

(defn pattern-bindings
  "Lists the variable names bound by pattern `p`, left to right.

  An :as alias comes before the names bound by its inner pattern, and
  duplicates are kept. An unknown pattern :type throws (no matching
  `case` clause)."
  [p]
  (let [all-of (fn [ps] (mapcat pattern-bindings ps))]
    (case (:type p)
      (:wildcard :const) []
      :var [(:var p)]
      (:list :tuple) (all-of (:patterns p))
      :record (all-of (map second (:pairs p)))
      :cons (concat (pattern-bindings (:head p))
                    (pattern-bindings (:tail p)))
      :as (cons (:alias p) (pattern-bindings (:pattern p))))))

(defn bind-pattern
  "Wraps `target` in one :pruning per variable bound by `pattern`.

  Each variable name is bound to `_PatternGet(bridge, \"<name>\")`. `bridge`
  names the variable holding the extraction result, presumably the output
  of _PatternExtract (see translate-pattern). Variables are wrapped in
  `pattern-bindings` order, so the first one ends up innermost."
  [pattern bridge target]
  (letfn [(lookup [var-name]
            {:node   :call
             :target site-pattern-get
             :args   [{:node :var :var bridge}
                      {:node :const :value var-name}]})
          (wrap [inner var-name]
            {:node    :pruning
             :left    inner
             :pattern {:type :var :var var-name}
             :right   (lookup var-name)})]
    (reduce wrap target (pattern-bindings pattern))))

(defn make-match
  "Builds the match skeleton used when translating clausal definitions.

  Structure of the result:

      ((source >res> _WrapSome(res)) ; _None())
        >maybe-res>
      ( _UnwrapSome(maybe-res) >b> (target b)
      | _IsNone(maybe-res) >> else )

  `target` is a function from the fresh variable name `b` to an AST.
  `source` and `else` are ASTs. res, maybe-res and b are fresh names."
  [source target else]
  (macro/with-freshs 3 [res maybe-res target-binding]
    (let [var-ref (fn [v] {:node :var :var v})
          call    (fn [site-node & args]
                    {:node :call :target site-node :args (vec args)})
          wrapped {:node    :sequential
                   :left    source
                   :pattern {:type :var :var res}
                   :right   (call site-wrap-some (var-ref res))}
          maybe   {:node  :otherwise
                   :left  wrapped
                   :right (call site-none)}
          on-some {:node    :sequential
                   :left    (call site-unwrap-some (var-ref maybe-res))
                   :pattern {:type :var :var target-binding}
                   :right   (target target-binding)}
          on-none {:node    :sequential
                   :left    (call site-is-none (var-ref maybe-res))
                   :pattern {:type :wildcard}
                   :right   else}]
      {:node    :sequential
       :left    maybe
       :pattern {:type :var :var maybe-res}
       :right   {:node  :parallel
                 :left  on-some
                 :right on-none}})))

(defn rebind-vars
  "Wraps `expr` in one :pruning per `[var-param var-binding]` pair in `vars`.

  Each pruning binds the parameter's variable name `(:var var-param)` to a
  reference to `var-binding`. Pairs are applied in order, so the first pair
  ends up innermost. Returns `expr` unchanged when `vars` is empty."
  [vars expr]
  (loop [wrapped expr
         pairs   (seq vars)]
    (if (empty? pairs)
      wrapped
      (let [[param source-var] (first pairs)]
        (recur {:node    :pruning
                :left    wrapped
                :pattern {:type :var :var (:var param)}
                :right   {:node :var :var source-var}}
               (next pairs))))))

(defn translate-clause*
  "Translates one clause that has at least one strict (non-var,
  non-wildcard) parameter.

  `strict` holds the `[pattern binding]` pairs of the strict parameters and
  `vars` the pairs of the plain-variable parameters.

  How it works:
  - The strict bindings are gathered into a :list.
  - That list is passed, with a :list pattern of the strict patterns, to
    _PatternExtract, and the result goes through `make-match`.
  - In the matching branch, the clause body runs with the pattern
    variables (via `bind-pattern`) and the plain parameters (via
    `rebind-vars`) bound.
  - In the other branch, `else` (the translation of the following clauses)
    runs.

  NOTE(review): the clause's :guard is not consulted here."
  [vars strict else instance]
  (let [pattern  {:type     :list
                  :patterns (mapv first strict)}
        arg-refs (mapv (fn [b] {:node :var :var b}) (map second strict))]
    (macro/with-fresh stricted-values
      (let [extract  {:node   :call
                      :target site-pattern-extract
                      :args   [{:node  :const
                                :value pattern}
                               {:node :var
                                :var  stricted-values}]}
            on-match (fn [bridge]
                       (rebind-vars vars (bind-pattern pattern bridge (:body instance))))]
        {:node    :sequential
         :left    {:node   :list
                   :values arg-refs}
         :pattern {:type :var :var stricted-values}
         :right   (make-match extract on-match else)}))))

(defn pattern-type
  "Classifies pattern `p` for clause translation. Returns :var or
  :wildcard for those pattern types, and :strict for every other type."
  [p]
  (let [t (:type p)]
    (if (contains? #{:var :wildcard} t)
      t
      :strict)))

(defn translate-clause
  "Translates one clause `instance` of a definition.

  `bindings` are the fresh argument variable names shared by all clauses.
  `else` is the translation of the clauses that follow this one.

  Parameters are paired with `bindings` by position and grouped by
  `pattern-type`:
  - Wildcard parameters bind nothing and are dropped.
  - Plain variables are rebound with `rebind-vars`.
  - Only when strict parameters exist does the clause go through
    `translate-clause*`, which can fall through to `else`. Otherwise `else`
    is not used.

  NOTE(review): the clause's :guard is not consulted here."
  [bindings else instance]
  (let [grouped      (group-by (comp pattern-type first)
                               (map vector (:params instance) bindings))
        var-pairs    (get grouped :var)
        strict-pairs (get grouped :strict)]
    (if (empty? strict-pairs)
      (rebind-vars var-pairs (:body instance))
      (translate-clause* var-pairs strict-pairs else instance))))

(defn translate-clauses
  "Folds the clauses `instances` of a definition into a single body.

  Clauses are processed from last to first. Each clause receives the
  translation of the clauses after it as its `else`, and the last clause
  receives {:node :stop}."
  [bindings instances]
  (reduce (fn [else instance]
            (translate-clause bindings else instance))
          {:node :stop}
          (reverse instances)))

(defn translate-def
  "Collapses the clauses of definition `def` into one instance.

  The new instance has one fresh :var parameter per parameter of the first
  clause, and its body comes from `translate-clauses`. Other keys of `def`
  are kept.

  NOTE(review): arity comes from the first clause only. Clauses of
  differing arity are not checked here."
  [def]
  (let [arity (-> def :instances first :params count)]
    (macro/with-freshs arity bindings
      (let [params (map (fn [b] {:type :var :var b}) bindings)
            body   (translate-clauses bindings (:instances def))]
        (assoc def :instances [{:params params
                                :body   body}])))))

;; TODO: site imports & type data structures
(defn translate-declarations
  "Translates a :declarations node `{:decls [...] :expr e}` into nested
  binding nodes.

  Declarations are consumed left to right, and consecutive runs of the same
  kind are grouped:
  - :def     one :defs-group. Clauses sharing a :name are merged into one
             definition, run through `translate-def` when
             `translate-clauses?` is true.
  - :refer   one :refer node whose :namespaces keeps source order.
  - :site    one :sites node.
  - :val     only when nothing has been grouped yet: a :pruning binding the
             pattern over the remaining declarations, immediately passed
             to `k`.
  - :include spliced in place.
  - :def-sig skipped.

  The first declaration of a different kind ends the group. It and
  everything after it stay in a nested :declarations node, used as the
  group's :expr.

  `k` translates a freshly built node (callers pass `translate*`). When no
  declaration was grouped at all (empty decls, or only :def-sig / empty
  :include), the result is `(k expr)`."
  [{:keys [decls expr]} k translate-clauses?]
  (letfn [(finalize [[state acc] decls]
            (let [expr' (if (seq decls)
                          {:node  :declarations
                           :decls decls
                           :expr  expr}
                          expr)]
              (case state
                ;; Nothing was grouped. `decls` is empty here, so expr' is
                ;; the body expression; translate it like a :val would be.
                :init (k expr')
                :def {:node :defs-group
                      :defs (if translate-clauses?
                              (mapv translate-def (vals acc))
                              (vals acc))
                      :expr expr'}
                :refer {:node       :refer
                        :namespaces acc
                        :expr       expr'}
                :site {:node :sites
                       :definitions acc
                       :expr expr'})))]
    (loop [[state acc :as s] [:init] [d & tail :as decls] decls]
      (if d
        (case (:type d)
          ;; Included declarations are processed as if written here.
          :include (recur s (concat (:decls d) tail))
          :val (if (= :init state)
                 (k {:node    :pruning
                     :pattern (:pattern d)
                     :left    (if (seq tail)
                                {:node  :declarations
                                 :decls tail
                                 :expr  expr}
                                expr)
                     :right   (:expr d)})
                 (finalize s decls))
          :def (if (#{:def :init} state)
                 (let [inst {:params (:params d)
                             :body   (:body d)
                             :guard  (:guard d)}
                       def  (get acc (:name d)
                                 {:name      (:name d)
                                  :instances []})]
                   (recur [:def (assoc acc (:name d) (update def :instances conj inst))] tail))
                 (finalize s decls))
          ;; Accumulate into a vector so conj appends and source order is
          ;; kept (conj onto nil would build a list in reverse order).
          :refer (if (#{:refer :init} state)
                   (recur [:refer (conj (or acc []) [(:namespace d) (:symbols d)])] tail)
                   (finalize s decls))
          :site (if (#{:site :init} state)
                  (recur [:site (assoc acc (:name d) (:definition d))] tail)
                  (finalize s decls))
          ;; Signatures carry no runtime meaning for this pass.
          :def-sig (recur s tail))
        (finalize s decls)))))

(defn translate-pattern
  "Lowers a strict `pattern` that binds the values of `source` over `target`.

  Returns `[source' bridge-pattern target']`:
  - source' pipes each value of `source`, bound to the fixed name
    \"to-extract\", into `_PatternExtract(pattern, to-extract)`.
  - bridge-pattern is a :var pattern on a fresh name, meant to bind the
    values of source'.
  - target' is `target` with every variable of `pattern` bound from that
    bridge via `bind-pattern`.

  The caller reassembles these pieces into the original combinator.
  NOTE(review): \"to-extract\" is not a fresh name. This is presumably safe
  because Orc identifiers cannot contain '-'. Confirm."
  [pattern source target]
  (macro/with-fresh bridge
    (let [extract-var "to-extract"
          source'     {:node    :sequential
                       :pattern {:type :var :var extract-var}
                       :left    source
                       :right   {:node   :call
                                 :target site-pattern-extract
                                 :args   [{:node  :const
                                           :value pattern}
                                          {:node :var
                                           :var  extract-var}]}}]
      [source'
       {:type :var :var bridge}
       (bind-pattern pattern bridge target)])))

(defn translate-conditional
  "Lowers an `if c then a else b` node into core combinators:

      (Ift(t) >> a | Iff(t) >> b) <t< c

  where t is a fresh variable. The node's :if, :then and :else keys hold
  the condition and the two branches."
  [{test-expr :if, then-expr :then, else-expr :else}]
  (macro/with-fresh t
    (let [t-ref  {:node :var :var t}
          branch (fn [site-node body]
                   {:node    :sequential
                    :left    {:node   :call
                              :target site-node
                              :args   [t-ref]}
                    :pattern {:type :wildcard}
                    :right   body})]
      {:node    :pruning
       :pattern {:type :var :var t}
       :right   test-expr
       :left    {:node  :parallel
                 :left  (branch site-ift then-expr)
                 :right (branch site-iff else-expr)}})))

;; NOTE(review): nothing in this part of the file calls `translate` before
;; its definition below, so this forward declaration looks unnecessary.
(declare translate)
(defn translate*
  "One rewrite step of `translate`, applied to each node through
  utils/ast-prewalk. Returns the replacement node. Node types without a
  clause here are returned unchanged.

  Flags read from `options`:
  - :deflate?     hoist non-primitive children into fresh variables via
                  `deflate-values`. This covers :list/:tuple values,
                  :record values, a :call's target and args, a
                  :field-access target, and a :conditional's :if.
  - :patterns?    lower strict patterns of :sequential/:pruning via
                  `translate-pattern`.
  - :conditional? lower :conditional via `translate-conditional` (checked
                  before :deflate?).
  - :clauses?     merge clausal definitions via `translate-def` inside
                  :declarations.

  Unconditional rewrites:
  - :declarations are grouped by `translate-declarations`.
  - A :lambda becomes a one-definition :defs-group named from the SHA of
    its body.
  - `a := b` becomes a call of the `write` field of `a` with `b`.
  - A :dereference becomes a call of the target's `read` field.
  - An :ns node is replaced by the translation of its :body."
  [{:keys [deflate? patterns? conditional? clauses?] :as options} ast]
  (case (:node ast)
    ;; translate* is passed as the continuation so nodes built while
    ;; grouping declarations are translated too.
    :declarations (translate-declarations ast (partial translate* options) clauses?)
    ;; NOTE(review): the generated name depends only on the body, not on
    ;; :params or :guard.
    :lambda (let [body (:body ast)
                  n    (str "__def_" (utils/sha body))]
              {:node :defs-group
               :defs [{:name      n
                       :instances [{:guard  (:guard ast)
                                    :params (:params ast)
                                    :body   body}]}]
               :expr {:node :var
                      :var  n}})
    (:list :tuple) (if deflate?
                     (macro/as-cursor [c ast] (deflate-values (seq (:values c)) c))
                     ast)
    :record (if deflate?
              (macro/as-cursor [c ast] (deflate-values (map second (:pairs c)) c))
              ast)
    ;; `a := b` is rewritten to a call of a's "write" field. The new call is
    ;; fed back into translate* so it is deflated like any other call.
    :call (if (= ":=" (get-in ast [:target :var]))
            (translate* options
                        {:node   :call
                         :args   [(second (:args ast))]
                         :target {:node   :field-access
                                  :target (first (:args ast))
                                  :field  "write"}})
            (if deflate?
              (macro/as-cursor [c ast] (deflate-values (concat [(:target c)] (:args c)) c))
              ast))
    :field-access (if deflate?
                    (macro/as-cursor [c ast] (deflate-values [(:target c)] c))
                    ast)
    ;; Only the condition is deflated; the branches stay as written.
    :conditional (cond
                   conditional? (translate-conditional ast)
                   deflate? (macro/as-cursor [c ast] (deflate-values [(:if c)] c))
                   :else ast)
    ;; Strict patterns are replaced by a fresh bridge variable plus
    ;; explicit extraction (see translate-pattern). For :sequential the
    ;; source is :left; for :pruning it is :right.
    :sequential (if (and patterns? (strict-pattern? (:pattern ast)))
                  (let [[source bridge target] (translate-pattern (:pattern ast) (:left ast) (:right ast))]
                    {:node    :sequential
                     :pattern bridge
                     :left    source
                     :right   target})
                  ast)

    :pruning (if (and patterns? (strict-pattern? (:pattern ast)))
               (let [[source bridge target] (translate-pattern (:pattern ast) (:right ast) (:left ast))]
                 {:node    :pruning
                  :pattern bridge
                  :left    target
                  :right   source})
               ast)

    ;; Dereference becomes a call of the target's "read" field, itself
    ;; translated (and so deflated) again.
    :dereference (translate* options
                             {:node   :call
                              :args   []
                              :target {:node   :field-access
                                       :target (:target ast)
                                       :field  "read"}})

    ;; NOTE(review): analyze-namespace treats a namespace's :body as a seq
    ;; of declarations. If this :ns node has that shape, (:body ast) has no
    ;; :node and falls through unchanged. Confirm.
    :ns (translate* options (:body ast))

    ast))

(defn translate
  "Applies the source-level rewrites of `translate*` to every node of
  `ast` via utils/ast-prewalk, using the flags in `options`."
  [ast options]
  (utils/ast-prewalk #(translate* options %) ast))

(defn with-sha
  "Runs `utils/with-sha` on every node of `ast` via utils/ast-postwalk.
  Presumably this attaches the :sha that later passes read (e.g.
  analyze-instance). Confirm in orcl.utils."
  [ast]
  (->> ast
       (utils/ast-postwalk utils/with-sha)))

;; Forward declaration: analyze-instance (below) calls analyze-env, which is
;; defined later.
(declare analyze-env)

;; Tail-position marker used by analyze-env. The root value is unbound;
;; analyze-final binds it to nil before analysis. analyze-instance binds it
;; to {:id <def sha>} while analyzing a def body. A :call whose target
;; resolves to that same def, reached while the marker is still set, gets a
;; :tail-pos key.
;;
;; A call is in tail position if it is
;; - not in the left branch of a sequential,
;; - not in the right branch of a pruning,
;; - not in the left branch of an otherwise.
(def ^:dynamic *tail-pos*)

(defn analyze-instance
  "Analyzes one clause `instance` of the definition with id `id`.

  Every variable bound by the i-th parameter pattern enters the environment
  as an :argument source. The source records position i, the :sha of the
  instance body, and `id`. The body is then analyzed with *tail-pos* set to
  {:id id}."
  [id instance]
  (let [body-sha (:sha (:body instance))
        arg-env  (into {}
                       (for [[position pattern] (map-indexed vector (:params instance))
                             var-name (patterns/pattern-bindings pattern)]
                         [var-name {:type         :argument
                                    :position     position
                                    :instance-sha body-sha
                                    :id           id}]))]
    (macro/with-envs arg-env
      (binding [*tail-pos* {:id id}]
        (update instance :body analyze-env)))))

(defn analyze-def
  "Analyzes one definition `node` of a defs group.

  Every definition of the group in `defs` (each already carrying :sha and a
  :usages atom) is put in scope as a :def source while the instances are
  analyzed. Returns `node` with :arity (the parameter count of its first
  instance) and every instance analyzed under the def's :sha."
  [defs {:keys [sha instances] :as node}]
  (let [group-env (into {}
                        (map (fn [d]
                               [(:name d) {:type   :def
                                           :id     (:sha d)
                                           :usages (:usages d)}]))
                        defs)]
    (macro/with-envs group-env
      (assoc node
        :arity (count (:params (first instances)))
        :instances (mapv #(analyze-instance sha %) instances)))))

(defn analyze-defs
  "Analyzes a group of definitions that can all see each other.

  Each def first gets a fresh :usages counter (an atom starting at 0) and a
  :sha computed from the def as given. Each is then analyzed with the whole
  prepared group in scope."
  [defs]
  (let [prepared (mapv (fn [d]
                         (assoc d
                           :usages (atom 0)
                           :sha (utils/sha d)))
                       defs)]
    (mapv #(analyze-def prepared %) prepared)))

;; TODO check target & arity
(defn check-call!
  "Placeholder for validating a :call node (target and arity). It
  currently performs no checks and returns nil."
  [call]
  nil)

;:site (macro/with-env (:name decl) {:type       :site
;                                    :source     {:type :import :pos (:pos decl)}
;                                    :definition (:definition decl)}
;        (analyze-env (:expr ast)))

(defn analyze-env
  "Resolves variable references in `ast` against vars/*env*, extending the
  environment as binders are entered.

  Binders:
  - The pattern of a :pruning is in scope in its left branch; the pattern
    of a :sequential is in scope in its right branch.
  - :defs-group definitions, :refer symbols and :sites definitions are in
    scope in their :expr.

  Results:
  - Each :var gets a :source (without :usages). References to definitions
    increment that definition's :usages atom. An unknown name throws
    ex-info with :orcl/error-pos, :orcl/error and :variable.
  - :refer and :sites nodes are not kept; only their analyzed :expr is
    returned.
  - A :call reached while *tail-pos* is set, whose target resolves to the
    def identified by *tail-pos*, gets :tail-pos.
  - Other nodes are walked with utils/ast-walk."
  [ast]
  (case (:node ast)
    :pruning (assoc ast :left (macro/with-pattern (:pattern ast) {:type     :pruning
                                                                  :node-sha (:sha (:right ast))}
                                (analyze-env (:left ast)))
                        :right (binding [*tail-pos* nil]
                                 (analyze-env (:right ast))))
    :sequential (assoc ast :right (macro/with-pattern (:pattern ast) {:type     :sequential
                                                                      :node-sha (:sha (:left ast))}
                                    (analyze-env (:right ast)))
                           :left (binding [*tail-pos* nil]
                                   (analyze-env (:left ast))))
    :otherwise (assoc ast :right (analyze-env (:right ast))
                          :left (binding [*tail-pos* nil]
                                  (analyze-env (:left ast))))
    :defs-group (let [defs (analyze-defs (:defs ast))]
                  (macro/with-envs (into {} (for [{:keys [name sha usages]} defs]
                                              [name {:type   :def
                                                     :id     sha
                                                     :usages usages}]))
                    (assoc ast
                      :defs defs
                      :expr (analyze-env (:expr ast)))))
    :refer (macro/with-envs (into {} (for [[ns symbols] (:namespaces ast)
                                           s symbols]
                                       [s {:type      :refer
                                           :namespace ns}]))
             (analyze-env (:expr ast)))
    :sites (macro/with-envs (into {} (for [[var definition] (:definitions ast)]
                                       [var {:type       :site
                                             :source     {:type :custom}
                                             :definition definition}]))
             (analyze-env (:expr ast)))
    :call (let [;; The callee and the arguments are evaluated before the call
                ;; itself, so no call nested inside them is in tail position.
                ;; This matters when :deflate? is off and args are not
                ;; atomic. mapv is eager, so the binding covers all of it.
                ast' (binding [*tail-pos* nil]
                       (-> ast
                           (update :target analyze-env)
                           (assoc :args (mapv analyze-env (:args ast)))))]
            (check-call! ast')
            (let [s (get-in ast' [:target :source])]
              ;; Only direct self-calls of the enclosing def are marked.
              (if (and *tail-pos* (= :def (:type s)) (= (:id s) (:id *tail-pos*)))
                (assoc ast'
                  :tail-pos *tail-pos*)
                ast')))
    :var (if-let [source (get vars/*env* (:var ast))]
           (do
             (when (= :def (:type source))
               (swap! (:usages source) inc))
             ;; The counter atom is not copied into the AST.
             (assoc ast :source (dissoc source :usages)))
           (throw (ex-info "Undefined variable" {:orcl/error-pos (:pos ast)
                                                 :orcl/error     "Undefined variable"
                                                 :variable       (:var ast)})))
    (utils/ast-walk analyze-env identity ast)))

(defn def-instance-locals
  "The :locals of the body of clause `instance`, minus every variable bound
  by the clause's own parameter patterns."
  [instance]
  (reduce set/difference
          (:locals (:body instance))
          (map patterns/pattern-bindings (:params instance))))

(defn def-locals
  "The union of `def-instance-locals` over every clause of definition `def`."
  [def]
  (into #{} (mapcat def-instance-locals) (:instances def)))

(defn analyze-stage2
  "Second analysis pass over the output of `analyze-env`, applied through
  utils/ast-postwalk.

  Every node gets :locals: the referenced variables whose :source is not a
  :def, :site or :refer. Details:
  - A pattern's variables are removed only from the branch the pattern
    scopes over: the left of a :pruning, the right of a :sequential.
  - In a :defs-group, each def gets :locals (minus its own parameters), and
    the group's :locals is the union of those.
    NOTE(review): the group's :locals does not include its :expr's locals.
    Confirm this is intended.

  When (:remove-unused? options) is truthy:
  - Every def's :usages atom is replaced by its value.
  - Defs with zero usages are removed.
  - A group left with no defs is replaced by its :expr."
  [ast options]
  (let [pattern-vars  (fn [pattern]
                        ;; Coerce to a set so set/difference never has to
                        ;; call contains? on a seq.
                        (set (patterns/pattern-bindings pattern)))
        set-locals    (fn [ast]
                        (if (= :defs-group (:node ast))
                          (let [defs' (mapv #(assoc % :locals (def-locals %)) (:defs ast))]
                            (assoc ast
                              :defs defs'
                              :locals (set (mapcat :locals defs'))))
                          (assoc ast :locals
                                     (case (:node ast)
                                       (:otherwise :parallel) (set/union (:locals (:left ast))
                                                                         (:locals (:right ast)))
                                       ;; `left <p< right`: p scopes over the left branch only.
                                       :pruning (set/union (set/difference (:locals (:left ast))
                                                                           (pattern-vars (:pattern ast)))
                                                           (:locals (:right ast)))
                                       ;; `left >p> right`: p scopes over the right branch only.
                                       :sequential (set/union (:locals (:left ast))
                                                              (set/difference (:locals (:right ast))
                                                                              (pattern-vars (:pattern ast))))
                                       ;; The condition is stored under :if (the key that
                                       ;; translate* and translate-conditional read).
                                       :conditional (set/union (:locals (:if ast))
                                                               (:locals (:then ast))
                                                               (:locals (:else ast)))
                                       (:field-access :call) (apply set/union
                                                                    (get-in ast [:target :locals])
                                                                    (map :locals (:args ast)))
                                       (:tuple :list) (apply set/union (map :locals (:values ast)))
                                       :record (apply set/union (map (comp :locals second) (:pairs ast)))
                                       ;; Definitions, sites and referred names are not locals.
                                       :var (if (#{:def :site :refer} (get-in ast [:source :type]))
                                              #{}
                                              #{(:var ast)})
                                       (:const :stop) #{}))))
        remove-unused (fn [ast]
                        (case (:node ast)
                          :defs-group
                          (let [in-use (filterv #(pos? (:usages %)) (map #(update % :usages deref) (:defs ast)))]
                            ;; A group with no referenced defs collapses to its body.
                            (if (not-empty in-use)
                              (assoc ast :defs in-use)
                              (:expr ast)))
                          ast))]
    (utils/ast-postwalk (if (:remove-unused? options)
                          (comp set-locals remove-unused)
                          set-locals)
                        ast)))

(defn analyze-final
  "Runs both analysis passes on a translated, hashed `ast`.

  First `analyze-env` resolves names against the initial environment
  `env`, with *tail-pos* starting as nil. Then `analyze-stage2` runs with
  `options`."
  [ast env options]
  (binding [vars/*env* env
            *tail-pos* nil]
    (-> ast
        analyze-env
        (analyze-stage2 options))))

(defn analyze
  "Entry point: translates `ast` with `options`, annotates it with
  `with-sha`, and analyzes it with `analyze-final` in environment `env`.

  `env` defaults to {}. When `options` is omitted, deflation and removal of
  unused defs are on, and conditional, clause and pattern lowering are off.
  An explicit `options` map replaces these defaults entirely; it is not
  merged with them."
  ([ast]
   (analyze ast {}))
  ([ast env]
   (analyze ast env {:deflate?       true
                     :conditional?   false
                     :clauses?       false
                     :patterns?      false
                     :remove-unused? true}))
  ([ast env options]
   (let [translated (translate ast options)
         hashed     (with-sha translated)]
     (analyze-final hashed env options))))

(defn flat-namespace
  "Collects the definitions of a chain of top-level :defs-group nodes.

  :refer nodes in the chain are skipped over. Collection stops at the first
  node of any other kind. Returns nil when `ast` is neither a :defs-group
  nor a :refer."
  [ast]
  (let [kind (:node ast)]
    (cond
      (= :defs-group kind) (concat (:defs ast) (flat-namespace (:expr ast)))
      (= :refer kind)      (recur (:expr ast))
      :else                nil)))

(defn analyze-namespace
  "Analyzes the declarations in (:body ns) and returns the namespace's
  definitions as a flat seq (see `flat-namespace`).

  The declarations are wrapped in a :declarations node with a :stop body.
  Unused definitions are kept (:remove-unused? is forced to false)."
  [ns env options]
  (let [wrapper {:node  :declarations
                 :decls (:body ns)
                 :expr  {:node :stop}}
        opts    (assoc options :remove-unused? false)]
    (flat-namespace (analyze wrapper env opts))))