(ns scicloj.ml.data
  (:require [tablecloth.api :as tc])
  )

(defn load-iris
  "Loads the iris CSV from the metamorph.ml GitHub repository into a
  tablecloth dataset. `:key-fn keyword` is passed to `tc/dataset` so that
  column names are turned into keywords. Needs network access."
  []
  (let [iris-csv-url "https://raw.githubusercontent.com/scicloj/metamorph.ml/main/test/data/iris.csv"
        parse-opts   {:key-fn keyword}]
    (tc/dataset iris-csv-url parse-opts)))

(comment
  ;; REPL-only aliases used by the forms in this rich-comment block.
  ;; They are intentionally not part of the `ns` form above.
  (require
   '[scicloj.ml.core :as ml]
   '[scicloj.ml.dataset :as ds]
   '[scicloj.ml.metamorph :as mm]
   '[tech.v3.dataset.math :as std-math]
   '[tech.v3.datatype.functional :as dtf]
   '[applied-science.waqi :as waqi])

  ;; The iris dataset every later form in this block works on.
  (def iris
    (load-iris))





  (defn std-scale
    "Returns a metamorph pipeline step that standard-scales the columns
    named in `col-seq`.

    The returned function takes a metamorph context map and dispatches on
    `:metamorph/mode`:

    - `:fit`       fits a std-scale transform on the selected columns, stores
                   `{:cols col-seq :fit-std-xform <fitted transform>}` in the
                   context under the step's `:metamorph/id`, and replaces the
                   selected columns in `:metamorph/data` with their scaled
                   versions.
    - `:transform` reads the columns and fitted transform stored under the
                   step's `:metamorph/id` and applies them to
                   `:metamorph/data`.

    Any other mode throws, because `case` has no default clause."
    [col-seq]
    (fn [{:metamorph/keys [data id mode] :as ctx}]
      (case mode
        :fit
        (let [selected      (tc/select-columns data col-seq)
              fit-std-xform (std-math/fit-std-scale selected)]
          (assoc ctx
                 id {:cols          col-seq
                     :fit-std-xform fit-std-xform}
                 ;; merging overwrites the original columns with the scaled ones
                 :metamorph/data (merge data
                                        (std-math/transform-std-scale selected fit-std-xform))))

        :transform
        ;; reuse exactly what the :fit pass stored under this step's id
        (let [{:keys [cols fit-std-xform]} (get ctx id)]
          (assoc ctx
                 :metamorph/data
                 (merge data
                        (std-math/transform-std-scale (tc/select-columns data cols)
                                                      fit-std-xform)))))))
  
  ;; Pipeline: keep the target plus two features, standard-scale the
  ;; features, mark :species as inference target, encode it numerically and
  ;; run a Smile random-forest classifier.
  (def pipe-fn
    (let [feature-cols [:sepal_length :sepal_width]
          target-col   :species]
      (ml/pipeline
       (mm/select-columns (into [target-col] feature-cols))
       (std-scale feature-cols)
       (mm/set-inference-target target-col)
       (mm/categorical->number [target-col])
       (mm/model {:model-type :smile.classification/random-forest}))))

  ;; Context returned by running the pipeline once in :fit mode on iris.
  (def fitted-ctx
    (let [fit-ctx {:metamorph/mode :fit
                   :metamorph/data iris}]
      (pipe-fn fit-ctx)))

  ;; Value ranges of the two feature columns; they bound the prediction grid.
  (def min-sepal-length
    (dtf/reduce-min (:sepal_length iris)))

  (def min-sepal-width
    (dtf/reduce-min (:sepal_width iris)))

  (def max-sepal-length
    (dtf/reduce-max (:sepal_length iris)))

  (def max-sepal-width
    (dtf/reduce-max (:sepal_width iris)))


  (defn stepped-range
    "Returns `n-steps` evenly spaced values starting at `start` and stopping
    before `end` (`end` itself is excluded). Returns an empty sequence when
    `start` equals `end`. `n-steps` is expected to be a positive number.

    Each element is computed as `start + i * step` rather than by repeatedly
    adding the step, as `range` does. With floating-point bounds the
    accumulated rounding error of repeated addition can produce one element
    too many, e.g. 11 values for `(range 0.0 1.0 0.1)`."
    [start end n-steps]
    (let [diff (- end start)
          step (/ diff n-steps)]
      (if (zero? diff)
        ()
        (map (fn [i] (+ start (* i step)))
             (range n-steps)))))

  ;; Cartesian product of 100 sepal-length steps and 100 sepal-width steps.
  ;; Each point is a row map with an empty :species to be filled in by the
  ;; model.
  (def grid
    (let [lengths (stepped-range min-sepal-length max-sepal-length 100)
          widths  (stepped-range min-sepal-width max-sepal-width 100)]
      (for [len lengths
            wid widths]
        {:species      nil
         :sepal_length len
         :sepal_width  wid})))

  ;; The grid row maps turned into a dataset.
  (def grid-ds
    (-> grid
        ds/dataset))




  (defn predict-species
    "Runs the fitted pipeline in :transform mode on `dataset` and returns the
    resulting :species column as a seq, after converting it back with
    `ds/column-values->categorical`.

    Uses the `pipe-fn` and `fitted-ctx` vars defined earlier in this block."
    [dataset]
    (-> (pipe-fn (merge fitted-ctx
                        {:metamorph/data dataset
                         :metamorph/mode :transform}))
        :metamorph/data
        (ds/column-values->categorical :species)
        seq))

  ;; Predicted species for every point of the grid.
  (def prediction-grid
    (predict-species grid-ds))

  ;; Grid dataset with its :species column replaced by the predictions.
  (def grid-ds-prediction
    (ds/add-column grid-ds :species prediction-grid))

  ;; Predicted species for the original iris rows.
  (def prediction-iris
    (predict-species iris))

  ;; Iris dataset with its :species column replaced by the predictions.
  (def ds-prediction
    (ds/add-column iris :species prediction-iris))
  )


;; Plotting experiments. They use the `waqi` and `ds` aliases and the
;; `grid-ds-prediction` / `ds-prediction` vars, none of which are set up by
;; this namespace's `ns` form. They are therefore kept in a rich-comment block
;; so that loading the file does not fail on unresolved symbols. Evaluate the
;; forms one by one at the REPL after the definitions they depend on.
(comment
  ;; Predicted species over the grid, drawn as rectangles on nominal axes.
  (waqi/plot!
   {:data {:values (ds/rows grid-ds-prediction :as-maps)}
    :width 300
    :height 300
    :mark "rect"
    :encoding {:x {:field :sepal_length
                   :type "nominal"
                   :axis {:format "2.2"
                          :labelOverlap true}}
               :y {:field :sepal_width
                   :type "nominal"
                   :axis {:format "2.2"
                          :labelOverlap true}}
               :color {:field :species}}})

  ;; Predicted species for the iris rows, drawn as circles on nominal axes.
  (waqi/plot!
   {:data {:values (ds/rows ds-prediction :as-maps)}
    :width 300
    :height 300
    :mark "circle"
    :encoding {:x {:field :sepal_length
                   :type "nominal"
                   :axis {:format "2.2"
                          :labelOverlap true}}
               :y {:field :sepal_width
                   :type "nominal"
                   :axis {:format "2.2"
                          :labelOverlap true}}
               :color {:field :species}}})

  ;; Layered plot on quantitative axes: faint squares for the grid
  ;; predictions with opaque circles for the iris predictions on top.
  (waqi/plot!
   {:layer
    [{:data {:values (ds/rows grid-ds-prediction :as-maps)}
      :width 500
      :height 500
      :mark {:type "square" :opacity 0.1 :strokeOpacity 0.1 :stroke nil}
      :encoding {:x {:field :sepal_length
                     :type "quantitative"
                     :scale {:domain [4 8]}
                     :axis {:format "2.2"
                            :labelOverlap true}}
                 :y {:field :sepal_width
                     :type "quantitative"
                     :axis {:format "2.2"
                            :labelOverlap true}
                     :scale {:domain [2 4.5]}}
                 :color {:field :species}}}

     {:data {:values (ds/rows ds-prediction :as-maps)}
      :width 500
      :height 500
      :mark {:type "circle" :opacity 1 :strokeOpacity 1}
      :encoding {:x {:field :sepal_length
                     :type "quantitative"
                     :axis {:format "2.2"
                            :labelOverlap true}
                     :scale {:domain [4 8]}}
                 :y {:field :sepal_width
                     :type "quantitative"
                     :axis {:format "2.2"
                            :labelOverlap true}
                     :scale {:domain [2 4.5]}}
                 :color {:field :species}}}]}))
