(ns de.cognesys.libs.bayes.naive
  (:import java.io.File))

(def common-words
  ;; German stop words, lower-case, with umlauts spelled out (ae/oe/ue) and ß
  ;; kept as is. The experiment code below removes these words from the token
  ;; stream before training, so they do not act as features.
  (into (sorted-set)
        ["aber" "alle" "allein" "allen" "alles" "als" "also" "am" "an" "anderen" "andern" "anders" "auch" "auf" "aus" "bald" "bei" "beiden" "bin" "bis" "da" "damit" "dann" "darauf" "das" "dass" "dem" "den" "denen" "denn" "der" "deren"
         "derselben" "des" "dessen" "deutschen" "dich" "die" "dies" "diese" "diesem" "diesen" "dieser" "dieses" "dir" "doch" "dort" "du" "durch" "eben" "ein" "eine" "einem" "einen" "einer" "eines" "einmal" "einzelnen" "er" "erst" "ersten"
         "es" "etwas" "euch" "frau" "fuer" "ganz" "ganze" "ganzen" "gar" "gegen" "gemacht" "gewesen" "große" "großen" "gut" "habe" "haben" "haette" "hand" "hat" "hatte" "hatten" "herr" "herren" "herrn" "heute" "hier" "ich" "ihm" "ihn"
         "ihnen" "ihr" "ihre" "ihrem" "ihren" "ihrer" "im" "immer" "in" "ist" "ja" "jahre" "jetzt" "kann" "kein" "keine" "koennen" "kommen" "konnte" "laesst" "lange" "lassen" "leben" "liebe" "machen" "macht" "man" "mann" "mehr" "mein"
         "meine" "meiner" "menschen" "mich" "mir" "mit" "muessen" "muss" "nach" "nicht" "nichts" "noch" "nun" "nur" "ob" "oder" "ohne" "paragraph" "recht" "sagen" "schon" "sehen" "sehr" "sei" "sein" "seine" "seinem" "seinen" "seiner"
         "seite" "selbst" "sich" "sie" "sind" "so" "solche" "soll" "sollte" "sondern" "ueber" "um" "und" "uns" "unter" "viel" "vielleicht" "vom" "von" "vor" "waehrend" "waere" "war" "waren" "was" "weil" "weise" "weit" "weiter" "weiß"
         "welche" "welcher" "welches" "welt" "wenn" "wer" "werde" "werden" "wie" "wieder" "will" "wir" "wird" "wo" "wohl" "wollen" "worden" "wuerde" "wurde" "wurden" "zeit" "zu" "zum" "zur" "zwar" "zwei" "zwischen"]))


(defn update-map
  "Returns a map with the same keys as `map` where every value `v` has been
  replaced by `(f v)`. A nil `map` yields an empty map."
  [map f]
  (let [step (fn [acc k v]
               (assoc acc k (f v)))]
    (reduce-kv step {} map)))

(defn add-frequencies
  "Adds the counts of the frequency map `f2` onto the frequency map `f1`
  (which may be nil) and returns the combined map. Keys missing from `f1`
  start at 0."
  [f1 f2]
  (let [step (fn [acc word n]
               (assoc! acc word (+ n (get acc word 0))))]
    (->> f2
         (reduce-kv step (transient (or f1 {})))
         persistent!)))

;; Intermediate training result built by `bag`:
;;   bag        - the count-fn results of all training texts, merged with merge-fn
;;   categories - set of categories that occurred on at least one training text
;;   cat-counts - category -> normaliser for P(W|C): summed feature counts
;;                (:freqs :feature) or number of texts (:freqs :text)
;;   pcs        - category -> P(C): texts tagged with the category divided by
;;                the total number of training texts
;;   pwcs       - category -> word -> P(W|C): summed word count divided by the
;;                category's cat-counts entry
(defrecord Bag [bag categories cat-counts pcs pwcs])

(defn bag
  "Counts the features of the training `texts` and derives the Naive Bayes
  tables from them. Returns a `Bag` record.

  Each text is read via `:text` (the raw string) and `:tags`; only the tags for
  which `categories` returns truthy (normally a set) are used as that text's
  categories.

  * `normalize-fn`, `split-fn` - applied to `:text` in that order to obtain
    the features of a text.
  * `count-fn` - turns the features into a word -> count map.
  * `merge-fn` - combines the counts of one word across texts (`:bag` field).
  * `freqs`    - `:feature`: each text adds its number of features to the
    normaliser of its categories; `:text`: each text adds 1. Any other value
    throws IllegalArgumentException (only once there is at least one text).

  P(C) is divided by the number of all texts, including texts that carry none
  of `categories`; this scales every P(C) by the same factor."
  [categories texts normalize-fn split-fn count-fn merge-fn freqs]
  (loop [[t & ts :as txts] texts
         bags       []   ; one count-fn result per text
         cat-counts {}   ; category -> normaliser for P(W|C)
         pcs        {}   ; category -> number of texts tagged with it
         pwcs       {}]  ; category -> word -> summed count
    (if (seq txts)
      (let [cats       (filter categories (:tags t))
            splitted   (->> (:text t) normalize-fn split-fn)
            ;; Weight this text adds to the normaliser of each of its categories.
            word-count (case freqs
                         :feature (count splitted)
                         :text    1
                         (throw (IllegalArgumentException.
                                 (str "Unknown :freqs value " (pr-str freqs)
                                      "; expected :feature or :text"))))
            b          (count-fn splitted)]
        (recur ts
               (conj bags b)
               (add-frequencies cat-counts (reduce #(assoc %1 %2 word-count) {} cats))
               (add-frequencies pcs (frequencies cats))
               (reduce (fn [m c]
                         (assoc m c (add-frequencies (m c) b)))
                       pwcs
                       cats)))
      ;; Count once: `texts` may be a lazy seq, where `count` walks the whole seq.
      (let [n-texts (count texts)]
        (Bag. (apply merge-with merge-fn bags)
              (set (keys pwcs))
              cat-counts
              (reduce-kv #(assoc %1 %2 (/ %3 n-texts)) {} pcs)
              (reduce-kv (fn [m k v]
                           (assoc m k (update-map v #(/ % (cat-counts k)))))
                         {} pwcs))))))

;; Trained classifier returned by `train-naive-bayes`:
;;   categories      - set of categories the model can assign
;;   pwcs            - category -> word -> P(W|C)
;;   pcs             - category -> P(C)
;;   category-counts - category -> normaliser that was used for P(W|C)
;;                     (the Bag's cat-counts); needed for smoothing
;;   split-fn        - text -> features, i.e. normalize-fn followed by split-fn
(defrecord NaiveBayes [categories pwcs pcs category-counts split-fn])

(defn train-naive-bayes
  "Trains a Naive Bayes classifier for `categories` on `texts`.

  Keyword options:
  * `:normalize-fn` - text preprocessing (default `normalize-text`)
  * `:split-fn`     - normalised text -> features (default `split-into-words`)
  * `:count-fn`     - features -> word -> count map (default `frequencies`)
  * `:merge-fn`     - combines one word's counts across texts (default `+`)
  * `:freqs`        - `:feature` or `:text`, selects the P(W|C) normaliser
                      (default `:feature`)

  The returned `NaiveBayes` keeps `(comp split-fn normalize-fn)` so unseen
  text is preprocessed exactly like the training texts."
  [categories texts
   & {:keys [normalize-fn split-fn count-fn merge-fn freqs]
      :or {normalize-fn normalize-text
           split-fn     split-into-words
           count-fn     frequencies
           merge-fn     +
           freqs        :feature}}]
  (let [preprocess (comp split-fn normalize-fn)
        {trained-cats :categories
         word-probs   :pwcs
         cat-probs    :pcs
         cat-counts   :cat-counts}
        (bag (set categories) texts normalize-fn split-fn count-fn merge-fn freqs)]
    (NaiveBayes. trained-cats word-probs cat-probs cat-counts preprocess)))

(defn category-freq
  "Looks up the prior `P(C)` of category `c` in the classifier `nb`.
  Returns nil when `c` is unknown."
  [nb c]
  (get-in nb [:pcs c]))

(defn select-word-freq
  "Looks up `P(W|C)` for word `w` in category `c` of the classifier `nb`.
  Returns nil when the word was never seen in that category."
  [nb c w]
  (get-in nb [:pwcs c w]))

(defn calc-prob
  "Takes a NaiveBayes `nb`, a seq of `words` and a `category`.
  Words ==> W1 - Wn
  category ==> C
  Calculates P(C) * P(W1|C) * … * P(Wn|C) with additively smoothed word
  probabilities:

      P(Wi|C) = (count(Wi, C) + 1) / (N_C + n)

  where N_C is the category's normaliser (`:category-counts`) and n is the
  number of `words`. Words never seen in C get 1 / (N_C + n) instead of 0.
  NOTE(review): textbook Laplace smoothing uses the vocabulary size instead of
  n here; the query length is kept as in the original code, confirm intended.
  With rational stored probabilities the result is an exact rational."
  [nb words category]
  (let [wc           (count words)
        sample-count (get (:category-counts nb) category)  ; N_C
        pc           (category-freq nb category)            ; P(C)
        smoothed     (fn [w]
                       ;; The stored P(W|C) is count/N_C already reduced to
                       ;; lowest terms (a plain integer when it equals 1), so
                       ;; its numerator/denominator are NOT the raw counts and
                       ;; `numerator` would throw on an integer. Multiply back
                       ;; by N_C to recover count(W, C).
                       (let [x          (select-word-freq nb category w)
                             word-count (if x (* x sample-count) 0)]
                         (/ (inc word-count) (+ wc sample-count))))
        pwcs         (map smoothed words)]                 ; P(W|C)'s
    (apply * pc pwcs)))

(defn classify
  "Classifies `text` with the trained Naive Bayes `nb`.

  Returns a sorted map {category P(C|W1, …, Wn)} covering every category of
  `nb`; the values are doubles that sum to 1.0 (up to rounding)."
  [nb text]
  (let [words    ((:split-fn nb) text)
        ;; Unnormalised score P(C) * P(W1|C) * … * P(Wn|C) per category.
        scores   (into {}
                       (map (fn [c] [c (calc-prob nb words c)]))
                       (:categories nb))
        ;; Bayes: P(C|W) = P(C) * P(W|C) / P(W1, …, Wn). The denominator is the
        ;; same for every category and, because the posteriors must sum to 1,
        ;; it equals the sum of all scores.
        evidence (reduce + (vals scores))]
    (into (sorted-map)
          (for [[c score] scores]
            [c (double (/ score evidence))]))))


;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;; Tests ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;


;; Disabled (#_) smoke test: trains a tiny model on three hand-written texts
;; tagged :x / :y. Remove the #_ reader macro to evaluate it.
#_(def netti (train-naive-bayes #{:x :y}
            [(Text. :id1 "Ich kündige meinen Vertrag." #{:x :train})
             (Text. :id2 "Hiermit bestelle ich meinen Vertrag ab!\nIch brauch den nicht mehr!!\nGruß, André" #{:x :train})
             (Text. :id3 "Wo ist meine Rechnung?" #{:y :train})]))


(defn- load-tagged-texts
  "Returns a lazy seq of `Text`s, one per `.txt` file found (recursively, via
  `file-seq`) under the directory `dir`. Each file is slurped as UTF-8; its
  file name becomes the id and `tags` the tags. Files are read only when the
  seq is realised."
  [dir tags]
  (map (fn [^File f]
         (Text. (.getName f) (slurp f :encoding "UTF-8") tags))
       (filter #(.endsWith (.getName ^File %) ".txt")
               (file-seq (File. ^String dir)))))

(def format-texts
  (load-tagged-texts "c:/temp/cewe_bayes/cewe_format/aufbereitet/" #{:format}))

(def bank-texts
  (load-tagged-texts "c:/temp/cewe_bayes/cewe_bank/aufbereitet/" #{:non-format}))

(def kulanz-texts
  (load-tagged-texts "c:/temp/cewe_bayes/cewe_kulanz/aufbereitet/" #{:non-format}))

;; Models that are trained at the REPL; declared up front so the test
;; functions that refer to them compile before any training has run.
(declare netti net-kula)

;; Disabled (#_) training of the :format / :non-format model: the first 250
;; format texts plus the first 350 bank and 500 kulanz texts (both tagged
;; :non-format). Stop words from `common-words` are removed and every text is
;; reduced to its set of words. testrun-format / testrun-nonformat evaluate on
;; texts beyond these slices.
#_
(def netti (train-naive-bayes #{:format :non-format}
            (concat (take 250 format-texts)
                    (take 350 bank-texts)
                    (take 500 kulanz-texts))
            :freqs :feature
            :split-fn #(->> % split-into-words (remove common-words) set)))

;; Disabled (#_) training of the :kulanz / :non-kulanz model used by
;; `testrun-kulanz`. Texts are re-tagged through `:tags` (the key `bag` reads
;; categories from), and the model is trained on exactly the two categories
;; that `test-kulanz-texts2` looks up in the classification result.
#_
(def net-kula (train-naive-bayes #{:kulanz :non-kulanz}
            (concat (map #(assoc % :tags #{:non-kulanz}) (take 250 format-texts))
                    (map #(assoc % :tags #{:non-kulanz}) (take 350 bank-texts))
                    (map #(assoc % :tags #{:kulanz}) (take 500 kulanz-texts)))
            :freqs :feature
            :split-fn #(->> % split-into-words (remove common-words) set)))

(defn test-format-texts
  "Classifies every text in `texts` with `nb` and returns a lazy seq of
  booleans: true where P(:format) > P(:non-format).

  Texts are processed in parallel with `pmap`, one task per chunk of
  `chunk-size` texts (default 1). Larger chunks reduce per-task overhead when
  a single classification is cheap."
  ([nb texts]
   (test-format-texts nb texts 1))
  ([nb texts chunk-size]
   {:pre [(pos? chunk-size)]}
   (apply concat
          (pmap (fn [chunk]
                  ;; mapv (eager) so the classification really runs inside the
                  ;; pmap worker; a lazy `map` would only return an unrealised
                  ;; seq and defer all work to the consuming thread.
                  (mapv #(let [{fmt :format non-fmt :non-format} (classify nb (:text %))]
                           (> fmt non-fmt))
                        chunk))
                (partition-all chunk-size texts)))))

(defn test-format-texts2
  "Classifies every text in `texts` with `nb` via `pmap` and returns a lazy seq
  of booleans: true where the text scored higher for :format than for
  :non-format."
  [nb texts]
  (letfn [(format-wins? [t]
            (let [probs (classify nb (:text t))]
              (> (:format probs) (:non-format probs))))]
    (pmap format-wins? texts)))

(defn testrun-format
  "Classifies `format-texts` except the first 250 with `netti` and prints the
  number of texts and how many were classified as :format (count and
  percentage). Prints \"-\" instead of a percentage when there are no texts."
  []
  (let [res (test-format-texts2 netti (drop 250 format-texts))
        resc (count res)
        correct (count (filter true? res))]
    (println "Texte:" resc)
    (println (format "Richtig: %d (%s)"
                     correct
                     (if (zero? resc)
                       "-" ; no texts (e.g. missing directory): avoid divide by zero
                       (double (/ (* 100 correct) resc)))))))

(defn testrun-nonformat
  "Classifies the bank texts after the first 500 and the kulanz texts after the
  first 800 with `netti` and prints the number of texts and how many were NOT
  classified as :format (count and percentage). Prints \"-\" instead of a
  percentage when there are no texts."
  []
  (let [res (test-format-texts2 netti (concat (drop 500 bank-texts) (drop 800 kulanz-texts)))
        resc (count res)
        correct (count (filter false? res))]
    (println "Texte:" resc)
    (println (format "Richtig: %d (%s)"
                     correct
                     (if (zero? resc)
                       "-" ; no texts (e.g. missing directory): avoid divide by zero
                       (double (/ (* 100 correct) resc)))))))

(defn test-kulanz-texts2
  "Classifies every text in `texts` with `nb` via `pmap` and returns a lazy seq
  of booleans: true where the text scored higher for :kulanz than for
  :non-kulanz."
  [nb texts]
  (letfn [(kulanz-wins? [t]
            (let [probs (classify nb (:text t))]
              (> (:kulanz probs) (:non-kulanz probs))))]
    (pmap kulanz-wins? texts)))

(defn testrun-kulanz
  "Classifies the bank texts after the first 400 and the format texts after the
  first 250 with `net-kula` and prints the number of texts and how many were
  NOT classified as :kulanz (count and percentage). Prints \"-\" instead of a
  percentage when there are no texts."
  []
  (let [res (test-kulanz-texts2 net-kula #_(drop 700 kulanz-texts) (concat (drop 400 bank-texts) (drop 250 format-texts)))
        resc (count res)
        correct (count (filter false? res))]
    (println "Texte:" resc)
    (println (format "Richtig: %d (%s)"
                     correct
                     (if (zero? resc)
                       "-" ; no texts (e.g. missing directory): avoid divide by zero
                       (double (/ (* 100 correct) resc)))))))


;;;
;;; Travel Tests
;;;

(defonce strandnah-files
  ;; Lazy seq with one Text per .txt file below e:/travel/beach/strandnaehe/,
  ;; id = file name, tagged :strandnah. Files are slurped (UTF-8) on realisation.
  (for [^File f (file-seq (File. "e:/travel/beach/strandnaehe/"))
        :when (.endsWith (.getName f) ".txt")]
    (Text. (.getName f) (slurp f :encoding "UTF-8") #{:strandnah})))

(defonce strandfern-files
  ;; Same as strandnah-files, for e:/travel/beach/strandferne/, tagged :strandfern.
  (for [^File f (file-seq (File. "e:/travel/beach/strandferne/"))
        :when (.endsWith (.getName f) ".txt")]
    (Text. (.getName f) (slurp f :encoding "UTF-8") #{:strandfern})))

(defn shuffle-strand
  "Draws a fresh random order of both beach corpora and stores it in the
  namespace vars `strandnah-texts` and `strandfern-texts` (redefined via
  `def` on every call, strandnah first)."
  []
  (->> strandnah-files shuffle (def strandnah-texts))
  (->> strandfern-files shuffle (def strandfern-texts)))

;;(count strandnah-texts)   ; 2024
;;(count strandfern-texts)  ; 9412

(defn strand-training
  "Reshuffles the beach corpora (`shuffle-strand`), slices them into training
  and test sets, stores the slices in the vars `strandnah-training-texts`,
  `strandnah-testing-texts`, `strandfern-training-texts` and
  `strandfern-testing-texts`, and (re)defines `netti` as a model trained on
  both training slices with `:freqs :text`, stop words removed and every text
  reduced to its set of words.

  Slices: the first 1000 strandnah / 1000 strandfern texts for training, the
  following 1000 strandnah / 2000 strandfern texts for testing."
  []
  (shuffle-strand)
  (let [stop-words      (conj common-words "b" "br" "00" "strong" "und" "z" )
        nah-train-size  1000
        nah-test-size   1000
        fern-train-size 1000
        fern-test-size  2000]
    (def strandnah-training-texts (take nah-train-size strandnah-texts))
    (def strandnah-testing-texts
      (take nah-test-size (drop nah-train-size strandnah-texts)))
    (def strandfern-training-texts (take fern-train-size strandfern-texts))
    (def strandfern-testing-texts
      (take fern-test-size (drop fern-train-size strandfern-texts)))
    (def netti
      (train-naive-bayes #{:strandnah :strandfern}
                         (concat strandnah-training-texts strandfern-training-texts)
                         :freqs :text
                         :split-fn #(->> % split-into-words (remove stop-words) set)))))

(defn- report-strand-accuracy
  "Prints `heading`, then (inside `time`) runs `correct?` over `texts` with
  `pmap` and prints `count-label` with the number of texts, followed by how
  many texts `correct?` returned true for, as count and percentage. Prints
  \"-\" instead of a percentage when `texts` is empty."
  [heading count-label correct? texts]
  (println heading)
  (time
   (let [total   (count texts)
         results (pmap correct? texts)
         correct (count (filter true? results))]
     (println count-label total)
     (println (format "Correct: %d (%s%%)"
                      correct
                      (if (zero? total)
                        "-" ; avoid divide by zero on an empty test set
                        (double (/ (* 100 correct) total))))))))

(defn strand-testing
  "Evaluates the classifier `nb` (default: `netti`) on `strandnah-testing-texts`
  and `strandfern-testing-texts` and prints timing and accuracy for each set.
  A text counts as correct when its own category gets the strictly higher
  probability."
  [& {:keys [nb] :or {nb netti}}]
  (time
   (let [test-strandnah  (fn [t]
                           (let [{:keys [strandnah strandfern]} (classify nb (:text t))]
                             (> strandnah strandfern)))
         test-strandfern (fn [t]
                           (let [{:keys [strandnah strandfern]} (classify nb (:text t))]
                             (< strandnah strandfern)))]
     (report-strand-accuracy "Testing strandnah" "Strandnah texts:"
                             test-strandnah strandnah-testing-texts)
     (newline)
     (report-strand-accuracy "Testing strandfern" "Strandfern texts:"
                             test-strandfern strandfern-testing-texts))))

(defn testus
  "For each category in `cats`, prints \"Testing <cat>\" and classifies with
  `nb` every text in `texts` whose `:tags` contain that category.

  Returns a lazy seq with one element per category: the lazy seq of `classify`
  results for that category's texts. Nothing is printed or classified until
  the result is consumed.

  NOTE(review): the previous body was unfinished. It called `map` with only a
  function, which builds a transducer that was never applied; it also had two
  no-op `let`s, ignored `texts`, and yielded nil per category. This
  implements the apparent intent; confirm."
  [nb cats texts]
  (map (fn [cat]
         (println "Testing" cat)
         (->> texts
              (filter #(contains? (:tags %) cat))
              (map #(classify nb (:text %)))))
       cats))

