(ns vectio.netty.websocket
  (:require [vectio.netty :as n]
            [clj-commons.byte-streams :as bs]
            [clojure.string :as st])
  (:import [io.netty.channel
            Channel
            ChannelHandlerContext
            ChannelDuplexHandler]
           [io.netty.util ReferenceCounted]
           [io.netty.buffer ByteBuf]
           [java.util.concurrent ExecutorService]
           [io.netty.handler.codec.http.websocketx
            WebSocketFrame
            PingWebSocketFrame
            PongWebSocketFrame
            TextWebSocketFrame
            ContinuationWebSocketFrame
            BinaryWebSocketFrame
            CloseWebSocketFrame]))

(defn websocket-message-coerce-fn
  "Coerces an outbound message into a Netty WebSocketFrame.

  - a WebSocketFrame is returned unchanged;
  - a CharSequence is turned into a TextWebSocketFrame (via bs/to-string);
  - anything else is converted with n/to-byte-buf and wrapped in a
    BinaryWebSocketFrame."
  [msg]
  (cond
    (instance? WebSocketFrame msg)
    msg

    (instance? CharSequence msg)
    (TextWebSocketFrame. (bs/to-string msg))

    :else
    (BinaryWebSocketFrame. (n/to-byte-buf msg))))

(defn data->websocket-frames
  "Converts outbound `data` into a vector of WebSocketFrames ready to write.

  The keywords :ping and :pong each yield a single Ping / Pong frame built with
  the no-arg constructor. Any other value is converted with n/to-byte-buf and
  split with n/slice-byte-buf using `max-frame-size` (presumably a per-frame
  byte limit -- confirm in vectio.netty). The first piece becomes a
  TextWebSocketFrame when `data` is a String, otherwise a BinaryWebSocketFrame;
  every later piece becomes a ContinuationWebSocketFrame. Only the last frame
  is marked as the final fragment.

  Each slice is n/acquire'd before the source buffer is n/release'd, so the
  slices are meant to outlive the source buffer (assuming acquire/release are
  reference-count retain/release -- TODO confirm)."
  [^ChannelHandlerContext ctx ^long max-frame-size data]
  (cond
    (= :ping data) [(PingWebSocketFrame.)]
    (= :pong data) [(PongWebSocketFrame.)]
    :else
    (let [^ByteBuf source (n/to-byte-buf ctx data)
          slices (mapv n/acquire (n/slice-byte-buf source max-frame-size))
          last-index (dec (count slices))
          text? (string? data)]
      ;; The slices now hold their own references; drop ours on the source.
      (n/release source)
      (into []
            (map-indexed
             (fn [i ^ByteBuf slice]
               (let [fin? (= i last-index)]
                 (cond
                   (pos? i) (ContinuationWebSocketFrame. fin? 0 slice)
                   text?    (TextWebSocketFrame. fin? 0 slice)
                   :else    (BinaryWebSocketFrame. fin? 0 (n/to-byte-buf slice))))))
            slices))))

(defn websocket-frame-bytes
  "Copies the readable bytes of `frame`'s content buffer into a new byte array
  and returns it.

  This uses ByteBuf.readBytes, so the content buffer's reader index is advanced
  past the copied bytes."
  [^WebSocketFrame frame]
  (let [^ByteBuf buf (.content frame)]
    (doto (byte-array (.readableBytes buf))
      (->> (.readBytes buf)))))

(defn websocket-frame-coerced-content
  "Returns the payload of `frame`: a String for a TextWebSocketFrame, otherwise
  a byte array produced by websocket-frame-bytes (which consumes the content
  buffer's readable bytes)."
  [^WebSocketFrame frame]
  (if-not (instance? TextWebSocketFrame frame)
    (websocket-frame-bytes frame)
    (let [^TextWebSocketFrame text-frame frame]
      (.text text-frame))))

(defn send-websocket-frames
  "Writes every WebSocketFrame in `frames` to `ctx` and flushes.

  `ctx` may be a ChannelHandlerContext; any other value is treated as a Channel.
  A flush is issued each time the payload bytes written since the previous
  flush reach `flush-size`, and one final flush is always issued (so frames
  with empty payloads, e.g. pings, are still sent). Returns the result of that
  final .flush call."
  [ctx flush-size frames]
  (let [ctx? (instance? ChannelHandlerContext ctx)
        write (if ctx?
                #(.write ^ChannelHandlerContext ctx %)
                #(.write ^Channel ctx %))
        flush (if ctx?
                #(.flush ^ChannelHandlerContext ctx)
                #(.flush ^Channel ctx))
        pending-bytes (volatile! 0)]
    ;; doseq walks any seqable in O(n); indexing with nth would be O(n^2)
    ;; when `frames` is a seq rather than a vector.
    (doseq [^WebSocketFrame frame frames]
      ;; Count the actual payload (readableBytes), not the buffer capacity,
      ;; which may be larger than the data. Measure before writing: after the
      ;; write the frame may be released by Netty.
      (vswap! pending-bytes + (.readableBytes (.content frame)))
      (write frame)
      (when (>= @pending-bytes flush-size)
        (vreset! pending-bytes 0)
        (flush)))
    (flush)))

(defn websocket-frame-collector
  "Starts reassembling a fragmented WebSocket message from its first frame.

  `initial-frame` is the opening frame of the message: a TextWebSocketFrame
  makes this a text message, any other frame type a binary one. Returns a fn
  that must be called with each following ContinuationWebSocketFrame. That fn:
  - returns nil until the final fragment arrives;
  - then returns the whole message: a String for text, a byte array otherwise;
  - throws ex-info when given a frame that is not a
    ContinuationWebSocketFrame.

  Fragment payloads are accumulated as raw bytes and decoded as UTF-8 only once
  the message is complete. RFC 6455 allows a fragment boundary to fall inside a
  multi-byte UTF-8 character, so decoding each fragment separately would
  corrupt such characters.

  Payload bytes are read via websocket-frame-bytes, which consumes each frame's
  readable bytes (the initial frame's immediately)."
  [^WebSocketFrame initial-frame]
  (let [text? (instance? TextWebSocketFrame initial-frame)
        ^java.io.ByteArrayOutputStream buffer (java.io.ByteArrayOutputStream.)
        append! (fn [^WebSocketFrame frame]
                  (let [^bytes chunk (websocket-frame-bytes frame)]
                    (.write buffer chunk 0 (alength chunk))))]
    (append! initial-frame)
    (fn [^WebSocketFrame frame]
      (if (instance? ContinuationWebSocketFrame frame)
        (do (append! frame)
            (when (.isFinalFragment frame)
              (if text?
                (.toString buffer "UTF-8")
                (.toByteArray buffer))))
        (throw (ex-info "Can only collect WebSocketContinuationFrame frames."
                        {:frame frame}))))))

(defn inbound-collector
  "Builds a frame-dispatch fn `(fn [ctx frame] ...)` that reassembles
  fragmented text/binary messages and reacts to control frames.

  The returned fn yields either nil (nothing more to do) or a zero-arg fn the
  caller should invoke to deliver the message:
  - complete text message   -> thunk calling (on-text-message string)
  - complete binary message -> thunk calling (on-binary-message byte-array)
  - CloseWebSocketFrame     -> `on-close` itself (which may be nil)

  Behaviour for other inputs:
  - Ping frame: a Pong carrying the same payload is written and flushed on
    `ctx` right away, then `on-ping-message` (if any) is called synchronously
    with the ping frame.
  - Pong frame: `on-pong-message` (if any) is called synchronously with it.
  - Non-WebSocketFrame input: returns nil.
  - Unknown frame types: printed and ignored.

  Fragment state lives in a single atom closed over by the returned fn, so one
  instance should only ever see the frames of a single connection."
  [{:keys [on-text-message
           on-binary-message
           on-pong-message
           on-ping-message
           on-close]}]
  (let [frame-collector (atom nil)
        handle-final-content (fn [content]
                               ;; The message is complete: drop fragment state.
                               (reset! frame-collector nil)
                               (if (string? content)
                                 #(on-text-message content)
                                 #(on-binary-message content)))
        handle-initial-frame (fn [^WebSocketFrame frame]
                               (if (.isFinalFragment frame)
                                 (->> frame
                                      websocket-frame-coerced-content
                                      handle-final-content)
                                 (do (reset! frame-collector
                                             (websocket-frame-collector frame))
                                     nil)))]
    (fn [^ChannelHandlerContext ctx frame]
      (when (instance? WebSocketFrame frame)
        (cond
          (or (instance? TextWebSocketFrame frame)
              (instance? BinaryWebSocketFrame frame))
          (handle-initial-frame frame)

          (instance? ContinuationWebSocketFrame frame)
          (if-let [collect @frame-collector]
            (when-let [content (collect frame)]
              (handle-final-content content))
            ;; A continuation without a preceding non-final Text/Binary frame
            ;; is a protocol violation; fail loudly rather than with an NPE.
            (throw (ex-info "Received a continuation frame with no fragmented message in progress."
                            {:frame frame})))

          (instance? PingWebSocketFrame frame)
          (let [^PingWebSocketFrame ping frame]
            ;; RFC 6455 5.5.3: a Pong answering a Ping must carry identical
            ;; application data. retainedDuplicate bumps the reference count so
            ;; the payload survives the caller releasing the ping, and gives the
            ;; pong independent reader/writer indices.
            (.writeAndFlush ctx (PongWebSocketFrame. (.retainedDuplicate (.content ping))))
            (when on-ping-message
              (on-ping-message ping))
            nil)

          (instance? PongWebSocketFrame frame)
          (do (when on-pong-message
                (on-pong-message frame))
              nil)

          (instance? CloseWebSocketFrame frame)
          on-close

          :else
          (locking Object
            (println (ex-info "Unhandled frame" {:frame frame}))))))))

(defn inbound-handler
  "Builds a Netty ChannelDuplexHandler that turns inbound WebSocket frames into
  callback invocations.

  Every inbound message is passed to the dispatch fn from inbound-collector
  (built from the :on-text-message, :on-binary-message, :on-ping-message,
  :on-pong-message and :on-close options). When that dispatch fn returns a
  thunk, the thunk is submitted to `exec-service`. Before running it, the task
  derefs `:ready` when one is supplied. Exceptions thrown by the task are
  printed and forwarded with fireExceptionCaught.

  Each message that is ReferenceCounted is released once dispatch finishes,
  including when dispatch throws. Messages are not passed further down the
  pipeline."
  ^ChannelDuplexHandler
  [^ExecutorService exec-service
   {:keys [on-text-message
           on-binary-message
           on-ping-message
           on-pong-message
           on-close ready]}]
  (let [frame-handler (inbound-collector
                       {:on-text-message on-text-message
                        :on-binary-message on-binary-message
                        :on-ping-message on-ping-message
                        :on-pong-message on-pong-message
                        :on-close on-close})]
    (proxy [ChannelDuplexHandler] []
      (channelRead [^ChannelHandlerContext ctx msg]
        (try
          ;; frame-handler runs inside the try so that msg is still released
          ;; if it throws (e.g. on a malformed fragment sequence).
          (when-let [handler (frame-handler ctx msg)]
            (.submit exec-service
                     (reify Runnable
                       (run [_]
                         (try
                           (when ready @ready)
                           (handler)
                           (catch Exception e
                             ;; msg may already have been released by the
                             ;; finally below; it is printed only to identify it.
                             (locking Object
                               (println "Exception occurred handling inbound websocket frame"
                                        {:frame msg}
                                        e))
                             (.fireExceptionCaught ctx e)))))))
          (finally
            (when (instance? ReferenceCounted msg)
              (.release ^ReferenceCounted msg))))))))
