(ns fides.certificates
  (:require [tempus.core :as t]
            [tempus.interval :as i]
            [tempus.duration :as d]
            [fides.util.bytes :as bytes])
  (:import [java.lang String]
           [java.util Date]
           [java.time ZoneOffset]
           [java.io BufferedReader ByteArrayInputStream InputStreamReader StringWriter]
           [java.security KeyPairGenerator PrivateKey PublicKey SecureRandom]
           [java.security.cert CertificateFactory X509Certificate]
           [java.security.spec ECGenParameterSpec]
           [javax.security.auth.x500 X500Principal]
           [org.bouncycastle.asn1 ASN1Encodable DERSequence]
           [org.bouncycastle.asn1.x509 BasicConstraints Extension KeyPurposeId KeyUsage GeneralName GeneralNames]
           [org.bouncycastle.cert.jcajce JcaX509CertificateConverter JcaX509v3CertificateBuilder]
           [org.bouncycastle.jce.provider BouncyCastleProvider]
           [org.bouncycastle.openssl PEMParser]
           [org.bouncycastle.openssl.jcajce JcaPEMKeyConverter JcaPEMWriter]
           [org.bouncycastle.operator ContentSigner]
           [org.bouncycastle.operator.jcajce JcaContentSignerBuilder]
           [org.bouncycastle.pkcs PKCS10CertificationRequest]
           [org.bouncycastle.pkcs.jcajce JcaPKCS10CertificationRequestBuilder]))

;; Forward declarations: the public API below calls these private helpers,
;; which are only defined further down, under ";;; Private".
(declare ->pem ->key-pem pem-> pem->x509-cert ->tempus
         key-tuple unbuilt-certificate certificate-extensions
         sign-certificate sign-csr)

;; Key-material defaults used by every generator in this namespace:
;; EC key pairs on the secp384r1 curve, signed with ECDSA over SHA-256.
;; Plain `def` (not `defonce`): these are constants, so reloading the
;; namespace should pick up edits to them.
(def ^:private ^String default-key-algorithm "EC")
(def ^:private ^String default-signature-algorithm "SHA256withECDSA")
(def ^:private ^String default-group-name "secp384r1")

(defn self-signed-certificate
  "Generates a fresh key pair and a certificate for subject `CN=<cn>` that is
  issued by that same subject and signed with the new private key.

  `validity` is the validity period in days. `extensions` is an extension map
  as accepted by `certificate-extensions`; it defaults to `{}` (no extensions).

  Returns `{:certificate <PEM string> :key <PEM PKCS#8 private key>}`."
  ([^String cn validity]
   (self-signed-certificate cn validity {}))
  ([^String cn validity extensions]
   (let [[^PublicKey pub ^PrivateKey priv] (key-tuple default-key-algorithm default-group-name)
         ;; Issuer and subject are the same principal: that is what makes it self-signed.
         principal (X500Principal. (str "CN=" cn))
         unsigned (unbuilt-certificate principal pub principal validity extensions)
         cert (sign-certificate unsigned priv)]
     {:certificate (->pem cert)
      :key (->key-pem priv)})))


(defn certificate-signing-request
  "Generates a fresh key pair and a PKCS#10 certificate signing request for
  subject `CN=<cn>`, signed with the new private key.

  Returns `{:csr <PEM string> :key <PEM PKCS#8 private key>}`."
  [cn]
  (let [[^PublicKey pub ^PrivateKey priv] (key-tuple default-key-algorithm default-group-name)
        request-builder (JcaPKCS10CertificationRequestBuilder. (X500Principal. (str "CN=" cn)) pub)
        signed-request (sign-csr request-builder priv)]
    {:csr (->pem signed-request)
     :key (->key-pem priv)}))


(defn signed-certificate
  "Issues a certificate for the subject and public key in the PEM-encoded
  PKCS#10 request `csr`, signed by an issuing CA.

  `validity`      - validity period in days.
  `extensions`    - extension map as accepted by `certificate-extensions`.
  `signing-cert`  - PEM-encoded certificate of the issuer; its subject becomes
                    the issuer name of the new certificate.
  `signing-key`   - PEM-encoded PKCS#8 private key of the issuer.

  Only the CSR's subject and subject public key are used. Any attributes or
  requested extensions inside the CSR are ignored.
  NOTE(review): the CSR's own signature (proof of possession) is not verified here.

  Returns `{:certificate <PEM string>}`."
  [csr validity extensions signing-cert signing-key]
  (let [^PKCS10CertificationRequest pkcs10 (pem-> csr)
        ^X509Certificate issuer-cert (pem->x509-cert signing-cert)
        converter (JcaPEMKeyConverter.)
        ;; The issued certificate must carry the *requester's* public key taken
        ;; from the CSR, not the issuer's public key.
        subject-public-key (.getPublicKey converter (.getSubjectPublicKeyInfo pkcs10))
        ;; Rebuild the principal from the DER encoding rather than a string
        ;; round-trip, so multi-RDN names keep their exact order and escaping.
        subject-principal (X500Principal. (.getEncoded (.getSubject pkcs10)))
        issuer-private-key (.getPrivateKey converter (pem-> signing-key))]
    {:certificate (-> (unbuilt-certificate (.getSubjectX500Principal issuer-cert)
                                           subject-public-key
                                           subject-principal
                                           validity
                                           extensions)
                      (sign-certificate issuer-private-key)
                      ->pem)}))

(defn not-before
  "Returns the start of the validity period of the PEM-encoded certificate
  `pem-cert`, converted to a tempus value by `->tempus` (UTC offset)."
  [pem-cert]
  (let [^X509Certificate cert (pem->x509-cert pem-cert)
        start (.getNotBefore cert)]
    (->tempus start)))

(defn not-after
  "Returns the end of the validity period of the PEM-encoded certificate
  `pem-cert`, converted to a tempus value by `->tempus` (UTC offset)."
  [pem-cert]
  (let [^X509Certificate cert (pem->x509-cert pem-cert)
        end (.getNotAfter cert)]
    (->tempus end)))

(defn expires
  "Returns the length of the validity period (notBefore -> notAfter) of the
  PEM-encoded certificate `pem-cert` as a whole number of days, rounded down."
  [pem-cert]
  (let [^X509Certificate cert (pem->x509-cert pem-cert)
        validity (i/interval (->tempus (.getNotBefore cert))
                             (->tempus (.getNotAfter cert)))
        days (i/into :days validity)]
    (int (Math/floor days))))

(defn serial-number
  "Returns the serial number of the PEM-encoded certificate `pem-cert` as a
  decimal string (`str` of the BigInteger serial)."
  [pem-cert]
  (let [^X509Certificate cert (pem->x509-cert pem-cert)
        serial (.getSerialNumber cert)]
    (str serial)))

(defn public-key
  "Returns the subject public key of the PEM-encoded certificate `pem-cert`,
  itself re-encoded as a PEM string."
  [pem-cert]
  (let [^X509Certificate cert (pem->x509-cert pem-cert)
        key (.getPublicKey cert)]
    (->pem key)))

(defn subject
  "Returns the value of the first (left-most) CN attribute in the subject of the
  PEM-encoded certificate `pem-cert`, or nil when the subject has no CN.

  The value is read from the RFC 2253 rendering of the subject
  (`X500Principal/getName`), so RFC 2253 escapes such as `\\,` are left in place."
  [pem-cert]
  (let [^X509Certificate x509 (pem->x509-cert pem-cert)
        rfc2253-name (.getName (.getSubjectX500Principal x509))]
    ;; Anchor `CN=` at the start of an RDN (string start or after ','), and stop
    ;; at the next unescaped ',' or '+'. Without this, following RDNs such as
    ;; ",O=..." would be swallowed into the returned value.
    (second (re-find #"(?:^|,)CN=((?:\\.|[^,+\\])*)" rfc2253-name))))


;;; Private

(defn- sign-csr
  "Signs the PKCS#10 request builder `unsigned-csr` with `private-key`, using
  `default-signature-algorithm`, and returns the built request."
  [^JcaPKCS10CertificationRequestBuilder unsigned-csr ^PrivateKey private-key]
  (let [signer-builder (JcaContentSignerBuilder. default-signature-algorithm)
        ^ContentSigner content-signer (.build signer-builder private-key)]
    (.build unsigned-csr content-signer)))

;; Single BouncyCastle provider instance reused by every certificate conversion,
;; instead of instantiating a new provider on each call.
(defonce ^:private ^BouncyCastleProvider bc-provider (BouncyCastleProvider.))

(defn- sign-certificate
  "Signs the certificate builder `unsigned-certificate` with `private-key`,
  using `default-signature-algorithm`, and returns the resulting
  `java.security.cert.X509Certificate`. The conversion uses the BouncyCastle provider."
  [^JcaX509v3CertificateBuilder unsigned-certificate ^PrivateKey private-key]
  (let [^ContentSigner signer (.build (JcaContentSignerBuilder. default-signature-algorithm)
                                      private-key)
        holder (.build unsigned-certificate signer)
        converter (.setProvider (JcaX509CertificateConverter.) bc-provider)]
    (.getCertificate converter holder)))

(defn- unbuilt-certificate
  "Creates an unsigned X.509 v3 certificate builder and applies `extensions`
  to it via `certificate-extensions`.

  The serial number is a random 64-bit non-negative BigInteger. notBefore is
  the current time, and notAfter is `expires-in` days after the current time
  (computed with tempus)."
  [^X500Principal issuer ^PublicKey public-key ^X500Principal subject expires-in extensions]
  (let [serial (java.math.BigInteger. 64 (SecureRandom.))
        not-before (Date.)
        expiry-millis (long (t/into :long (t/+ (t/now) (d/days expires-in))))
        builder (JcaX509v3CertificateBuilder. issuer
                                              serial
                                              not-before
                                              (Date. expiry-millis)
                                              subject
                                              public-key)]
    (certificate-extensions builder extensions)))

(defn- basic-constraints
  "Builds a BasicConstraints value. `{:ca true :path-len n}` gives a CA with a
  path-length constraint; otherwise the CA flag is `(boolean ca)`, so nil means false."
  [{:keys [ca path-len]}]
  (if (and ca path-len)
    (BasicConstraints. (int path-len))
    (BasicConstraints. (boolean ca))))

(defn- extended-key-usage
  "Encodes a collection of EKU keywords as a DER SEQUENCE of KeyPurposeIds."
  [usages]
  (->> usages
       (map (fn [u]
              (case u
                :any-extended-key-usage KeyPurposeId/anyExtendedKeyUsage
                :server-auth KeyPurposeId/id_kp_serverAuth
                :client-auth KeyPurposeId/id_kp_clientAuth
                :code-signing KeyPurposeId/id_kp_codeSigning
                :email-protection KeyPurposeId/id_kp_emailProtection
                :ipsec-end-system KeyPurposeId/id_kp_ipsecEndSystem
                :ipsec-tunnel KeyPurposeId/id_kp_ipsecTunnel
                :ipsec-user KeyPurposeId/id_kp_ipsecUser
                (throw (IllegalArgumentException. (str "Unsupported extended key usage: " u))))))
       (into-array ASN1Encodable)
       (DERSequence.)))

(defn- key-usage
  "Builds a KeyUsage value whose bit mask is the OR of the flags named in `usages`.
  Key usage is a BIT STRING, so the int flags are combined rather than sequenced."
  [usages]
  (->> usages
       (reduce (fn [mask u]
                 (bit-or mask
                         (case u
                           :crl-sign KeyUsage/cRLSign
                           :data-encipherment KeyUsage/dataEncipherment
                           :decipher-only KeyUsage/decipherOnly
                           :digital-signature KeyUsage/digitalSignature
                           :encipher-only KeyUsage/encipherOnly
                           :key-agreement KeyUsage/keyAgreement
                           :key-certsign KeyUsage/keyCertSign
                           :key-encipherment KeyUsage/keyEncipherment
                           :non-repudiation KeyUsage/nonRepudiation
                           (throw (IllegalArgumentException. (str "Unsupported key usage: " u))))))
               0)
       int
       (KeyUsage.)))

(defn- general-names
  "Builds GeneralNames from a map (or seq of [type value] pairs) of SAN entries.
  Each value may be a single value or a collection of values of that type."
  [sans]
  (->> sans
       (mapcat (fn [[san-k san-v]]
                 (let [tag (case san-k
                             :dns GeneralName/dNSName
                             :email GeneralName/rfc822Name
                             :ip GeneralName/iPAddress
                             ;; keytool's "oid" SAN type is a registeredID; BC
                             ;; cannot build an otherName from a plain String.
                             :oid GeneralName/registeredID
                             :uri GeneralName/uniformResourceIdentifier
                             (throw (IllegalArgumentException.
                                     (str "Unsupported subject alternative name type: " san-k))))]
                   (map (fn [value] (GeneralName. (int tag) value))
                        (if (coll? san-v) san-v [san-v])))))
       (into-array GeneralName)
       (GeneralNames.)))

(defn- certificate-extensions
  "Adds the X.509 v3 extensions described by the map `m` to the certificate
  builder `b` and returns the builder. The map syntax is similar to, but not
  exactly like, keytool's -ext option.
  Example: {:bc {:ca true} :san {:dns \"google.com\"} :ku [:crl-sign] :eku [:server-auth :client-auth]}

  Supported keys:
    :bc  - basic constraints: {:ca bool}, optionally with :path-len (used only when :ca is truthy).
    :ku  - key usage: collection of :digital-signature :non-repudiation :key-encipherment
           :data-encipherment :key-agreement :key-certsign :crl-sign :encipher-only :decipher-only.
    :eku - extended key usage: collection of :any-extended-key-usage :server-auth :client-auth
           :code-signing :email-protection :ipsec-end-system :ipsec-tunnel :ipsec-user.
    :san - subject alternative names: map (or seq of pairs) from :dns :email :ip :uri :oid
           to a value or a collection of values; :oid is encoded as a registeredID.

  Unsupported (currently): IAN, SIA and AIA. Unknown keys or values throw
  IllegalArgumentException. Every extension is added with the critical flag set;
  criticality cannot currently be configured.

  More will be added as use cases require and sanity permits."
  [^JcaX509v3CertificateBuilder b m]
  (reduce (fn [^JcaX509v3CertificateBuilder builder [k v]]
            ;; addExtension returns the same builder, so it is threaded through.
            (case k
              :bc (.addExtension builder Extension/basicConstraints true (basic-constraints v))
              :eku (.addExtension builder Extension/extendedKeyUsage true (extended-key-usage v))
              :ku (.addExtension builder Extension/keyUsage true (key-usage v))
              :san (.addExtension builder Extension/subjectAlternativeName true (general-names v))
              (throw (IllegalArgumentException. (str "Unsupported certificate extension: " k)))))
          b
          m))

(defn- key-tuple
  "Generates a new key pair and returns it as `[public-key private-key]`.
  The generator for `algorithm` is always initialised with
  `(ECGenParameterSpec. group)`, so `algorithm` must accept EC curve
  parameters, and `group` names the curve (e.g. \"secp384r1\")."
  [^String algorithm ^String group]
  (let [generator (KeyPairGenerator/getInstance algorithm)]
    (.initialize generator (ECGenParameterSpec. group))
    (let [pair (.generateKeyPair generator)]
      [(.getPublic pair) (.getPrivate pair)])))

(defn- ->key-pem
  "Encodes the private key `o` as a PEM \"PRIVATE KEY\" block. The body is the
  base64 of `.getEncoded`, split into lines of at most 64 characters. Lines are
  joined with \"\\n\" and there is no trailing newline."
  [o]
  ;; Joined with core `interpose`/`str` rather than `clojure.string/join`:
  ;; clojure.string is not required by this namespace, so it would only resolve
  ;; if some other namespace happened to load it first.
  (let [b64-lines (->> (.getEncoded ^PrivateKey o)
                       bytes/to-b64-str
                       (re-seq #".{1,64}"))
        pem-lines (concat ["-----BEGIN PRIVATE KEY-----"]
                          b64-lines
                          ["-----END PRIVATE KEY-----"])]
    (apply str (interpose "\n" pem-lines))))

(defn- ->pem
  "Serializes `o` (a key, certificate, CSR, ...) to a PEM string with
  BouncyCastle's JcaPEMWriter."
  [o]
  (let [out (StringWriter.)
        pem-writer (JcaPEMWriter. out)]
    (doto pem-writer
      (.writeObject o)
      (.flush))
    (str out)))

(defn- pem->
  "Parses the first PEM object in `pem-string` with BouncyCastle's PEMParser
  and returns it. For example, a \"CERTIFICATE REQUEST\" block yields a
  PKCS10CertificationRequest and a \"PRIVATE KEY\" block a PrivateKeyInfo."
  [^String pem-string]
  (let [raw (ByteArrayInputStream. (.getBytes pem-string))
        reader (BufferedReader. (InputStreamReader. raw))
        parser (PEMParser. reader)]
    (.readObject parser)))

(defn- pem->x509-cert
  "Decodes the PEM-encoded certificate `pem-cert` with the JDK's X.509
  CertificateFactory. Throws on input that is not a certificate."
  [^String pem-cert]
  (let [factory (CertificateFactory/getInstance "X.509")
        in (java.io.ByteArrayInputStream. (.getBytes pem-cert))]
    ^X509Certificate (.generateCertificate factory in)))

(defn- ->tempus
  "Converts the java.util.Date `d` to a tempus value: the instant is viewed at
  the UTC offset, then passed to `(t/from :native ...)`. Currently assumes UTC."
  [^Date d]
  (let [at-utc (.atOffset (.toInstant d) ZoneOffset/UTC)]
    (t/from :native at-utc)))
