;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns fides.jwt
  (:require [jsonista.core :as j]
            [tempus.core :as t]
            [fides.util.bytes :as bytes]
            [fides.jwe :as jwe])
  (:import [java.util Date]))

(declare prepare-claims validate-claims)

(defn encrypt
  ([claims key]
   (encrypt claims key nil))
  ([claims key opts]
   {:pre [(map? claims)]}
   (-> claims
       (prepare-claims opts)
       (j/write-value-as-string j/keyword-keys-object-mapper)
       (jwe/encrypt key opts))))

(defn decrypt
  ([message key]
   (decrypt message key nil))
  ([message key {:keys [skip-validation] :or {skip-validation false} :as opts}]
   (try
     (let [claims (-> message
                      (jwe/decrypt key opts)
                      (bytes/to-str)
                      (j/read-value j/keyword-keys-object-mapper))]
       (if skip-validation
         claims
         (validate-claims claims opts)))
     (catch com.fasterxml.jackson.core.JsonParseException e
       (throw (ex-info "Message seems corrupt or manipulated."
                       {:type :validation :cause :integrity}))))))

;;; Private

(defn- validate-claims
  "Checks the issuer in the `:iss` claim against one of the allowed
  issuers in the passed `:iss`. Passed `:iss` may be a string or a
  vector of strings.  If no `:iss` is passed, this check is not
  performed.
  ----
  Checks one or more audiences in the `:aud` claim against the single
  valid audience in the passed `:aud`.  If no `:aud` is passed, this
  check is not performed.
  ----
  Checks the subject in the `:sub` claim.  If no `:sub` is passed,
  this check is not performed.
  ----
  Checks the `:exp` claim is not less than the passed `:now`, with a
  leeway of the passed `:leeway`.  If no `:exp` claim exists, this
  check is not performed.
  ----
  Checks the `:nbf` claim is less than the passed `:now`, with a
  leeway of the passed `:leeway`.  If no `:nbf` claim exists, this
  check is not performed.
  ----
  Checks the passed `:now` is greater than the `:iat` claim plus the
  passed `:max-age`. If no `:iat` claim exists, this check is not
  performed.
  ----
  A check that fails raises an exception with `:type` of `:validation`
  and `:cause` indicating which check failed.
  `:now` is an integer POSIX time and defaults to the current time.
  `:leeway` is an integer number of seconds and defaults to zero."
  [claims {:keys [max-age iss aud sub now leeway]
           :or {now (t/into :long (t/now)) leeway 0}}]
  ;; Check the `:iss` claim.
  (when (and iss (let [iss-claim (:iss claims)]
                   (if (coll? iss)
                     (not-any? #{iss-claim} iss)
                     (not= iss-claim iss))))
    (throw (ex-info (str "Issuer does not match " iss)
                    {:type :validation :cause :iss})))

  ;; Check the `:aud` claim.
  (when (and aud (let [aud-claim (:aud claims)]
                   (if (coll? aud-claim)
                     (not-any? #{aud} aud-claim)
                     (not= aud aud-claim))))
    (throw (ex-info (str "Audience does not match " aud)
                    {:type :validation :cause :aud})))

  ;; Check the `:exp` claim.
  (when (and (:exp claims) (<= (:exp claims) (- now leeway)))
    (throw (ex-info (format "Token is expired (%s)" (:exp claims))
                    {:type :validation :cause :exp})))

  ;; Check the `:nbf` claim.
  (when (and (:nbf claims) (> (:nbf claims) (+ now leeway)))
    (throw (ex-info (format "Token is not yet valid (%s)" (:nbf claims))
                    {:type :validation :cause :nbf})))

  ;; Check the `:max-age` option.
  (when (and (:iat claims) (number? max-age) (> (- now (:iat claims)) max-age))
    (throw (ex-info (format "Token is older than max-age (%s)" max-age)
                    {:type :validation :cause :max-age})))

  ;; Check the `:sub` claim.
  (when (and sub (let [sub-claim (:sub claims)]
                   (if (coll? sub-claim)
                     (not-any? #{sub} sub-claim)
                     (not= sub sub-claim))))
    (throw (ex-info (str "The subject does not match " sub)
                    {:type :validation :cause :sub})))
  claims)

(defprotocol ITimestamp
  "Convert time objects to unix timestamp."
  (to-timestamp [obj] "Covert to timestamp"))

(extend-protocol ITimestamp
  java.util.Date
  (to-timestamp [obj]
    (quot (.getTime ^Date obj) 1000)))

(defn- normalize-date-claims
  "Normalize date related claims and return transformed object."
  [data]
  (into {} (map (fn [[key val]]
                  (if (satisfies? ITimestamp val)
                    [key (to-timestamp val)]
                    [key val]))
                data)))

(defn- normalize-nil-claims
  "Given a raw headers, try normalize it removing any key with null values."
  [data]
  (into {} (remove (comp nil? second) data)))

(defn- prepare-claims [claims opts]
  (-> claims
      (normalize-date-claims)
      (merge (-> opts
                 (select-keys [:exp :nbf :iat :iss :aud])
                 (normalize-nil-claims)
                 (normalize-date-claims)))))
