(ns orcl.typecheck
  (:require [orcl.utils :as utils]
            [clojure.set :as set]))

(defn all?
  "Returns true when every element of `coll` is logically true, false otherwise.

  An empty (or nil) collection is vacuously true, like `every?`. Implemented
  with `every?` instead of non-tail recursion so long inputs cannot overflow
  the stack."
  [coll]
  (every? identity coll))

(defn subst
  "Replaces :formal type variables inside type `T`.

  The 3-arity form pairs `formals` with `types` positionally; the 2-arity form
  takes a map `m` from formal to replacement type. Formals not in `m` are kept.
  A function type must not bind any of the formals being replaced (asserted)."
  ([T formals types]
   (subst T (zipmap formals types)))
  ([T m]
   (let [sub (fn [t] (subst t m))]
     (case (:type T)
       :formal (get m T T)
       :tuple (update T :args #(mapv sub %))
       :record (update T :fields #(utils/map-vals sub %))
       :instance (-> T
                     (update :args #(mapv sub %))
                     (update :fields #(utils/map-vals sub %)))
       :fun (do (assert (every? #(not (contains? m %)) (:formals T)))
                (-> T
                    (update :return sub)
                    (update :params #(mapv sub %))))
       T))))

(defn apply-variance
  "Composes variances. `v1` is the variance of a position relative to its
  enclosing type. `v2` is the variance with which a type variable occurs
  inside that position. The result is the variance of the variable relative
  to the enclosing type.

  - :covariant     is the identity, so the result is `v2`.
  - :contravariant swaps :covariant and :contravariant and keeps the rest.
  - :invariant     makes any actual occurrence invariant; a variable that does
                   not occur (`v2` = :constant) stays :constant.
  - :constant      returns `v2` unchanged. The exact composition would be
                   :constant. Passing `v2` through is a conservative
                   approximation (it never claims more freedom than exists),
                   and it never introduces :constant, which `elim` has no
                   branch for."
  [v1 v2]
  (case v1
    :covariant v2
    :contravariant (case v2
                     :covariant :contravariant
                     :contravariant :covariant
                     v2)
    :invariant (if (= :constant v2) :constant :invariant)
    :constant v2))

(defn combine-variance
  "Combines the variances of two occurrences of the same type variable.
  :constant (no occurrence) is neutral and :invariant absorbs everything.
  Otherwise the result keeps the variance when both sides agree, and becomes
  :invariant when they disagree."
  [v1 v2]
  (case v1
    :constant v2
    :invariant :invariant
    (:covariant :contravariant) (if (contains? #{:constant v1} v2)
                                  v1
                                  :invariant)))

(defn variance-of
  "Returns how the type variable `v` (a :formal) occurs in type `T`:
  :constant when it does not occur, otherwise :covariant, :contravariant or
  :invariant.

  Each composite case folds its components with `combine-variance`, starting
  from :constant, which is the neutral element of that operation. So types
  without components (empty tuples and records, instances of datatypes
  without type parameters) are :constant and do not hit a zero-argument
  `reduce`."
  [T v]
  (let [fold (fn [variances] (reduce combine-variance :constant variances))]
    (case (:type T)
      :formal (if (= v T)
                :covariant
                :constant)
      :tuple (fold (map #(variance-of % v) (:args T)))
      :record (fold (map (fn [[_ field-T]] (variance-of field-T v)) (:fields T)))
      :fun (do (assert (not (contains? (set (:formals T)) v)))
               ;; parameters sit in contravariant position, the return does not
               (fold (cons (variance-of (:return T) v)
                           (map #(apply-variance :contravariant (variance-of % v)) (:params T)))))
      ;; each argument's variance is composed with the datatype's declared
      ;; variance for that parameter position
      :instance (fold (map #(apply-variance %1 (variance-of %2 v))
                           (:variances (:dt T)) (:args T)))
      :constant)))

(declare lift)

(def formals-counter
  "Source of the :_c tags stored in freshly created :formal types, so two type
  parameters with the same name remain distinct values. `process` resets it
  to 0 before each run."
  (atom 0))

(defn with-formals
  "Makes sure type `T` carries :formal types for its :type-params, and brings
  them into scope.

  Reuses `(:formals T)` when present. Otherwise it creates one
  {:type :formal :name p :_c n} per type parameter, drawing `n` from
  `formals-counter`. Returns `[T-with-formals ctx-with-type-params-bound]`."
  [ctx T]
  (let [type-params (:type-params T)
        fresh       (fn [p] {:type :formal :name p :_c (swap! formals-counter inc)})
        formals     (if-let [existing (:formals T)]
                      existing
                      (doall (map fresh type-params)))]
    [(assoc T :formals formals)
     (update ctx :types into (map vector type-params formals))]))

(defn datatype-with-variances
  "Returns datatype `T` (bound to `name`) with :formals attached and a
  :variances sequence holding one entry per formal.

  Explicit (:variances T) are kept as-is. Otherwise each formal's variance is
  found by guess-and-check. The guesses :constant, :covariant and
  :contravariant are tried in that order, and the first one consistent with
  how the formal occurs in the constructors' parameter types wins. If none is
  consistent, the formal is :invariant."
  [ctx ast name T]
  (let [[T' ctx'] (with-formals ctx T)
        ;; lifted parameter types of constructor `c`, resolved in `ctx`
        variant-types (fn [ctx c]
                        (map #(lift ctx ast %) (:params (:T c))))
        ;; returns `guess` when assuming that variance for `formal` is
        ;; self-consistent with the constructors, nil otherwise
        find-variance (fn [formal guess]
                        ;; while `formal` is tested, every other formal is assumed :invariant
                        (let [guess-variances (for [f (:formals T')] (if (= f formal) guess :invariant))
                              ;; provisional datatype bound to `name`, so recursive uses of the
                              ;; datatype inside its constructors are typed with the guess
                              dt              {:type      :datatype
                                               :variances guess-variances}
                              ctx''           (assoc-in ctx' [:types name] dt)
                              v-types         (doall (mapcat #(variant-types ctx'' %) (:constructors T)))
                              v-variances     (mapv #(variance-of % formal) v-types)
                              ;; :constant is the neutral element of combine-variance
                              variance        (reduce combine-variance :constant v-variances)]
                          (when (= guess variance)
                            guess)))]
    (assoc T' :variances (or (:variances T)
                             ;; first consistent guess per formal, :invariant as the fallback
                             (doall (for [f (:formals T')]
                                      (or (some #(find-variance f %) [:constant :covariant :contravariant])
                                          :invariant)))))))

(defn apply-T
  "Instantiates the type operator `T` with the argument types `types`.

  - :polymorphic lifts the body (:T T) with the type parameters in scope and
    substitutes `types` for them.
  - :datatype builds an :instance of the datatype. Its :fields are lifted and
    substituted in the same way.
  Any other type cannot be applied and is reported with `utils/error`."
  [ctx ast T types]
  (let [instantiate (fn [scope formals t]
                      (subst (lift scope ast t) formals types))]
    (case (:type T)
      :polymorphic (let [[T' scope] (with-formals ctx T)]
                     (instantiate scope (:formals T') (:T T')))
      :datatype (let [scope (second (with-formals ctx T))]
                  {:type   :instance
                   :dt     T
                   :args   types
                   :fields (utils/map-vals #(instantiate scope (:formals T) %) (:fields T))})
      (utils/error "Unappliable type" ast))))

(defn lift
  "Resolves a syntactic type `T` against the type environment `(:types ctx)`.

  - Named types (:var) are looked up and then lifted themselves.
  - Type applications are instantiated through `apply-T`.
  - Tuples and records are lifted component-wise.
  - Function types get formals for their type parameters before their return
    and parameter types are lifted.
  Other types are returned unchanged. Unknown names are reported with
  `utils/error`."
  [ctx ast T]
  (let [lift-in (fn [scope t] (lift scope ast t))]
    (case (:type T)
      :var (let [found (get-in ctx [:types (:name T)])]
             (if found
               (lift-in ctx found)
               (utils/error "Can't find type" ast :name (:name T))))
      :application (let [op          (or (get-in ctx [:types (:name T)])
                                         (utils/error "Can't find type operator" ast
                                                      :name (:name T)))
                         lifted-op   (lift-in ctx op)
                         lifted-args (mapv #(lift-in ctx %) (:args T))]
                     (apply-T ctx ast lifted-op lifted-args))
      :tuple (update T :args #(mapv (partial lift-in ctx) %))
      :record (update T :fields #(utils/map-vals (partial lift-in ctx) %))
      :fun (let [[T' scope] (with-formals ctx T)
                 return'    (lift-in scope (:return T))
                 params'    (mapv #(lift-in scope %) (:params T))]
             (assoc T' :return return' :params params'))
      T)))

(declare subtype?)

(defn all-subtypes?
  "True when each type in `ss` is a subtype of the type at the same position
  in `ts`. Pairing works like `map`, so only the common prefix of the two
  sequences is compared."
  [ss ts]
  (every? identity (map subtype? ss ts)))

(defn subtype-tuple?
  "Tuple subtyping: `s` <: `t` when both tuples have the same number of
  components and each component of `s` is a subtype of the corresponding
  component of `t`.

  Tuples of different arity are unrelated. Without the length check, the
  pairwise comparison would stop at the shorter tuple and accept, for
  example, (Integer, String) <: (Integer)."
  [s t]
  (and (= (count (:args s)) (count (:args t)))
       (all-subtypes? (:args s) (:args t))))

(defn subtype-record?
  "Record subtyping in width and depth. `s` <: `t` when `s` has every field of
  `t`, and each of those fields has a type in `s` that is a subtype of its
  type in `t`. The :fields of `s` are looked up as a map (field name -> type)."
  [s t]
  (let [t-names (map first (:fields t))]
    (and (set/subset? (set t-names) (set (map first (:fields s))))
         (all-subtypes? (map (:fields s) t-names)
                        (map second (:fields t))))))

(defn same-shape?
  "True when function types `s` and `t` declare the same number of type
  parameters and the same number of value parameters."
  [s t]
  (let [shape (juxt (comp count :type-params) (comp count :params))]
    (= (shape s) (shape t))))

(defn subtype-fun?
  "Function subtyping: same shape, contravariant parameters (each parameter
  of `t` is a subtype of the matching parameter of `s`) and a covariant
  return type."
  [s t]
  (let [params-ok? #(all-subtypes? (:params t) (:params s))
        return-ok? #(subtype? (:return s) (:return t))]
    (and (same-shape? s t)
         (params-ok?)
         (return-ok?))))

(defn subtype-datatype?
  "Subtyping between instances `s` and `t`, which are expected to be of the
  same datatype. Each type argument is compared according to the variance
  that `s`'s datatype declares for its position."
  [s t]
  (let [arg-ok? (fn [variance s-arg t-arg]
                  (case variance
                    :covariant (subtype? s-arg t-arg)
                    :contravariant (subtype? t-arg s-arg)
                    :invariant (= s-arg t-arg)
                    :constant true))]
    (every? identity (map arg-ok? (:variances (:dt s)) (:args s) (:args t)))))

(defn subtype?
  "Checks `s` <: `t` and returns true or false.

  - Equal types are related. Bot is below and Top above every type.
  - Tuples, records, functions and instances of the same datatype are
    compared structurally.
  - Integer <: Number.
  - An overloaded `s` is a subtype when any of its alternatives is.
  - Integer and string constant types are compared as Integer and String,
    respectively, on either side."
  [s t]
  (let [s-type (:type s)
        t-type (:type t)
        both?  (fn [k] (= k s-type t-type))]
    (cond
      (= s t) true
      (= :bot s-type) true
      (= :top t-type) true
      (both? :tuple) (subtype-tuple? s t)
      (both? :record) (subtype-record? s t)
      (both? :fun) (subtype-fun? s t)
      (and (= :integer s-type) (= :number t-type)) true
      (= :overload s-type) (boolean (some #(subtype? % t) (:alternatives s)))
      (and (both? :instance) (= (:dt s) (:dt t))) (subtype-datatype? s t)
      (= :integer-constant s-type) (subtype? {:type :integer} t)
      (= :integer-constant t-type) (subtype? s {:type :integer})
      (= :string-constant s-type) (subtype? {:type :string} t)
      (= :string-constant t-type) (subtype? s {:type :string})
      :else false)))

(declare meet)

(defn join
  "Upper bound of types `t1` and `t2` (the least one where this
  approximation can find it).

  - Tuples of equal arity are joined component-wise.
  - Records keep only their common fields.
  - Functions of the same shape meet their parameters and join their returns.
  - Otherwise, if one type is a subtype of the other, the larger one is
    returned.
  - Two instances of the same datatype are joined argument-wise according to
    the datatype's variances, provided every invariant argument is equal on
    both sides.
  - Anything else joins to Top."
  [t1 t2]
  (cond
    ;; tuples of different arity fall through to the generic rules below
    ;; (pairing with `mapv` would silently drop the extra components)
    (and (= :tuple (:type t1) (:type t2))
         (= (count (:args t1)) (count (:args t2))))
    {:type :tuple :args (mapv join (:args t1) (:args t2))}

    (= :record (:type t1) (:type t2))
    (let [common-fields (set/intersection (set (keys (:fields t1)))
                                          (set (keys (:fields t2))))]
      {:type   :record
       :fields (into {} (for [k common-fields]
                          [k (join (get (:fields t1) k)
                                   (get (:fields t2) k))]))})

    ;; parameters are contravariant, hence met; returns are joined
    (and (= :fun (:type t1) (:type t2)) (same-shape? t1 t2))
    (assoc t1
      :params (mapv meet (:params t1) (:params t2))
      :return (join (:return t1) (:return t2)))

    (subtype? t1 t2)
    t2

    (subtype? t2 t1)
    t1

    ;; an invariant argument admits no common supertype unless both sides
    ;; agree (e.g. Ref[Top] is not a supertype of Ref[Integer]); in that case
    ;; fall through to Top
    (and (= :instance (:type t1) (:type t2))
         (= (:dt t1) (:dt t2))
         (every? (fn [[v s t]] (or (not= :invariant v) (= s t)))
                 (map vector (:variances (:dt t1)) (:args t1) (:args t2))))
    {:type :instance
     :dt   (:dt t1)
     :args (mapv (fn [v s t]
                   (case v
                     :covariant (join s t)
                     :contravariant (meet s t)
                     :invariant s
                     ;; the argument is irrelevant for a :constant parameter
                     :constant {:type :bot}))
                 (:variances (:dt t1)) (:args t1) (:args t2))}

    :else {:type :top}))

(defn meet
  "Lower bound of types `t1` and `t2` (the greatest one where this
  approximation can find it); the dual of `join`.

  - Tuples of equal arity meet component-wise.
  - Records take the union of fields; a field missing on one side counts as
    Top there.
  - Functions of the same shape join their parameters and meet their returns.
  - Otherwise, if one type is a subtype of the other, the smaller one is
    returned.
  - Instances of the same datatype meet argument-wise by variance, provided
    every invariant argument is equal on both sides.
  - Anything else meets to Bot."
  [t1 t2]
  (cond
    ;; tuples are covariant in their components, so their meet is the
    ;; component-wise *meet*; different arities fall through
    (and (= :tuple (:type t1) (:type t2))
         (= (count (:args t1)) (count (:args t2))))
    {:type :tuple :args (mapv meet (:args t1) (:args t2))}

    (= :record (:type t1) (:type t2))
    (let [all-fields (set/union (set (keys (:fields t1)))
                                (set (keys (:fields t2))))]
      {:type   :record
       :fields (into {} (for [k all-fields]
                          [k (meet (get (:fields t1) k {:type :top})
                                   (get (:fields t2) k {:type :top}))]))})

    ;; parameters are contravariant, hence joined; returns are met
    (and (= :fun (:type t1) (:type t2)) (same-shape? t1 t2))
    (assoc t1
      :params (mapv join (:params t1) (:params t2))
      :return (meet (:return t1) (:return t2)))

    (subtype? t1 t2)
    t1

    (subtype? t2 t1)
    t2

    ;; an invariant argument admits no common subtype unless both sides agree
    ;; (e.g. Ref[Bot] is not a subtype of Ref[Integer]); otherwise fall through to Bot
    (and (= :instance (:type t1) (:type t2))
         (= (:dt t1) (:dt t2))
         (every? (fn [[v s t]] (or (not= :invariant v) (= s t)))
                 (map vector (:variances (:dt t1)) (:args t1) (:args t2))))
    {:type :instance
     :dt   (:dt t1)
     :args (mapv (fn [v s t]
                   (case v
                     :covariant (meet s t)
                     :contravariant (join s t)
                     :invariant s
                     ;; the argument is irrelevant for a :constant parameter
                     :constant {:type :top}))
                 (:variances (:dt t1)) (:args t1) (:args t2))}

    :else {:type :bot}))

(declare check)
(declare infer)

(defn unzip
  "Transposes a sequence of equally sized tuples:
  [[a1 b1] [a2 b2]] -> ((a1 a2) (b1 b2)).
  The width is taken from the first tuple, so an empty `coll` yields an
  empty seq."
  [coll]
  (let [width (count (first coll))]
    (map (fn [i] (map #(nth % i) coll))
         (range width))))

(defn assert-subtype
  "Reports a \"Wrong type\" error through `utils/error` (with the :expected
  and :actual types) unless `s` is a subtype of `t`. Returns nil when the
  check passes."
  [ast s t]
  (if (subtype? s t)
    nil
    (utils/error "Wrong type" ast :expected t :actual s)))

(defn elim
  "Removes the type variables in `V` from type `T`, approximating each
  occurrence by the variance of its position:
  - a covariant occurrence becomes Top;
  - a contravariant occurrence becomes Bot;
  - an invariant occurrence cannot be approximated and is reported with
    `utils/error`.
  Formals outside `V` are kept. The type parameters of a nested function type
  are also eliminated inside that function.

  `V` is used as a set of :formal types. `variance` is the variance of the
  position that `T` occupies."
  [ast T V variance]
  (case (:type T)
    :formal (if (contains? V T)
              (case variance
                :covariant {:type :top}
                :contravariant {:type :bot}
                :invariant (utils/error "Can't infer type parameter" ast))
              T)
    :tuple (assoc T :args (mapv #(elim ast % V variance) (:args T)))
    :record (assoc T :fields (utils/map-vals #(elim ast % V variance) (:fields T)))
    :fun (let [;; `set/union` conjs the smaller argument onto the larger one, so with
               ;; a small V and the :formals seq it returns a plain seq, on which
               ;; `contains?` throws; build a proper set instead
               scope'          (into (set V) (:formals T))
               params-variance (apply-variance :contravariant variance)]
           (-> T
               (update :return #(elim ast % scope' variance))
               (assoc :params (mapv #(elim ast % scope' params-variance) (:params T)))))
    ;; each argument's position variance is composed with the datatype's declared variance
    :instance (let [args' (mapv (fn [arg-variance arg-T]
                                  (elim ast arg-T V (apply-variance arg-variance variance)))
                                (:variances (:dt T)) (:args T))]
                (assoc T :args args'))
    T))

(defn promote
  "Eliminates the variables `V` from `T`, treating `T` as sitting in covariant
  position (see `elim`). Covariant occurrences become Top."
  [ast T V]
  (elim ast T V :covariant))

(defn demote
  "Eliminates the variables `V` from `T`, treating `T` as sitting in
  contravariant position (see `elim`). Covariant occurrences become Bot."
  [ast T V]
  (elim ast T V :contravariant))

(defn tc-meet
  "Merges two constraint maps `C` and `D` (type variable -> [lower upper]).

  For each variable the lower bounds are joined and the upper bounds met. A
  variable missing from one map is unconstrained there (Bot..Top). Reports
  \"Type variable is overconstrained\" when a merged lower bound is not a
  subtype of the merged upper bound."
  [ast C D]
  (let [bound     (fn [m v i default] (get-in m [v i] default))
        merge-var (fn [v]
                    (let [lower (join (bound C v 0 {:type :bot})
                                      (bound D v 0 {:type :bot}))
                          upper (meet (bound C v 1 {:type :top})
                                      (bound D v 1 {:type :top}))]
                      (if (subtype? lower upper)
                        [v [lower upper]]
                        (utils/error "Type variable is overconstrained" ast :var v))))
        all-vars  (set/union (set (keys C)) (set (keys D)))]
    (into {} (map merge-var all-vars))))

;; TODO: implement; currently an identity stub.
(defn share-type-params
  "Returns `[f1 f2]` unchanged.

  NOTE(review): judging by its name, and by a removed unused draft that built
  one {:type :formal :name p} per type parameter of `f1`, this was presumably
  meant to make `f1` and `f2` share one set of formals so their signatures
  can be compared. Confirm before relying on it."
  [f1 f2]
  ;; FIXME: name clashes between the two functions' type parameters may be possible here.
  [f1 f2])

(defn type-constraints
  "Generates constraints on the type variables `X` under which `below` is a
  subtype of `above`. The result is a map from variable to
  [lower-bound upper-bound].

  Bounds are cleaned of the variables in `V` with `promote`/`demote` before
  they are recorded. `V` is extended with `into`, so it should be a set of
  formals. Constraints from sub-terms are combined with `tc-meet`. When no
  structural rule applies, `below` must already be a subtype of `above`
  (checked with `assert-subtype`) and no constraints are produced."
  [ast below above X V]
  (cond
    ;; anything is below Top
    (= :top (:type above))
    {}

    ;; Bot is below anything
    (= :bot (:type below))
    {}

    ;; the same type variable on both sides
    ;; NOTE(review): formals are compared by :name only, not by their :_c tag
    (and (= :formal (:type above) (:type below)) (= (:name below) (:name above)))
    {}

    ;; X <: above  gives X an upper bound
    (and (= :formal (:type below)) (contains? X below))
    (let [above' (demote ast above V)]
      {below [{:type :bot} above']})

    ;; below <: X  gives X a lower bound
    (and (= :formal (:type above)) (contains? X above))
    (let [below' (promote ast below V)]
      {above [below' {:type :top}]})

    ;; functions: parameters are related in the opposite direction
    ;; (upper-params <: lower-params), returns in the same direction; the
    ;; type parameters of `below` are added to the variables to eliminate
    (and (= :fun (:type below) (:type above)) (same-shape? below above))
    (let [[{:keys [formals] lower-params :params lower-return :return}
           {upper-params :params upper-return :return}] (share-type-params below above)
          V'      (into V formals)
          args-tc (reduce #(tc-meet ast %1 %2) {} (map #(type-constraints ast %1 %2 X V') upper-params lower-params))]
      (tc-meet ast args-tc (type-constraints ast lower-return upper-return X V')))

    ;; tuples of equal arity: component-wise
    (and (= :tuple (:type below) (:type above)) (= (count (:args below)) (count (:args above))))
    (reduce #(tc-meet ast %1 %2)
            {}
            (map #(type-constraints ast %1 %2 X V) (:args below) (:args above)))

    ;; records: every field of `above` must exist in `below`; field-wise
    (and (= :record (:type below) (:type above)) (set/subset? (set (keys (:fields above))) (set (keys (:fields below)))))
    (reduce #(tc-meet ast %1 %2)
            {}
            (map (fn [[f field-above]] (type-constraints ast (get (:fields below) f) field-above X V)) (:fields above)))

    ;; instances of the same datatype: argument-wise by declared variance;
    ;; an invariant argument is constrained in both directions
    (and (= :instance (:type below) (:type above)) (= (:dt below) (:dt above)))
    (reduce #(tc-meet ast %1 %2)
            {}
            (map (fn [v s t]
                   (case v
                     :constant {}
                     :covariant (type-constraints ast s t X V)
                     :contravariant (type-constraints ast t s X V)
                     :invariant (tc-meet ast
                                         (type-constraints ast s t X V)
                                         (type-constraints ast t s X V))))
                 (:variances (:dt below)) (:args below) (:args above)))

    ;; no variable involved structurally: the plain subtype relation must hold
    :else
    (do (assert-subtype ast below above)
        {})))

(defn call-fun
  "Result type of calling a function type with the explicit `type-args` and
  argument types `args-T`, or nil when the call does not fit. A call does not
  fit when it has the wrong number of arguments or type arguments, or when an
  argument is not a subtype of its instantiated parameter type.

  A parameterless function called with no arguments fits; the former `all?`
  check returned false for an empty sequence and rejected such calls."
  [{:keys [params formals return]} type-args args-T]
  (cond
    (not= (count params) (count args-T)) nil
    (not= (count formals) (count type-args)) nil
    (not-every? (fn [[arg-T param-T]] (subtype? arg-T (subst param-T formals type-args)))
                (map vector args-T params)) nil
    :else (subst return formals type-args)))

(defn assert-number-of-arguments
  "Reports a \"Wrong number of arguments\" error through `utils/error` when
  `expected` and `actual` differ. Returns nil when they are equal."
  [ast expected actual]
  (when (not= expected actual)
    (utils/error "Wrong number of arguments" ast :expected expected :actual actual)))

(defn call-tuple
  "Result type of indexing a tuple of type `T` with a single argument.

  - A constant integer index selects that component. An index out of range is
    reported with `utils/error` rather than as an IndexOutOfBoundsException.
  - A non-constant integer index yields the join of all components. For the
    empty tuple that join is Bot, instead of a zero-argument `reduce` failure.
  - Any other argument type is reported as a \"Wrong type\" error."
  [ast T args-T]
  (assert-number-of-arguments ast 1 (count args-T))
  (let [arg-T      (first args-T)
        components (:args T)]
    (case (:type arg-T)
      :integer-constant (let [i (:value arg-T)]
                          (if (and (integer? i) (< -1 i (count components)))
                            (nth components i)
                            (utils/error "Tuple index out of range" ast
                                         :index i
                                         :size (count components))))
      ;; Bot is the neutral element of `join`, so non-empty tuples join as before
      :integer (reduce join {:type :bot} components)
      (utils/error "Wrong type" ast
                   :type arg-T))))

;; Used internally by pattern matching.
(defn call-tuple-arity-checker
  "Checks that the tuple type `t` (the first argument type) has exactly
  `(:value n)` components, where `n` is the second argument type. Returns `t`
  on success and reports \"Wrong tuple's arity\" otherwise. `T` is unused."
  [ast T [t n :as args-T]]
  (assert-number-of-arguments ast 2 (count args-T))
  (let [expected (:value n)
        actual   (count (:args t))]
    (if (= actual expected)
      t
      (utils/error "Wrong tuple's arity" ast
                   :expected expected
                   :actual actual))))

(defn call-unapply
  "Result type of deconstructing a datatype value. The argument types are
  `[instance name]`, where `(:value name)` is a constructor name. Returns a
  tuple type holding that constructor's declared parameter types.

  NOTE(review): the parameter types are returned as declared (neither lifted
  nor instantiated with the instance's type arguments). When the datatype has
  no such constructor, the tuple's :args are nil. Confirm both are intended."
  [ast T args-T]
  (assert-number-of-arguments ast 2 (count args-T))
  (let [[instance name] args-T
        constructor-name (:value name)
        constructor      (some #(when (= constructor-name (:name %)) %)
                               (:constructors (:dt instance)))]
    {:type :tuple
     :args (:params (:T constructor))}))

(defn call-type
  "Result type of calling a value of type `T` with the explicit `type-args`
  and argument types `args-T`, or nil when `T` cannot be called this way.
  Overloads try their alternatives in order and use the first that fits."
  [ast T type-args args-T]
  (case (:type T)
    ;; overload types are possible only as sites; a programmer can't define them in Orc
    :overload (some #(call-type ast % type-args args-T) (:alternatives T))
    :unapply (call-unapply ast T args-T)
    :fun (call-fun T type-args args-T)
    :tuple-constructor {:type :tuple
                        :args args-T}
    ;; arguments alternate between a field-name constant and the field's type.
    ;; :fields must be a map, because record subtyping and field access look
    ;; fields up by key
    :record-constructor {:type   :record
                         :fields (into {} (for [[f v] (partition 2 args-T)]
                                            [(:value f) v]))}
    :tuple-arity-checker (call-tuple-arity-checker ast T args-T)
    :tuple (call-tuple ast T args-T)
    :let (case (count args-T)
           0 {:type :signal}
           1 (first args-T)
           {:type :tuple :args args-T})
    nil))



(defn any-subst
  "Substitution that maps each variable of the constraint map `C`
  (var -> [lower upper]) to its lower bound."
  [C]
  (utils/map-vals (fn [[lower _]] lower) C))

(defn minimal-subst
  "Chooses, for each variable of the constraint map `C`, the bound that makes
  `T` as small as possible:
  - the lower bound when the variable occurs covariantly or not at all;
  - the upper bound when it occurs contravariantly;
  - for an invariant occurrence both bounds must coincide, otherwise
    \"No minimal type could be inferred\" is reported."
  [ast T C]
  (let [choose (fn [v [lower upper]]
                 (case (variance-of T v)
                   (:constant :covariant) lower
                   :contravariant upper
                   :invariant (if (= lower upper)
                                lower
                                (utils/error "No minimal type could be inferred" ast :var v))))]
    (into {} (map (fn [[v bounds]] [v (choose v bounds)]) C))))

(defn infer-call
  "Type-checks a :call node and returns `[ast' result-type]`.

  - The target is inferred first.
  - For a monomorphic function, the arguments are checked against the
    declared parameter types. Otherwise the argument types are inferred.
  - A polymorphic call either uses the explicit `type-args`, or infers them
    from constraints. `check-return`, when non-nil, is the expected result
    type and is used as an extra constraint. The inferred type arguments are
    stored in the returned ast as :type-args, in the order of the function's
    formals.
  - Non-function targets are handled by `call-type`."
  [ctx {:keys [target args type-args] :as ast} check-return]
  (let [[target' target-T] (infer ctx target)
        monomorphic-fun? (and (= :fun (:type target-T)) (empty? (:type-params target-T)))
        ;; limited arg types inference
        [args' args-T]
        (if monomorphic-fun?
          (let [params-T (:params target-T)]
            (assert-number-of-arguments ast (count params-T) (count args))
            [(mapv (partial check ctx) params-T args) params-T])
          ;; XXX why we can't check args against fun params in presence of type params
          ;; eager transpose; yields [] / [] (not nil) for a call without arguments
          (let [inferred (mapv (partial infer ctx) args)]
            [(mapv first inferred) (mapv second inferred)]))
        ast' (assoc ast :target target' :args args')]
    (if (= :fun (:type target-T))
      (let [{:keys [formals params return]} target-T]
        ;; polymorphic calls skip the arity check above; without it `map`/`mapv`
        ;; below would silently ignore surplus or missing arguments
        (assert-number-of-arguments ast (count params) (count args-T))
        (cond
          (seq type-args)
          (let [type-args (mapv #(lift ctx ast %) type-args)]
            (when-not (= (count formals) (count type-args))
              (utils/error "Wrong number of type parameters" ast
                           :expected (count formals)
                           :actual (count type-args)))
            (doseq [[arg-T param-T] (map vector args-T params)]
              (assert-subtype ast arg-T (subst param-T formals type-args)))
            [ast' (subst return formals type-args)])

          ;; arguments were already checked against the parameter types above
          monomorphic-fun?
          [ast' return]

          :else
          (let [X          (set formals)
                C          (into {} (for [x X] [x [{:type :bot} {:type :top}]]))
                ;; the variables to eliminate (V) start as an empty *set*
                C'         (reduce #(tc-meet ast %1 %2) C
                                   (map #(type-constraints ast %1 %2 X #{}) args-T params))
                subst-m    (if check-return
                             (any-subst (tc-meet ast C' (type-constraints ast return check-return X #{})))
                             (minimal-subst ast return C'))
                ;; follow the declared order of the formals; iterating the hash-set X
                ;; would list the type arguments in an arbitrary order
                type-args' (mapv subst-m formals)]
            [(assoc ast' :type-args type-args')
             (subst return subst-m)])))
      (if-let [return-T (call-type ast target-T type-args args-T)]
        [ast' return-T]
        (utils/error "Type can't be inferred" ast)))))

(defn infer-field-access
  "Types `target.field`. The target must be a record or a datatype instance
  that has `field`. Returns `[ast' field-type]`."
  [ctx {:keys [target field] :as ast}]
  (let [[target' target-T] (infer ctx target)]
    (if-not (#{:instance :record} (:type target-T))
      (utils/error "Wrong type for field access" ast)
      (let [field-T (get-in target-T [:fields field])]
        (when-not field-T
          (utils/error "Unknown field" ast :field field))
        [(assoc ast :target target') field-T]))))

(defn value-type
  "Type of the literal constant `v`.
  - Integers and strings get singleton constant types that carry the value.
  - Other numbers are Number.
  - Lists are typed as List[Top].
  - Unsupported values fail an assertion."
  [v]
  (cond
    (nil? v)                  {:type :null}
    (or (true? v) (false? v)) {:type :boolean}
    (= :signal v)             {:type :signal}
    (string? v)               {:type :string-constant :value v}
    (integer? v)              {:type :integer-constant :value v}
    (number? v)               {:type :number}
    (list? v)                 {:type :application :name "List" :args [{:type :top}]}
    :else                     (assert false)))

;; Builds the context for a definition body: type parameters in scope, and
;; each parameter variable bound to its declared type.
(defn def-ctx
  "Returns `[T' ctx']`.
  - `T'` is the definition's function type `T` with formals attached and its
    :params lifted.
  - `ctx'` extends `ctx` with the type parameters, and binds the parameter
    variables of instance `i` to their lifted types."
  [ctx ast {:keys [params] :as T} i]
  (let [[T' scope]     (with-formals ctx T)
        lift-here      #(lift scope ast %)
        params'        (mapv lift-here params)
        var-names      (map :var (:params i))
        value-bindings (mapv (fn [var-name param-T] [var-name (lift-here param-T)])
                             var-names params')]
    [(assoc T' :params params')
     (update scope :values into value-bindings)]))

(defn infer-defs
  "Type-checks the definition groups in `defs`, in order, and returns
  `[groups' bindings]`. `groups'` holds the groups with typed bodies, and
  `bindings` maps each definition name to its function type.

  Each element of `defs` is a group of definitions handled together
  (presumably a mutually recursive group — hence `mr-group`).
  - A lone definition without a declared return type gets its return type
    inferred from its body. This is refused when the definition's own name is
    among its :free-vars.
  - Every other group must declare return types. Its bodies are checked
    against them in a context where all definitions seen so far are bound.
  In both cases every parameter needs a type annotation, and only the first
  instance of each definition is examined."
  [ctx defs]
  (loop [res [] [mr-group & defs] defs ctx ctx bindings {}]
    (cond
      ;; all groups processed
      (nil? mr-group)
      [res bindings]
      ;; we can infer return type only for non recursive functions
      (and (= 1 (count mr-group)) (nil? (:return (:T (first mr-group)))))
      (let [d (first mr-group)
            i (first (:instances d))]
        ;; every parameter of the instance must have a non-nil type annotation
        (when-not (= (count (keep identity (:params (:T d)))) (count (:params i)))
          (utils/error "Arguments type annotations are required" d))
        (if (contains? (:free-vars d) (:name d))
          (utils/error "Can't infer return type for recursive function" d)
          (let [[d-T' body-ctx] (def-ctx ctx d (:T d) i)
                [body' T] (infer body-ctx (:body i))
                ;; keep the typed body, the signature with lifted params, and the inferred return type
                d' (-> d
                       (assoc-in [:instances 0 :body] body')
                       (assoc :T d-T')
                       (assoc-in [:T :return] T))]
            ;; later groups see this definition with its inferred type
            (recur (conj res [d']) defs
                   (assoc-in ctx [:values (:name d)] (:T d'))
                   (assoc bindings (:name d) (:T d'))))))

      ;; groups with declared return types: bind every declared type first, so
      ;; the bodies can refer to each other, then check each body
      :else (let [bindings' (into bindings (for [d mr-group]
                                             (do
                                               (when-not (:return (:T d))
                                                 (utils/error "Can't infer return type for recursive function" d))
                                               [(:name d) (:T d)])))
                  ctx'      (update ctx :values merge bindings')
                  ;; each body is checked against its declared return type, lifted in the body's context
                  mr-group' (mapv (fn [d]
                                    (let [i (first (:instances d))
                                          [d-T ctx''] (def-ctx ctx' d (:T d) i)]
                                      (when-not (= (count (keep identity (:params (:T d)))) (count (:params i)))
                                        (utils/error "Arguments type annotations are required" d))
                                      (-> d
                                          (assoc :T d-T)
                                          (update-in [:instances 0 :body] #(check ctx'' (lift ctx'' d (:return d-T)) %)))))
                                  mr-group)]
              (recur (conj res mr-group') defs
                     ctx'
                     bindings')))))

(def base-types
  "Built-in types, keyed by the name used in type annotations. Datatype entries
  list their type parameters and variances. Ref, Cell and Channel also expose
  their methods as :fields, written in terms of the type parameter \"T\"."
  (let [T         {:type :var :name "T"}
        signal    {:type :signal}
        thunk     (fn [return] {:type :fun :params [] :return return})
        rw-fields {"read"  (thunk T)
                   "readD" (thunk T)
                   "write" {:type :fun :params [T] :return signal}}
        container (fn [variance fields]
                    {:type :datatype :type-params ["T"] :variances [variance] :fields fields})]
    {"List"    {:type :datatype :prelude true :type-params ["T"] :variances [:covariant]}
     "Integer" {:type :integer}
     "String"  {:type :string}
     "Number"  {:type :number}
     "Boolean" {:type :boolean}
     "Signal"  signal
     "Option"  {:type :datatype :prelude true :type-params ["T"] :variances [:covariant]}
     "Ref"     (container :invariant rw-fields)
     "Cell"    (container :invariant rw-fields)
     "Channel" (container :invariant
                          {"get"      (thunk T)
                           "getD"     (thunk T)
                           "put"      {:type :fun :params [T] :return signal}
                           "close"    (thunk signal)
                           "closeD"   (thunk signal)
                           "isClosed" (thunk {:type :boolean})
                           "getAll"   (thunk {:type :application :name "List" :args [T]})})
     "Top"     {:type :top}
     "Bot"     {:type :bot}}))

(defn infer
  "Infers the type of the expression node `ast` under `ctx` and returns
  `[ast' T]`, where `ast'` carries the typed sub-expressions.

  `ctx` holds three entries:
  - :values, variable name -> type;
  - :types, type name -> type;
  - :dependencies, namespace -> symbol -> map with the symbol's :T, consulted
    by :refer.
  Node kinds not handled here are reported with `utils/error`."
  [ctx ast]
  (case (:node ast)
    :stop [ast {:type :bot}]
    :const (let [t (value-type (:value ast))]
             [(assoc ast :T t) t])
    ;; the stored type is lifted on every lookup
    :var (if-let [t (get-in ctx [:values (:var ast)])]
           [(assoc ast :T t) (lift ctx ast t)]
           (utils/error "Undefined variable" ast
                        :variable (:var ast)))
    :call (infer-call ctx ast nil)
    :field-access (infer-field-access ctx ast)
    ;; the result is the join of both branches' types
    (:parallel :otherwise) (let [[l' lT] (infer ctx (:left ast))
                                 [r' rT] (infer ctx (:right ast))]
                             [(assoc ast :left l' :right r')
                              (join lT rT)])
    ;; the pattern variable is bound to the left side's type while typing the right
    :sequential (let [[l' lT] (infer ctx (:left ast))
                      [r' rT] (infer (assoc-in ctx [:values (:var (:pattern ast))] lT) (:right ast))]
                  [(assoc ast :left l' :right r') rT])
    ;; the pattern variable is bound to the right side's type while typing the left
    :pruning (let [[r' rT] (infer ctx (:right ast))
                   [l' lT] (infer (assoc-in ctx [:values (:var (:pattern ast))] rT) (:left ast))]
               [(assoc ast :left l' :right r') lT])
    :defs-group (let [[defs' bindings] (infer-defs ctx (:defs ast))
                      [expr' T] (infer (update ctx :values merge bindings) (:expr ast))]
                  [(assoc ast :defs defs' :expr expr') T])
    :declare-types (let [types' (for [[n t] (:types ast)]
                                  [n (case (:type t)
                                       :datatype (datatype-with-variances ctx ast n t)
                                       t)])]
                     (infer (update ctx :types into types') (:expr ast)))
    :refer (let [ctx' (update ctx :values into (for [[ns symbols] (:namespaces ast)
                                                     s symbols]
                                                 [s (or (get-in ctx [:dependencies ns s :T])
                                                        (utils/error "Unknown type" ast :ns ns :symbol s))]))]
             (infer ctx' (:expr ast)))
    :has-type (let [T'    (lift ctx ast (:T ast))
                    expr' (check ctx T' (:expr ast))]
                [(assoc ast :expr expr') T'])
    (utils/error "Unknown expression node" ast :node (:node ast))))

(defn process
  "Entry point: type-checks `ast` and returns the typed ast.

  The given value types (`values-ctx`) and `dependencies` are used, starting
  from `base-types`, with variances computed for its datatypes. The formals
  counter is reset to 0 first."
  [ast values-ctx dependencies]
  (reset! formals-counter 0)
  (let [prepare (fn [[n t]]
                  [n (if (= :datatype (:type t))
                       (datatype-with-variances {} {} n t)
                       t)])
        ctx     {:values       values-ctx
                 :types        (into {} (map prepare base-types))
                 :dependencies dependencies}]
    (first (infer ctx ast))))

;(defn lambda-to-infer? [ast]
;  (let [[[d]] (:defs ast)]
;    (when (and (:lambda ast) (empty? (:params (:T d))))
;      d)))
;(if-let [d (lambda-to-infer? ast)]
;  (case (:type T)
;    :fun (let [[i] (:instances d)
;               [T' ctx'] (def-ctx ctx d T i)
;               body' (check ctx' (:return T') (:body d))]
;           (-> ast
;               (assoc-in [:defs 0 0 :body] body')
;               (assoc-in [:defs 0 0 :T] T')))
;    (utils/error "Wrong type" ast :expected T))
;  ...)

(defn check
  "Checks the expression node `ast` against the expected type `T` and returns
  the typed ast.

  The expected type is pushed into:
  - calls, as the expected result for type-argument inference;
  - both branches of :parallel / :otherwise;
  - the body side of :sequential / :pruning;
  - the expression of a :defs-group.
  Any other node is inferred, and its type must be a subtype of `T`."
  [ctx T ast]
  (let [bind-pattern (fn [scope pattern-T]
                       (assoc-in scope [:values (:var (:pattern ast))] pattern-T))]
    (case (:node ast)
      :call (first (infer-call ctx ast T))
      (:parallel :otherwise) (assoc ast
                                    :left (check ctx T (:left ast))
                                    :right (check ctx T (:right ast)))
      :sequential (let [[left' left-T] (infer ctx (:left ast))
                        right'         (check (bind-pattern ctx left-T) T (:right ast))]
                    (assoc ast :left left' :right right'))
      :pruning (let [[right' right-T] (infer ctx (:right ast))
                     left'            (check (bind-pattern ctx right-T) T (:left ast))]
                 (assoc ast :right right' :left left'))
      :defs-group (let [[defs' bindings] (infer-defs ctx (:defs ast))
                        expr'            (check (update ctx :values merge bindings) T (:expr ast))]
                    (assoc ast :defs defs' :expr expr'))
      (let [[ast' inferred-T] (infer ctx ast)]
        (assert-subtype ast inferred-T T)
        ast'))))
