(ns optimusbuf.ast-transform
  (:require [clojure.string]
            [clojure.walk]
            [com.rpl.specter :refer [collect collect-one if-path keypath nthpath select select-one selected? setval srange subselect transform transformed
                                     ALL FIRST MAP-VALS NONE STAY]]))

;; AST node records. `recordify-dispatch` turns a node [:tag a b ...] into
;; [:tag (->record a b ...)] using the positional constructors, so each record's
;; field order must follow the order of the node's items.

;; File-level statements: [:syntax version] and [:package namespace].
(defrecord syntax [version])
(defrecord package [namespace])

;; Leaf declarations, built directly from the node's items
;; (tags :field, :mapField, :oneofField, :enumField and :rpc respectively).
(defrecord field [label type name fnum options])
(defrecord map-field [ktype vtype name fnum options])
(defrecord oneof-field [type name fnum options])
(defrecord enum-field [name fnum options])
(defrecord rpc [name input output options])

;; Container declarations (tags :message, :enum, :group and :service).
;; Their nodes carry a variable number of body items; `gather-in-btw` collects
;; those into a single vector that becomes `main`, while the fixed number of
;; items before and after it fill the remaining fields.
(defrecord message [name main extensions reserved-names reserved-ranges options])
(defrecord enum [name main options])
(defrecord group [label name fnum main extensions reserved-names reserved-ranges options])
(defrecord service [name main options])

(defn- gather-in-btw
  "Wraps the middle items of `form` into one vector, leaving the first `front-n`
   items and the last `back-n` items where they are.
   (gather-in-btw [1 2 3 4 5 6 7 8] :front-n 2 :back-n 4) ; => [1 2 [3 4] 5 6 7 8]"
  [form & {:keys [front-n back-n]}]
  (let [middle-end (- (count form) back-n)
        middle     (srange front-n middle-end)]
    ;; the navigated sub-sequence is replaced by a 1-element sequence holding it
    (transform middle (fn [items] (vector items)) form)))

; (gather-in-btw [1 2 3 4 5 6 7 8] :front-n 2 :back-n 4)

(defn- recordify-dispatch
  "postwalk step for `recordify`: rewrites one AST node.
   - [:opts x]                      => x
   - [:exts ...], [:rranges ...],
     [:rnames ...]                  => the items after the tag (a seq)
   - fixed-shape nodes              => [tag (->record items...)]
   - message/enum/group/service     => [tag (->record ...)] where the variable
                                       middle items are first gathered into one
                                       vector (see `gather-in-btw`)
   Anything else is returned unchanged."
  [form]
  (if-not (sequential? form)
    form
    (let [tag   (first form)
          args  (rest form)
          ;; keep the tag, replace the items with a record built from them
          build (fn [ctor items] [tag (apply ctor items)])
          ;; items at front & back are fixed in number; the ones in between
          ;; become a single vector that fills the record's `main` field
          build-gathered (fn [ctor front-n back-n]
                           (build ctor (rest (gather-in-btw form
                                                            :front-n front-n
                                                            :back-n back-n))))]
      (case tag
        :opts                    (second form)
        (:exts :rranges :rnames) args
        :syntax                  (build ->syntax args)
        :package                 (build ->package args)
        :field                   (build ->field args)
        :mapField                (build ->map-field args)
        :oneofField              (build ->oneof-field args)
        :enumField               (build ->enum-field args)
        :rpc                     (build ->rpc args)
        :message                 (build-gathered ->message 2 4)
        :enum                    (build-gathered ->enum 2 1)
        :group                   (build-gathered ->group 4 4)
        :service                 (build-gathered ->service 2 1)
        form))))

(defn recordify
  "Rewrites every node of `ast` bottom-up with `recordify-dispatch`, then
   re-tags the last top-level item as `[:options <item>]`.
   That last item is presumably the file-level option list, already unwrapped
   from its [:opts ...] node by `recordify-dispatch` -- confirm against the grammar.
   A nil or empty `ast` is returned as-is. Previously, the index (dec 0) = -1
   made `assoc-in` throw on [] and return {-1 [:options nil]} for nil."
  [ast]
  (let [ast2 (clojure.walk/postwalk recordify-dispatch ast)]
    (if (empty? ast2)
      ast2
      (assoc ast2 (dec (count ast2)) [:options (last ast2)]))))

; (gather-in-btw [1 2 3 4 5 6 7 8] {:front-n 2 :back-n 4})

(defn- seq-type-of?
  "Truthy when `form` is a non-empty seqable whose first item is one of `args`
   (returns that item); otherwise false or nil."
  [form & args]
  (let [tags (set args)]
    (and (seqable? form)
         (seq form)
         (tags (first form)))))

;; Node-kind predicates: each tests the tag (first item) of an AST node.
(defn- msg-enm-grp-ext? [node] (seq-type-of? node :message :enum :group :extend))
(defn- grp?             [node] (seq-type-of? node :group))
(defn- ext?             [node] (seq-type-of? node :extend))
(defn- msg-enm-grp?     [node] (seq-type-of? node :message :enum :group))
;; true for anything that is not a message/enum/extend node. Groups count as
;; unnestable: they stay in their parent (as a field) and are also hoisted.
(defn- unnestable?      [node] (not (seq-type-of? node :message :enum :extend)))

(defn- group->field
  "Replaces a protobuf group node with the field that refers to it; any other
   form is returned unchanged. A group is like an embedded message that also
   carries a label (repeated, optional, required) and a field number.
   Example input:
     [:group :repeated \"G1\" 1
       [:field :required \"Ymsg\" \"g1\" 2 nil] ...]
   Example output:
     [:field :repeated \"G1\" \"G1\" 1]
   The group name is used as both the field's type and its name.
   NOTE(review): unlike other :field nodes, the result has no trailing options slot."
  [form]
  (if-not (grp? form)
    form
    (let [[_ label group-name field-num] form]
      [:field label group-name group-name field-num])))

(defn group->group-
  "Drops the label (repeated, optional, required) and the field number from a
   group node so it is shaped like a message node; other forms pass through.
   Example input:
     [:group :repeated \"G1\" 1 [:field ...] ...]
   Example output:
     [:group \"G1\" [:field ...] ...]"
  [form]
  (if-not (grp? form)
    form
    (let [[_tag _label group-name _field-num & body] form]
      (into [:group group-name] body))))

(defn- get-name-idx
  "Position of the name in a message/enum/group node (always 1), or nil for any
   other form. Groups are expected to already be label-less (see `group->group-`)."
  [form]
  (when (msg-enm-grp? form)
    1))

(defn- join-path
  "Extends the dotted `path` with the name found at `name-idx` in `form`.
   Returns `path` unchanged when `name-idx` is nil, and just the name when
   `path` is empty."
  [form name-idx path]
  (cond
    (nil? name-idx) path
    (empty? path)   (nth form name-idx)
    :else           (str path "." (nth form name-idx))))

(defn- rename
  "Puts `new-name` at position `name-idx` of `form`; a nil `name-idx` leaves
   `form` untouched."
  [form name-idx new-name]
  (cond-> form
    (some? name-idx) (assoc name-idx new-name)))

(defn unnest
  "Flattens nested message/enum/group/extend declarations.

   Message, enum and group nodes are renamed to their dotted path from the root
   (e.g. \"Outer.Inner\"). A group leaves a field behind in its parent (see
   `group->field`) and is hoisted as a label-less group node (see
   `group->group-`). An extend node keeps its target name and gets a trailing
   `[:from path]` item holding the path it was declared in.

   `path` is the dotted path of the enclosing scope (\"\" by default).
   For a declaration node the result is a vector of nodes: the node itself
   followed by its hoisted descendants. For any other seqable form (e.g. the
   root) the result is the form without nested declarations, with the hoisted
   nodes appended at its end. Non-seqable forms are returned unchanged."
  ([form] (unnest form ""))
  ([form path]
   (if-not (seqable? form)
     form
     (let [;; drop nested message/enum/extend nodes; groups stay here (turned
           ;; into fields below) and are also hoisted further down
           kept     (into [] (filter unnestable?) form)
           name-idx (get-name-idx kept)
           new-path (join-path kept name-idx path)
           renamed  (mapv group->field (rename kept name-idx new-path))
           ;; record where an extend was declared so merge-extend can resolve it
           tagged   (if (ext? renamed)
                      (conj renamed [:from new-path])
                      renamed)
           ;; a declaration becomes a 1-element vector so descendants can be appended
           own      (if (msg-enm-grp-ext? tagged) [tagged] tagged)]
       (reduce (fn [acc child]
                 (into acc (unnest (group->group- child) new-path)))
               own
               (filter msg-enm-grp-ext? form))))))

;; ------------------------------------------------------------------------------------------------
;; Merge and remove extensions
;; ------------------------------------------------------------------------------------------------

(defn starts-with?
  "Returns a predicate that is true for seqable nodes whose first item equals `x`,
   e.g. ((starts-with? :package) [:package \"a.b\"]) => true."
  [x]
  (fn [node]
    (and (seqable? node)
         (= x (first node)))))

(defn extract-extends
  "Gathers, for every file in `reg-u` (a map of {filename ast}), the data needed
   to resolve its `extend` declarations. Returns a vector with one entry per file:
     [filename package-node import-nodes extend-nodes]
   - package-node: the file's top-level [:package ...] node (presumably nil if absent)
   - import-nodes: vector of the file's top-level [:import ...] nodes
   - extend-nodes: vector of the file's top-level [:extend ...] nodes (empty if none)
   These entries are the input of `merge-extends`."
  [reg-u]
  (select [ALL                   ; {filename ast}
           (collect-one [FIRST]) ; key: filename
           (nthpath 1)           ; val: ast
           (collect-one [ALL (starts-with? :package)]) ; collected: [:package ...] node
           (collect     [ALL (starts-with? :import)])  ; collected: vector of [:import ...] nodes
           (subselect   [ALL (starts-with? :extend)])] reg-u)) ; value: vector of [:extend ...] nodes

(defn remove-extends
  "Returns `reg-u` ({filename ast}) with every top-level [:extend ...] node
   removed from each file's ast."
  [reg-u]
  (->> reg-u
       (setval [MAP-VALS ALL (selected? [(starts-with? :extend)])] NONE)))

(defn ref-by?
  "Checks whether `subject` (full path of a message) is what the name `ref-target`
   resolves to when used inside the scope `ref-from` (package + enclosing messages).
   True when `subject` ends with `ref-target` and whatever precedes that suffix
   is a prefix of `ref-from`.
   Example: ref-from:      [a b c d e x y]
            ref-target:    [x y Msg1]
            subject:       [a b c x y Msg1]         => true
            subject:       [a b c y Msg1]           => false (does not end with ref-target)
            subject:       [a b c d e x y x y Msg1] => true
            subject:       [a b t x y Msg1]         => false ([a b t] is not a prefix of ref-from)
            subject:       [a b c d e x y Msg1]     => true"
  [ref-from ref-target subject]
  (let [prefix-len    (- (count subject) (count ref-target))
        [prefix tail] (split-at prefix-len subject)]
    (and (= tail ref-target)
         (= prefix (take (count prefix) ref-from)))))

(defn- find-max
  "Returns the index of the greatest number in `targets`, or nil when `targets`
   is empty. Previously, empty input threw an ArityException from `max-key`.
   Ties resolve to the last maximal index, following the `max-key` contract."
  [targets]
  (when (seq targets)
    (->> targets
         (map-indexed vector)          ; [[idx value] ...]
         (apply max-key second)
         first)))

(defn- dot-split
  "Splits a dotted name such as \"a.b.Msg\" into [\"a\" \"b\" \"Msg\"]."
  [dotted]
  (-> dotted
      (clojure.string/split #"\.")))

(defn merge-extend
  "Merges the fields of one `extend` block into the message/group it targets.

   reg-u     - map of {filename ast}; asts are expected to be flattened by
               `unnest` (message/group names are full dotted paths).
   extend    - the target message/group to extend, i.e. `extend p.q.Target { ... }`
               in pb, as an unnested node [:extend \"p.q.Target\" ... [:from \"scope\"]].
               Its last item holds the message path where the extend was declared.
   src-pkg   - package where the extend is defined (may be nil).
   filenames - set of keys of reg-u whose asts are searched for the target.

   Candidates are top-level :message/:group nodes whose last name segment equals
   the target's. A relative target must be resolvable from the extend's scope
   (see `ref-by?`). A fully-qualified target (leading \".\") must match exactly.
   The most specific candidate (longest package + message path) wins. The
   extend's :field nodes, renamed to <scope>.<field-name>, are inserted just
   before the winner's last 4 items.

   Returns the updated reg-u. Throws ex-info when no candidate matches."
  [reg-u extend src-pkg filenames]
  (let [;; dot-split without empty segments. nil or \"\" gives [] instead of an NPE
        ;; or the [\"\"] that clojure.string/split returns for \"\" (top-level
        ;; extends carry [:from \"\"]).
        split-name (fn [s] (into [] (remove empty?) (dot-split (str s))))
        target-str (str (second extend))
        fully-qualified? (clojure.string/starts-with? target-str ".")
        extend-target (split-name target-str)                 ; "x.y.Msg1" => [x y Msg1]
        extend-target-name (last extend-target)
        ;; scope of the extend: package + path from [:extend ... [:from "x.y"]]
        extend-scope (into (split-name src-pkg)                ; => [a b c x y]
                           (split-name (-> extend last last)))
        candidates (select [ALL
                            (collect-one [FIRST]) ; filename
                            (if-path #(filenames (first %)) (nthpath 1)) ; {filename ast}
                            (collect-one [ALL (starts-with? :package) (nthpath 1) ; [:package "a.b.c"]
                                          (transformed [STAY] dot-split)])        ; "a.b.c" => [a b c]
                            ALL
                            #(and (#{:message :group} (first %)) ; [:message ...] or [:group ...]
                                  (= extend-target-name (-> % second str
                                                            dot-split
                                                            last))) ; matches name
                            (nthpath 1)
                            (transformed [STAY] dot-split)] reg-u) ; "x.y.Msg" => [x y Msg]
        finalists (filterv (fn [[_ pkg msg]]
                             ;; pkg is nil for files without a package; (vec nil)
                             ;; avoids (into nil msg), which would reverse msg
                             (let [subject (into (vec pkg) msg)]
                               (if fully-qualified?
                                 (= subject extend-target)
                                 (ref-by? extend-scope extend-target subject))))
                           candidates)
        _ (when (empty? finalists)
            (throw (ex-info (str "No message/group found for extend " target-str)
                            {:extend extend :src-pkg src-pkg :filenames filenames})))
        winner (->> finalists
                    (map (fn [[_ pkg msg]] (+ (count pkg) (count msg))))
                    find-max
                    (nth finalists))
        [filename _ msg] winner
        scope-name (clojure.string/join "." extend-scope)
        ;; full path field name. No leading "." when the scope is empty.
        qualify (fn [field-name]
                  (if (empty? scope-name)
                    field-name
                    (str scope-name "." field-name)))
        extend-fields (->> extend
                           (select [ALL (starts-with? :field)])
                           (transform [ALL (nthpath 3)] qualify))
        ;; insert the new fields before the fixed trailing items
        ;; (extensions, reserved names, reserved ranges, options)
        insert-fields (fn [node]
                        (let [[front back] (split-at (- (count node) 4) node)]
                          (-> (into [] front)
                              (into extend-fields)
                              (into back))))
        msg-name (clojure.string/join "." msg)]
    (transform [(keypath filename)
                ALL
                #(and (#{:message :group} (first %))
                      (= msg-name (second %)))]
               insert-fields reg-u)))

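; For each `extend XXX { ... }` declared in a file:
; (1) starting from the scope (package plus enclosing messages) in which the `extend` is
;     declared, find the matching message/group XXX in that file or the files it imports;
; (2) merge the fields of the `extend` into the message found.
; NOTE: step (1) settles on the best (most specific) match even if that message declares
;       no `extensions` range.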
(defn merge-extends
  "Applies every `extend` declared in one file to `reg-u` ({filename ast}).
   `extends-in-a-file` is one entry of `extract-extends`:
     [filename package-node import-nodes extend-nodes]
   Targets are searched in that file and in the files it imports.
   Returns the updated {filename ast} map."
  [reg-u extends-in-a-file]
  (let [src-file     (first extends-in-a-file)
        src-pkg      (select-one [(nthpath 1 1)] extends-in-a-file)          ; "a.b.c"
        search-in    (-> (select [(nthpath 2) ALL (nthpath 1)] extends-in-a-file) ; imported filenames
                         (conj src-file)
                         set)
        extend-nodes (select [(nthpath 3) ALL] extends-in-a-file)]
    (reduce (fn [acc extend-node]
              (merge-extend acc extend-node src-pkg search-in))
            reg-u
            extend-nodes)))