;;   Copyright (c) 7theta. All rights reserved.
;;   The use and distribution terms for this software are covered by the
;;   MIT License (https://opensource.org/licenses/MIT) which can also be
;;   found in the LICENSE file at the root of this distribution.
;;
;;   By using this software in any fashion, you are agreeing to be bound by
;;   the terms of this license.
;;   You must not remove this notice, or any others, from this software.

(ns fides.tls-credential
  (:require [tempus.core :as t]
            [tempus.interval :as ti]
            [utilis.map :refer [compact]]
            [utilis.fn :refer [apply-kw]])
  (:import [java.security KeyPairGenerator PrivateKey PublicKey SecureRandom]
           [java.security.cert X509Certificate]
           [java.security.spec ECGenParameterSpec]
           [javax.security.auth.x500 X500Principal]
           [org.bouncycastle.cert.jcajce JcaX509CertificateConverter JcaX509v3CertificateBuilder]
           [org.bouncycastle.jce.provider BouncyCastleProvider]
           [org.bouncycastle.openssl PEMParser]
           [org.bouncycastle.openssl.jcajce JcaPEMWriter JcaPKCS8Generator JcaPEMKeyConverter]
           [org.bouncycastle.operator ContentSigner]
           [org.bouncycastle.operator.jcajce JcaContentSignerBuilder]
           [org.bouncycastle.asn1 ASN1Encodable DERSequence]
           [org.bouncycastle.asn1.x509 BasicConstraints Extension KeyPurposeId KeyUsage GeneralName GeneralNames]
           [clojure.lang IObj IMeta ILookup]
           [java.io BufferedReader ByteArrayInputStream InputStreamReader StringWriter]
           [java.util Date]))

;; Forward declarations: these private helpers are defined in the "Private"
;; section further down but are referenced by the public API above it.
(declare key-pair certificate-builder sign-certificate pr-credential)

;; Defaults used when generating key pairs and signing certificates.
;; Key algorithm name handed to `key-pair` by `self-signed` and `signed`.
(def ^:private ^String default-key-algorithm "EC")
;; Named group (curve) handed to `key-pair` by `self-signed` and `signed`.
(def ^:private ^String default-group-name "secp384r1")
;; Signature algorithm name handed to JcaContentSignerBuilder in `sign-certificate`.
;; NOTE(review): spelled with a lower-case "d"; presumably the provider resolves
;; algorithm names case-insensitively - confirm before normalising the spelling.
(def ^:private ^String default-signature-algorithm "SHA256withECdSA")

;; Pairing of an X.509 certificate with its (optional) private key.
;;
;; Fields:
;;   cert     - the X509Certificate
;;   key      - the PrivateKey belonging to the certificate, or nil when the
;;              credential only carries a certificate
;;   meta-map - Clojure metadata, exposed through IMeta / IObj
;;
;; Keyword lookup (ILookup) supports:
;;   :cert          - the certificate
;;   :key           - the private key (may be nil)
;;   :serial-number - the certificate serial number rendered as a string
;;   :subject       - the subject distinguished name as returned by
;;                    X500Principal.getName
;;   :cn            - the value of the CN attribute of the subject, or nil
;;   :validity      - tempus interval from notBefore to notAfter
;;   :expiry        - the :end of :validity
;; Two-argument `get` returns the supplied default for any other key, while
;; single-argument lookup throws ex-info "invalid key".
(deftype TLSCredential [^X509Certificate cert ^PrivateKey key meta-map]
  Object
  (toString [^TLSCredential this]
    (pr-credential this))
  (hashCode [_]
    (hash [:fides/tls-credential cert key]))
  (equals [this other]
    ;; Check the type before touching the fields of `other`: Object.equals
    ;; must answer false (not throw) for nil and for objects of other types.
    (or (identical? this other)
        (and (instance? TLSCredential other)
             (= cert (.cert ^TLSCredential other))
             (= key (.key ^TLSCredential other)))))

  ILookup
  (valAt [this k default]
    (case k
      :cert cert
      :key key
      :serial-number
      (str (.getSerialNumber ^X509Certificate cert))
      :subject ^String
      (let [x500-principal (.getSubjectX500Principal ^X509Certificate cert)]
        (.getName ^X500Principal x500-principal))
      :cn
      ;; Capture only the CN attribute value: stop at the next unescaped
      ;; "," or "+" so that a subject such as "CN=foo,O=bar" yields "foo"
      ;; rather than "foo,O=bar". Backslash-escaped characters stay part of
      ;; the value.
      (second (re-find #"(?:^|[,+])CN=((?:\\.|[^,+\\])*)" (:subject this)))
      :validity
      (let [->tempus (fn [^Date d]
                       (->> d .toInstant .toEpochMilli (t/from :long)))]
        (ti/interval
         (->tempus (.getNotBefore cert))
         (->tempus (.getNotAfter cert))))
      :expiry (:end (:validity this))
      default))
  (valAt [this k]
    ;; A namespaced sentinel distinguishes "unsupported key" from a supported
    ;; key whose value is nil (e.g. :key on a certificate-only credential).
    (let [v (get this k ::not-found)]
      (if (= ::not-found v)
        (throw (ex-info "invalid key" {:key k}))
        v)))

  IMeta
  (meta [_]
    meta-map)

  IObj
  (withMeta [_ meta-map]
    (TLSCredential. cert key meta-map)))

(defn tls-credential?
  "Returns true when `x` is a TLSCredential, false otherwise.

  Options:
    :key? - when truthy (the default) the credential must also carry a
            private key; pass false to accept certificate-only credentials.

  The result is always a strict boolean, so the private key object itself is
  never handed back to the caller."
  [x & {:keys [key?] :or {key? true}}]
  (boolean
   (and (instance? TLSCredential x)
        (or (not key?)
            (some? (.key ^TLSCredential x))))))

(defn self-signed
  "Creates a TLSCredential whose certificate is signed with its own freshly
  generated private key (issuer and subject are the same name).

  Keyword arguments:
    :subject    - full subject distinguished name string
    :cn         - common name; used as \"CN=<cn>\" when :subject is absent
    :validity   - validity period, forwarded to `certificate-builder`
    :expiry     - expiry, forwarded to `certificate-builder`
    :extensions - extension map, forwarded to `certificate-builder`

  Throws ex-info when neither :subject nor :cn, or neither :validity nor
  :expiry, is supplied."
  ^TLSCredential
  [& {:keys [subject cn validity expiry extensions]}]
  (cond
    (not (or subject cn))
    (throw (ex-info "missing subject or cn" {:subject subject :cn cn}))

    (not (or validity expiry))
    (throw (ex-info "missing validity or expiry" {:validity validity :expiry expiry})))
  (let [^String subject-dn (or subject (str "CN=" cn))
        ;; Issuer and subject are identical for a self-signed certificate.
        issuer-name (X500Principal. subject-dn)
        subject-name (X500Principal. subject-dn)
        {^PublicKey pub :public
         ^PrivateKey priv :private} (key-pair default-key-algorithm default-group-name)
        unsigned (certificate-builder issuer-name pub subject-name validity expiry extensions)
        certificate (sign-certificate unsigned priv)]
    (TLSCredential. certificate priv {})))

(defn self-signed-ca
  "Like `self-signed`, but forces the basic-constraints CA flag
  ([:extensions :bc :ca]) to true before delegating. Accepts the same keyword
  arguments as `self-signed`."
  ^TLSCredential
  [& {:as args}]
  (let [ca-args (assoc-in args [:extensions :bc :ca] true)]
    (apply-kw self-signed ca-args)))

(defn signed
  "Creates a TLSCredential with a freshly generated key pair whose certificate
  is issued and signed by `signer`.

  Keyword arguments:
    :subject    - full subject distinguished name string
    :cn         - common name; used as \"CN=<cn>\" when :subject is absent
    :validity   - validity period, forwarded to `certificate-builder`
    :expiry     - expiry, forwarded to `certificate-builder`
    :extensions - extension map, forwarded to `certificate-builder`
    :signer     - TLSCredential of the issuer; its certificate subject becomes
                  the issuer name and its private key signs the certificate

  Throws ex-info when neither :subject nor :cn, or neither :validity nor
  :expiry, is supplied, when :signer is missing or not a TLSCredential, or
  when the signer carries no private key."
  [& {:keys [subject cn validity expiry extensions ^TLSCredential signer]}]
  (when-not (or subject cn)
    (throw (ex-info "missing subject or cn" {:subject subject :cn cn})))
  (when-not (or validity expiry)
    (throw (ex-info "missing validity or expiry" {:validity validity :expiry expiry})))
  (when-not (and signer (instance? TLSCredential signer))
    (throw (ex-info "missing signer" {:signer signer})))
  ;; A certificate-only credential cannot sign anything; fail here with a
  ;; clear message instead of deep inside the content signer.
  (when-not (.key signer)
    (throw (ex-info "signer has no private key" {:signer signer})))
  (let [^String subject (or subject (str "CN=" cn))
        {^PublicKey public-key :public
         ^PrivateKey private-key :private} (key-pair default-key-algorithm default-group-name)
        signing-cert ^X509Certificate (.cert signer)
        signing-key (.key signer)]
    (TLSCredential.
     (-> (certificate-builder (.getSubjectX500Principal signing-cert) public-key (X500Principal. subject) validity expiry extensions)
         (sign-certificate signing-key))
     private-key
     {})))

(defn signed-service
  "Like `signed`, but forces the basic-constraints CA flag to false and sets
  the extended key usage to [:server-auth]. Accepts the same keyword
  arguments as `signed`."
  [& {:as args}]
  (let [service-extensions (-> (:extensions args)
                               (assoc-in [:bc :ca] false)
                               (assoc :eku [:server-auth]))]
    (apply-kw signed (assoc args :extensions service-extensions))))

(defn signed-client
  "Like `signed`, but forces the basic-constraints CA flag to false and sets
  the extended key usage to [:client-auth]. Accepts the same keyword
  arguments as `signed`."
  [& {:as args}]
  (let [client-extensions (-> (:extensions args)
                              (assoc-in [:bc :ca] false)
                              (assoc :eku [:client-auth]))]
    (apply-kw signed (assoc args :extensions client-extensions))))

(defn ->pem
  "Serialises the credential `c` to PEM text.

  Returns a map with :cert holding the PEM-encoded certificate and, only when
  the credential has a private key, :key holding the key written through a
  JcaPKCS8Generator without an encryptor."
  [^TLSCredential c]
  (letfn [(encode [o]
            (with-open [sw (StringWriter.)
                        pw (JcaPEMWriter. sw)]
              (.writeObject pw o)
              ;; Flush before reading so the writer's buffered text is in sw.
              (.flush pw)
              (.toString sw)))]
    (let [cert-pem (encode (.cert c))
          ^PrivateKey private-key (.key c)]
      (if private-key
        {:cert cert-pem
         :key (encode (JcaPKCS8Generator. private-key nil))}
        {:cert cert-pem}))))

(defn pem->
  "Builds a TLSCredential from a map of PEM strings.

  :cert - PEM text of the certificate
  :key  - PEM text of the private key; optional, the resulting credential has
          a nil key when it is absent or when the parser yields no object"
  ^TLSCredential
  [{:keys [cert key]}]
  (letfn [(decode [^String pem-string]
            ;; Reads the first PEM object from the text; nil input gives nil.
            (when (some? pem-string)
              (let [bytes-in (ByteArrayInputStream. (.getBytes pem-string))
                    reader (BufferedReader. (InputStreamReader. bytes-in))
                    parser (PEMParser. reader)]
                (.readObject parser))))]
    (let [certificate (.getCertificate (JcaX509CertificateConverter.)
                                       (decode cert))
          key-object (decode key)
          private-key (when (some? key-object)
                        (.getPrivateKey (JcaPEMKeyConverter.) key-object))]
      (TLSCredential. certificate private-key {}))))

(defn x509->
  "Wraps an existing X509Certificate in a TLSCredential that has no private
  key and empty metadata."
  ^TLSCredential
  [^X509Certificate cert]
  (let [no-key nil
        no-meta {}]
    (TLSCredential. cert no-key no-meta)))

;; Prints a TLSCredential using the summary text produced by `pr-credential`.
(defmethod print-method TLSCredential
  [^TLSCredential credential ^java.io.Writer writer]
  (let [^String text (pr-credential credential)]
    (.write writer text)))


;;; Private

(defn- sign-certificate
  "Signs the certificate described by `unsigned-certificate` with
  `private-key`, using `default-signature-algorithm`, and converts the result
  to a certificate through a JcaX509CertificateConverter configured with a
  BouncyCastleProvider."
  [^JcaX509v3CertificateBuilder unsigned-certificate ^PrivateKey private-key]
  (let [signer-builder (JcaContentSignerBuilder. default-signature-algorithm)
        ^ContentSigner content-signer (.build signer-builder private-key)
        converter (.setProvider (JcaX509CertificateConverter.) (BouncyCastleProvider.))
        holder (.build unsigned-certificate content-signer)]
    (.getCertificate converter holder)))

(defn- add-extensions
  "Adds every entry of `extension-map` to the certificate builder `b` and
  returns the resulting builder. The map is similar to, but not exactly like,
  the extension syntax of keytool:

    :bc  - map with a boolean :ca entry (basic constraints)
    :eku - collection of extended key usage keywords, e.g. :server-auth
    :ku  - collection of key usage keywords, e.g. :crl-sign
    :san - map from :dns, :email, :ip or :uri to a string or a collection
           of strings (subject alternative names)

  Example: {:bc {:ca true} :san {:dns \"google.com\"} :ku [:crl-sign] :eku [:server-auth :client-auth]}

  Supported extensions are BC, EKU, KU, and SAN. Unsupported (currently) are
  IAN, SIA, and AIA. Explicit criticality is not supported: every extension is
  added with the critical flag set to true. Keys or usage keywords without a
  matching `case` clause cause an exception.

  More will be added as use cases require and sanity permits."
  [^JcaX509v3CertificateBuilder b extension-map]
  (letfn [;; Keyword -> extended key usage purpose identifier.
          (purpose-id [usage]
            (case usage
              :any-extended-key-usage KeyPurposeId/anyExtendedKeyUsage
              :server-auth KeyPurposeId/id_kp_serverAuth
              :client-auth KeyPurposeId/id_kp_clientAuth
              :code-signing KeyPurposeId/id_kp_codeSigning
              :email-protection KeyPurposeId/id_kp_emailProtection
              :ipsec-end-system KeyPurposeId/id_kp_ipsecEndSystem
              :ipsec-tunnel KeyPurposeId/id_kp_ipsecTunnel
              :ipsec-user KeyPurposeId/id_kp_ipsecUser))
          ;; Keyword -> key usage bit constant.
          (usage-bit [usage]
            (case usage
              :crl-sign KeyUsage/cRLSign
              :data-encipherment KeyUsage/dataEncipherment
              :decipher-only KeyUsage/decipherOnly
              :digital-signature KeyUsage/digitalSignature
              :encipher-only KeyUsage/encipherOnly
              :key-agreement KeyUsage/keyAgreement
              :key-certsign KeyUsage/keyCertSign
              :key-encipherment KeyUsage/keyEncipherment
              :non-repudiation KeyUsage/nonRepudiation))
          ;; One SAN map entry -> vector of GeneralName objects; the value may
          ;; be a single string or an arbitrarily nested collection of strings.
          (general-names [[san-k san-v]]
            (let [tag (case san-k
                        :dns GeneralName/dNSName
                        :email GeneralName/rfc822Name
                        :ip GeneralName/iPAddress
                        :uri GeneralName/uniformResourceIdentifier)]
              (mapv (fn [^String value]
                      (GeneralName. (int tag) value))
                    (flatten [san-v]))))
          ;; Adds a single [extension-key value] entry to the builder.
          (add-one [^JcaX509v3CertificateBuilder builder [k v]]
            (case k
              :bc (.addExtension
                   builder
                   Extension/basicConstraints true
                   (BasicConstraints. ^boolean (:ca v)))
              :eku (.addExtension
                    builder
                    Extension/extendedKeyUsage true
                    (DERSequence.
                     ^"[Lorg.bouncycastle.asn1.ASN1Encodable;"
                     (into-array ASN1Encodable (mapv purpose-id v))))
              :ku (.addExtension
                   builder
                   Extension/keyUsage true
                   ;; The individual usage bits are OR-ed into one bit mask.
                   (KeyUsage. (reduce bit-or (mapv usage-bit v))))
              :san (.addExtension
                    builder
                    Extension/subjectAlternativeName true
                    (GeneralNames.
                     ^"[Lorg.bouncycastle.asn1.x509.GeneralName;"
                     (into-array GeneralName (mapcat general-names v))))))]
    (reduce add-one b extension-map)))

(defn- certificate-builder
  "Creates a JcaX509v3CertificateBuilder for `subject` / `public-key` issued by
  `issuer`, with a random 64-bit serial number, and adds the (compacted)
  `extensions` to it.

  The validity period is chosen by the first matching rule:
    - `validity` is a tempus Interval: its :start and :end are used
    - `expiry` is a tempus Duration:   from now until now + expiry
    - `expiry` is a tempus DateTime:   from now until expiry
  Anything else throws ex-info \"validity invalid\"."
  [^X500Principal issuer ^PublicKey public-key ^X500Principal subject validity expiry extensions]
  (let [->date (fn [ts] (Date. ^long (t/into :long ts)))
        [not-before not-after] (cond
                                 (instance? tempus.interval.Interval validity)
                                 [(:start validity) (:end validity)]

                                 (instance? tempus.duration.Duration expiry)
                                 (let [now (t/now)]
                                   [now (t/+ now expiry)])

                                 (instance? tempus.core.DateTime expiry)
                                 [(t/now) expiry]

                                 :else (throw (ex-info "validity invalid" {:validity validity :expiry expiry})))
        ^Date start (->date not-before)
        ^Date end (->date not-after)
        serial-number (java.math.BigInteger. 64 (SecureRandom.))
        cert-builder (JcaX509v3CertificateBuilder. issuer serial-number start end subject public-key)]
    (add-extensions cert-builder (compact extensions))))

(defn- key-pair
  "Generates a new key pair for the JCA key `algorithm`, initialised with an
  ECGenParameterSpec for the named `group`. Returns a map with the generated
  keys under :public and :private."
  [^String algorithm ^String group]
  (let [generator (KeyPairGenerator/getInstance algorithm)
        _ (.initialize generator (ECGenParameterSpec. group))
        generated (.generateKeyPair generator)]
    {:public (.getPublic generated)
     :private (.getPrivate generated)}))

(defn- pr-credential
  "Renders the credential `c` as a short, human-readable tag containing its
  hash (in hex), subject, serial number and validity bounds, followed by
  \" :key true\" when a private key is present. The private key itself is
  never printed."
  [^TLSCredential c]
  (let [validity (:validity c)
        valid-part (str ":valid [\"" (:start validity) "\" \"" (:end validity) "\"]")
        key-part (when (.key c)
                   " :key true")
        description (str ":cert :subject \"" (:subject c) "\" "
                         ":serial-number \"" (:serial-number c) "\" "
                         valid-part
                         key-part)]
    (format "#<fides/tls-credential@0x%x: %s>" (hash c) description)))
