(ns com.github.ivarref.paginate-vector.impl.bst)

; based on https://gist.github.com/dmh43/83a7b3e452e83e80eb30

; Binary tree node. `value` is the payload; `left` and `right` hold the child
; nodes (nil when a child is absent, see `make-node`).
(defrecord Node [value left right])

(defn make-node
  "Builds a `Node` holding `value`.

  With one argument the node has no children (both are nil); with three
  arguments `left` and `right` become its children."
  ([value]
   (make-node value nil nil))
  ([value left right]
   (->Node value left right)))

(defn value
  "Returns the payload stored in `node`, or nil when `node` is nil."
  [^Node node]
  (if-not node
    nil
    (.-value node)))

(defn left
  "Returns the left child of `node`, or nil when `node` is nil."
  [^Node node]
  (if-not node
    nil
    (.-left node)))

(defn right
  "Returns the right child of `node`, or nil when `node` is nil."
  [^Node node]
  (if-not node
    nil
    (.-right node)))


(defn depth-first-vals
  "Collects the values of the tree rooted at `root` in in-order sequence
  (left subtree, node, right subtree).

  Arity 1 returns a persistent vector of the values.
  Arity 2 appends the values to the transient vector `v` and returns the
  resulting transient. The transient is threaded through every `conj!`
  rather than mutated in place, because the transient contract requires
  callers to use the value returned by `conj!`."
  ([root]
   (persistent! (depth-first-vals root (transient []))))
  ([root v]
   (if (nil? root)
     v
     (let [v (depth-first-vals (left root) v)
           v (conj! v (value root))]
       ; the right subtree is visited last, so it can be a tail call
       (recur (right root) v)))))


(defn visit-all-depth-first
  "Walks the tree rooted at `root` in in-order sequence (left subtree, node,
  right subtree) and calls `f` with every node value for which `keep?`
  returns a truthy result. Called for side effects; returns nil."
  [root keep? f]
  (when-not (nil? root)
    (visit-all-depth-first (left root) keep? f)
    (let [v (value root)]
      (when (keep? v)
        (f v)))
    (recur (right root) keep? f)))


(defn balanced-tree
  "Builds a tree of nodes from `sorted-vec` by recursively picking the middle
  element of the index range as the node value.

  Arity 1 uses the whole vector. Arity 3 uses the half-open index range
  [`start`, `end`). Returns nil for an empty range."
  ([sorted-vec]
   (balanced-tree sorted-vec 0 (count sorted-vec)))
  ([sorted-vec start end]
   (if (>= start end)
     nil
     ; integer division; avoids building a Ratio and round-tripping via double
     (let [mid (quot (+ start end) 2)]
       (make-node (nth sorted-vec mid)
                  (balanced-tree sorted-vec start mid)
                  (balanced-tree sorted-vec (inc mid) end))))))


(defn tree-vals
  "Returns a vector of all values in the tree rooted at `node` in pre-order
  (node first, then the left subtree, then the right subtree). A nil node
  yields an empty vector."
  [node]
  (if (nil? node)
    []
    (-> [(value node)]
        (into (tree-vals (left node)))
        (into (tree-vals (right node))))))


(defn take-all-last
  "Walks the tree rooted at `root` in reverse in-order sequence (right
  subtree, node, left subtree) and appends each value accepted by `keep?`
  to the transient vector `res`. Every step is skipped once `res` holds
  `max-items` elements. Returns nil."
  [root keep? res max-items]
  (when-not (nil? root)
    (let [room? #(not= max-items (count res))
          v (value root)]
      (when (room?)
        (take-all-last (right root) keep? res max-items))
      (when (and (room?) (keep? v))
        (conj! res v))
      (when (room?)
        (recur (left root) keep? res max-items)))))


(defn cmp-attrs
  "Compares `a` and `b` by the vector of results obtained from applying each
  function in `sort-attrs` to them. Returns the result of `compare` on those
  two vectors."
  [sort-attrs a b]
  (let [key-of (apply juxt sort-attrs)
        key-a (key-of a)
        key-b (key-of b)]
    (compare key-a key-b)))


(defn take-all-first
  "Walks the tree rooted at `root` in in-order sequence (left subtree, node,
  right subtree) and appends each value accepted by `keep?` to the transient
  vector `res`. Every step is skipped once `res` holds `max-items` elements.

  For a nil `root` the result is true when `res` does not yet hold
  `max-items` elements and false otherwise. For a non-nil `root` the result
  is nil when `res` is full after visiting the node, otherwise whatever the
  walk of the right subtree returns."
  [root keep? res max-items]
  (let [room? (fn [] (not= max-items (count res)))]
    (if (nil? root)
      (room?)
      (let [v (value root)]
        (when (room?)
          (take-all-first (left root) keep? res max-items))
        (when (and (room?) (keep? v))
          (conj! res v))
        (when (room?)
          (recur (right root) keep? res max-items))))))


(defn- after-value-inner
  "Appends to the transient vector `res`, in ascending order, the values of
  the tree rooted at `root` whose sort key (per `sort-fn`) is greater than
  the sort key of `find-value`, keeping only values accepted by `keep?` and
  stopping once `res` holds `max-items` elements."
  [root keep? find-value res sort-fn max-items]
  (when (some? root)
    (let [node-val (value root)
          order (compare (sort-fn find-value) (sort-fn node-val))
          room? #(not= max-items (count res))]
      (cond
        ; exact match: everything after it lives in the right subtree
        (zero? order)
        (take-all-first (right root) keep? res max-items)

        ; find-value < node: search the left subtree first, then this node
        ; and its whole right subtree also come after find-value
        (neg-int? order)
        (do
          (after-value-inner (left root) keep? find-value res sort-fn max-items)
          (when (room?)
            (when (keep? node-val)
              (conj! res node-val))
            (when (room?)
              (take-all-first (right root) keep? res max-items))))

        ; find-value > node: only the right subtree can contain later values
        :else
        (recur (right root) keep? find-value res sort-fn max-items)))))


(defn after-value
  "Returns a vector of at most `max-items` values from the tree rooted at
  `root` that sort after `from-value`, restricted to values accepted by
  `keep?`. Ordering is determined by applying the functions in `sort-attrs`
  to each value."
  [root keep? from-value sort-attrs max-items]
  (let [sort-fn (apply juxt sort-attrs)
        acc (transient [])]
    (after-value-inner root keep? from-value acc sort-fn max-items)
    (persistent! acc)))


(defn take-min
  "Returns whichever of `a` and `b` has the smaller sort key according to
  `sort-fn`; `a` wins ties. When either argument is nil the result is
  `(or a b)`, so a single non-nil argument is returned as-is."
  [a b sort-fn]
  (if (or (nil? a) (nil? b))
    (or a b)
    (if (pos-int? (compare (sort-fn a) (sort-fn b)))
      b
      a)))

(comment
  ; REPL scratch check for `take-min`: all three calls yield {:inst 0},
  ; so the whole expression evaluates to true.
  (= (take-min {:inst 0} {:inst 1} (juxt :inst))
     (take-min {:inst 0} nil (juxt :inst))
     (take-min {:inst 1} {:inst 0} (juxt :inst))))


(defn- next-value-inner
  "Returns the value in the tree rooted at `root` with the smallest sort key
  (per `sort-fn`) that is strictly greater than the sort key of
  `nearest-value`, or nil when there is no such value."
  [root nearest-value sort-fn]
  (when-not (nil? root)
    (let [node-val (value root)
          order (compare (sort-fn node-val) (sort-fn nearest-value))]
      (if (pos-int? order)
        ; the node is bigger than the value we look for: it is a candidate,
        ; but something smaller may still be found in its left subtree
        (take-min node-val
                  (next-value-inner (left root) nearest-value sort-fn)
                  sort-fn)
        ; the node is equal to or smaller than the value we look for:
        ; disregard it and continue in the right subtree
        (recur (right root) nearest-value sort-fn)))))


(defn next-value
  "Returns the value in the tree rooted at `root` that sorts immediately
  after `nearest-value`, or nil when there is none. Ordering is determined
  by applying the functions in `sort-attrs` to each value."
  [root nearest-value sort-attrs]
  (let [sort-fn (apply juxt sort-attrs)]
    (next-value-inner root nearest-value sort-fn)))


(defn- before-value-inner
  "Appends to the transient vector `res`, in descending order, the values of
  the tree rooted at `root` whose sort key (per `sort-fn`) is smaller than
  the sort key of `find-value`, keeping only values accepted by `keep?` and
  stopping once `res` holds `max-items` elements."
  [root keep? find-value res sort-fn max-items]
  (if (nil? root)
    nil
    (let [curr-val (value root)
          cmp-int (compare (sort-fn find-value) (sort-fn curr-val))]
      (cond (= 0 cmp-int)
            ; exact match: everything before it lives in the left subtree
            (take-all-last (left root) keep? res max-items)

            (pos-int? cmp-int) ; find-value > curr-val
            (do
              (before-value-inner (right root) keep? find-value res sort-fn max-items)
              (when (not= max-items (count res))
                ; curr-val must pass the filter like every other value;
                ; this mirrors the check done in after-value-inner
                (when (keep? curr-val)
                  (conj! res curr-val))
                (when (not= max-items (count res))
                  (take-all-last (left root) keep? res max-items))))

            :else ; find-value < curr-val
            (before-value-inner (left root) keep? find-value res sort-fn max-items)))))


(defn before-value
  "Returns a vector of at most `max-items` values from the tree rooted at
  `root` that sort before `from-value`, restricted by `keep?`. Values are
  collected nearest-first by `before-value-inner` and then reversed before
  being returned. Ordering is determined by applying the functions in
  `sort-attrs` to each value."
  [root keep? from-value sort-attrs max-items]
  (let [sort-fn (apply juxt sort-attrs)
        acc (transient [])]
    (before-value-inner root keep? from-value acc sort-fn max-items)
    (-> acc persistent! reverse vec)))


(defn tree-contains?
  "Returns true when the tree rooted at `root` holds a value whose sort key
  equals the sort key of `v`, false otherwise. The sort key of a value is
  the vector of results from applying each function in `sort-attrs` to it."
  [root v sort-attrs]
  (if (nil? root)
    false
    ; build the key function and the key of `v` once instead of once per
    ; visited node
    (let [sort-key (apply juxt sort-attrs)
          wanted (sort-key v)]
      (loop [node root]
        (if (nil? node)
          false
          (let [cmp-int (compare wanted (sort-key (value node)))]
            (cond (= 0 cmp-int)
                  true

                  (pos-int? cmp-int)
                  (recur (right node))

                  :else
                  (recur (left node)))))))))


(defn get-leftmost-value
  "Returns the first truthy value, in in-order sequence (left subtree, node,
  right subtree), of the tree rooted at `root` that is accepted by `keep?`.
  Returns nil when there is none."
  [root keep?]
  (when-not (nil? root)
    (if-let [found (get-leftmost-value (left root) keep?)]
      found
      (let [v (value root)]
        (if (and (keep? v) v)
          v
          (recur (right root) keep?))))))


(defn get-rightmost-value
  "Returns the first truthy value, in reverse in-order sequence (right
  subtree, node, left subtree), of the tree rooted at `root` that is
  accepted by `keep?`. Returns nil when there is none."
  [root keep?]
  (when-not (nil? root)
    (if-let [found (get-rightmost-value (right root) keep?)]
      found
      (let [v (value root)]
        (if (and (keep? v) v)
          v
          (recur (left root) keep?))))))


(defn from-beginning
  "Collects values from the start of the tree rooted at `root`, in in-order
  sequence, keeping only values accepted by `keep?`.

  Arity 3 returns a persistent vector holding at most `max-items` values.
  Arity 4 appends to the transient vector `res`. It returns nil for a nil
  `root` or when walking the left subtree reported that `res` is full;
  otherwise it returns true when `res` still has room after this subtree
  and false when it does not."
  ([root keep? max-items]
   (let [acc (transient [])]
     (from-beginning root keep? acc max-items)
     (persistent! acc)))
  ([root keep? res max-items]
   (when-not (nil? root)
     (let [room? #(not= max-items (count res))
           lhs (left root)]
       ; continue with this node only if there is no left subtree or the
       ; left subtree finished with room to spare
       (when (or (nil? lhs)
                 (from-beginning lhs keep? res max-items))
         (let [v (value root)]
           (when (and (room?) (keep? v))
             (conj! res v)))
         (when (room?)
           (take-all-first (right root) keep? res max-items))
         (room?))))))


(defn from-end
  "Collects values from the end of the tree rooted at `root`, keeping only
  values accepted by `keep?`.

  Arity 3 returns a persistent vector holding at most `max-items` values;
  the values are gathered in reverse in-order sequence and then reversed
  before being returned.
  Arity 4 appends to the transient vector `res` in reverse in-order sequence
  (right subtree, node, left subtree) and returns nil."
  ([root keep? max-items]
   (let [acc (transient [])]
     (from-end root keep? acc max-items)
     (-> acc persistent! reverse vec)))
  ([root keep? res max-items]
   (when-not (nil? root)
     (let [room? #(not= max-items (count res))]
       ; the right subtree, when present, is walked first
       (when-some [rhs (right root)]
         (from-end rhs keep? res max-items))
       (let [v (value root)]
         (when (and (room?) (keep? v))
           (conj! res v)))
       (when (room?)
         (take-all-last (left root) keep? res max-items))))))