(ns thi.ng.common.data.unionfind)

(defprotocol PUnionFind
  "Operations of a persistent union-find (disjoint set) structure.
  Operations that change the structure return a new value."
  (register [_ p]
    "Returns a structure in which `p` is known. If `p` is not yet
    present it is added as a component of its own.")
  (canonical [_ p]
    "Returns the representative element of the component containing
    `p`, or nil if `p` is unknown.")
  (disjoint-components [_]
    "Returns a seq of all components, each a vector of its members.")
  (component [_ p]
    "Returns the vector of members of the component containing `p`.")
  (union [_ [p q]] [_ p q]
    "Returns a structure in which the components of `p` and `q` are
    merged. The pair may be passed either as a single `[p q]`
    collection or as two separate arguments.")
  (unified? [_ p q]
    "Returns true if `p` and `q` have the same representative."))

;; Legacy names from an earlier revision of the protocol. They are kept as
;; plain functions so existing references keep resolving.

(defn add-single
  "Deprecated: use `register`."
  [uf p]
  (register uf p))

(defn component-for
  "Deprecated: use `component`."
  [uf p]
  (component uf p))

(deftype DisjointSet [index components]
  ;; index      - map of element -> representative (root) of its component
  ;; components - map of representative -> vector of all component members
  PUnionFind
  (canonical [_ p]
    ;; A key of `components` is its own representative; anything else is
    ;; looked up in `index` (nil when unknown).
    (if (components p) p (index p)))
  (unified? [_ p q]
    ;; Elements without an index entry are treated as their own root.
    (= (index p p) (index q q)))
  (component [this p]
    (components (canonical this p)))
  (disjoint-components [_]
    (vals components))
  (register [this p]
    ;; Membership is tested with `contains?` rather than by truthiness of
    ;; the representative, so elements whose root is `false` or `nil` are
    ;; not mistaken for unknown ones and re-registered.
    (if (or (contains? components p) (contains? index p))
      this
      (DisjointSet. (assoc index p p) (assoc components p [p]))))
  (union [this pq]
    ;; Pair arity declared by the protocol; delegates to the 3-arg form.
    (let [[p q] pq]
      (union this p q)))
  (union [this p q]
    (let [root-p (index p p)
          root-q (index q q)]
      (if (= root-p root-q)
        this
        (let [members-p (or (components root-p) [root-p])
              members-q (or (components root-q) [root-q])
              ;; Merge the smaller component into the larger one; on a
              ;; tie, p's component is merged into q's.
              [old-root new-root moved kept]
              (if (<= (count members-p) (count members-q))
                [root-p root-q members-p members-q]
                [root-q root-p members-q members-p])]
          (DisjointSet.
           ;; Re-point every moved member at the surviving root.
           (persistent!
            (reduce (fn [idx x] (assoc! idx x new-root))
                    (transient index)
                    moved))
           (-> components
               (dissoc old-root)
               (assoc new-root (into kept moved))))))))
  Object
  (toString [_] (pr-str {:index index :components components})))

(defn disjoint-set
  "Creates a DisjointSet. With no arguments the set is empty. With a
  collection `xs`, each item of `xs` is spread as the trailing arguments
  of `union` (i.e. `(apply union ds item)`), folding left to right from
  an empty set."
  ([]
     (DisjointSet. {} {}))
  ([xs]
     (reduce
      (fn [ds item] (apply union ds item))
      (DisjointSet. {} {})
      xs)))
