;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns fides.jwe
  (:require [clojure.string :as st]
            [jsonista.core :as j]
            [fides.util.bytes :as bytes]
            [fides.util.zip :as zip]
            [fides.cipher :as cipher]
            [fides.nonce :as nonce]))

;; Forward declarations: these helpers are defined further down in this
;; namespace but are referenced by `encrypt` and `decrypt`, which come first.
(declare encode-header decode-header
         encode-payload decode-payload
         split-crypt split-message
         generate-iv)

(defn encrypt
  "Encrypt then sign an arbitrary length string using json web encryption.
  Does not support key wrapping: `:alg` is always set to `:dir` and `:enc`
  to `:a128cbc-hs256`, overriding any values present in the supplied header.

  Arguments:
    payload - the string to encrypt.
    key     - the key, handed to `cipher/encrypt` unchanged.
    opts    - optional map with:
                :zip    when truthy, the payload is run through
                        `zip/compress` before encryption and the header
                        carries `:zip \"DEF\"`.
                :header additional header fields to include.

  Returns a string of five '.'-separated segments:
  header, an empty key segment, iv, ciphertext and auth tag."
  [payload key & [{:keys [zip header] :or {zip false} :as opts}]]
  (let [iv (generate-iv)
        ;; The :zip header field must mirror what is actually done to the
        ;; payload. `decrypt` expands the payload whenever the header carries
        ;; a truthy zip value, so a caller-supplied :zip must not survive
        ;; when compression is disabled.
        header (encode-header (-> (if zip
                                    (assoc header :zip "DEF")
                                    (dissoc header :zip))
                                  (assoc :alg :dir
                                         :enc :a128cbc-hs256)))
        ;; The encoded header doubles as the additional authenticated data.
        [ciphertext auth-tag] (-> payload
                                  (encode-payload zip)
                                  (cipher/encrypt {:aad header :key key :iv iv})
                                  (split-crypt))]
    (st/join "." [(bytes/to-str header)
                  (bytes/to-b64-str (byte-array 0)) ;; placeholder for key wrapping
                  (bytes/to-b64-str iv)
                  (bytes/to-b64-str ciphertext)
                  (bytes/to-b64-str auth-tag)])))

(defn decrypt
  "Decrypt the jwe compliant message and return its payload.

  Arguments:
    input - the '.'-separated message as produced by `encrypt`.
    key   - the key, handed to `cipher/decrypt` unchanged.
    opts  - optional; currently unused and accepted only so the call shape
            mirrors `encrypt`.

  Returns the output of `cipher/decrypt`, passed through `zip/expand` when
  the header carries a truthy `zip` field.

  Throws ex-info with `:type :validation` when the message has fewer than
  five segments, when an AssertionError is raised while decrypting, or
  when a JSON parse error is raised. Header decoding failures raised by
  `decode-header` propagate unchanged."
  ([input key]
   (decrypt input key nil))
  ([input key opts]
   (let [[header crypt-key iv ciphertext auth-tag] (split-message input)]
     ;; Destructuring yields nil for every segment the message lacks.
     (when (or (nil? crypt-key) (nil? iv) (nil? ciphertext) (nil? auth-tag))
       (throw (ex-info "Message seems corrupt or manipulated."
                       {:type :validation :cause :signature})))
     (try
       (let [header (bytes/from-str header)
             {:keys [zip]} (decode-header header)]
         ;; Re-join ciphertext and tag; `encrypt` splits them apart with
         ;; `split-crypt` after `cipher/encrypt`. The encoded header is the
         ;; additional authenticated data, exactly as in `encrypt`.
         (-> (bytes/concat (bytes/from-b64-str ciphertext)
                           (bytes/from-b64-str auth-tag))
             (cipher/decrypt {:aad header :key key :iv (bytes/from-b64-str iv)})
             (decode-payload zip)))
       ;; NOTE(review): presumably raised by `cipher/decrypt` when
       ;; authentication fails - confirm against fides.cipher.
       (catch java.lang.AssertionError e
         (throw (ex-info "Message seems corrupt or manipulated."
                         {:type :validation :cause :token})))
       (catch com.fasterxml.jackson.core.JsonParseException e
         (throw (ex-info "Message seems corrupt or manipulated."
                         {:type :validation :cause :signature})))))))

;;; Private

(defn- encode-header
  "Serialise the header map for use as the first message segment.
  `:alg` becomes \"dir\" when it is `:dir` and the upper-cased name
  otherwise; `:enc` becomes its upper-cased name. The map is then written
  as JSON, converted with `bytes/from-str` and url-base64 encoded."
  [header]
  (let [alg->str (fn [alg]
                   (if (= :dir alg)
                     "dir"
                     (st/upper-case (name alg))))
        enc->str (fn [enc]
                   (st/upper-case (name enc)))
        normalised (update (update header :alg alg->str) :enc enc->str)
        json (j/write-value-as-string normalised j/keyword-keys-object-mapper)]
    (bytes/to-url-b64s (bytes/from-str json))))

(defn- decode-header
  "Inverse of `encode-header`: url-base64 decode `data`, parse it as JSON
  with keyword keys and turn the `:alg` and `:enc` values, when present,
  into lower-cased keywords.

  Throws ex-info with `:type :validation` and `:cause :header` when the
  JSON cannot be parsed or does not describe a map.

  NOTE(review): `decrypt` passes the result of `bytes/from-str` here, so
  the String hint may not match the actual argument - confirm."
  [^String data]
  (let [corrupt-header (fn []
                         (ex-info "Message seems corrupt or manipulated."
                                  {:type :validation :cause :header}))
        keywordise (fn [m k]
                     ;; Leave the map untouched when the field is absent.
                     (if-let [v (get m k)]
                       (assoc m k (keyword (st/lower-case v)))
                       m))]
    (try
      (let [parsed (j/read-value (bytes/to-str (bytes/from-url-b64s data))
                                 j/keyword-keys-object-mapper)]
        (if (map? parsed)
          (reduce keywordise parsed [:alg :enc])
          (throw (corrupt-header))))
      (catch com.fasterxml.jackson.core.JsonParseException e
        (throw (corrupt-header))))))

(defn- encode-payload
  "Convert `input` with `bytes/from-str` and, when `zip` is truthy, pass
  the result through `zip/compress`."
  [input zip]
  (let [raw (bytes/from-str input)]
    (if zip
      (zip/compress raw)
      raw)))

(defn- decode-payload
  "Return `payload` passed through `zip/expand` when `zip` is truthy,
  and unchanged otherwise."
  [payload zip]
  (if zip
    (zip/expand payload)
    payload))

(defn- split-crypt
  "Split `crypt` into `[ciphertext tag]`, where the tag is the trailing
  `cipher/tag-length` elements and the ciphertext is everything before it."
  [crypt]
  (let [total (count crypt)
        boundary (- total cipher/tag-length)]
    [(bytes/slice crypt 0 boundary)
     (bytes/slice crypt boundary total)]))

(defn- split-message
  "Split `message` on '.' into at most five parts; any further dots
  remain inside the fifth part."
  [message]
  (let [separator #"\."
        max-parts 5]
    (st/split message separator max-parts)))

(defn- generate-iv
  "Produce a fresh initialization vector by asking `nonce/random-bytes`
  for 16 bytes."
  []
  (let [iv-length 16]
    (nonce/random-bytes iv-length)))
