(ns active.clojure.new-monad)

(comment
  ;; Design sketch of the intended usage; nothing in this form is evaluated.
  ;; A command is declared, an implementation function is written that may
  ;; read *env* and *state* and call *recur*, and both are tied together in
  ;; a command configuration.
  ;; NOTE(review): the constructor implemented in this file is
  ;; `command-config`, taking [env initial-state command-impls]; the
  ;; `monad-command-config` name and argument order below are from the sketch
  ;; only.

  (declare-command do-something [arg])


  (defn do-something-impl [arg]
    *env*
    *state*
    (*recur* env state m) => [result state]
    (*recur* m) => result

    ;; =>
    [result state]
    ;; (with-state result state)

    ;; (set! *state* ...) ???
    )
  
  (monad-command-config
   {do-something do-something-impl}
   env
   initial-state)
  
  )


;; Dynamic vars through which a command implementation talks to the
;; interpreter. `run-call` binds *env* and *state* to the current environment
;; and state around the call of an implementation; the value of *state* after
;; the call is what `run-call` returns as the new state.
(def ^:dynamic *env* nil)
(def ^:dynamic *state* nil)

;; Bound by `run-call` to a function with two arities:
;; ([m]) runs m in the current *env*/*state*, updates *state*, returns the result;
;; ([env state m]) runs m in the given env and state and returns [result state].
(def ^:dynamic *recur* nil)

;; A declared command; `name` is what `run` prints when no implementation is
;; found for it.
(defrecord ^:private Command [name])
;; One invocation of a command: the Command plus the arguments passed to it.
;; `run` looks up the implementation by :command and applies it to :args.
(defrecord ^:private Call [command args])

(defmacro declare-command
  "Declares a command called `name` that takes the parameters `params`.

  Expands to a definition of a function `name` which, when called, does not
  perform anything but returns a `Call` record holding the command and the
  vector of arguments.

  The command itself is a `Command` record named by the namespace-qualified
  symbol of `name`. It is stored under `::command` in the metadata of the
  defined var, which is where `run` looks it up when the var (`#'name`) is
  used as a key in the command implementations map.

  Evaluates to the defined var."
  [name params]
  ;; TODO: + docstring
  ;; `name` shadows clojure.core/name here, hence `str`. *ns* is read at
  ;; macro-expansion time, so it is the namespace declaring the command.
  (let [qualified (symbol (str (ns-name *ns*)) (str name))]
    `(let [command# (Command. '~qualified)]
       (defn ~name [~@params]
         (Call. command# [~@params]))
       ;; The metadata has to go onto the var (not onto the symbol or the
       ;; function value), because implementations are keyed by var.
       (alter-meta! (var ~name) assoc ::command command#)
       (var ~name))))

;; Bundles everything `run` needs besides the program itself.
(defrecord ^:private CommandConfig [env initial-state command-impls])

(defn command-config
  "Creates a command configuration from the environment map `env`, the map
  `initial-state` and the map `command-impls` of command implementations.
  All three arguments must be maps; this is checked with `assert`."
  [env initial-state command-impls]
  (assert (map? env) env)
  (assert (map? initial-state) initial-state)
  (assert (map? command-impls))
  (->CommandConfig env initial-state command-impls))

(defn- run-call
  "Applies the command implementation `impl` to `args` and returns
  `[result state]`.

  While `impl` runs, `*env*` and `*state*` are bound to `env` and `state`,
  and `*recur*` is bound to a function that lets the implementation run
  nested monadic programs via `run-any` (a function of env, state and m that
  returns `[result state]`). The returned state is the value of `*state*`
  after `impl` has returned, so implementations change the state with
  `(set! *state* ...)`."
  [run-any env state impl args]
  (binding [*env* env
            *state* state]
    ;; TODO: maybe have some meta-data where the impl must say if they use recursion. (OPT)
    (binding [*recur* (fn ([m]
                           (let [[result state] (run-any *env* *state* m)]
                             (set! *state* state)
                             result))
                        ([env state m]
                         ;; returns [result state], state must be used by caller if they want to.
                         (run-any env state m)))]
      ;; The call must happen inside the *recur* binding, otherwise the
      ;; implementation sees *recur* as nil.
      (let [res (apply impl args)]
        [res *state*]))))

;; A program that is already finished with the value `v`.
(defrecord ^:private Return [v])

;; A program that runs `m` and continues with the program `(f result-of-m)`.
(defrecord ^:private Bind [m f])

(defn return
  "Returns the monadic program that does nothing and results in `v`."
  [v]
  (->Return v))

;; Disabled with the reader discard macro `#_`: a variant of compm that
;; composes two continuations into a plain function. Note the argument order
;; [g f] (f runs first), which is the reverse of the active compm further down.
#_(defn- compm [g f]
  (fn [v]
    (let [x (f v)]
      (if (instance? Return x)
        (g (:v x))
        (bind x g)))))

;; A queue of continuations that can be called like a single continuation.
;; `f` is the first continuation to run, `more` is a vector of those to run
;; afterwards, in order. Calling it with a value runs as many continuations
;; as possible right away (as long as they yield a Return) and otherwise
;; returns a Bind of the unfinished program with the remaining continuations.
(defrecord ^:private Continuations [f more]
  clojure.lang.IFn
  (invoke [_this v]
    (loop [f f
           v v
           more more]
      (let [res (f v)]
        (cond
          ;; last continuation: its result is the overall result
          (empty? more)
          res

          ;; finished immediately: feed the value to the next continuation
          (instance? Return res)
          (recur (first more)
                 (:v res)
                 (rest more))

          ;; res still has to be run: bind it to what is left of the queue
          :else
          (let [fs (rest more)]
            (if (empty? fs)
              (Bind. res (first more))
              ;; keep the queue a vector, so that conj (see compm) appends
              (Bind. res (Continuations. (first more) (vec fs)))))))))
  (applyTo [this args]
    (clojure.lang.AFn/applyToHelper this args)))

(defn- compm
  "Composes the continuations `g` and `f` into one that runs `g` first and
  then `f`. If `g` already is a `Continuations` queue, `f` is appended to
  it; otherwise a new queue of the two is created."
  [f g]
  (if (instance? Continuations g)
    ;; vec: make sure conj appends at the end, whatever collection :more is
    (Continuations. (:f g) (conj (vec (:more g)) f))
    (Continuations. g [f])))

(defn bind
  "Returns the monadic program that runs `m` and then the program that `f`
  returns for the result of `m`.

  A `Return` is short-cut by calling `f` right away. Binding a `Bind` again
  is re-associated, so that the inner program is bound to the composition
  of both continuations."
  [m f]
  (cond
    (instance? Return m)
    (f (:v m))

    ;; use a 'continuation queue' like in freer-monad?
    ;; (bind (bind x g) f) == (bind x (compm f g))  ?
    (instance? Bind m)
    (let [inner (:m m)
          g (:f m)]
      (bind inner (compm f g)))

    :else
    (Bind. m f)))

;; Program that calls `f` with an escape continuation; see the Shift case in `run`.
(defrecord ^:private Shift [f])

;; Program that delimits how far a Shift inside of `m` escapes; see the Reset
;; case in `run`.
(defrecord ^:private Reset [m])

(defn shift
  "Returns the program that calls `f` with a continuation function and then
  runs the program `f` returns. Must be run inside of a `reset`."
  [f]
  (Shift. f))

(defn reset
  "Returns the program that runs `m` as the delimiter for `shift`s in it."
  [m]
  (Reset. m))

(defn- make-reset-exn
  "Creates the exception that carries the value `v` and the state `state`
  to the reset identified by `id`."
  [id v state]
  (let [payload {:type ::reset
                 :id id
                 :v v
                 :state state}]
    (ex-info "Reset" payload)))

(defn- is-reset-exn?
  "If `exn` is a reset exception for the reset identified by `id`, returns
  the `[v state]` it carries; otherwise nil."
  [id exn]
  (let [{kind :type, exn-id :id, v :v, state :state} (ex-data exn)]
    (when (and (= ::reset kind)
               (= id exn-id))
      [v state])))

(defn run
  "Runs the monadic program `m` with the environment, initial state and
  command implementations of `command-config`.

  Returns the result or throws on unhandled exceptions. The final state is
  dropped.

  The keys of the command implementations must be the vars of commands
  defined with `declare-command`; otherwise an exception is thrown. An
  exception is also thrown when `m` calls a command that has no
  implementation, or contains something that is not a monadic program."
  [command-config m]
  (let [command-impls (->> (:command-impls command-config)
                           (map (fn [[cv impl]]
                                  ;; translate from command var to Command record
                                  (if-let [cmd (::command (meta cv))]
                                    [cmd impl]
                                    (throw (ex-info (str "Not a command: " (pr-str cv) ". Use declare-command, and implement them via the command vars.")
                                                    {:problem cv})))))
                           (into {}))
        run-any
        ;; `reset` is nil, or a function of value and state that creates the
        ;; exception to escape to the innermost enclosing Reset.
        ;; Returns [result state].
        (fn run-any [reset env state m]
          (loop [m m
                 state state]
            ;; m = ...bind return throw etc
            (condp instance? m
              Return [(:v m) state]

              Bind
              (let [[v state] (run-any reset env state (:m m))]
                (recur ((:f m) v)
                       state))

              Call
              (if-let [impl (get command-impls (:command m))]
                (run-call (partial run-any reset) env state impl (:args m)) ;; returns [result state]

                (throw (ex-info (str "Command not implemented: " (:name (:command m)))
                                {:command (:command m)})))

              Reset
              ;; A fresh object is equal only to itself, so it identifies
              ;; exactly this run of this Reset.
              (let [reset-id (Object.)]
                (try (run-any (partial make-reset-exn reset-id)
                              env state (:m m))
                     (catch clojure.lang.ExceptionInfo exn
                       (if-let [[v state] (is-reset-exn? reset-id exn)]
                         [v state]
                         (throw exn)))))

              Shift
              (do
                (assert (some? reset) "shift outside reset")
                ;; the continuation escapes with the state as of the shift
                (recur ((:f m) (fn continuation [v]
                                 (throw (reset v state))))
                       state))

              ;; else: unknown monad command
              (throw (ex-info (str "Unknown monad command: " (pr-str m))
                              {:command m})))))]
    (-> (run-any nil
                 (:env command-config)
                 (:initial-state command-config)
                 m)
        (first))))

;; *** State monad

;; get-state results in the current state, put-state! replaces it.
(declare-command get-state [])
(declare-command put-state! [v])

;; Implements the state commands on top of the *state* var; starts with an
;; empty environment and an empty state.
(def state-command-config
  (let [read-state (fn [] *state*)
        write-state (fn [v]
                      (set! *state* v)
                      nil)]
    (command-config {} {}
                    {#'get-state read-state
                     #'put-state! write-state})))

;; *** Environment (Reader monad)

;; get-env results in the current environment; with-env runs the program m
;; with env as the environment instead.
(declare-command get-env [])
(declare-command with-env [env m])

;; Implements the environment commands on top of *env* and *recur*; starts
;; with an empty environment and an empty state.
(def env-command-config
  (command-config {} {}
                  {#'get-env (fn [] *env*)
                   #'with-env (fn [env m]
                                (let [[res state] (*recur* env *state* m)]
                                  ;; *state* is a thread-bound dynamic var
                                  ;; holding the state itself, not an atom,
                                  ;; so it is updated with set!.
                                  (set! *state* state)
                                  res))}))

(defn get-env-component
  "Returns the program that results in the value stored under the key `k`
  in the current environment."
  [k]
  (bind (get-env)
        (fn [current-env]
          (return (get current-env k)))))

(defn with-env-component
  "Returns the program that runs `m` in the current environment with the
  key `k` associated to `v`."
  [k v m]
  (bind (get-env)
        (fn [current-env]
          (let [extended-env (assoc current-env k v)]
            (with-env extended-env m)))))

;; *** Exceptions - requires env-command-config

(defn- shift-reset
  "Returns a `shift` of `f` that is directly wrapped in its own `reset`."
  [f]
  (-> f
      (shift)
      (reset)))

(defn free-throw
  "Returns the program that throws `exn`.

  If the environment has a handler function under `::catch`, the program
  continues with whatever that handler returns for `exn`. Otherwise `exn`
  is thrown as an ordinary exception while the program runs."
  [exn]
  (-> (get-env-component ::catch)
      (bind (fn [catch*]
              (if catch*
                (catch* exn)
                ;; else real throw - unhandled exception?
                (throw exn))))))

(defn with-handler
  "If m throws, calls `(handler exception)`, or `(cont result)` otherwise."
  [m handler cont]
  (letfn [;; Runs m with a ::catch handler installed that escapes with
          ;; [false exn]; a normal result v is tagged as [true v].
          (guarded [escape]
            (with-env-component ::catch (fn [exn] (escape [false exn]))
              (bind m (fn [v] (return [true v])))))
          ;; Continues with cont or handler, depending on the tag.
          (dispatch [[ok? res]]
            (if ok?
              (cont res)
              (handler res)))]
    (bind (shift-reset guarded) dispatch)))

(defn try-catch
  "Returns the program that runs `try-m` and results in its result, or, if
  it throws, continues with the program `(handler exception)`."
  [try-m handler]
  (let [on-success return]
    (with-handler try-m handler on-success)))
