(ns apache-commons-matrix.core
  (:require [clojure.core.matrix.protocols :as mp]
            [clojure.core.matrix.implementations :as imp])
  (:import [org.apache.commons.math3.linear
            Array2DRowRealMatrix RealMatrix
            ArrayRealVector RealVector
            RealMatrixChangingVisitor
            MatrixUtils
            SparseRealVector
            SparseRealMatrix
            OpenMapRealVector
            OpenMapRealMatrix
            CholeskyDecomposition
            EigenDecomposition
            LUDecomposition
            QRDecomposition
            SingularValueDecomposition]
           [org.apache.commons.math3.stat.regression
            OLSMultipleLinearRegression]))

;; Implementation hooks used when the prototype object is a RealMatrix.
;; New vectors are ArrayRealVector instances, new matrices are
;; Array2DRowRealMatrix instances.
(extend-protocol mp/PImplementation
  RealMatrix
    (implementation-key [_] :apache-commons)
    (new-vector [_ length] (ArrayRealVector. length))
    (new-matrix [_ rows columns] (Array2DRowRealMatrix. rows columns))
    ;; Dispatches on the number of requested dimensions: a zero-dimensional
    ;; request yields the scalar 0.0, anything above two dimensions throws.
    (new-matrix-nd [_ dims]
      (let [d0 (first dims)
            d1 (second dims)]
        (case (count dims)
          0 0.0
          1 (ArrayRealVector. d0)
          2 (Array2DRowRealMatrix. d0 d1)
          (throw (ex-info "Apache Commons Math matrices only supports up to 2 dimensions"
                          {:requested-shape dims})))))
    ;; Copies `data` into a native object; zero-dimensional data is returned
    ;; unchanged.
    (construct-matrix [_ data]
      (case (mp/dimensionality data)
        0 data
        1 (ArrayRealVector. ^doubles (mp/to-double-array data))
        2 (->> (mp/get-major-slice-seq data)
               (map mp/to-double-array)
               into-array
               (Array2DRowRealMatrix.))))
    (supports-dimensionality? [_ dims]
      (and (>= dims 1) (<= dims 2))))

;; Implementation hooks used when the prototype object is a RealVector.
;; The constructors produce the same concrete classes as the RealMatrix
;; extension: ArrayRealVector and Array2DRowRealMatrix.
(extend-protocol mp/PImplementation
  RealVector
    (implementation-key [_] :apache-commons)
    (new-vector [_ length] (ArrayRealVector. length))
    (new-matrix [_ rows columns] (Array2DRowRealMatrix. rows columns))
    ;; 0 dims -> scalar 0.0, 1 dim -> vector, 2 dims -> matrix, else throws.
    (new-matrix-nd [_ dims]
      (let [size-0 (first dims)
            size-1 (second dims)]
        (case (count dims)
          0 0.0
          1 (ArrayRealVector. size-0)
          2 (Array2DRowRealMatrix. size-0 size-1)
          (throw (ex-info "Apache Commons Math matrices only supports up to 2 dimensions"
                          {:requested-shape dims})))))
    ;; Builds a native vector or matrix from `data`; scalars pass through.
    (construct-matrix [_ data]
      (case (mp/dimensionality data)
        0 data
        1 (ArrayRealVector. ^doubles (mp/to-double-array data))
        2 (let [rows (map mp/to-double-array (mp/get-major-slice-seq data))]
            (Array2DRowRealMatrix. (into-array rows)))))
    (supports-dimensionality? [_ dims]
      (and (>= dims 1) (<= dims 2))))

;; Shape queries. A RealMatrix always reports two dimensions
;; (rows, columns); a RealVector always reports one.
(extend-protocol mp/PDimensionInfo
  RealMatrix
  (dimensionality [_] 2)
  (get-shape [m]
    (list (.getRowDimension m)
          (.getColumnDimension m)))
  (is-scalar? [_] false)
  (is-vector? [_] false)
  ;; Dimension 0 is the row count, dimension 1 the column count.
  (dimension-count [m dimension-number]
    (case dimension-number
      0 (.getRowDimension m)
      1 (.getColumnDimension m)
      (throw (ex-info "RealMatrix only has 2 dimensions"
                      {:requested-dimension dimension-number}))))

  RealVector
  (dimensionality [_] 1)
  (get-shape [v] (vector (.getDimension v)))
  (is-scalar? [_] false)
  (is-vector? [_] true)
  ;; Only dimension 0 exists for a vector.
  (dimension-count [v dimension-number]
    (when-not (zero? dimension-number)
      (throw (ex-info "RealVector only has 1 dimension"
                      {:requested-dimension dimension-number})))
    (.getDimension v)))

;; Element access by index.
(extend-protocol mp/PIndexedAccess
  RealMatrix
  ;; A single index on a matrix selects the whole row, as a vector.
  (get-1d [m row] (.getRowVector m row))
  (get-2d [m row column] (.getEntry m row column))
  ;; Delegates to get-1d / get-2d depending on how many indexes are given.
  (get-nd [m indexes]
    (let [head (first indexes)]
      (case (count indexes)
        1 (mp/get-1d m head)
        2 (mp/get-2d m head (second indexes))
        (throw (ex-info "RealMatrix only has 2 dimensions"
                        {:requested-index indexes
                         :index-count (count indexes)})))))

  RealVector
  (get-1d [v index] (.getEntry v index))
  ;; Two-index access is never valid on a vector.
  (get-2d [_ row column]
    (throw (ex-info "RealVector only has 1 dimension"
                    {:index-count 2})))
  (get-nd [v indexes]
    (when-not (= 1 (count indexes))
      (throw (ex-info "RealVector only has 1 dimension"
                      {:requested-index indexes
                       :index-count (count indexes)})))
    (mp/get-1d v (first indexes))))

;; Non-destructive setters: each one copies the receiver and then applies
;; the corresponding mutating setter to the copy.
(extend-protocol mp/PIndexedSetting
  RealMatrix
  (set-1d [m row e]
    (-> (.copy m) (mp/set-1d! row e)))
  (set-2d [m row column e]
    (-> (.copy m) (mp/set-2d! row column e)))
  (set-nd [m indexes e]
    (-> (.copy m) (mp/set-nd! indexes e)))
  (is-mutable? [_] true)

  RealVector
  (set-1d [v index e]
    (-> (.copy v) (mp/set-1d! index e)))
  (set-2d [v row column e]
    (-> (.copy v) (mp/set-2d! row column e)))
  (set-nd [v indexes e]
    (-> (.copy v) (mp/set-nd! indexes e)))
  (is-mutable? [_] true))

;; In-place setters. Every method returns the mutated receiver.
(extend-protocol mp/PIndexedSettingMutable
  RealMatrix
  ;; Replaces an entire row. `e` must be a vector (per mp/is-vector?); it is
  ;; converted to a double[] first because RealMatrix.setRow only accepts a
  ;; primitive double array, not a RealVector or a Clojure vector.
  (set-1d! [m row e]
    (if (mp/is-vector? e)
      (doto m (.setRow (int row) ^doubles (mp/to-double-array e)))
      (throw (ex-info "Unable to set row" {}))))
  (set-2d! [m row column e] (doto m (.setEntry row column e)))
  ;; Dispatches on the number of indexes: one index sets a row, two indexes
  ;; set a single entry, anything else is an error.
  (set-nd! [m indexes e]
    (case (count indexes)
      1 (mp/set-1d! m (first indexes) e)
      2 (mp/set-2d! m (first indexes) (second indexes) e)
      (throw (ex-info "RealMatrix only has 2 dimensions"
                      {:requested-index indexes
                       :index-count (count indexes)}))))

  RealVector
  (set-1d! [v index e] (doto v (.setEntry index e)))
  ;; A vector has no second dimension; set-nd! raises the error.
  (set-2d! [v row column e] (mp/set-nd! v [row column] e))
  (set-nd! [v indexes e]
    (if (= (count indexes) 1)
      (mp/set-1d! v (first indexes) e)
      (throw (ex-info "RealVector only has 1 dimension"
                      {:requested-index indexes
                       :index-count (count indexes)})))))

;; Cloning delegates to the `copy` method of the underlying object.
(extend-protocol mp/PMatrixCloning
  RealMatrix
  (clone [matrix]
    (.copy matrix))

  RealVector
  (clone [vect]
    (.copy vect)))

;; Both matrices and vectors report the primitive double type for their
;; elements.
(extend-protocol mp/PTypeInfo
  RealMatrix
  (element-type [_] Double/TYPE)

  RealVector
  (element-type [_] Double/TYPE))

;; A mutable version of a matrix or vector is produced by calling `copy`
;; on it.
(extend-protocol mp/PMutableMatrixConstruction
  RealMatrix
  (mutable-matrix [matrix]
    (.copy matrix))

  RealVector
  (mutable-matrix [vect]
    (.copy vect)))

;; Scaling by a scalar factor. scale and pre-scale make the same call, so
;; they give the same result.
(extend-protocol mp/PMatrixScaling
  RealMatrix
  (scale [matrix factor]
    (.scalarMultiply matrix factor))
  (pre-scale [matrix factor]
    (.scalarMultiply matrix factor))

  RealVector
  (scale [vect factor]
    (.mapMultiply vect factor))
  (pre-scale [vect factor]
    (.mapMultiply vect factor)))

;; In-place scaling of a RealMatrix using a changing visitor that
;; multiplies each visited entry by the scale factor.
(extend-protocol mp/PMatrixMutableScaling
  RealMatrix
  (scale! [m a]
    (let [^RealMatrixChangingVisitor scaler
          (reify RealMatrixChangingVisitor
            ;; Nothing to prepare before the walk.
            (start [_ rows cols start-row end-row start-col end-col])
            ;; The returned value replaces the visited entry.
            (visit [_ row column value] (* a value))
            (end [_] 0.0))]
      (.walkInOptimizedOrder m scaler)
      ;; Hand back the mutated matrix itself.
      m))
  (pre-scale! [m a]
    (mp/scale! m a)))

;; Negation is expressed as scaling by -1.
(extend-protocol mp/PNegation
  RealMatrix
  (negate [matrix]
    (-> matrix (mp/scale -1)))

  RealVector
  (negate [vect]
    (-> vect (mp/scale -1))))

;; Transposition delegates to RealMatrix's own `transpose` method.
(extend-protocol mp/PTranspose
  RealMatrix
  (transpose [matrix]
    (.transpose matrix)))

;; Basic vector operations on RealVector.
(extend-protocol mp/PVectorOps
  RealVector
  ;; The right-hand operand is coerced via mp/coerce-param before the call.
  (vector-dot [a b]
    (let [rhs (mp/coerce-param a b)]
      (.dotProduct a rhs)))
  (length [a]
    (.getNorm a))
  ;; Squares the norm reported by getNorm.
  (length-squared [a]
    (let [norm (.getNorm a)]
      (* norm norm)))
  (normalise [a]
    (.unitVector a)))

;; In-place normalisation: calls `unitize` on the vector and returns that
;; same vector.
(extend-protocol mp/PMutableVectorOps
  RealVector
  (normalise! [vect]
    (.unitize vect)
    vect))

;; Distance between two vectors; the second operand is coerced via
;; mp/coerce-param before being handed to getDistance.
(extend-protocol mp/PVectorDistance
  RealVector
  (distance [a b]
    (let [other (mp/coerce-param a b)]
      (.getDistance a other))))

;; Matrix and element-wise multiplication.
;;
;; matrix-multiply dispatches on the dimensionality of the right operand:
;;   0 -> scalar scaling
;;   1 -> dot product (vector receiver) or matrix * vector (matrix receiver)
;;   2 -> vector * matrix (vector receiver) or matrix * matrix
;; Operands that are not already native Apache Commons objects are coerced
;; with mp/coerce-param. Any higher dimensionality throws.
(extend-protocol mp/PMatrixMultiply
  RealVector
  (matrix-multiply [a b]
    (let [b-dims (mp/dimensionality b)]
      (cond
       (== b-dims 0) (mp/scale a b)
       (== b-dims 1) (mp/vector-dot a b)
       ;; Row vector times matrix: preMultiply computes a^T * B directly,
       ;; so no intermediate 1 x n matrix is needed.
       (== b-dims 2) (let [^RealMatrix b-mat (if (instance? RealMatrix b)
                                               b
                                               (mp/coerce-param a b))]
                       (.preMultiply b-mat a))
       :else (throw (ex-info "RealVector can only be multiplied by arrays of up to 2 dimensions"
                             {:dimensionality b-dims})))))
  (element-multiply [a b]
    (if (number? b)
      (mp/scale a b)
      (let [[a b] (mp/broadcast-compatible a b)]
        (mp/element-map a clojure.core/* b))))

  RealMatrix
  (matrix-multiply [a b]
    (let [b-dims (mp/dimensionality b)]
      (cond
       (== b-dims 0) (mp/scale a b)
       ;; Matrix times column vector: operate computes A * b.
       (== b-dims 1) (let [^RealVector b-vec (if (instance? RealVector b)
                                               b
                                               (mp/coerce-param a b))]
                       (.operate a b-vec))
       (== b-dims 2) (let [^RealMatrix b-mat (if (instance? RealMatrix b)
                                               b
                                               (mp/coerce-param a b))]
                       (.multiply a b-mat))
       :else (throw (ex-info "RealMatrix can only be multiplied by arrays of up to 2 dimensions"
                             {:dimensionality b-dims})))))
  (element-multiply [a b]
    (if (number? b)
      (mp/scale a b)
      (let [[a b] (mp/broadcast-compatible a b)]
        (mp/element-map a clojure.core/* b)))))

;; Whole-matrix scalar quantities and inversion for RealMatrix.
(extend-protocol mp/PMatrixOps
  RealMatrix
  (trace [matrix]
    (.getTrace matrix))
  ;; The determinant is read from an LU decomposition of the matrix.
  (determinant [matrix]
    (let [lu-decomp (LUDecomposition. matrix)]
      (.getDeterminant lu-decomp)))
  (inverse [matrix]
    (MatrixUtils/inverse matrix)))

;; Sparsity check. A matrix or vector counts as sparse when it implements
;; the corresponding Apache Commons sparse marker interface.
(extend-protocol mp/PSparseArray
  RealMatrix
  (is-sparse? [m]
    (instance? SparseRealMatrix m))

  ;; Vectors are covered as well, so that sparse vectors such as
  ;; OpenMapRealVector are recognised in the same way as sparse matrices.
  RealVector
  (is-sparse? [v]
    (instance? SparseRealVector v)))

(defn- make-sparse-array
  "Creates an empty sparse container for `shape`, a sequence of dimension
  sizes: an OpenMapRealVector for one dimension, an OpenMapRealMatrix for
  two. Throws ex-info for any other number of dimensions."
  [shape]
  (let [[rows cols] shape]
    (case (count shape)
      1 (OpenMapRealVector. (int rows))
      2 (OpenMapRealMatrix. (int rows) (int cols))
      ;; A trailing expression is the default clause of `case`. Writing
      ;; `:else` here would only match the literal keyword :else.
      (throw (ex-info "only 1- or 2-dim arrays supported"
                      {:requested-shape shape})))))

;; Sparse construction ignores the prototype object and is driven purely
;; by the requested shape, via make-sparse-array.
(extend-protocol mp/PNewSparseArray
  RealMatrix
  (new-sparse-array [_ shape]
    (-> shape make-sparse-array))

  RealVector
  (new-sparse-array [_ shape]
    (-> shape make-sparse-array)))

(defn- ->array2d
  "Converts `m` into a Java array holding one double array per major slice."
  [m]
  (->> (mp/get-major-slice-seq m)
       (map mp/to-double-array)
       into-array))

;; Solves the linear system a * x = b through an LU decomposition of `a`.
(extend-protocol mp/PSolveLinear
  RealMatrix
  (solve [a b]
    (let [;; The solver only accepts native vectors or matrices, so
          ;; anything else is coerced first.
          rhs (if (or (instance? RealVector b) (instance? RealMatrix b))
                b
                (mp/coerce-param a b))
          solver (.getSolver (LUDecomposition. a))]
      ;; `b` must be passed as the argument of .solve. Threading it as the
      ;; last form of `->` would instead call `b` as a function.
      (.solve solver rhs))))

;; Least-squares solution of a * x = b.
;;
;; Uses the solver of a QR decomposition of `a`, which Apache Commons Math
;; documents as solving the system in the least-squares sense. The result
;; is the solution x, a RealVector when `b` is a vector and a RealMatrix
;; when `b` is a matrix. It is not the regression "hat" matrix.
(extend-protocol mp/PLeastSquares
  RealMatrix
  (least-squares [a b]
    (let [;; Coerce foreign right-hand sides to native objects so that the
          ;; solver's overloads can accept them.
          rhs (if (or (instance? RealVector b) (instance? RealMatrix b))
                b
                (mp/coerce-param a b))
          solver (.getSolver (QRDecomposition. a))]
      (.solve solver rhs))))

;; Singular value decomposition. `options` is accepted but not consulted.
(extend-protocol mp/PSVDDecomposition
  RealMatrix
  (svd [m options]
    (let [decomp (SingularValueDecomposition. m)]
      {:rank (-> decomp .getRank)
       ;; singular values, as returned by getSingularValues
       :S (-> decomp .getSingularValues)
       ;; matrix of left singular vectors
       :U (-> decomp .getU)
       ;; transposed matrix of right singular vectors
       :V* (-> decomp .getVT)})))

;; LU decomposition. Returns the L and U factors under the keys :l and :u.
;; `options` is accepted but not consulted.
(extend-protocol mp/PLUDecomposition
  RealMatrix
  (lu [m options]
    (let [decomp (LUDecomposition. m)]
      {:l (-> decomp .getL)
       :u (-> decomp .getU)})))

;; Cholesky decomposition. Returns the L factor directly.
;; `options` is accepted but not consulted.
(extend-protocol mp/PCholeskyDecomposition
  RealMatrix
  (cholesky [m options]
    (let [decomp (CholeskyDecomposition. m)]
      (.getL decomp))))

;; QR decomposition. Returns the Q and R factors under the keys :q and :r.
;; `options` is accepted but not consulted.
(extend-protocol mp/PQRDecomposition
  RealMatrix
  (qr [m options]
    (let [decomp (QRDecomposition. m)]
      {:q (-> decomp .getQ)
       :r (-> decomp .getR)})))

;; Eigen decomposition. `options` is accepted but not consulted.
(extend-protocol mp/PEigenDecomposition
  RealMatrix
  (eigen [m options]
    (let [decomp (EigenDecomposition. m)]
      {;; real parts of the eigenvalues
       :w (-> decomp .getRealEigenvalues)
       ;; matrix of eigenvectors
       :v (-> decomp .getV)})))

;; Register this implementation with core.matrix, using a 1x1 matrix as the
;; prototype object.
(let [prototype (Array2DRowRealMatrix. 1 1)]
  (imp/register-implementation prototype))
