(ns splendid.ml.rbm01
  (:use clojure.repl
        )
  (:import java.security.SecureRandom
           (org.jblas DoubleMatrix MatrixFunctions)))

(defn ^DoubleMatrix double-matrix
  "Builds a jblas DoubleMatrix from `doubles`, a sequence of rows where each
  row is itself a sequence of numbers. Every row is copied into a primitive
  double array and the rows are handed to the double[][] constructor."
  [doubles]
  (let [rows (into-array (map double-array doubles))]
    ;; The hint selects the double[][] constructor among the 1-arg overloads.
    (DoubleMatrix. ^"[[D" rows)))

;; Restricted Boltzmann machine state.
;;   nvis  - number of visible units (bias unit not counted)
;;   nhid  - number of hidden units (bias unit not counted)
;;   learn - learning rate, stored as a primitive double
;;   wts   - weight matrix; create-rbm builds it as a DoubleMatrix of
;;           (nvis + 1) x (nhid + 1) with row 0 / column 0 reserved for the
;;           bias units, and train updates it in place
(defrecord RBM [nvis nhid ^double learn wts])

(defn create-rbm
  "Creates an RBM with `nvis` visible and `nhid` hidden units.

  The weight matrix is (nvis + 1) x (nhid + 1): the first row and the first
  column belong to the bias units and start at 0.0, every other entry is a
  Gaussian sample scaled by 0.1. The optional keyword argument :learn sets
  the learning rate (default 0.1)."
  [nvis nhid & {:keys [learn] :or {learn 0.1}}]
  (let [rng (SecureRandom.)
        small-gaussian #(* 0.1 (.nextGaussian rng))
        ;; Row 0: weights leaving the visible bias unit, all zero.
        bias-row (repeat (inc nhid) 0.0)
        ;; One row per visible unit, each prefixed with the hidden-bias weight.
        unit-rows (for [row (partition nhid (repeatedly (* nvis nhid) small-gaussian))]
                    (cons 0.0 row))
        wts (double-matrix (cons bias-row unit-rows))]
    (RBM. nvis nhid learn wts)))

(defn ^DoubleMatrix logistic
  "Element-wise logistic function 1 / (1 + e^-x). Returns a new matrix and
  leaves `x` untouched."
  [^DoubleMatrix x]
  (-> x
      .neg                  ; -x (fresh copy)
      MatrixFunctions/exp   ; e^-x
      (.add 1.0)            ; 1 + e^-x
      (.rdiv 1.0)))         ; 1 / (1 + e^-x)

(defn ^DoubleMatrix run-visible
  "Samples hidden-unit states for the visible vectors in `data`.

  `data` is a sequence of rows, one per example. A constant 1.0 is prepended
  to every row for the bias unit, the rows are multiplied by the RBM weights,
  squashed with `logistic`, and compared against uniform random numbers.
  Returns the resulting 0/1 matrix without the bias column (one row per
  example, one column per hidden unit)."
  [^RBM rbm data]
  (let [example-count (count data)
        hidden-cols (inc (:nhid rbm))
        ;; Equivalent of np.insert(data, 0, 1, axis = 1).
        visible (double-matrix (map #(cons 1.0 %) data))
        activations (.mmul visible ^DoubleMatrix (.wts rbm))
        probabilities (logistic activations)
        thresholds (DoubleMatrix/rand example-count hidden-cols)
        states (.gt probabilities thresholds)
        ;; Column 0 is the bias unit; keep columns 1..nhid only.
        unit-columns (int-array (range 1 hidden-cols))]
    (.getColumns states unit-columns)))

(defn train
  "Trains `rbm` on `data` by mutating its weight matrix in place.

  `data` is a sequence of rows, one per example; a constant 1.0 is prepended
  to each row for the bias unit. Every epoch performs one positive phase
  (data -> hidden), one reconstruction (hidden states -> visible
  probabilities -> hidden probabilities) and adds
  learn * (pos-associations - neg-associations) / num-examples to the weights.

  Options:
    :max-epochs - number of epochs to run (default 1000).

  Prints the summed squared reconstruction error whenever the 1-based epoch
  number is a multiple of max-epochs / 10. Returns nil."
  [rbm, data, & {:keys [max-epochs] :or {max-epochs 1000}}]
  (let [example-count (count data)
        ;; Double copy for the division; the integer count is kept for
        ;; DoubleMatrix/rand, which takes int dimensions.
        num-examples (double example-count)
        hidden-cols (inc (:nhid rbm))
        learn-rate (double (:learn rbm))
        report-every (/ max-epochs 10)
        data (map #(cons 1.0 %) data) ;; data = np.insert(data, 0, 1, axis = 1)
        data (double-matrix data)
        ;; data is never mutated below, so its transpose is loop-invariant.
        data-t (.transpose data)
        ^DoubleMatrix wts (:wts rbm)]
    (dotimes [epoch max-epochs]
      (let [;; Positive phase: drive the hidden units from the training data.
            pos-hidden-activations (.mmul data wts)
            pos-hidden-probs (logistic pos-hidden-activations)
            pos-hidden-states (.gt pos-hidden-probs
                                   (DoubleMatrix/rand (int example-count) (int hidden-cols)))
            pos-associations (.mmul data-t pos-hidden-probs)
            ;; Negative phase: reconstruct the visible units from the sampled
            ;; hidden states, then recompute the hidden probabilities.
            neg-visible-activations (.mmul pos-hidden-states (.transpose wts))
            neg-visible-probs (logistic neg-visible-activations)
            ;; Force the bias column (column 0) of the reconstruction to 1.
            neg-visible-probs (.put neg-visible-probs
                                    (int-array (range (.getRows neg-visible-probs)))
                                    (int-array [0])
                                    1.0)
            neg-hidden-activations (.mmul neg-visible-probs wts)
            neg-hidden-probs (logistic neg-hidden-activations)
            neg-associations (.mmul (.transpose neg-visible-probs) neg-hidden-probs)
            delta (-> (.sub pos-associations neg-associations)
                      (.div num-examples)
                      (.mul learn-rate))]
        ;; In-place update: the matrix stored in the record is modified.
        (.addi wts delta)
        (when (zero? (mod (inc epoch) report-every))
          (println (format "Epoch %s: error is %s"
                           (inc epoch)
                           (-> data
                               (.sub neg-visible-probs)
                               (MatrixFunctions/pow 2.0)
                               .sum))))))))

(defn testus
  "Manual smoke test for a 6-visible / 2-hidden RBM.

  Trains on 500 repetitions of four patterns (all zeros, low half set, high
  half set, all ones) for 9000 epochs. Then samples the hidden states of the
  four patterns 10000 times and prints how many times the state sampled for
  [1 1 1 0 0 0] differed from the states of all three other patterns. The
  trained RBM is also stored in the global var `ribbi` for inspection at the
  REPL, and one sampled hidden state per pattern is printed."
  []
  (let [r (create-rbm 6 2 :learn 0.1)
        training-data (apply concat (repeatedly 500 #(vector [0 0 0 0 0 0] [0,0,0,1,1,1]
                                                             [1,1,1,0,0,0] [1 1 1 1 1 1])))]
    (train r training-data :max-epochs 9000)
    (println (->> (map (fn [_]
                         (let [r1 (run-visible r [[1,1,1,0,0,0]])
                               r2 (run-visible r [[0,0,0,1,1,1]])
                               r3 (run-visible r [[0,0,0,0,0,0]])
                               r4 (run-visible r [[1,1,1,1,1,1]])]
                           ;; Only r1 is compared against the others; r2/r3/r4
                           ;; are not compared with each other.
                           (or (.equals r1 r2)
                               (.equals r1 r3)
                               (.equals r1 r4))))
                       (range 10000))
                  (remove true?)
                  count))
    ;; Keep the trained machine reachable from the REPL.
    (def ribbi r)
    (println (run-visible r [[0, 0, 0, 1, 1, 1]]))
    (println (run-visible r [[1, 1, 1, 0, 0, 0]]))
    (println (run-visible r [[0, 0, 0, 0, 0, 0]]))
    (println (run-visible r [[1, 1, 1, 1, 1, 1]]))))


(defn test-counting
  "Manual smoke test for an 8-visible / 2-hidden RBM.

  The training set has 28 rows: four all-zero rows, the eight one-hot rows
  (set bit moving from the last position to the first), four all-zero rows,
  the eight one-hot rows again, and four more all-zero rows. After 100000
  epochs the trained RBM is stored in the global var `ribbi` and one sampled
  hidden state is printed for the all-zero pattern and for the one-hot
  patterns with the bit at positions 7, 6 and 5."
  []
  (let [r (create-rbm 8 2 :learn 0.1)
        blank (vec (repeat 8 0))
        one-hot (fn [position] (assoc blank position 1))
        blanks (repeat 4 blank)
        ;; [0 0 0 0 0 0 0 1] first, [1 0 0 0 0 0 0 0] last.
        singles (map one-hot (range 7 -1 -1))
        training-data (vec (concat blanks singles blanks singles blanks))]
    (train r training-data :max-epochs 100000)
    (def ribbi r)
    (doseq [pattern [blank (one-hot 7) (one-hot 6) (one-hot 5)]]
      (println (run-visible r [pattern])))))






;;;;;;;;;;;;;


(defn- get-format-str-for-column
  "Chooses a [width decimals] pair for printing the single column `col`.

  The choice depends on the largest value and on the magnitude of the
  smallest value in the column: both below 10 -> [12 8], both below 100 ->
  [12 7], both below 1000 -> [12 6], anything else -> [13 5]."
  [^DoubleMatrix col]
  (let [largest (.max col)
        ;; Math/abs instead of the core `abs`, which only exists from
        ;; Clojure 1.11 on; the locals also no longer shadow core max/min.
        smallest-magnitude (Math/abs (.min col))
        below? (fn [limit]
                 (and (< smallest-magnitude limit) (< largest limit)))]
    (cond
     (below? 10)   [12 8]
     (below? 100)  [12 7]
     (below? 1000) [12 6]
     :else [13 5])))

(defmethod print-method DoubleMatrix [^DoubleMatrix m, ^java.io.Writer w]
  ;; Writes the matrix as
  ;;   <DoubleMatrix (ROWSxCOLS):
  ;;   [v v v]
  ;;   ...
  ;;   >
  ;; Each column gets its own "%,W.Df" format, chosen by
  ;; get-format-str-for-column. All output goes to the writer `w` only;
  ;; nothing is printed to *out*.
  (let [row-count (.rows m)
        col-count (.columns m)
        row-format (->> (.columnsAsList m)
                        (map get-format-str-for-column)
                        (map (fn [[width decimals]]
                               (str "%," (dec width) "." decimals "f")))
                        ;; Core-only join, so clojure.string need not be loaded.
                        (interpose " ")
                        (apply str))]
    (.write w (str "<DoubleMatrix (" row-count "x" col-count "):\n"))
    (doseq [^DoubleMatrix row (.rowsAsList m)]
      (.write w "[")
      ;; `data` is the row's backing double array; its elements become the
      ;; format arguments.
      (.write w ^String (apply format row-format (.data row)))
      (.write w "]\n")))
  (.write w ">\n"))

;; Flat vector of 21 doubles (7 groups of 3). It is not referenced by any
;; other definition in the visible source.
;; NOTE(review): looks like scratch data, possibly weights copied from a
;; trained RBM - confirm before relying on it.
(def x [-11.284468356038781, 5.139513366337572, 6.594349540028067
        3.864727283236045, 3.8354833998981928, -9.228396774974318
        3.865087030973879, 3.83760577330622, -9.232020901548555
        3.864735831498734, 3.8355336355787113, -9.228482577519175
        5.431318381318273, -9.403413918581478, 0.4969522739554263
        5.429140003994505, -9.40055292429378, 0.49849862735857026
        5.43022782373914, -9.401289787274182, 0.49597933329477895])

;; 7x3 DoubleMatrix built when the namespace is loaded. The values match `x`
;; except for two entries (5.8376... in row 3 and -554.2284... in row 4).
;; NOTE(review): the altered entries presumably exist to exercise the
;; per-column width selection of the DoubleMatrix print-method - confirm.
(def ^DoubleMatrix y
  (double-matrix [[-11.284468356038781, 5.139513366337572, 6.594349540028067]
                  [3.864727283236045, 3.8354833998981928, -9.228396774974318]
                  [3.865087030973879, 5.83760577330622, -9.232020901548555]
                  [3.864735831498734, 3.8355336355787113, -554.228482577519175]
                  [5.431318381318273, -9.403413918581478, 0.4969522739554263]
                  [5.429140003994505, -9.40055292429378, 0.49849862735857026]
                  [5.43022782373914, -9.401289787274182, 0.49597933329477895]]))