(ns com.github.ivarref.paginate-vector.impl.bst-multi
  (:require [com.github.ivarref.paginate-vector.impl.bst :as bst]
            [clojure.edn :as edn])
  (:import (java.util UUID)))


(defn balanced-tree-multi
  "Sorts `v` by the attributes listed under `:sort-by` in `opts` and builds a
  balanced tree from the sorted elements.

  The attributes are combined with `juxt`, so elements are ordered by the
  tuple of those attribute values. `:sort-by` must therefore name at least
  one attribute (`juxt` requires one or more functions).

  Returns a map of:
    :root         - the result of `bst/balanced-tree` on the sorted vector
    :input-vector - the sorted elements, as a vector
    :id           - a freshly generated random UUID, as a string
    :opts         - {:sort-by <the sort attributes taken from opts>}"
  [v opts]
  (let [sort-attrs (:sort-by opts)
        tuple-key  (apply juxt sort-attrs)
        sorted-v   (->> v
                        (sort-by tuple-key)
                        vec)]
    {:root         (bst/balanced-tree sorted-v)
     :input-vector sorted-v
     :id           (str (UUID/randomUUID))
     :opts         {:sort-by sort-attrs}}))


(defn maybe-decode-cursor
  "Reads a pagination cursor encoded as an EDN string.

  Returns the value produced by `clojure.edn/read-string` when `cursor` is a
  string, and nil for nil or any other non-string value."
  [cursor]
  (if (string? cursor)
    (edn/read-string cursor)
    nil))


(defn paginate-first
  "Returns one forward page of up to `max-items` nodes from a tree map shaped
  like the one `balanced-tree-multi` returns.

  Arguments:
    tree     - map with :root (the tree), :id (identity string stored in
               cursors) and :opts, whose :sort-by lists the attributes used
               to build each node's position cursor.
    params   - map with
                 :max-items - page size (required; it is incremented, so it
                              must be a number).
                 :f         - applied to every returned node (default identity).
                 :keep?     - predicate forwarded to the bst traversal
                              functions (default: (constantly true)).
                 :context   - value stored under :context in every cursor
                              (default nil). A :context already present in a
                              decoded cursor takes precedence.
    cursor   - nil for the first page, or a cursor string from a previous
               call (decoded with `maybe-decode-cursor`).

  Returns {:edges [...] :pageInfo {...}}.
  Each edge is {:node (f node) :cursor <string>}. The cursor string is the
  `pr-str` of a map holding :context, :id, :totalCount and :cursor (the
  node's sort attributes, via select-keys).

  :totalCount is reused from the incoming cursor when that cursor carries
  this tree's :id and a truthy count. Otherwise the nodes accepted by
  `keep?` are counted with `bst/visit-all-depth-first`."
  [{:keys [root id opts]}
   {:keys [max-items
           f
           keep?
           context]
    :or   {f       identity
           context nil
           keep?   (constantly true)}}
   cursor]
  (let [sort-attrs      (get opts :sort-by)
        ;; Fetch one node beyond the page size: if it exists, there is a next page.
        fetch-count     (inc max-items)
        org-cursor      cursor
        ;; NOTE(review): the cursor is client-supplied; a decoded value that is
        ;; not a map makes `merge` throw here.
        prev-state      (merge {:context context} (maybe-decode-cursor cursor))
        need-new-count? (not= (:id prev-state) id)
        count-nodes     (fn []
                          (let [cnt (atom 0)]
                            (bst/visit-all-depth-first root keep? (fn [_] (swap! cnt inc)))
                            @cnt))
        ;; State embedded in every emitted cursor. The operations below run in
        ;; the same order as before (merge, :id, :totalCount, then :cursor per
        ;; node), so the key order of the printed cursor maps is unchanged.
        state           (-> prev-state
                            (assoc :id id)
                            (update :totalCount
                                    (fn [old-count]
                                      ;; Reuse the count only for a cursor issued for this tree.
                                      (if (and (not need-new-count?) old-count)
                                        old-count
                                        (count-nodes)))))
        from-value      (:cursor state)
        nodes           (vec (if from-value
                               (bst/after-value root keep? from-value sort-attrs fetch-count)
                               (bst/from-beginning root keep? fetch-count)))
        all-edges       (into []
                              (map-indexed
                                (fn [i node]
                                  {;; The look-ahead node (index = max-items) is never
                                   ;; returned to the caller, so `f` is not applied to it.
                                   :node   (if (= i max-items) node (f node))
                                   :cursor (pr-str (assoc state :cursor (select-keys node sort-attrs)))}))
                              nodes)
        edges           (vec (take max-items all-edges))
        has-prev?       (or (when (seq nodes)
                              ;; A page that does not start at the leftmost kept
                              ;; value has something before it.
                              (not= (first nodes) (bst/get-leftmost-value root keep?)))
                            ;; An empty result for a supplied cursor is also
                            ;; reported as having a previous page.
                            (and (empty? nodes) (some? org-cursor)))]
    {:edges    edges
     :pageInfo {:hasPrevPage (true? has-prev?)
                :hasNextPage (= (count nodes) fetch-count)
                :startCursor (or (:cursor (first edges)) org-cursor)
                :endCursor   (or (:cursor (peek edges)) org-cursor)
                :totalCount  (:totalCount state)}}))