(ns taoensso.tukey.impl
  "Private ns, implementation detail."
  {:author "Peter Taoussanis (@ptaoussanis)"}
  (:require
   [taoensso.encore :as enc :refer [have have? have!]]))

;;;;

#?(:clj
   ;; Resolve the primitive long-array class ("[J") once, at load time.
   (let [long-array-class (Class/forName "[J")]
     (defn longs?
       "Returns true iff `x` is a primitive long array (JVM only)."
       [x]
       (instance? long-array-class x))))
#?(:clj
   ;; Resolve the primitive double-array class ("[D") once, at load time.
   (let [double-array-class (Class/forName "[D")]
     (defn doubles?
       "Returns true iff `x` is a primitive double array (JVM only)."
       [x]
       (instance? double-array-class x))))

(defn is-p
  "Validates `x` with `enc/pnum?` and returns it unchanged on success.
  Otherwise throws an `ex-info` whose data map carries the offending value
  and its type."
  [x]
  (when-not (enc/pnum? x)
    (throw
      (ex-info "Expected number between 0 and 1"
        {:value x :type (type x)})))
  x)

;;;; Sorted nums

;; Wrapper around an array of longs that `sorted-longs` has already sorted.
;; Supports `count`, `nth`, `reduce` and `str`.
(deftype SortedLongs [^longs a]
  #?@(:clj
      [Object
       ;; Parts must be joined with `str`; a bare sequence of body forms would
       ;; return only the last one.
       (toString [_] (str "SortedLongs[len=" (alength a) "]"))

       clojure.lang.Counted
       (count [_] (alength a))

       clojure.lang.Indexed
       (nth [_ idx] (aget a idx))
       (nth [_ idx not-found]
         ;; Valid indexes are 0 (inclusive) to len (exclusive).
         (if (and (>= idx 0) (< idx (alength a)))
           (aget a idx)
           not-found))

       clojure.lang.IReduceInit
       (reduce [_ f init]
         #_(areduce a i acc init (f acc (aget a i)))
         (reduce (fn [acc i] (f acc (aget a i)))
           init (range (alength a))))]

      :cljs
      [Object
       (toString [_] (str "SortedLongs[len=" (alength a) "]"))

       ICounted
       (-count [_] (alength a))

       IIndexed
       (-nth [_ idx] (aget a idx))
       (-nth [_ idx not-found]
         ;; Valid indexes are 0 (inclusive) to len (exclusive).
         (if (and (>= idx 0) (< idx (alength a)))
           (aget a idx)
           not-found))

       IReduce
       (-reduce [_ f init]
         #_(areduce a i acc init (f acc (aget a i)))
         (reduce (fn [acc i] (f acc (aget a i)))
           init (range (alength a))))]))

;; Wrapper around an array of doubles that `sorted-doubles` has already sorted.
;; Supports `count`, `nth`, `reduce` and `str`.
(deftype SortedDoubles [^doubles a]
  #?@(:clj
      [Object
       ;; Parts must be joined with `str`; a bare sequence of body forms would
       ;; return only the last one.
       (toString [_] (str "SortedDoubles[len=" (alength a) "]"))

       clojure.lang.Counted
       (count [_] (alength a))

       clojure.lang.Indexed
       (nth [_ idx] (aget a idx))
       (nth [_ idx not-found]
         ;; Valid indexes are 0 (inclusive) to len (exclusive).
         (if (and (>= idx 0) (< idx (alength a)))
           (aget a idx)
           not-found))

       clojure.lang.IReduceInit
       (reduce [_ f init]
         #_(areduce a i acc init (f acc (aget a i)))
         (reduce (fn [acc i] (f acc (aget a i)))
           init (range (alength a))))]

      :cljs
      [Object
       (toString [_] (str "SortedDoubles[len=" (alength a) "]"))

       ICounted
       (-count [_] (alength a))

       IIndexed
       (-nth [_ idx] (aget a idx))
       (-nth [_ idx not-found]
         ;; Valid indexes are 0 (inclusive) to len (exclusive).
         (if (and (>= idx 0) (< idx (alength a)))
           (aget a idx)
           not-found))

       IReduce
       (-reduce [_ f init]
         #_(areduce a i acc init (f acc (aget a i)))
         (reduce (fn [acc i] (f acc (aget a i)))
           init (range (alength a))))]))

(defn sorted-nums?
  "Returns true iff `x` is a `SortedLongs` or a `SortedDoubles`."
  [x]
  (or
    (instance? SortedLongs   x)
    (instance? SortedDoubles x)))
(defn sorted-longs?
  "Returns true iff `x` is a `SortedLongs`."
  [x]
  (instance? SortedLongs x))
(defn sorted-doubles?
  "Returns true iff `x` is a `SortedDoubles`."
  [x]
  (instance? SortedDoubles x))

(defn sorted-longs
  "Coerces `nums` to a `SortedLongs`:
    - A `SortedLongs` is returned as-is.
    - A `SortedDoubles` has its backing array re-wrapped (Clj: via `long-array`,
      Cljs: the same array is reused).
    - Anything else is copied into a fresh array, which is then sorted in place."
  ^SortedLongs [nums]
  (cond
    (sorted-longs? nums) nums

    (sorted-doubles? nums)
    (let [src (.-a ^SortedDoubles nums)]
      #?(:clj  (SortedLongs. (long-array src))
         :cljs (SortedLongs.             src)))

    :else
    #?(:clj
       (let [^longs arr
             (if (longs? nums)
               (aclone ^longs nums) ; Copy, so the caller's array isn't mutated
               (long-array nums))]
         (java.util.Arrays/sort arr) ; O(n.log_n) on JDK 7+
         (SortedLongs. arr))

       :cljs
       (let [arr (if (array? nums) (aclone nums) (to-array nums))]
         (goog.array/sort arr)
         (SortedLongs. arr)))))

(comment (sorted-longs [1 2 5 4 4 3 8]))

(defn sorted-doubles
  "Coerces `nums` to a `SortedDoubles`:
    - A `SortedDoubles` is returned as-is.
    - A `SortedLongs` has its backing array re-wrapped (Clj: via `double-array`,
      Cljs: the same array is reused).
    - Anything else is copied into a fresh array, which is then sorted in place."
  ^SortedDoubles [nums]
  (cond
    (sorted-doubles? nums) nums

    (sorted-longs? nums)
    (let [src (.-a ^SortedLongs nums)]
      #?(:clj  (SortedDoubles. (double-array src))
         :cljs (SortedDoubles.               src)))

    :else
    #?(:clj
       (let [^doubles arr
             (if (doubles? nums)
               (aclone ^doubles nums) ; Copy, so the caller's array isn't mutated
               (double-array nums))]
         (java.util.Arrays/sort arr) ; O(n.log_n) on JDK 7+
         (SortedDoubles. arr))

       :cljs
       (let [arr (if (array? nums) (aclone nums) (to-array nums))]
         (goog.array/sort arr)
         (SortedDoubles. arr)))))

(defn sorted-nums
  "Returns `x` as a sorted-nums type. An existing `SortedLongs`/`SortedDoubles`
  is returned as-is; otherwise the type is chosen by inspecting only the first
  element: `int?` => `sorted-longs`, anything else => `sorted-doubles`."
  [x]
  (if (or (sorted-longs? x) (sorted-doubles? x))
    x
    (if (int? (first x))
      (sorted-longs   x)
      (sorted-doubles x))))

;;;; Tuples

;; Minimal positional tuple types. `Pair` and `Trip` serve as the accumulator
;; holders in `multi-reduce`; the `Double*` variants carry primitive-hinted
;; fields.
(deftype Pair [x y])
(deftype Trip [x y z])

(deftype DoublePair [^double x ^double y])
(deftype DoubleTrip [^double x ^double y ^double z])

(defn multi-reduce
  "Like `reduce` but supports separate simultaneous accumulators
  as a micro-optimisation when reducing a large collection multiple
  times.

  Arities:
    [f init coll]                       => plain `reduce` result
    [f1 init1 f2 init2 coll]            => [acc1 acc2]
    [f1 init1 f2 init2 f3 init3 coll]   => [acc1 acc2 acc3]

  Each `fN` is a reducing fn of [acc in] that is paired with its own `initN`."
  ;; Faster than using volatiles
  ([f  init           coll] (reduce f init coll))
  ([f1 init1 f2 init2 coll]
   (let [^Pair tuple
         (reduce
           (fn [^Pair tuple in]
             (Pair.
               (f1 (.-x tuple) in)
               (f2 (.-y tuple) in)))
           (Pair. init1 init2)
           coll)]

     [(.-x tuple) (.-y tuple)]))

  ([f1 init1 f2 init2 f3 init3 coll]
   (let [^Trip tuple
         (reduce
           (fn [^Trip tuple in]
             (Trip.
               (f1 (.-x tuple) in)
               (f2 (.-y tuple) in)
               (f3 (.-z tuple) in))) ; Third accumulator must use f3, not f2
           (Trip. init1 init2 init3)
           coll)]

     [(.-x tuple) (.-y tuple) (.-z tuple)])))

;;;; Percentiles

(defn- double-nth
  "Returns the element of `nums` at the (possibly fractional) index `idx`, as a
  double. A whole-number `idx` is a direct lookup; a fractional `idx` linearly
  interpolates between the two neighbouring elements."
  ^double [nums ^double idx]
  (let [lo (Math/floor idx)
        hi (Math/ceil  idx)]

    (if-not (== hi lo)
      ;; Generalization of (floor+ceil)/2
      (let [lo-weight (- hi idx)
            hi-weight (- 1 lo-weight)
            lo-val    (double (nth nums (int lo)))
            hi-val    (double (nth nums (int hi)))]
        (+
          (* lo-weight lo-val)
          (* hi-weight hi-val)))

      (double (nth nums (int idx))))))

(defn percentile
  "Returns ?double: the value at proportion `p` of `nums` once sorted, or nil
  when `nums` is empty. `p` is validated by `is-p`."
  [nums p]
  (let [snums (sorted-nums nums)
        n     (count snums)]
    (when (pos? n)
      (double-nth snums
        (* (dec n) (double (is-p p)))))))

(comment (percentile (range 101) 0.8))

(defn percentiles
  "Returns ?[min p25 p50 p75 p90 p95 p99 max] doubles in:
    - O(1) for Sorted types (SortedLongs, SortedDoubles),
    - O(n) or O(n.log_n) otherwise.

  Returns nil when `nums` is empty."
  [nums]
  (let [snums   (sorted-nums nums)
        ;; Count the sorted wrapper (as `percentile` does), not the raw input:
        ;; it's the collection actually being indexed.
        max-idx (dec (count snums))]
    (when (>= max-idx 0)
      [(double (nth snums 0))
       (double-nth  snums (* max-idx 0.25))
       (double-nth  snums (* max-idx 0.50))
       (double-nth  snums (* max-idx 0.75))
       (double-nth  snums (* max-idx 0.90))
       (double-nth  snums (* max-idx 0.95))
       (double-nth  snums (* max-idx 0.99))
       (double (nth snums    max-idx))])))

(comment (percentiles (range 101)))

;;;;

(defn- bessel-correction
  "Returns `n` (coerced to double) plus the adjustment `add`."
  ^double [n ^double add]
  (let [n* (double n)]
    (+ n* add)))

(defn- rf-sum
  "Reducing fn: returns the running total `acc` plus the input `in`."
  ^double [^double acc ^double in]
  (+ acc in))
(defn- rf-sum-variance
  "Reducing fn (once `xbar` is partially applied): adds the squared difference
  between `x` and `xbar` to `acc`."
  ^double [^double xbar ^double acc x]
  (let [delta (- (double x) xbar)]
    (+ acc (Math/pow delta 2.0))))

(defn- rf-sum-abs-deviation
  "Reducing fn (once `central-point` is partially applied): adds the absolute
  difference between `x` and `central-point` to `acc`."
  ^double [^double central-point ^double acc x]
  (let [delta (- (double x) central-point)]
    (+ acc (Math/abs delta))))

;;;;

;; Forward declaration: `summary-stats->map` is defined further down in this
;; ns, but is referenced by the deref implementation of `MergeableSummaryStats`.
(declare ^:private summary-stats->map)

;; Immutable holder for summary statistics, as built by `summary-stats` and
;; combined by `summary-stats-merge`. Fields are positional, so constructor
;; calls must supply them in exactly the order listed here.
;; Dereferencing (`@mss`) returns the map produced by `summary-stats->map`.
(deftype MergeableSummaryStats
    ;; Field names chosen to avoid shadowing
    [^boolean xlongs? ; True iff the source nums were a SortedLongs
     ^long    nx      ; Number of samples
     ^double  xmin    ; Smallest sample
     ^double  xmax    ; Largest sample
     ^double  xsum    ; Sum of samples
     ^double  xmean   ; xsum / nx
     ^double  xvar-sum ; Sum of squared differences from the mean
     ^double  xmad-sum ; Sum of absolute differences from the mean
     ^double  xvar    ; Variance, derived from xvar-sum
     ^double  xmad    ; Mean absolute deviation, derived from xmad-sum
     ^double  p25     ; p25..p99: percentile values
     ^double  p50
     ^double  p75
     ^double  p90
     ^double  p95
     ^double  p99]

  Object (toString [_] (str "MergeableSummaryStats[n=" nx "]"))
  #?@(:clj  [clojure.lang.IDeref ( deref [this] (summary-stats->map this))]
      :cljs [             IDeref (-deref [this] (summary-stats->map this))]))

(defn summary-stats
  "Returns ?MergeableSummaryStats for the given `nums`, or nil when `nums` is
  nil/false or contains no elements. Deref the result for a plain map."
  [nums]
  (when nums
    (let [snums (sorted-nums nums)
          nx    (count snums)]

      (when (pos? nx)
        (let [xsum  (double (reduce rf-sum 0.0 snums))
              xmean (/ xsum (double nx))

              ;; One pass over snums for both deviation sums
              [^double xvar-sum ^double xmad-sum]
              (multi-reduce
                (partial rf-sum-variance      xmean) 0.0
                (partial rf-sum-abs-deviation xmean) 0.0
                snums)

              [xmin p25 p50 p75 p90 p95 p99 xmax] (percentiles snums)

              xvar (/ xvar-sum nx) ; nx w/o bessel-correction
              xmad (/ xmad-sum nx)]

          (MergeableSummaryStats. (sorted-longs? snums)
            nx xmin xmax xsum xmean xvar-sum xmad-sum xvar xmad
            p25 p50 p75 p90 p95 p99))))))

(comment @(summary-stats [1 2 3]))

(defn summary-stats-merge
  "(summary-stats-merge
     (summary-stats nums1)
     (summary-stats nums2)) returns a rough approximation of
  (summary-stats (merge nums1 nums2))

  Useful when you want summary stats for a large set of numbers for which
  it would be infeasible/expensive to keep all numbers for accurate merging.

  Either argument may be nil, in which case the other is returned as-is.

  The merged variance is the pooled estimate `xvar-sum / (n - 2)` when the
  combined count exceeds 2; otherwise (two single-sample stats) it falls back
  to `xvar-sum / n`, the same uncorrected denominator `summary-stats` uses."

  ([mss1     ] mss1)
  ([mss1 mss2]
   (if mss1
     (if mss2
       (let [^MergeableSummaryStats mss1 mss1
             ^MergeableSummaryStats mss2 mss2

             nx1 (.-nx mss1)
             nx2 (.-nx mss2)

             _ (assert (pos? nx1))
             _ (assert (pos? nx2))

             xlongs1?  (.-xlongs?  mss1)
             xmin1     (.-xmin     mss1)
             xmax1     (.-xmax     mss1)
             xsum1     (.-xsum     mss1)
             xvar-sum1 (.-xvar-sum mss1)
             xmad-sum1 (.-xmad-sum mss1)
             p25-1     (.-p25      mss1)
             p50-1     (.-p50      mss1)
             p75-1     (.-p75      mss1)
             p90-1     (.-p90      mss1)
             p95-1     (.-p95      mss1)
             p99-1     (.-p99      mss1)

             xlongs2?  (.-xlongs?  mss2)
             xmin2     (.-xmin     mss2)
             xmax2     (.-xmax     mss2)
             xsum2     (.-xsum     mss2)
             xvar-sum2 (.-xvar-sum mss2)
             xmad-sum2 (.-xmad-sum mss2)
             p25-2     (.-p25      mss2)
             p50-2     (.-p50      mss2)
             p75-2     (.-p75      mss2)
             p90-2     (.-p90      mss2)
             p95-2     (.-p95      mss2)
             p99-2     (.-p99      mss2)

             ;; Merged result is long-typed only if both inputs were
             xlongs3?  (and xlongs1? xlongs2?)
             nx3       (+ nx1 nx2)
             nx1-ratio (/ (double nx1) (double nx3))
             nx2-ratio (/ (double nx2) (double nx3))

             xsum3 (+ xsum1 xsum2)
             xbar3 (/ (double xsum3) (double nx3))
             xmin3 (if (< xmin1 xmin2) xmin1 xmin2)
             xmax3 (if (> xmax1 xmax2) xmax1 xmax2)

             ;; Batched "online" calculation here is at least as good as the
             ;; standard Knuth/Welford method, Ref. http://goo.gl/QLSfOc,
             ;;                                     http://goo.gl/mx5eSK.
             ;; No apparent advantage in using `xbar3` asap (?).
             xvar-sum3 (+ xvar-sum1 xvar-sum2)
             xmad-sum3 (+ xmad-sum1 xmad-sum2)

             ;;; These are pretty rough approximations. More sophisticated
             ;;; approaches not worth the extra cost/effort in our case.
             ;;; Each percentile is the sample-count-weighted mean of the inputs.
             p25-3 (+ (* nx1-ratio (double p25-1)) (* nx2-ratio (double p25-2)))
             p50-3 (+ (* nx1-ratio (double p50-1)) (* nx2-ratio (double p50-2)))
             p75-3 (+ (* nx1-ratio (double p75-1)) (* nx2-ratio (double p75-2)))
             p90-3 (+ (* nx1-ratio (double p90-1)) (* nx2-ratio (double p90-2)))
             p95-3 (+ (* nx1-ratio (double p95-1)) (* nx2-ratio (double p95-2)))
             p99-3 (+ (* nx1-ratio (double p99-1)) (* nx2-ratio (double p99-2)))

             ;; Must always be a number: the `xvar` field is a primitive
             ;; double, so a nil here would throw on the JVM. With nx3 = 2 the
             ;; pooled denominator (nx3 - 2) would be zero, so fall back to
             ;; the uncorrected denominator.
             xvar3 (if (> nx3 2)
                     (/ xvar-sum3 (bessel-correction nx3 -2.0))
                     (/ xvar-sum3 (double            nx3)))
             xmad3   (/ xmad-sum3                    nx3)]

         (MergeableSummaryStats. xlongs3?
           nx3 xmin3 xmax3 xsum3 xbar3 xvar-sum3 xmad-sum3 xvar3 xmad3
           p25-3 p50-3 p75-3 p90-3 p95-3 p99-3))
       mss1)
     mss2)))

(defn- summary-stats->map
  "Returns a plain map view of the given ?MergeableSummaryStats, or nil when
  given nil/false. `:min` and `:max` are passed through `enc/round0` when the
  stats are flagged as long-typed; all other values are returned as stored."
  [mss]
  (when mss
    (let [^MergeableSummaryStats s mss
          finalize (if (.-xlongs? s) enc/round0 identity)]
      {:n    (.-nx s)
       :min  (finalize (.-xmin s))
       :max  (finalize (.-xmax s))
       :p25  (.-p25   s)
       :p50  (.-p50   s)
       :p75  (.-p75   s)
       :p90  (.-p90   s)
       :p95  (.-p95   s)
       :p99  (.-p99   s)
       :mean (.-xmean s)
       :var  (.-xvar  s)
       :mad  (.-xmad  s)})))

;;;; Print methods

#?(:clj
   ;; Capture the namespace at load time, so that printed output is prefixed
   ;; with this ns rather than whatever `*ns*` is bound to at print time.
   (let [home-ns *ns*]
     (defmethod print-method MergeableSummaryStats
       [x ^java.io.Writer w]
       (let [printed (str "#" home-ns "." x)]
         (.write w printed)))))
