;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns vectio.tcp
  (:require
   [fluxus.fiber :refer [fiber]]
   [fluxus.flow :as f]
   [fluxus.promise :as p]
   [spectator.log :as log]
   [utilis.map :refer [assoc-if]])
  (:import
   [java.net InetSocketAddress ServerSocket]
   [java.nio ByteBuffer ByteOrder]
   [java.nio.channels ClosedChannelException SocketChannel]))

;; Forward declaration: `slice` is used by `client` and `framer-xform` but is
;; defined in the Private section at the end of this file.
(declare slice)

(defn client
  "Opens a TCP connection and returns a promise of a flow bound to it.

  Options:
    :host, :port - address to connect to.
    :_tls        - accepted but currently unused.
    :byte-order  - :big-endian (default) or :little-endian; applied to every
                   received ByteBuffer.
    :max-frame   - size of the buffer allocated for each socket read
                   (default 64 KiB).
    :tx-buffer   - passed as :buffer to the first side of `f/entangled`.
    :rx-buffer   - passed as :buffer to the second side of `f/entangled`.
    :label       - value used in log entries and as the reader fiber's label
                   (default \"tcp://<host>:<port>\").

  ByteBuffers put on the returned flow are written to the socket. Data read
  from the socket is delivered on the flow as ByteBuffer slices. The promise
  is rejected with the exception if the connection cannot be set up."
  [{:keys [host port _tls byte-order max-frame tx-buffer rx-buffer label]
    :or {byte-order :big-endian
         max-frame (* 64 1024)
         label (str "tcp://" host ":" port)}}]
  (let [byte-order (case byte-order
                     :big-endian ByteOrder/BIG_ENDIAN
                     :little-endian ByteOrder/LITTLE_ENDIAN)
        client (p/promise)
        [client-flow internal] (f/entangled (assoc-if {} :buffer tx-buffer)
                                            (assoc-if {} :buffer rx-buffer))]
    (try
      (let [address (InetSocketAddress. ^String host ^int port)
            socket (SocketChannel/open address)]
        ;; Closing the internal side of the flow closes the socket.
        (f/on-close internal (fn [_] (.close socket)))
        ;; Outbound: every frame taken from `internal` is written to the socket.
        (f/consume (fn [^ByteBuffer frame]
                     (let [bytes-written (.write socket frame)]
                       (log/trace [:vectio.tcp/socket label :> (.position frame) (.limit frame) bytes-written])))
                   internal)
        ;; Inbound: read into a fresh buffer each time and hand the filled
        ;; part to the flow; the put is awaited before the next read.
        (fiber {:label label}
          (loop []
            (let [read-buffer (doto (ByteBuffer/allocate max-frame)
                                (.order byte-order))
                  bytes-read (try
                               (.read socket read-buffer)
                               (catch ClosedChannelException _ nil)
                               (catch Throwable e
                                 (log/error [:vectio.tcp/socket label] e)
                                 nil))]
              ;; SocketChannel.read returns -1 at end-of-stream. Stop reading
              ;; then; -1 is truthy, and continuing would spin forever.
              ;; NOTE(review): nothing closes `internal` when this loop ends,
              ;; so readers of the client flow may not be notified - confirm.
              (when (and bytes-read (not (neg? bytes-read)))
                (log/trace [:vectio.tcp/socket label :< bytes-read (.position read-buffer)])
                @(f/put! internal (slice read-buffer 0 (.position read-buffer)))
                (recur)))))
        (p/resolve! client client-flow))
      (catch Throwable e
        (p/reject! client e)))
    client))

(defn framer-xform
  "Returns a stateful transducer that turns a stream of ByteBuffers
  (arbitrary chunks) into one ByteBuffer per length-prefixed frame.

  Options:
    :offset - number of bytes preceding the length field in each frame.
    :length - size in bytes of the length field: 1, 2, 4 or 8. The field is
              read at index `offset` in the buffer's byte order. 1, 2 and 4
              byte fields are read as unsigned; 8 byte fields as a signed long.
    :label  - value included in log entries.

  The length field counts the bytes following the header, so a whole frame
  spans `offset + length + <field value>` bytes. Each input is read by
  absolute index from 0 up to its limit, so inputs are expected to have
  position 0. Data ending mid-frame (including mid-header) is cached and
  prepended to the next input. Complete frames are emitted as slices.
  Exceptions are logged and rethrown."
  [{:keys [offset length label]}]
  (log/debug [:vectio.tcp/framer :offset offset :length length])
  (let [offset (int offset)
        field-size (int length)
        header-length (+ offset field-size)]
    (fn [rf]
      ;; State is created per reducing function so that every application of
      ;; this transducer buffers its own partial frame.
      (let [cache (volatile! nil)]
        (fn
          ([] (rf))
          ;; NOTE(review): a partial frame still cached at completion is dropped.
          ([result] (rf result))
          ([result ^ByteBuffer input]
           (try
             (loop [result result
                    ;; Prepend any partial frame left over from earlier input.
                    ^ByteBuffer buffer (if-let [^ByteBuffer pending @cache]
                                         (do (vreset! cache nil)
                                             (doto (ByteBuffer/allocate (+ (.limit pending) (.limit input)))
                                               (.order (.order input))
                                               (.put pending)
                                               (.put input)
                                               (.position 0)))
                                         input)]
               (let [limit (.limit buffer)]
                 (if (< limit header-length)
                   ;; Not even the length field is complete yet: wait for more.
                   (do (vreset! cache buffer)
                       result)
                   (let [field (case field-size
                                 1 (bit-and (.get buffer offset) 0xff)
                                 2 (bit-and (.getShort buffer offset) 0xffff)
                                 4 (bit-and (.getInt buffer offset) 0xffffffff)
                                 8 (.getLong buffer offset))
                         frame-length (+ header-length field)]
                     (log/trace [:vectio.tcp/framer label :length frame-length :limit limit :cache @cache])
                     (cond
                       ;; Exactly one frame.
                       (= frame-length limit)
                       (rf result buffer)

                       ;; One frame plus more data: emit it, then process the rest.
                       (< frame-length limit)
                       (let [result (rf result (slice buffer 0 frame-length))]
                         (if (reduced? result)
                           result
                           (recur result (slice buffer frame-length (- limit frame-length)))))

                       ;; Incomplete frame: keep it until more data arrives.
                       :else
                       (do (vreset! cache buffer)
                           result))))))
             (catch Throwable e
               (log/error [:vectio.tcp/framer label :input input :cache @cache] e)
               (throw e)))))))))

(defn free-port
  "Binds a ServerSocket to port 0 so the OS assigns a port, closes it, and
  returns the assigned port number. The socket is closed before returning,
  so the port is not held for the caller."
  []
  (let [probe (ServerSocket. 0)]
    (try
      (.getLocalPort probe)
      (finally
        (.close probe)))))


;;; Private

(defn- slice
  "Returns a view of `length` bytes of `buffer` starting at absolute `index`,
  sharing its content. The view is given `buffer`'s byte order, which
  ByteBuffer.slice does not carry over by itself."
  [^ByteBuffer buffer ^long index ^long length]
  (let [view (.slice buffer index length)]
    (.order view (.order buffer))))
