(ns libapl-clj.sidecar
  (:require [libapl-clj.apl :as apl :reload true]
            [libapl-clj.impl.jna :as jna :reload true]
            [tech.v3.tensor :as tensor]
            tech.v3.tensor.dimensions
            [tech.v3.datatype :as dtype]
            tech.v3.datatype.pprint
            tech.v3.jna
            tech.v3.datatype.jna
            [complex.core :refer [complex]]

            [clojure.string]
            [clojure.java.io :as io]
            [libapl-clj.impl.api :as api]
            [libapl-clj.impl.pointer :as p]
            [libapl-clj.impl.helpers :as h])
  (:import [java.util UUID]))

(defn ^:private random-var-name
  "Returns a fresh variable name for temporary APL values: the letter `a`
  followed by a random UUID with its dashes removed. The `a` prefix keeps the
  name from starting with a digit, and stripping `-` keeps it a single token."
  []
  (let [uuid-string (str (UUID/randomUUID))]
    (str "a" (clojure.string/replace uuid-string #"-" ""))))

(defn shape
  "Returns the shape (APL `⍴`) of `apl-value` as a vector of the integers read
  with `jna/get_int`.

  Works by binding `apl-value` to a temporary APL variable, evaluating `⍴` of
  it into a second temporary, and reading that result element by element.
  Both temporaries are erased before returning, also when a step throws.

  NOTE: this round-trips through the interpreter; a direct JNA binding for
  shape would avoid creating workspace variables at all."
  [apl-value]
  (let [value-name (random-var-name)
        shape-name (random-var-name)]
    (try
      (jna/set_var_value value-name apl-value " ")
      (apl/run-simple-string! (str shape-name "← ⍴ " value-name))
      (let [shape-pointer (apl/value-pointer shape-name)]
        (mapv #(jna/get_int shape-pointer %)
              (range (jna/get_element_count shape-pointer))))
      (finally
        ;; Erase on the calling thread. Issuing these from a `future` let them
        ;; run concurrently with the caller's next interpreter calls
        ;; (presumably unsafe for the embedded interpreter).
        (jna/apl_command (str "ERASE " value-name))
        (jna/apl_command (str "ERASE " shape-name))))))

;; Initialize the embedded APL interpreter when this namespace is first
;; loaded. `defonce` skips the call if the namespace is re-evaluated or
;; reloaded, so `apl/initialize!` runs only once. The top-level APL calls
;; further down in this file run after this form at load time.
(defonce _ (apl/initialize!))

(defn apl->tensor
  "Converts `apl-value` (a libapl value pointer) into a tech.v3 tensor.

  A non-scalar value (non-empty shape) becomes an `:object` tensor of the same
  shape. Each element is read through JNA when it is accessed, dispatching on
  the cell type reported by `jna/get_type`:

    0    -> nil
    0x02 -> character (`jna/get_char`)
    0x04 -> nested value, converted recursively with `apl->tensor`
    0x10 -> integer (`jna/get_int`)
    0x20 -> real (`jna/get_real`)
    0x40 -> complex built from `jna/get_real` and `jna/get_imag`

  Any other cell type throws an ex-info carrying `:idx` and `:type` when that
  element is read.

  The source value is also bound to a fresh APL variable. That variable's name
  is stored in the tensor's metadata under `:apl/name` so APL code can refer
  to the value again (as `drawing` does). The variable is never erased.

  A scalar (empty shape) is returned unchanged as the original pointer. This
  includes enclosed values such as `⊂1 2`."
  [apl-value]
  (let [shape' (shape apl-value)]
    (if (empty? shape')
      apl-value
      (let [random-name (random-var-name)
            n           (jna/get_element_count apl-value)
            ;; `dtype/make-reader` binds `idx` for the element expression.
            res         (-> (dtype/make-reader
                             :object n
                             (let [cell-type (jna/get_type apl-value idx)]
                               (case cell-type
                                 0    nil
                                 0x02 (jna/get_char apl-value idx)
                                 0x04 (apl->tensor (jna/get_value apl-value idx))
                                 0x10 (jna/get_int apl-value idx)
                                 0x20 (jna/get_real apl-value idx)
                                 0x40 (complex (jna/get_real apl-value idx)
                                               (jna/get_imag apl-value idx))
                                 (throw (ex-info "unsupported APL cell type"
                                                 {:idx  idx
                                                  :type cell-type})))))
                            (tensor/construct-tensor
                             (tech.v3.tensor.dimensions/dimensions shape')))]
        (jna/set_var_value random-name apl-value "")
        (with-meta res {:apl/name random-name})))))

;; REPL scratch: bind the `mixed` sample value (defined in a later `comment`
;; form) to the APL variable "m", read it back, and draw it.
(comment
  (jna/set_var_value "m" mixed "")
  (drawing (jna/get_var_value "m" "")))

;; Thin wrapper around an APL value pointer `t`. Its string form is the
;; interpreter's own rendering of that value (`jna/print_value_to_string`),
;; preceded by a newline so it starts on a fresh line at the REPL.
(deftype PTensor [t]
  Object
  (toString [_this]
    (->> (jna/print_value_to_string t)
         (str "\n"))))

;; Hook PTensor into Clojure printing via tech.v3.datatype.pprint
;; (presumably a print-method delegating to `.toString` above).
(tech.v3.datatype.pprint/implement-tostring-print PTensor)

(defn drawing
  "Renders an APL value as a box drawing (`4 ⎕CR`) and returns it wrapped in a
  `PTensor`, whose printed form is the drawing.

  `apl-value-or-tensor` may be either of:
  - a libapl value pointer, which is first bound to a temporary variable;
  - a tensor returned by `apl->tensor`, whose `:apl/name` metadata names the
    APL variable holding the source value.

  Temporary APL variables created here are erased before returning, even if
  evaluation fails. Throws ex-info when given a tensor without `:apl/name`
  metadata, because it has no APL variable to draw."
  [apl-value-or-tensor]
  (let [tensor?      (tensor/tensor? apl-value-or-tensor)
        drawing-name (random-var-name)
        apl-name     (if tensor?
                       (or (-> apl-value-or-tensor meta :apl/name)
                           (throw (ex-info "drawing: tensor has no :apl/name metadata"
                                           {:type (type apl-value-or-tensor)})))
                       (let [temp-name (random-var-name)]
                         (jna/set_var_value temp-name apl-value-or-tensor "")
                         temp-name))]
    (try
      (apl/run-simple-string! (str drawing-name " ← 4 ⎕CR " apl-name))
      ;; Fetch the result with get_var_value, as `pointer-binop` does, so the
      ;; pointer kept in the PTensor is the reference libapl hands to the
      ;; caller rather than the variable erased below.
      ;; NOTE(review): relies on libapl returning a caller-owned reference
      ;; (release_value semantics) — confirm against libapl.h.
      (PTensor. (jna/get_var_value drawing-name ""))
      (finally
        (when-not tensor?
          (jna/apl_command (str "ERASE " apl-name)))
        ;; Erase by variable name. Passing the pointer object here erased
        ;; nothing and leaked one workspace variable per call.
        (jna/apl_command (str "ERASE " drawing-name))))))

(defn draw!
  "Prints the box drawing of `t` (anything `drawing` accepts) to `*out*`, one
  line at a time, and returns nil.

  The first line of `(str (drawing t))` is skipped: `PTensor`'s string form
  begins with a newline, so that line is empty."
  [t]
  (let [lines (rest (clojure.string/split-lines (str (drawing t))))]
    ;; Write straight to *out* instead of `(with-open [w (io/writer *out*)] …)`:
    ;; closing the writer returned by `io/writer` also closes *out* itself,
    ;; breaking all later output on that stream.
    (doseq [line lines]
      (.write ^java.io.Writer *out* ^String line)
      (.write ^java.io.Writer *out* "\n"))
    (flush)))

;; REPL scratch: experiments with `PTensor` printing, `drawing` and `draw!`
;; on the `mixed` sample value (defined in a later `comment` form).
(comment
  (type  (apl->tensor mixed))
  (deftype PTensor [t]
    Object
    (toString [_]
      (str "\n" (jna/print_value_to_string t))))
  (tech.v3.datatype.pprint/implement-tostring-print PTensor)

  (apl->tensor mixed)

  (drawing mixed)

  (-> mixed apl->tensor drawing)

  (PTensor. mixed)

  (reify Object
    (toString [_]
      (drawing mixed)))
  (println (draw! mixed)))

(defn ->apl
  "Copies a tech.v3 tensor into a newly created APL array and returns the
  libapl value pointer for it.

  An APL variable with the tensor's shape is created under a fresh random
  name and filled with 0. A rank-0 tensor uses shape `⍬`, which gives a
  scalar. The variable is then filled element by element, in
  `tensor/tensor->buffer` order:

    fixed-precision integers (`int?`)   -> `jna/set_int`
    Double/Float (`float?`)             -> `jna/set_double`
    characters                          -> `jna/set_char`
    strings, tensors, vectors and lists -> converted recursively with `->apl`
                                           and stored as a nested value

  Any other element type throws. A failure while storing an element is
  rethrown as ex-info \"something went wrong\". Its data holds `:idx`,
  `:value` and `:error`, and the original exception is also attached as the
  cause.

  The APL variable holding the result stays in the workspace. Its name is
  not attached to the returned value, because a raw JNA pointer cannot carry
  Clojure metadata."
  [tensor]
  (let [shape'       (dtype/shape tensor)
        ;; Inline the shape as ⍴'s left argument rather than storing it in a
        ;; temporary variable (which was never erased).
        shape-string (if (empty? shape')
                       "⍬"
                       (clojure.string/join " " shape'))
        tensor-name  (random-var-name)
        nested?      (some-fn string? tensor/tensor? vector? list?)
        ->nested     (fn [elt]
                       (-> elt
                           seq
                           dtype/ensure-reader
                           tensor/ensure-tensor
                           ->apl))]
    (apl/run-simple-string! (format "%s ← %s ⍴ 0" tensor-name shape-string))
    (let [apl-value (apl/value-pointer tensor-name)]
      (doseq [[idx elt] (map-indexed vector (tensor/tensor->buffer tensor))]
        (try
          (cond
            (int? elt)    (jna/set_int elt apl-value idx)
            (float? elt)  (jna/set_double (double elt) apl-value idx)
            (char? elt)   (jna/set_char elt apl-value idx)
            (nested? elt) (jna/set_value (->nested elt) apl-value idx)
            :else
            (throw (ex-info "unsupported type"
                            {:idx   idx
                             :value elt})))
          (catch Exception e
            (throw (ex-info "something went wrong"
                            {:idx   idx
                             :value elt
                             :error e}
                            e)))))
      apl-value)))


;; pointer ops


(defn pointer-binop
  "Returns a function `(fn [a b] …)` that applies the APL primitive `op` (a
  symbol such as `'+` or `'⍴`) dyadically to two libapl value pointers and
  returns a pointer to the result.

  Each call does the following:
  1. Binds `a` and `b` to fresh APL variables.
  2. Evaluates `res ← a op b` in the interpreter.
  3. Fetches the result with `jna/get_var_value`.
  4. Erases all three temporaries, also when evaluation throws."
  [op]
  (let [op-string (str op)]
    (fn binop [a b]
      (let [res-name (random-var-name)
            a-name   (random-var-name)
            b-name   (random-var-name)]
        (try
          (jna/set_var_value a-name a "")
          (jna/set_var_value b-name b "")
          (apl/run-simple-string!
           (format "%s ← %s %s %s" res-name a-name op-string b-name))
          (jna/get_var_value res-name "")
          (finally
            ;; Clean up synchronously on the calling thread. A `future` here
            ;; raced later interpreter calls and never erased res-name.
            ;; NOTE(review): erasing res-name assumes get_var_value returns a
            ;; caller-owned reference (libapl release_value semantics).
            (doseq [temp-name [a-name b-name res-name]]
              (jna/apl_command (str "ERASE " temp-name)))))))))

;; Dyadic APL primitives over libapl value pointers, built with
;; `pointer-binop`: `(f a b)` evaluates `a <op> b` and returns the result.
(def add
  "APL `+`: `(add a b)` evaluates `a + b`."
  (pointer-binop '+))

(def sub
  "APL `-`: `(sub a b)` evaluates `a - b`."
  (pointer-binop '-))

(def mul
  "APL `×`: `(mul a b)` evaluates `a × b`."
  (pointer-binop '×))

(def div
  "APL `÷`: `(div a b)` evaluates `a ÷ b`."
  (pointer-binop '÷))

(def lt
  "APL `<`: `(lt a b)` evaluates `a < b`."
  (pointer-binop '<))

(def lte
  "APL `≤`: `(lte a b)` evaluates `a ≤ b`."
  (pointer-binop '≤))

(def gt
  "APL `>`: `(gt a b)` evaluates `a > b`."
  (pointer-binop '>))

(def gte
  "APL `≥`: `(gte a b)` evaluates `a ≥ b`."
  (pointer-binop '≥))

(def eq
  "APL `=`: `(eq a b)` evaluates `a = b`."
  (pointer-binop '=))

(def neq
  "APL `≠`: `(neq a b)` evaluates `a ≠ b`."
  (pointer-binop '≠))

(def reshape
  "APL `⍴`: `(reshape shape value)` evaluates `shape ⍴ value`."
  (pointer-binop '⍴))

;; Sample values bound at namespace load.
;; NOTE(review): this assigns the APL workspace variables `a` and `twenty`,
;; which would clobber user variables with those names. These forms are kept
;; at top level only because `a'` / `twenty` may be referenced elsewhere.
(apl/run-simple-string! "a ← ⍳10")
(def a' (jna/get_var_value "a" ""))
(apl/run-simple-string! "twenty ← 20")
(def twenty (jna/get_var_value "twenty" ""))

;; REPL exploration. These forms used to run on every namespace load, doing
;; interpreter work (and creating temporary APL variables) whose results were
;; thrown away. They now run only when evaluated by hand.
(comment
  (drawing a')

  (drawing (add a' a'))
  (drawing (sub a' a'))
  (drawing (mul a' a'))
  (drawing (->> (mul a' a')
                (mul a')
                (reshape (->apl (tensor/->tensor [3 2 2])))))
  (drawing (div a' a'))
  (drawing (lt a' a'))
  (drawing (gt a' a'))
  (drawing (gte a' a'))
  (drawing (eq a' a'))
  (drawing (neq a' a'))

  (drawing (->> a'
                (mul a')
                (sub a')
                (mul a')
                (add a')
                (div a')
                (reshape
                 (-> [3 2 4]
                     tensor/->tensor
                     ->apl))))

  (drawing a')
  (apl->tensor a')
  (jna/print_value_to_string a')
  (drawing twenty)

  (drawing (add a' twenty))
  (drawing (mul a' twenty))
  (drawing (div a' twenty))

  (drawing a')

  (-> (mul a' a')
      apl->tensor
      meta)

  (hash [1 2 3])
  (hash (dtype/->reader (-> (tensor/->tensor (range (* 2 3 5)))
                            (tensor/reshape [2 3 5])))))

;; REPL scratch: converting Clojure data and tensors to APL with `->apl` and
;; drawing the results (strings, nested vectors, a 4-D tensor, and the
;; `mixed` / `enclosed` sample values defined in a later `comment` form).
(comment
  (tensor/->tensor (vec "hello"))
  (tensor/tensor->buffer (tensor/->tensor (partition 3 (range 10))))

  (-> (partition 3 (range 10))
      tensor/->tensor
      ->apl
      drawing)

  (drawing mixed)
  (drawing (->apl (apl->tensor mixed)))

  (-> "hello" seq dtype/ensure-reader tensor/ensure-tensor)

  (-> ["hello" [1 2 3]
       (into []
             (partition-all 3)
             (range 10))]
      tensor/ensure-tensor
      ->apl
      drawing)

  (drawing (->apl (tensor/->tensor (partition 3 (range 10)))))

  (-> (tensor/construct-tensor (dtype/->buffer (range (apply * [2 3 5 7])))
                               (tech.v3.tensor.dimensions/dimensions [2 3 5 7]))
      ->apl
      drawing)

  (apl->tensor enclosed)
  (draw! enclosed))

;; REPL scratch: probing the low-level libapl cell accessors (`get_type`,
;; `get_real`, `get_char`, `get_value`, `get_rank`, ...) on sample values, and
;; building readers/tensors from them by hand.
(comment
  (apl/initialize!)
  (def names (apl/value-pointer "names"))
  (apl/run-simple-string! "res ← 'a' 1 1.1 (⊂1 2) (1 2) 1j2")
  (def res (apl/value-pointer "res"))

  (jna/get_type res 2)
  (jna/get_real res 2)
  (jna/get_type res 3)
  (jna/get_type res 4)
  (jna/get_type res 5)
  (def x (apl->tensor res))
  (meta (var x))
  (with-meta)

  (jna/get_type res 4)
  (apl->tensor (jna/get_value res 4))
  0x20
  (apl->tensor res)

  ;; come back here
  (jna/apl_value 2
                 (dtype/make-container :native-heap [1 2])
                 "")

  (apl/run-simple-string! "res ← 3 3 ⍴ 'A'")
  (def res (apl/value-pointer "res"))
  (apl->tensor res)

  (jna/get_type res 0)

  (jna/get_char res 0)

  (jna/get_int res 1)

  (jna/get_real res 2)

  (jna/get_type res 3)

  (def three (jna/get_value res 3))

  (jna/get_rank three)
  (jna/get_element_count three)

  (apl/run-simple-string! "⎕IO ← 0")

  (do (apl/run-simple-string! "x ← ⍳3 4 5")
      (def x (apl/value-pointer "x")))
  (def x0 (jna/get_value x 0))
  (apl->tensor x)

  (dtype/make-reader :object
                     (jna/get_element_count x)
                     (->> idx (jna/get_value x) (apl->tensor)))
  (shape x0)
  (apl->tensor x0)
  (jna/get_type x 0)
  (shape x)
  (apl->tensor x)

  tech.v3.tensor.dimensions/dimensions
  (-> (dtype/make-reader :object
                         60

                         (tensor/->tensor [idx idx idx]))
      (tensor/construct-tensor (tech.v3.tensor.dimensions/dimensions [3 4 5]))))

;; REPL scratch: defines the sample values `A`, `mixed`, `iterations` and
;; `enclosed` (used by other `comment` forms in this file), and exercises
;; `drawing`, `draw!` and `4 ⎕CR` on them and on indexing results.
(comment
  (apl/run-simple-string! "A ← 3 3 ⍴ 'A'")
  (def A (apl/value-pointer "A"))
  (apl->tensor A)
  (drawing A)
  (drawing (api/int-scalar 1))

  (apl/run-simple-string! "mixed ← (1 2) (3 4) (5 6)")
  (def mixed (apl/value-pointer "mixed"))
  (apl->tensor mixed)
  (drawing mixed)

  (drawing (->apl (tensor/->tensor ["hello" [[1 2] [[3 4 "nexted"]]] 6])))

  (drawing (->apl (tensor/->tensor [[[1 0 0]
                                     [3 0 0]]
                                    [[1 2 0]
                                     [3 4 0]]
                                    [[1 2 4]
                                     [3 4 6]]])))

  (apl/run-simple-string! "iterations ← ⍳2 3 3 ")
  (apl/run-simple-string! "⎕io ← 0")
  (def iterations (apl/value-pointer "iterations"))
  (drawing iterations)
  (apl->tensor iterations)

  (def cmds ["res ← 3 3 ⍴ 2"
             "res1 ← res[,1;]"
             "res111 ← res[1 1 1;]"
             "res11111 ← res[1 1 1 1 1 1;]"
             "shape_res111 ← ⍴res111"])

  (doseq [cmd cmds]
    (apl/run-simple-string! cmd))

  (defn draw-symbol [s]
    (-> s
        apl/value-pointer
        drawing))
  (draw-symbol "res111")
  (draw-symbol "res11111")
  (draw-symbol "res1")
  (draw-symbol "res")

  (use '[clojure.java.io :only [output-stream]])
  (apl/run-simple-string! "enclosed ← (⊂1 2) 1 2 (1 2)")
  (def enclosed (apl/value-pointer "enclosed"))
  (apl->tensor enclosed)
  (require '[clojure.java.io :as io])

  io/writer

  *out*
  (apl/run-simple-string! "enclosedp ← 4 ⎕CR enclosed")
  (def enclosedp (apl/value-pointer "enclosedp"))
  (with-open [w (io/writer *out*)]
    (doseq [line (clojure.string/split-lines (jna/print_value_to_string enclosedp))]
      (.write w line)
      (.write w "\n")))

  (java.io.BufferedReader. (jna/print_value_to_string enclosed))

  (draw! enclosedp))

;; REPL scratch: experiments with lower-level libapl bindings (`apl_value`,
;; `char_vector`, `get_function_ucs`, `eval__A_fun_B`, `get_axis`).
(comment
  (drawing (jna/get_element_count  (jna/apl_value 2 (dtype/make-container :native-heap :int64 [2 3]) "")))
  (drawing (jna/char_vector "hello" ""))

  (jna/get_function_ucs "⍴"
                        nil
                        nil)

  (require 'tech.v3.datatype.jna)
  (def ⍴ (jna/get_function_ucs (dtype/make-container :native-heap :uint64 [\⍴])
                               (dtype/make-container :native-heap :uint64 [])
                               (dtype/make-container :native-heap :uint64 [])))

  (let [A (-> [3 3] tensor/->tensor ->apl)
        B (jna/char_vector "hello" "")]
    (jna/get_axis (jna/eval__A_fun_B A ⍴ B) 1)))

;; REPL scratch: the same kind of operations through the higher-level
;; `libapl-clj.api` functions and `libapl-clj.impl.helpers/drawing!`.
(comment
  (-> (partition 3 (range 9))
      libapl-clj.api/->apl 
      libapl-clj.api/->jvm)
  (h/drawing! (libapl-clj.api/arg+fp+arg 1 p/+ 1))
  (h/drawing! (libapl-clj.api/fp+arg p/⍳ 10))
  (h/drawing! (libapl-clj.api/arg+fp+arg [1 2 3] p/+ [1 2 3]))
  (h/drawing! (libapl-clj.api/arg+fp+arg [1 2 3] p/+ 10))
  (h/drawing! (libapl-clj.api/arg+fp+arg 10 p/+ [1 2 3])))










