(ns lambdaisland.witchcraft.matrix
  "Vector/Matrix math

  A vector in this context can be anything that
  implements [[lambdaisland.witchcraft/with-xyz]]: a Clojure vector (`[x y z]`),
  a Clojure map (`{:x .. :y .. :z ..}`), or a Glowstone `Location` or `Vector`.
  Functions that compute a new vector generally give you back the type you put
  in (the type of their first vector argument); `cross-product` and `center`
  return plain Clojure vectors.

  A matrix is a vector of vectors (regular Clojure vectors) and can be
  3x3 (linear) or 4x4 (affine, in homogeneous coordinates). Matrices act on
  column vectors: `m*v` takes the dot product of each matrix row with the point.

  This code is not optimized for speed; it is fine for generating and
  manipulating Minecraft structures, not for heavy number crunching.

  "
  (:require [lambdaisland.witchcraft :as wc]))

(defn v-
  "Component-wise vector subtraction, or negation with a single argument.

  `(v- a)` returns `-a`, `(v- a b)` returns `a - b`. Arguments can be Clojure
  maps (:x/:y/:z), vectors, or Glowstone Location or Vector instances. The
  result has the same type as `a`.
  "
  ([a]
   (let [coords (wc/xyz a)]
     (wc/with-xyz a (mapv - coords))))
  ([a b]
   (let [lhs (wc/xyz a)
         rhs (wc/xyz b)]
     (wc/with-xyz a (mapv (fn [p q] (- p q)) lhs rhs)))))

(defn v+
  "Component-wise vector addition.

  `a` and `b` can be Clojure maps (:x/:y/:z), vectors, or Glowstone Location or
  Vector instances. The result has the same type as `a`.
  "
  [a b]
  (let [sum (mapv (fn [p q] (+ p q)) (wc/xyz a) (wc/xyz b))]
    (wc/with-xyz a sum)))

(defn v*
  "Multiply a vector with a scalar.

  `v` can be a Clojure map (`:x/:y/:z`), vector (`[x y z]`), or Glowstone
  Location or Vector instance; `s` is a number. Every coordinate is multiplied
  by `s`. Returns the same type as `v`."
  [v s]
  ;; Use `mapv` (not a lazy `map`) so `with-xyz` receives a realized vector of
  ;; coordinates, the same kind of value `v+` and `v-` hand it.
  (wc/with-xyz v (mapv (partial * s) (wc/xyz v))))

(defn vlength
  "Length of vector `v`: its distance from the origin `[0 0 0]`, as computed by
  [[lambdaisland.witchcraft/distance]]."
  [v]
  (let [origin [0 0 0]]
    (wc/distance origin v)))

(defn manhatten
  "Manhattan (taxicab) distance between two points: the sum of the absolute
  differences of their x, y and z coordinates.

  Arguments can be anything [[lambdaisland.witchcraft/xyz]] accepts. Works for
  any numeric coordinates, including ratios (e.g. the output of `center`).
  The misspelled name is kept for backwards compatibility; see [[manhattan]]."
  [x y]
  ;; Absolute value via `neg?`/`-` instead of reflective `Math/abs`, which has no
  ;; overload for clojure.lang.Ratio or BigInt and fails at runtime on them.
  (reduce + (map (fn [a b]
                   (let [d (- a b)]
                     (if (neg? d) (- d) d)))
                 (wc/xyz x)
                 (wc/xyz y))))

(defn manhattan
  "Manhattan (taxicab) distance; correctly spelled alias of [[manhatten]]."
  [x y]
  (manhatten x y))

(defn chebyshev
  "Chebyshev (maximum metric) distance between two points: the largest absolute
  difference among their x, y and z coordinates.

  Arguments can be anything [[lambdaisland.witchcraft/xyz]] accepts. Works for
  any numeric coordinates, including ratios (e.g. the output of `center`)."
  [x y]
  ;; Absolute value via `neg?`/`-` instead of reflective `Math/abs`, which has no
  ;; overload for clojure.lang.Ratio or BigInt and fails at runtime on them.
  (apply max (map (fn [a b]
                    (let [d (- a b)]
                      (if (neg? d) (- d) d)))
                  (wc/xyz x)
                  (wc/xyz y))))

(defn vnorm
  "Normalize a vector to length 1, keeping its direction.

  Accepts the same argument types as [[v*]] and returns the same type as `v`.
  A zero-length vector has no direction; it is returned unchanged instead of
  dividing by zero."
  [v]
  (let [len (vlength v)]
    (if (zero? len)
      v
      (v* v (/ 1 len)))))

(defn m*
  "Multiply a matrix with a scalar: every element of `m` is multiplied by `s`.
  Returns a vector of vectors."
  [m s]
  (mapv (fn [row]
          (mapv #(* s %) row))
        m))

(defn dot-product
  "Vector dot product.

  Clojure vectors are used as-is, whatever their length (this is what lets
  `m*v` and `m*m` pass 4-element rows). Anything else (maps, Glowstone Location
  or Vector instances) goes through `wc/xyz` first. The sum runs over the
  shorter of the two sequences. Returns a number.
  "
  [a b]
  (let [coords (fn [v] (if (vector? v) v (wc/xyz v)))
        xs (coords a)
        ys (coords b)]
    (apply + (map * xs ys))))

(defn cross-product
  "Vector cross product `a × b`.

  Arguments can be anything [[lambdaisland.witchcraft/xyz]] accepts (Clojure
  maps, vectors, Glowstone Location or Vector). Unlike `v+`/`v-`, the result is
  always a plain Clojure vector `[x y z]`, regardless of the input types."
  [a b]
  (let [[ax ay az] (wc/xyz a)
        [bx by bz] (wc/xyz b)]
    [(- (* ay bz) (* az by))
     (- (* az bx) (* ax bz))
     (- (* ax by) (* ay bx))]))

(defn m*v
  "Multiply a matrix (vector of vectors) with a vector.

  `m` is a Clojure vector of vectors, 3x3 (linear) or 4x4 (affine). `v` can be a
  Clojure map (`:x/:y/:z`), vector (`[x y z]`), or Glowstone Location or Vector
  instance. Each row of `m` is dotted with `(wc/xyz1 v)`. Because `dot-product`
  stops at the shorter sequence, 3-element rows ignore the extra component.
  Returns the same type as `v`.
  "
  [m v]
  (let [point (wc/xyz1 v)
        coords (mapv #(dot-product point %) m)]
    (wc/with-xyz v coords)))

(defn transpose
  "Transpose a matrix, returning a vector of vectors whose rows are the columns
  of `m`. If the rows differ in length, the result is cut to the shortest row.
  An empty matrix transposes to `[]`."
  [m]
  ;; `(apply mapv vector [])` would call `mapv` with only a function, an arity
  ;; it doesn't have, so the empty case is handled explicitly.
  (if (seq m)
    (apply mapv vector m)
    []))

(defn m*m
  "Multiply matrices (vectors of vectors).

  With more than two arguments the product is folded from the left:
  `(m*m a b c)` is `(m*m (m*m a b) c)`. Returns a vector of vectors."
  ([m1 m2]
   (let [cols (transpose m2)]
     (mapv (fn [row]
             (mapv #(dot-product row %) cols))
           m1)))
  ([m1 m2 & rest]
   (reduce m*m (m*m m1 m2) rest)))

(defn identity-matrix
  "Return a `degree x degree` matrix (vector of vectors) with `1` on the
  diagonal and `0` everywhere else."
  [degree]
  (let [indices (range degree)]
    (vec (for [row indices]
           (vec (for [col indices]
                  (if (= col row) 1 0)))))))

(defn translation-matrix
  "Return a 4x4 affine matrix that, applied with `m*v`, shifts a point by the
  x, y and z of `v`."
  [v]
  (let [[dx dy dz] (wc/xyz v)]
    ;; Append the translation column to the first three rows of the identity,
    ;; and the homogeneous 1 to the last row.
    (mapv conj
          [[1 0 0] [0 1 0] [0 0 1] [0 0 0]]
          [dx dy dz 1])))

(defn rotation-matrix
  "Return a 4x4 matrix that rotates around the origin.

  Takes the rotation in radians and the two dimensions (`:x`/`:y`/`:z`) that
  form the plane in which the rotation is performed,
  e.g. `(rotation-matrix Math/PI :x :z)`. The row for `dim1` is
  `cos·dim1 - sin·dim2`, the row for `dim2` is `sin·dim1 + cos·dim2`, and the
  remaining dimension is left unchanged."
  [rad dim1 dim2]
  (let [c (Math/cos rad)
        s (Math/sin rad)
        ;; A row holding `at-dim1`/`at-dim2` in the columns of the two plane
        ;; dimensions and 0 everywhere else (including the homogeneous column).
        plane-row (fn [at-dim1 at-dim2]
                    (mapv (fn [col]
                            (condp = col
                              dim1 at-dim1
                              dim2 at-dim2
                              0))
                          [:x :y :z 0]))
        row1 (plane-row c (- s))
        row2 (plane-row s c)
        row-for (fn [axis unit-row]
                  (condp = axis
                    dim1 row1
                    dim2 row2
                    unit-row))]
    [(row-for :x [1 0 0 0])
     (row-for :y [0 1 0 0])
     (row-for :z [0 0 1 0])
     [0 0 0 1]]))

(defn mirror-matrix
  "Return a 4x4 matrix which mirrors points.

  `mappings` is a map from one or more of `:x/:y/:z` to one of
  `:x/:-x/:y/:-y/:z/:-z`, saying which (possibly negated) input coordinate
  each output coordinate is taken from. Dimensions not in `mappings` are left
  unchanged.

  E.g. a mapping of `{:x :-x}` means that the x value gets flipped, in other
  words it's a mirroring in the `x=0` plane. `{:x :z, :z :x}` means that the
  `x` and `z` values get swapped, i.e. a mirroring in the `x=z` plane.

  Throws `IllegalArgumentException` when a mapping target is not one of the
  six keywords above."
  [mappings]
  (mapv
   (fn [dim]
     (if (= 0 dim)
       [0 0 0 1]
       (let [target (get mappings dim dim)]
         (case target
           :x [1 0 0 0]
           :-x [-1 0 0 0]
           :y [0 1 0 0]
           :-y [0 -1 0 0]
           :z [0 0 1 0]
           :-z [0 0 -1 0]
           (throw (IllegalArgumentException.
                   (str "Invalid mirror-matrix mapping for " (pr-str dim) ": "
                        (pr-str target)
                        ", expected one of :x :-x :y :-y :z :-z")))))))
   [:x :y :z 0]))

(defn with-origin
  "Takes an affine transformation matrix, and an origin coordinate, and returns a
  matrix which performs the same transformation, but around the new origin. Use
  this to change the \"anchor\" around which e.g. a rotation happens, which by
  default is otherwise the `[0 0 0]` origin coordinate.

  Matrices act on column vectors (see [[m*v]]), so in a product the right-most
  matrix is applied first. The result moves `origin` onto `[0 0 0]`, applies
  `matrix`, then moves back by `origin`. For a matrix without a translation
  part (e.g. from `rotation-matrix` or `mirror-matrix`), `origin` itself stays
  fixed."
  [matrix origin]
  (m*m
   ;; applied last: move back to the chosen origin
   (translation-matrix origin)
   matrix
   ;; applied first: move `origin` onto [0 0 0]
   (translation-matrix (v- origin))))

(defn transform
  "Transform a collection of points by applying a matrix to each element with
  [[m*v]].

  With more than one matrix they are first multiplied together,
  `(m*m m & rest)`. Given the column-vector convention of `m*v`, the right-most
  matrix is applied first.

  Element order is preserved. Vectors, sets and other collections come back as
  the same collection type (via `empty`). Lists and other seqs, including lazy
  seqs such as the result of `extrude`, come back as a list."
  ([coll m & rest]
   (transform coll (apply m*m m rest)))
  ([coll m]
   (let [f (partial m*v m)]
     (if (seq? coll)
       ;; `(empty coll)` is `()` for seqs and `conj` onto a list prepends, so
       ;; `into` would silently reverse the order; build the list in order.
       (apply list (map f coll))
       (into (empty coll) (map f) coll)))))

(defn center
  "The center point of a collection of points: the average of each of the x, y
  and z coordinates, returned as a plain `[x y z]` vector. Integer coordinates
  may average to ratios. An empty collection divides by zero and throws."
  [coll]
  (let [n (count coll)
        mean (fn [axis]
               (/ (transduce (map axis) + coll) n))]
    [(mean wc/x) (mean wc/y) (mean wc/z)]))

(defn rotate
  "Rotate a shape around its center (the average of all block locations, see
  `center`). Takes an angle in radians, the two dimensions (as keywords,
  `:x`/`:y`/`:z`) that form the plane of rotation, and the collection of
  points. The rotation matrix is re-anchored on the center with `with-origin`
  and then applied with `transform`."
  [rad dim1 dim2 coll]
  (let [pivot (center coll)
        m (-> (rotation-matrix rad dim1 dim2)
              (with-origin pivot))]
    (transform coll m)))

(defn extrude
  "Extrude a shape in a given direction. Takes a collection of
  locations/blocks, a direction vector, and a number of steps. For every
  element, yields copies offset by `dir * i` for `i` from 0 (the element
  itself) up to `steps - 1`, grouped per original element, as a lazy seq."
  [coll dir steps]
  (mapcat (fn [blk]
            (map (fn [i] (v+ blk (v* dir i)))
                 (range steps)))
          coll))
