(ns org.srasu.concurrent-queue
  "An implementation of a concurrent FIFO queue for Clojure that supports CAS
  semantics on the head element.

  The queue implements relevant Java interfaces, some Clojure interfaces (but
  not [[clojure.lang.IPersistentCollection]] as the queue is mutable), but is
  intended to be mostly used via interop with the Java interfaces.

  The key feature that this queue provides over those from
  `java.util.concurrent` is the ability to remove elements from the queue
  conditionally with [[pop-when!]] (or the lower-level [[cas-head!]]), rather
  than only allowing the removal of all elements which compare equal to a passed
  element."
  (:import
   (clojure.lang
    IMeta IReference Seqable SeqIterator Sequential)
   (java.io Writer)
   (java.util Collection NoSuchElementException Queue)
   (java.util.concurrent.atomic AtomicReference)))

;; Have the compiler emit a warning for every interop call in this namespace
;; that it cannot resolve without runtime reflection.
(set! *warn-on-reflection* true)

;; One link of the queue's singly linked chain.
;;   v - the element stored in this node
;;   n - reference to the following node; holds nil while this node is the last one
(deftype QueueNode
  [v
   ^java.util.concurrent.atomic.AtomicReference n])

(defn- node
  "Creates a fresh node holding `v` whose next-link is an empty AtomicReference."
  ^QueueNode [v]
  (new QueueNode v (new AtomicReference)))

;; Lock-free FIFO queue following the Michael-Scott algorithm. `head` always
;; points at a sentinel node whose `n` link holds the first real element (nil
;; when the queue is empty). `tail` points at the last node or, transiently,
;; one node behind it; every operation that notices a lagging tail helps
;; advance it.
(deftype ConcurrentQueue [^:volatile-mutable meta ^AtomicReference head ^AtomicReference tail]
  IMeta
  (meta [_] meta)

  IReference
  ;; NOTE(review): alterMeta is a read-modify-write of a volatile field with no
  ;; lock, so concurrent alter-meta! calls can lose updates
  ;; (clojure.lang.AReference synchronizes these methods). Confirm whether
  ;; callers need that guarantee.
  (alterMeta [_ f args]
    (set! meta (apply f meta args)))
  (resetMeta [_ v]
    (set! meta v))

  Seqable
  ;; Lazily walks the chain starting after the current sentinel; returns nil
  ;; when the queue is empty. No snapshot is taken: elements added or removed
  ;; while the seq is being realized may or may not be observed.
  (seq [_]
    (letfn [(walk [^AtomicReference link]
              (lazy-seq
               (when-some [^QueueNode nd (.get link)]
                 (cons (.-v nd) (walk (.-n nd))))))]
      (seq (walk (.-n ^QueueNode (.get head))))))

  Sequential

  Queue
  ;; Appends `elt` at the end of the queue. The queue has no capacity limit, so
  ;; this always returns true, as Collection#add and Queue#offer require.
  (add [_ elt]
    (let [new-node (node elt)]
      (loop []
        (let [^QueueNode last-node (.get tail)
              ^AtomicReference last-link (.-n last-node)
              succ (.get last-link)]
          (cond
            ;; `tail` moved while we were reading it; retry with fresh reads.
            (not (identical? last-node (.get tail)))
            (recur)

            ;; `last-node` really is last: try to link the new node after it.
            ;; The follow-up tail swing may fail because another thread already
            ;; helped; the element is enqueued either way.
            (nil? succ)
            (if (.compareAndSet last-link nil new-node)
              (do (.compareAndSet tail last-node new-node)
                  true)
              (recur))

            ;; `tail` lags behind the real last node: help advance it, then retry.
            :else
            (do (.compareAndSet tail last-node succ)
                (recur)))))))
  (offer [this elt]
    (.add this elt))
  ;; Removes and returns the first element, or nil when the queue is empty.
  ;; NOTE(review): java.util.Queue#remove specifies NoSuchElementException for
  ;; an empty queue; this returns nil instead (same as poll). That is kept for
  ;; existing callers. A stored nil element is indistinguishable from "empty".
  (remove [_]
    (loop []
      (let [^QueueNode sentinel (.get head)
            ^QueueNode last-node (.get tail)
            ^QueueNode succ (.get ^AtomicReference (.-n sentinel))]
        (cond
          ;; `head` moved while we were reading; take a fresh snapshot.
          (not (identical? sentinel (.get head)))
          (recur)

          ;; Either the queue is empty (succ is nil), or `tail` still points at
          ;; the sentinel although an element follows it: help it forward.
          (identical? sentinel last-node)
          (when (some? succ)
            (.compareAndSet tail last-node succ)
            (recur))

          ;; Advance `head`: `succ` becomes the new sentinel, its value is the result.
          (.compareAndSet head sentinel succ)
          (.-v succ)

          :else
          (recur)))))
  (poll [this]
    (.remove this))
  ;; Returns the first element without removing it, or nil when empty.
  (peek [_]
    (when-some [^QueueNode first-node (.get ^AtomicReference (.-n ^QueueNode (.get head)))]
      (.-v first-node)))
  ;; Like peek, but throws NoSuchElementException when the queue is empty.
  ;; It checks for the node rather than the value, so a stored nil is returned
  ;; instead of being mistaken for emptiness.
  (element [_]
    (if-some [^QueueNode first-node (.get ^AtomicReference (.-n ^QueueNode (.get head)))]
      (.-v first-node)
      (throw (NoSuchElementException.))))

  Iterable
  (iterator [this]
    (SeqIterator. (seq this)))

  Collection
  ;; Appends every element of `c` in iteration order. Returns true iff `c`
  ;; contributed at least one element (i.e. the queue changed).
  (addAll [this c]
    (let [items (if (identical? c this)
                  ;; Adding a queue to itself: realize a snapshot first.
                  ;; Walking lazily would keep finding the elements being
                  ;; appended and never terminate.
                  (doall (seq c))
                  (seq c))]
      (doseq [elt items]
        (.add this elt))
      (some? items)))
  (isEmpty [_]
    (nil? (.get ^AtomicReference (.-n ^QueueNode (.get head)))))
  ;; Counts by walking the chain (linear in the number of elements).
  (size [this]
    (count (seq this)))
  (toArray [this]
    (object-array (seq this)))
  ;; Honours the Collection#toArray(T[]) contract by delegating to ArrayList:
  ;; the result has the runtime component type of `kind`, and `kind` itself is
  ;; filled (with a nil after the last element) when it is large enough.
  (^objects toArray [this ^objects kind]
    (.toArray (java.util.ArrayList. ^Collection (vec (seq this))) kind))
  (contains [this elt]
    (some? (some #(= elt %) (seq this))))
  (containsAll [this c]
    (every? #(.contains this %) (seq c))))

(defn queue
  "Constructs a new concurrent queue with the items from `coll`, in order.

  `coll` may be anything that [[seq]] accepts: vectors, lists, sets, maps
  (enqueued as map entries), strings, Java arrays and collections, or nil.

  If `meta` is provided the resulting queue will have metadata that can be
  retrieved with [[meta]], and altered with [[alter-meta!]].

  The resulting queue implements [[Collection]], [[Iterable]], and [[Queue]]
  from Java; and [[IMeta]], [[IReference]], [[Seqable]], and [[Sequential]] from
  Clojure."
  (^ConcurrentQueue [] (queue [] nil))
  (^ConcurrentQueue [coll] (queue coll nil))
  (^ConcurrentQueue [coll meta]
   ;; `head` and `tail` start out sharing a single sentinel node, which is how
   ;; an empty queue is represented.
   (let [sentinel (node nil)
         q (ConcurrentQueue. meta (AtomicReference. sentinel) (AtomicReference. sentinel))]
     ;; Enqueue element by element rather than via Collection#addAll so that
     ;; non-Collection seqables (strings, arrays, maps) are accepted too.
     (doseq [item coll]
       (.add q item))
     q)))

(defn cas-head!
  "Removes the first item from the `queue` if it's [[identical?]] to `old`.

  Returns true if the item was removed. Returns false if the queue is empty or
  its first item is not identical to `old`.

  This allows conditionally removing the first element of the queue when used
  along with the [[java.util.Queue#peek]] method. The removal advances the
  queue's head the same way [[java.util.Queue#poll]] does. A concurrent poll
  and cas-head! therefore cannot both remove the same element, and no node
  links are rewritten, so elements being appended concurrently are never lost."
  [^ConcurrentQueue queue old]
  (let [^AtomicReference head (.-head queue)
        ^AtomicReference tail (.-tail queue)]
    (loop []
      (let [^QueueNode sentinel (.get head)
            ^QueueNode last-node (.get tail)
            ^QueueNode first-node (.get ^AtomicReference (.-n sentinel))]
        (cond
          ;; `head` moved while we were reading; take a fresh snapshot.
          (not (identical? sentinel (.get head)))
          (recur)

          ;; Empty queue: nothing to remove.
          (nil? first-node)
          false

          (not (identical? (.-v first-node) old))
          false

          ;; `tail` still points at the sentinel although an element follows
          ;; it; help advance `tail` before moving `head` past it.
          (identical? sentinel last-node)
          (do (.compareAndSet tail last-node first-node)
              (recur))

          ;; `first-node` becomes the new sentinel, i.e. its item is dequeued.
          (.compareAndSet head sentinel first-node)
          true

          ;; Lost a race for `head`; re-examine the new first item.
          :else
          (recur))))))

(defmethod print-method ConcurrentQueue
  [^ConcurrentQueue q ^Writer writer]
  ;; Writes the opening "#org.srasu.concurrent-queue [ ", then every item via
  ;; print-method followed by one space, then the closing "]".
  (.write writer "#org.srasu.concurrent-queue [ ")
  (run! (fn [item]
          (print-method item writer)
          (.write writer " "))
        (seq q))
  (.write writer "]"))

(defmethod print-dup ConcurrentQueue
  [^ConcurrentQueue q ^Writer writer]
  ;; Emits a read-eval form that calls org.srasu.concurrent-queue/queue with a
  ;; vector of the items and, when present, the queue's metadata, both written
  ;; with print-dup.
  (.write writer "#=(org.srasu.concurrent-queue/queue [ ")
  (run! (fn [item]
          (print-dup item writer)
          (.write writer " "))
        (seq q))
  (.write writer "]")
  (when-let [m (meta q)]
    (.write writer " ")
    (print-dup m writer))
  (.write writer ")"))

(defn pop-when!
  "Pops the head of the `queue` if `pred` returns a truthy value when called on it.

  `pred` is not called when the queue is empty. Returns true if the queue was
  changed, false otherwise. That includes an empty queue, `pred` rejecting the
  head, or another thread replacing the head with a non-identical item before
  it could be removed."
  [^ConcurrentQueue queue pred]
  ;; Read the first node itself (not just its value, as peek would) so that an
  ;; empty queue is distinguished from a stored nil and pred never sees a
  ;; phantom nil.
  (let [^QueueNode sentinel (.get ^AtomicReference (.-head queue))
        ^QueueNode first-node (.get ^AtomicReference (.-n sentinel))]
    (if (nil? first-node)
      false
      (let [item (.-v first-node)]
        (if (pred item)
          (cas-head! queue item)
          false)))))
