(ns com.timezynk.useful.channel
  (:require
   [clojure.core.reducers :as r]
   [clojure.tools.logging :as log :refer [debug info warn]]
   [com.timezynk.useful.map :as umap]
   [com.timezynk.useful.mongo :as um]
   [com.timezynk.useful.prometheus.core :as prometheus]
   [somnium.congomongo :as mongo])
  (:import [java.util.concurrent LinkedBlockingQueue
            PriorityBlockingQueue
            BlockingQueue
            TimeUnit]
           [java.util UUID]))

(defonce ^{:dynamic true
           :doc "Dynamic debug flag, true by default. Declared with defonce, so
  reloading the namespace does not reset its root value. Not read by the
  code in this section."}
  *debug* true)

(def ^:dynamic *reply-channel*
  "Reply channel of the request-response task running on the current thread,
  or nil outside such a task. RequestResponseTask binds it while calling the
  subscriber fn, and publish! prefers it over a fresh reply channel, so
  nested request-response tasks report to the outermost caller."
  nil)

(def ^:const NUM_WORKERS
  "Number of worker threads create-broker! starts for each channel."
  2)

(defonce ^{:doc "Subscriber registry: a map from topic to a list of Subscriber
  records (most recently added first, since subscribe conj's onto a list)."}
  subscribers
  (ref {}))

(defonce ^{:doc "Monotonic counter used to allocate request-response task ids."}
  current-task-id
  (atom 0))

(defonce ^{:doc "Monotonic message counter. NOTE(review): not referenced by
  the code in this section; confirm whether it is used elsewhere."}
  current-message-id
  (atom 0))

(def queue-size
  "Prometheus gauge `channel_queue_size`, labelled by :queue_id. broker-loop
  sets it to the channel's size after each processed message."
  (prometheus/gauge :channel_queue_size
                    "Number of actions waiting in the channel queue"
                    :queue_id))

(def processed-messages
  "Prometheus counter `channel_processed_total`, labelled by :queue_id.
  route-message increments it once per message taken from the channel."
  (prometheus/counter :channel_processed_total
                      "Number of actions processed by the channel queue"
                      :queue_id))

;; Envelope placed on the broker's PriorityBlockingQueue. Messages are
;; ordered by ascending `prio`, so a smaller number is taken first
;; (request-response tasks use 5, broadcasts use 10). Messages with equal
;; prio compare as 0, and PriorityBlockingQueue does not guarantee FIFO
;; order among them.
(defrecord ChannelMessage [prio task]
  Comparable
  (compareTo [_ other]
    (compare prio (:prio other))))

(defprotocol MessageTask
  "A unit of work carried in a ChannelMessage and executed by a broker worker."
  (process [task channel]
    "Runs `task`. `channel` is the broker queue the task was taken from."))

;; Task produced by RequestResponseSubscriber. When processed, it reports
;; progress on `reply-channel`: first [:started task-id], then either
;; [:finished task-id result] with the subscriber fn's return value, or
;; [:exception task-id e] if the fn throws.
;; While the subscriber fn runs, *reply-channel* is bound to this task's
;; reply channel. Tasks published from inside the fn therefore report to
;; the same channel and are awaited by the outermost wait-for.
(defrecord RequestResponseTask [subscriber task-id topic cname message reply-channel]
  MessageTask
  (process [task channel]
    ;; Hinted local avoids reflective .put calls on the untyped record field.
    (let [^BlockingQueue replies reply-channel]
      (.put replies [:started task-id])
      (debug topic cname "running request-response task" task-id)
      (try
        ;; Bind dynamic reply channel so that recursive tasks are collected in
        ;; the outermost wait-for
        (binding [*reply-channel* replies]
          (let [result ((:f subscriber) topic cname message)]
            (.put replies [:finished task-id result])))
        ;; Catch Throwable, not just Exception. Suppose an Error escaped here
        ;; (an AssertionError from a :pre condition, a StackOverflowError, ...).
        ;; Then no terminal event would reach the reply channel, and wait-for
        ;; could only give up on timeout instead of rethrowing the failure.
        (catch Throwable e
          (warn e topic cname "request-response failed to run")
          (.put replies [:exception task-id e]))))))

;; Fire-and-forget task produced by BroadcastSubscriber. It calls the
;; subscriber fn with (topic cname message) and reports nothing to any reply
;; channel. Failures are logged and swallowed.
(defrecord BroadcastTask [subscriber topic cname message]
  MessageTask
  (process [task channel]
    (debug topic cname "running broadcast task")
    (try
      ((:f subscriber) topic cname message)
      ;; Catch Throwable rather than Exception. Otherwise an Error raised by
      ;; the subscriber (e.g. AssertionError from a :pre condition) would
      ;; escape `process` and propagate into the broker worker loop.
      (catch Throwable e
        (warn e topic cname "broadcast failed to run")))))

(defprotocol Subscriber
  "A consumer registered for a topic in the `subscribers` registry."
  (publish [this topic cname message reply-channel]
    "Offers `message` (published on `topic` for collection `cname`) to this
    subscriber. The implementations in this namespace return a ChannelMessage
    to enqueue on the broker, or nil when the subscriber is not interested.
    Request-response implementations announce queued tasks on `reply-channel`."))

;; Subscriber whose work the publisher can wait for. A nil collection-name
;; matches every collection; otherwise only messages whose cname equals it
;; match. On a match it allocates a task id from current-task-id, puts
;; [:queued id] on the reply channel, and returns a ChannelMessage with
;; prio 5. The broker queue takes lower prio values first, so these tasks
;; run ahead of prio-10 broadcasts.
(defrecord RequestResponseSubscriber [collection-name f]
  Subscriber
  (publish [this topic cname message reply-channel]
    (let [interested? (or (nil? collection-name)
                          (= cname collection-name))]
      (when interested?
        (let [id (swap! current-task-id inc)]
          (debug topic cname "add request-response task" id)
          (.put reply-channel [:queued id])
          (->> (RequestResponseTask. this id topic cname message reply-channel)
               (ChannelMessage. 5)))))))

;; Fire-and-forget subscriber. A nil collection-name matches every
;; collection; otherwise only messages whose cname equals it match. On a
;; match it returns a prio-10 ChannelMessage wrapping a BroadcastTask.
;; `reply-channel` is ignored.
(defrecord BroadcastSubscriber [collection-name f]
  Subscriber
  (publish [this topic cname message reply-channel]
    (let [interested? (or (nil? collection-name)
                          (= cname collection-name))]
      (when interested?
        (debug topic cname "add broadcast task")
        (->> (BroadcastTask. this topic cname message)
             (ChannelMessage. 10))))))

(defn ^BlockingQueue publish!
  "Offers each of `messages` on `topic` (for collection `cname`) to every
  subscriber registered for that topic. Any ChannelMessage a subscriber
  returns is put on the broker `channel`.

  Request-response subscribers announce their tasks on a reply channel.
  Outside a request-response task this is a fresh LinkedBlockingQueue
  created here. When *reply-channel* is bound (i.e. publish! is called from
  inside a running request-response task), the bound outer channel is used
  instead, so the outermost wait-for also waits for these nested tasks. In
  that case the returned channel only carries the marker described below.

  Returns nil when `messages` is empty. Otherwise it returns the fresh reply
  channel, onto which [:queued-message-tasks] is put after every message
  has been offered. Exceptions while publishing to one subscriber are logged
  and that subscriber is skipped."
  [^BlockingQueue channel topic cname messages]
  (when (seq messages)
    (let [reply-channel (LinkedBlockingQueue.)
          ;; One snapshot of the registry, so every message in this call goes
          ;; to the same subscriber set even if subscribe runs concurrently.
          topic-subscribers (get @subscribers topic)
          ;; Loop-invariant: the binding cannot change during this call.
          task-reply-channel (or *reply-channel* reply-channel)]
      (doseq [message messages
              s topic-subscribers]
        (try
          (when-let [msg (publish s topic cname message task-reply-channel)]
            (.put channel msg))
          (catch Exception e
            (warn e topic cname "failed to publish" message))))
      (.put reply-channel [:queued-message-tasks])
      reply-channel)))

(defn wait-for
  "Blocks until every request-response task announced on `reply-channel`
  (as returned by publish!) has finished.

  `timeout-ms` bounds each individual poll of the reply channel, not the
  total wait: the timeout restarts whenever an event arrives.

  Returns:
  - nil when `reply-channel` is nil (publish! got no messages);
  - true once the [:queued-message-tasks] marker has been seen and every
    task announced with :queued has reported :finished;
  - false when a poll times out while tasks are still outstanding.
  If a task reports :exception, its Throwable is rethrown here.
  If a poll times out while no task is outstanding, polling continues."
  [timeout-ms ^BlockingQueue reply-channel]
  (when reply-channel
    (let [next-event #(.poll reply-channel timeout-ms TimeUnit/MILLISECONDS)]
      ;; all-queued? becomes true once the publisher's marker has been read.
      ;; A worker may finish an early task before publish! has queued the
      ;; later ones, so an empty pending set alone does not mean completion.
      (loop [[event id payload] (next-event)
             tasks #{}
             all-queued? false]
        (debug "event" event id "waiting for" tasks)
        (case event
          :queued-message-tasks (if (empty? tasks)
                                  (do
                                    (debug "completed. No tasks to wait for")
                                    true)
                                  (recur (next-event) tasks true))

          :queued               (recur (next-event) (conj tasks id) all-queued?)

          :started              (recur (next-event) tasks all-queued?)

          :finished             (let [new-tasks (disj tasks id)]
                                  (if (and all-queued? (empty? new-tasks))
                                    (do
                                      (debug "completed. All tasks finished")
                                      true)
                                    (recur (next-event) new-tasks all-queued?)))

          :exception            (throw payload)

          ;; nil event: the poll timed out.
          (if (seq tasks)
            (do
              (info "timeout. Still waiting for" tasks)
              false)
            (recur (next-event) tasks all-queued?)))))))

(defn- subscribe
  "Adds `subscriber` to the subscriber list for `topic` in the `subscribers`
  ref, within a transaction. conj onto the list (or nil) prepends, so the
  newest subscriber comes first. Returns the updated registry map."
  [topic subscriber]
  (dosync
   (alter subscribers
          (fn [registry]
            (assoc registry topic (conj (get registry topic) subscriber))))))

(defn subscribe-broadcast
  "Registers `f` as a broadcast (fire-and-forget) handler for `topic`, or
  for each element of `topic` when it is sequential. `f` is called as
  (f topic cname message) for messages whose cname equals `collection-name`,
  or for every message when `collection-name` is nil.
  Does nothing and returns nil unless both `topic` and `f` are truthy."
  [topic collection-name f]
  (when (and topic f)
    (debug topic collection-name "new broadcast subscriber")
    (let [register! (fn [t]
                      (subscribe t (BroadcastSubscriber. collection-name f)))]
      (if (sequential? topic)
        (doseq [t topic]
          (register! t))
        (register! topic)))))

(defn subscribe-request-response
  "Registers `f` as a request-response handler for `topic`, or for each
  element of `topic` when it is sequential. Publishers can await these
  tasks with wait-for. `f` is called as (f topic cname message) for
  messages whose cname equals `collection-name`, or for every message when
  `collection-name` is nil.
  Does nothing and returns nil unless both `topic` and `f` are truthy."
  [topic collection-name f]
  (when (and topic f)
    (debug topic collection-name "new request-response subscriber")
    (let [register! (fn [t]
                      (subscribe t (RequestResponseSubscriber. collection-name f)))]
      (if (sequential? topic)
        (doseq [t topic]
          (register! t))
        (register! topic)))))

(defn route-message
  "Processes the task carried by `message` (skipped when it has no :task),
  then increments `message-counter`. The counter is incremented for task-less
  messages too, but not when processing throws."
  [message channel message-counter]
  (let [task (:task message)]
    (when task
      (process task channel)))
  (.inc message-counter))

(defn broker-loop
  "Returns a no-arg fn to run as a broker worker thread body. When invoked,
  it logs its start and then loops forever:
  - it takes the next message from `channel` (blocking);
  - it dispatches the message via route-message, which increments the
    `processed-messages` counter for `queue-id`;
  - it sets the `queue-size` gauge for `queue-id` to the channel's
    remaining size.
  Anything thrown while handling a message is logged, and the loop pauses
  100 ms before continuing."
  [^BlockingQueue channel queue-id]
  (fn []
    (info "starting message broker")
    (let [size-gauge (prometheus/gauge-with-labels queue-size queue-id)
          message-counter (prometheus/counter-with-labels processed-messages queue-id)]
      (while true
        (try
          (route-message (.take channel) channel message-counter)
          (.set size-gauge (.size channel))
          ;; Catch Throwable, not just Exception. An Error escaping a task
          ;; (e.g. AssertionError from a subscriber's :pre condition) would
          ;; otherwise end this loop and terminate the worker thread, and
          ;; each channel only has NUM_WORKERS of them.
          (catch Throwable e
            (warn e "Exception in channel broker")
            (Thread/sleep 100)))))))

(defn ^BlockingQueue create-broker!
  "Starts NUM_WORKERS daemon threads named mchan-0, mchan-1, ... Each runs
  (broker-loop channel queue-id) against the shared `channel`. Returns
  `channel`."
  [^BlockingQueue channel queue-id]
  (doseq [n (range NUM_WORKERS)]
    (let [worker (Thread. (broker-loop channel queue-id) (str "mchan-" n))]
      (.setDaemon worker true)
      (.start worker)))
  channel)

(defn ^BlockingQueue start-channel!
  "Creates a new PriorityBlockingQueue-backed channel and starts its broker
  workers under a random UUID queue id (used as the metrics label). Returns
  the channel."
  []
  (let [channel (PriorityBlockingQueue.)
        queue-id (str (UUID/randomUUID))]
    (create-broker! channel queue-id)))

