(ns orcl.parser
  (:require [blancas.kern.core :as k :refer [<|> >> << <:> <$> >>= <+>]]
            [blancas.kern.expr :as expr]
            [orcl.lexer :as lex]
            [orcl.parser.utils :as utils]
            [orcl.fs :as fs]
            [clojure.pprint :as pprint]
            [orcl.utils :refer [assoc-when]])
  #?(:cljs (:require-macros [blancas.kern.core :as k])))

(declare pattern)

(def value
  "Parses a literal constant. Tried in order: a boolean literal, the word
  `signal` (yielding the keyword :signal), a string literal, a float
  literal, a decimal literal, and the nil literal (`lex/nil-lit`, yielding
  nil). The float literal is tried before the decimal one."
  (<|> lex/bool-lit
       (<$> (constantly :signal) (lex/word "signal"))
       lex/string-lit
       lex/float-lit
       lex/dec-lit
       (<$> (constantly nil) lex/nil-lit)))

(def base-pattern
  "Parses one pattern without a cons, `as` suffix, or type annotation.
  Alternatives are tried in order: a constant, the `_` wildcard, a call
  pattern `f(p, ...)`, a variable, a tuple, a list, and a record. The call
  pattern is wrapped in `<:>` so that a bare identifier backtracks into the
  variable alternative."
  (<|>
    (<$> (fn [v] {:type  :const
                  :value v})
         value)
    (>> (lex/sym \_) (k/return {:type :wildcard}))
    (<:> (k/bind [callee       lex/identifier
                  sub-patterns (lex/parens (lex/comma-sep (k/fwd pattern)))]
           (k/return {:type   :call
                      :target callee
                      :args   sub-patterns})))
    (<$> (fn [id] {:type :var
                   :var  id})
         lex/identifier)
    (<$> (fn [ps] {:type     :tuple
                   :patterns ps})
         (utils/tuple-of (k/fwd pattern)))
    (<$> (fn [ps] {:type     :list
                   :patterns ps})
         (utils/list-of (k/fwd pattern)))
    (<$> (fn [kvs] {:type  :record
                    :pairs kvs})
         (utils/record-of (k/fwd pattern)))))

(def cons-pattern
  "Parses `head : tail` patterns over `base-pattern`, right-associative via
  `expr/chainr1`. A `:` that is immediately followed by another `:` (the
  start of a `::` type annotation) is not taken as the cons separator."
  (let [cons-sep  (<:> (<< (lex/token ":")
                           (k/not-followed-by (k/sym* \:))))
        make-cons (fn [head tail]
                    {:type :cons
                     :head head
                     :tail tail})]
    (expr/chainr1 base-pattern (>> cons-sep (k/return make-cons)))))

(def as-pattern
  "Parses a cons pattern with an optional `as name` suffix. The suffix
  parser yields a function that wraps the preceding pattern in an `:as`
  node; `utils/maybe-or-left` presumably applies it to the cons pattern
  when the suffix is present (TODO confirm against orcl.parser.utils)."
  (let [alias-suffix (k/bind [_     (lex/word "as")
                              alias lex/identifier]
                       (k/return
                         (fn [inner]
                           {:type    :as
                            :pattern inner
                            :alias   alias})))]
    (utils/maybe-or-left cons-pattern alias-suffix)))

(declare expr)
(declare type-parser)

(def pattern
  "Parses a full pattern: an `as-pattern` optionally followed by `:: Type`.
  When the annotation is present, the parsed type is stored under `:T`."
  (k/bind [p  as-pattern
           ty (k/optional (>> (lex/token "::") (k/fwd type-parser)))]
    (k/return (cond-> p
                ty (assoc :T ty)))))

(defn combinator-without-pattern
  "Returns a parser for the single-character combinator `c`. On success it
  yields a two-argument function that builds `{:node node :left l :right r}`,
  suitable as the operator parser of `expr/chainl1` / `expr/chainr1`."
  [node c]
  (let [build (fn [lhs rhs]
                {:node node :left lhs :right rhs})]
    (>> (lex/sym c) (k/return build))))

(defn combinator-with-pattern
  "Returns a parser for a combinator written as `c pattern c` (for example
  `>x>`), where the pattern is optional and defaults to a wildcard. On
  success it yields a two-argument function that builds
  `{:node node :pattern pat :left l :right r}`."
  [node c]
  (k/bind [_   (lex/sym c)
           pat (k/optional pattern)
           _   (lex/sym c)]
    (let [bound (or pat {:type :wildcard})]
      (k/return
        (fn [lhs rhs]
          {:node    node
           :pattern bound
           :left    lhs
           :right   rhs})))))

(def special-token
  "Parses an operator name that may be referenced as a value when written
  in parentheses, e.g. `(+)`.
  Longer tokens are listed before their prefixes (`**` before `*`, `/=`
  before `/`, `:>` before `:`). `<|>` commits to the first alternative that
  succeeds, so listing the shorter token first would leave the rest
  (e.g. the second `*` of `**`) unparsed and `(**)` would be rejected.
  Each alternative is wrapped in `<:>` so that a partial match never
  prevents the remaining alternatives from being tried."
  (apply <|> (map (comp <:> lex/token)
                  ["**" "*" "+" "-" "/=" "/" "%"
                   "<:" "<=" ":>" ">=" "="
                   "~" "&&" "||" ":"])))


(def base
  "Parses a primary expression. Alternatives are tried in order:
    - a constant (`value`)                       -> `:const`
    - an identifier                              -> `:var` with `:pos`
    - `_`                                        -> `:placeholder`
    - `stop`                                     -> `:stop`
    - a list of expressions                      -> `:list` with `:pos`
    - a record of expressions                    -> `:record` with `:pos`
    - a parenthesised operator such as `(+)`     -> `:var` named by the operator
    - a tuple of expressions                     -> `:tuple` with `:pos`;
      a one-element tuple `(e)` is grouping and yields `e` itself.
  Trim and section expressions are not supported."
  (<|> (k/bind [v value]
         (k/return {:node  :const
                    :value v}))
       (k/bind [p k/get-position v lex/identifier]
         (k/return {:node :var
                    :var  v
                    :pos  p}))
       ;; Every other `lex/sym` call in this file (e.g. the wildcard in
       ;; `base-pattern`) is given a character; a string argument would
       ;; presumably never match a single input character.
       (>> (lex/sym \_) (k/return {:node :placeholder}))
       (>> (lex/word "stop") (k/return {:node :stop}))
       (k/bind [p k/get-position v (utils/list-of (k/fwd expr))]
         (k/return {:node   :list
                    :values v
                    :pos    p}))
       (k/bind [p k/get-position v (utils/record-of (k/fwd expr))]
         (k/return {:node  :record
                    :pairs v
                    :pos   p}))
       ;; `<:>` lets `(` backtrack into the tuple alternative below when
       ;; the parenthesised content is not a lone operator token.
       (k/bind [p  k/get-position
                op (<:> (lex/parens special-token))]
         (k/return {:node :var
                    :var  op
                    :pos  p}))
       (k/bind [p k/get-position v (utils/tuple-of (k/fwd expr))]
         (k/return (if (= (count v) 1)
                     (first v)
                     {:node   :tuple
                      :values v
                      :pos    p})))))

(def argument-op
  "Parses one postfix operation applied to an already-parsed expression.
  It yields a function from that operand to the new node:
    - `.field`                                -> `:field-access`
    - `?`                                     -> `:dereference`
    - optional type arguments (`utils/list-of`)
      followed by a parenthesised argument list -> `:call`, with `:type-args`
      added only when present (via `assoc-when`)."
  (let [field-access (k/bind [_     lex/dot-tok
                              pos   k/get-position
                              field lex/identifier]
                       (k/return (fn [target]
                                   {:node   :field-access
                                    :target target
                                    :field  field
                                    :pos    pos})))
        dereference  (k/bind [_   (lex/sym \?)
                              pos k/get-position]
                       (k/return (fn [target]
                                   {:node   :dereference
                                    :target target
                                    :pos    pos})))
        invocation   (k/bind [pos   k/get-position
                              targs (k/optional (utils/list-of (k/fwd type-parser)))
                              args  (lex/parens (lex/comma-sep (k/fwd expr)))]
                       (k/return (fn [target]
                                   (assoc-when {:node   :call
                                                :target target
                                                :args   args
                                                :pos    pos}
                                               :type-args targs))))]
    (<|> field-access dereference invocation)))

(def call
  "A `base` expression followed by postfix operations (`argument-op`),
  combined with `expr/postfix1`."
  (expr/postfix1 base argument-op))

(def unary
  "Parses a prefix-operator expression or a plain `call`. Tried in order:
    - `-` followed by a numeric literal folds into a negated `:const`
      (backtracking via `<:>` if no literal follows);
    - `-e` becomes a call to the variable \"UMinus\", `~e` a call to \"~\";
    - otherwise a `call` expression."
  (<|>
    ;; float-lit must be tried before dec-lit here
    (<$> (fn [n] {:node  :const
                  :value (- n)})
         (<:> (>> (lex/sym \-) (<|> lex/float-lit lex/dec-lit))))
    (k/bind [op      (lex/one-of "-~")
             operand call]
      (k/return {:node   :call
                 :target {:node :var
                          :var  (if (= op \-) "UMinus" (str op))}
                 :args   [operand]}))
    call))

;; Binary-operator precedence ladder, from tightest to loosest binding.
;; Each level uses the level defined just above it for its operands.

(def expn-op
  "`**` over unary expressions, via `utils/right-assoc-infix`."
  (utils/right-assoc-infix unary [(lex/token "**")]))

(def mult-op
  "`lex/mult-tok`, `lex/div-tok` and `%`, via `utils/left-assoc-infix`."
  (utils/left-assoc-infix expn-op [lex/mult-tok lex/div-tok (lex/token "%")]))

(def additional-op
  "`-` and `+`, via `utils/left-assoc-infix`."
  (utils/left-assoc-infix mult-op (mapv lex/token ["-" "+"])))

(def cons-op
  "`lex/cons-tok`, via `utils/right-assoc-infix`."
  (utils/right-assoc-infix additional-op [lex/cons-tok]))

(def relational-op
  "Comparison operators, via `utils/non-assoc-infix`."
  (utils/non-assoc-infix cons-op (mapv lex/token ["<:" ":>" "<=" ">=" "=" "/="])))

(def logical-op
  "`||` and `&&` at the same level, via `utils/left-assoc-infix`."
  (utils/left-assoc-infix relational-op (mapv lex/token ["||" "&&"])))

(def assign-op
  "`:=`, via `utils/non-assoc-infix`."
  (utils/non-assoc-infix logical-op [(lex/token ":=")]))

;; Combinator ladder over `assign-op`, from tightest to loosest binding.

(def sequential
  "`a >p> b` (optional pattern), combined with `expr/chainr1`."
  (->> (combinator-with-pattern :sequential \>)
       (expr/chainr1 assign-op)))

(def parallel
  "`a | b`, combined with `expr/chainl1`."
  (->> (combinator-without-pattern :parallel \|)
       (expr/chainl1 sequential)))

(def pruning
  "`a <p< b` (optional pattern), combined with `expr/chainl1`."
  (->> (combinator-with-pattern :pruning \<)
       (expr/chainl1 parallel)))

(def otherwise
  "`a ; b`, combined with `expr/chainl1`."
  (->> (combinator-without-pattern :otherwise \;)
       (expr/chainl1 pruning)))

(def guard
  "Parses `if (expr)` and returns the parenthesised expression."
  (>> (lex/word "if")
      (lex/parens (k/fwd expr))))

(declare declaration)

(def declarations
  "Zero or more declarations."
  (k/many (k/fwd declaration)))

(defn state-field
  "Looks up `kw` in the `:user` map of kern parser state `s`."
  [s kw]
  (-> s :user (get kw)))

(defn- include-path
  "Resolves `target`, the string literal of an include, against the file
  containing the include. When there is no current file (the top-level
  source) or `target` starts with `/`, `target` is used unchanged.
  Otherwise it is resolved against the directory of `current-file`
  (everything up to and including its last `/`)."
  [current-file target]
  (if (or (nil? current-file) (= \/ (first target)))
    target
    ;; greedy `.*` matches up to the last `/`; nil (no directory) => target
    (str (re-find #"^.*/" current-file) target)))

(def include-declaration
  "Parses `include <string literal>`. It reads the resolved file through
  the `::file-system` in the user state, then parses the file's contents as
  declarations with `::current-file` set to that path. The result is an
  `:include` node holding the declarations. The included text may start
  with whitespace/comments (skipped with `lex/trim`) and must be consumed
  completely. A parse failure in the included file becomes this parser's
  failure."
  (k/bind [_  (lex/word "include")
           sl lex/string-lit]
    (fn [s]
      (let [file   (include-path (state-field s ::current-file) sl)
            body   (fs/read-file (state-field s ::file-system) file)
            ;; trim + eof: bare `declarations` would yield nothing for a
            ;; file starting with whitespace and silently drop any trailing
            ;; content that is not a declaration.
            parsed (k/parse (>> lex/trim (<< declarations k/eof))
                            body
                            file
                            (assoc (:user s) ::current-file file))]
        (if (:ok parsed)
          (assoc s :value {:type  :include
                           :src   file
                           :decls (:value parsed)})
          (assoc s :ok false
                   :error (:error parsed)))))))

(def site-declaration
  "Parses `import site Name = <string literal>` into a `:site` node that
  records the name, the literal and the start position."
  (k/bind [pos       k/get-position
           _         (>> (lex/word "import") (lex/word "site"))
           site-name lex/identifier
           _         (lex/sym \=)
           site-def  lex/string-lit]
    (k/return {:type       :site
               :name       site-name
               :definition site-def
               :pos        pos})))

(def return-type
  "Parses a `:: Type` annotation and returns the type."
  (k/bind [_ (lex/token "::")
           t (k/fwd type-parser)]
    (k/return t)))

(def def-instance
  "Parses one `def` clause into a `:def` node. The clause is `def name`,
  then optional type parameters, a parenthesised list of parameter
  patterns, an optional `:: Ret`, an optional `if (guard)`, `=`, and the
  body expression. The node's `:T` is a `:fun` type whose parameter types
  are the patterns' `:T` annotations (nil where a pattern has none). The
  node also records the clause's start position under `:pos`, as
  `site-declaration` and `val-declaration` do."
  (k/bind [pos         k/get-position
           _           (lex/word "def")
           id          lex/identifier
           type-params (k/optional (utils/list-of lex/identifier))
           ;; bound as `params` rather than `p` so it no longer shadows the
           ;; position, which was previously parsed and then lost
           params      (lex/parens (lex/comma-sep pattern))
           return      (k/optional return-type)
           g           (k/optional guard)
           _           (lex/sym \=)
           b           (k/fwd expr)]
    (k/return
      {:type   :def
       :name   id
       :body   b
       :params params
       :guard  g
       :pos    pos
       :T      {:type        :fun
                :type-params type-params
                :params      (map :T params)
                :return      return}})))

(def def-sig
  "Parses a signature-only `def`: the name, optional type parameters, a
  parenthesised list of parameter types and a mandatory `:: Ret`. Yields a
  `:def-sig` node whose `:T` is the corresponding `:fun` type."
  (k/bind [_           (lex/word "def")
           id          lex/identifier
           type-params (k/optional (utils/list-of lex/identifier))
           arg-types   (lex/parens (lex/comma-sep (k/fwd type-parser)))
           ret         return-type]
    (let [fun-type {:type        :fun
                    :type-params type-params
                    :params      arg-types
                    :return      ret}]
      (k/return {:type :def-sig
                 :name id
                 :T    fun-type}))))

(def def-declaration
  "A `def` is tried first as a signature, then as a clause. The signature
  attempt is wrapped in `<:>` because both forms start the same way, so a
  failed signature must not leave input consumed."
  (let [signature (<:> def-sig)]
    (<|> signature def-instance)))

(def val-declaration
  "Parses `val pattern = expr` into a `:val` node with its start position."
  (k/bind [pos k/get-position
           _   (lex/word "val")
           pat pattern
           _   (lex/sym \=)
           rhs (k/fwd expr)]
    (k/return {:type    :val
               :pattern pat
               :expr    rhs
               :pos     pos})))

(defn constructor
  "Returns a parser for one constructor `Name(slots)` of the datatype named
  `type` with type parameters `type-params`.
  If the slots are all written `_` (tried first, with backtracking), the
  result is untyped: `{:name .. :untyped true :arity n}`. Otherwise the
  slots are parsed as types. The constructor then gets a `:fun` type from
  those slot types to `type` applied to its type parameters (as `:var`s).
  An empty slot list `Name()` takes the typed branch."
  [type type-params]
  (let [untyped-slots (<:> (lex/parens (lex/comma-sep (lex/sym \_))))
        typed-slots   (lex/parens (lex/comma-sep (k/fwd type-parser)))
        result-type   {:type :application
                       :name type
                       :args (for [v type-params] {:type :var :name v})}]
    (k/bind [n     lex/identifier
             slots (<|> untyped-slots typed-slots)]
      (k/return
        (if (= \_ (first slots))
          {:name    n
           :untyped true
           :arity   (count slots)}
          {:name  n
           :arity (count slots)
           :T     {:type        :fun
                   :type-params type-params
                   :params      slots
                   :return      result-type}})))))

(def type-declaration
  "Parses `type Name = Ctor(...) | Ctor(...) ...` (optional type parameters
  after the name) into a `:datatype` node listing its constructors."
  (k/bind [_      (lex/word "type")
           tname  lex/identifier
           params (k/optional (utils/list-of lex/identifier))
           _      (lex/sym \=)
           ctors  (k/sep-by1 (lex/sym \|) (constructor tname params))]
    (k/return {:type         :datatype
               :name         tname
               :type-params  params
               :constructors ctors})))

(def type-alias-declaration
  "Parses `type Name = Type` (optional type parameters after the name) into
  a `:type-alias` node. With type parameters the aliased type is wrapped in
  a `:polymorphic` type; without them it is used directly."
  (k/bind [_          (lex/word "type")
           alias-name lex/identifier
           params     (k/optional (utils/list-of lex/identifier))
           _          (lex/sym \=)
           target     (k/fwd type-parser)]
    (let [aliased (if (empty? params)
                    target
                    {:type        :polymorphic
                     :type-params params
                     :T           target})]
      (k/return {:type :type-alias
                 :name alias-name
                 :T    aliased}))))

(def type-import-declaration
  "Parses `import type Name = <string literal>` into a `:type-import` node
  that records the name, the literal and the start position."
  (k/bind [pos       k/get-position
           _         (>> (lex/word "import") (lex/word "type"))
           type-name lex/identifier
           _         (lex/sym \=)
           type-def  lex/string-lit]
    (k/return {:type       :type-import
               :name       type-name
               :definition type-def
               :pos        pos})))

(def ^:dynamic *dependencies*
  "Atom of namespace names collected from `refer from` declarations.
  `parse*` binds it to a fresh atom for each parse."
  (atom []))

(def refer-declaration
  "Parses `refer from some.ns (a, b, ...)` into a `:refer` node. The
  namespace name is made of alphanumerics, `_` and `.`. Each time this
  parser succeeds, it appends the namespace to `*dependencies*`."
  (k/bind [_         (>> (lex/word "refer") (lex/word "from"))
           target-ns (lex/lexeme (<+> (k/many1 (<|> k/alpha-num
                                                    (k/sym* \_)
                                                    (k/sym* \.)))))
           symbols   (utils/tuple-of lex/identifier)]
    (do
      (swap! *dependencies* conj target-ns)
      (k/return {:type      :refer
                 :namespace target-ns
                 :symbols   symbols}))))

(def declaration
  "Parses one declaration, optionally followed by `#`.
  `<|>` only tries the next alternative when the previous one failed
  without consuming input. `site-declaration` and `type-import-declaration`
  both start with the word `import`. Likewise, `type-declaration` and
  `type-alias-declaration` both start with `type`. The first parser of each
  pair is therefore wrapped in `<:>`; without it, `import type ...` fails
  inside `site-declaration` after consuming `import`, and
  `type-import-declaration` can never be reached."
  (<< (<|> val-declaration
           def-declaration
           (<:> site-declaration)
           include-declaration
           (<:> type-declaration)
           type-alias-declaration
           type-import-declaration
           refer-declaration)
      (k/optional (lex/sym \#))))

(def with-declaration
  "One or more declarations followed by the expression they scope over."
  (k/bind [decls (k/many1 declaration)
           body  (k/fwd expr)]
    (k/return {:node  :declarations
               :decls decls
               :expr  body})))

(def conditional
  "Parses `if e1 then e2 else e3` into a `:conditional` node."
  (k/bind [_     (lex/word "if")
           test  (k/fwd expr)
           _     (lex/word "then")
           then  (k/fwd expr)
           _     (lex/word "else")
           else' (k/fwd expr)]
    (k/return {:node :conditional
               :if   test
               :then then
               :else else'})))

(def lambda
  "Parses `lambda` with optional type parameters, a parenthesised list of
  parameter patterns, an optional `:: Ret`, `=`, and a body. Yields a
  `:lambda` node. Its `:fun` type under `:T` takes parameter types from
  the patterns' `:T` annotations."
  (k/bind [pos         k/get-position
           _           (lex/word "lambda")
           type-params (k/optional (utils/list-of lex/identifier))
           params      (lex/parens (lex/comma-sep pattern))
           ret         (k/optional return-type)
           _           (lex/sym \=)
           body        (k/fwd expr)]
    (let [fun-type {:type        :fun
                    :type-params type-params
                    :params      (map :T params)
                    :return      ret}
          node     {:node   :lambda
                    :params params
                    :body   body
                    :pos    pos}]
      (k/return (assoc-when node :T fun-type)))))

(def type-variable
  "Parses a type name with optional type arguments (`utils/list-of`). With
  non-empty arguments it yields an `:application` type, otherwise a `:var`
  type."
  (k/bind [id   lex/identifier
           args (k/optional (utils/list-of (k/fwd type-parser)))]
    (k/return (if (empty? args)
                {:type :var
                 :name id}
                {:type :application
                 :name id
                 :args args}))))

(def type-tuple
  "A tuple of types (`utils/tuple-of`) as a `:tuple` type."
  (<$> (fn [ts] {:type :tuple
                 :args ts})
       (utils/tuple-of (k/fwd type-parser))))

(def type-record
  "A record of types (`utils/record-of`) as a `:record` type. The parsed
  field pairs are poured into a map under `:fields`."
  (<$> (fn [pairs] {:type   :record
                    :fields (into {} pairs)})
       (utils/record-of (k/fwd type-parser))))

(def type-function
  "Parses a function type: `lambda`, optional type parameters, a
  parenthesised list of parameter types and a mandatory `:: Ret`."
  (k/bind [_           (lex/word "lambda")
           type-params (k/optional (utils/list-of lex/identifier))
           arg-types   (lex/parens (lex/comma-sep (k/fwd type-parser)))
           ret         return-type]
    (k/return {:type        :fun
               :type-params type-params
               :params      arg-types
               :return      ret})))

(def type-parser
  "Parses any type. Alternatives are tried in order: a named type or type
  application, a tuple, a record, a function type."
  (let [alternatives [type-variable type-tuple type-record type-function]]
    (apply <|> alternatives)))

(def with-type
  "Parses an `otherwise` expression with an optional type ascription.
  `:: T` wraps the expression in a `:has-type` node with `:T` T. `:!: T`
  does the same but first adds `:overwrite true` to the parsed type.
  Without an ascription the expression is returned unchanged."
  (let [ascription (>> (lex/token "::") type-parser)
        override   (<$> #(assoc % :overwrite true)
                        (>> (lex/token ":!:") type-parser))]
    (k/bind [e  otherwise
             ty (k/optional (<|> ascription override))]
      (k/return (if ty
                  {:node :has-type
                   :expr e
                   :T    ty}
                  e)))))

(def expr
  "Parses any expression after skipping leading input with `lex/trim`.
  Alternatives are tried in order: a (possibly type-ascribed) combinator
  expression, declarations scoping an expression, a conditional, a lambda."
  (let [alternatives (<|> with-type
                          with-declaration
                          conditional
                          lambda)]
    (>> lex/trim alternatives)))

(def program
  "A complete program: one expression followed by end of input."
  (<< expr k/eof))

(defn parse*
  "Runs kern `parser` over source string `s`, with `file-system` stored in
  the user state under `::file-system`. `*dependencies*` is bound to a
  fresh atom for the duration of the parse.
  On success returns `{:node :ns :dependencies [...] :body <parsed value>}`.
  On failure throws an `ex-info` whose data holds `:orcl/error-pos` (the
  position of kern's error) and `:orcl/error`."
  [s parser file-system]
  (binding [*dependencies* (atom [])]
    (let [{:keys [ok value error]} (k/parse parser s nil {::file-system file-system})]
      (when-not ok
        (throw (ex-info "Parsing error" {:orcl/error-pos (:pos error)
                                         :orcl/error     "Parsing error"})))
      {:node         :ns
       :dependencies @*dependencies*
       :body         value})))

(defn parse
  "Parses `s` as a complete program (`program`). `file-system` defaults to
  an empty in-memory file system and is used to read `include`d files."
  ([s]
   (parse s (fs/in-memory-file-system {})))
  ([s file-system]
   (parse* s program file-system)))

(defn parse-namespace
  "Parses `s` as a namespace: a sequence of declarations with no trailing
  expression. `file-system` defaults to an empty in-memory file system.
  Like `program`, the parser skips leading input with `lex/trim` (as `expr`
  does) and requires the whole input to be consumed. Previously, a source
  starting with whitespace yielded no declarations, and trailing content
  that was not a declaration was silently dropped. Now such input raises
  the usual parse error."
  ([s] (parse-namespace s (fs/in-memory-file-system {})))
  ([s file-system]
   (parse* s (>> lex/trim (<< declarations k/eof)) file-system)))

