(ns toxi.geom.mesh.trianglemesh3d
  (:require
    [toxi.geom.utils :as utils]
    [toxi.data.index :as index]
    [toxi.math.core :as math]
    [toxi.math.matrix4x4]
    [toxi.geom.core :as geom]
    [toxi.geom.mesh.core :as mesh]
    [toxi.geom.aabb :as aabb])
  (:import
    [toxi.geom.mesh.types TriangleMesh3D Face3D]))

(defn mesh-attribute-accessor
  "Builds an accessor map for the per-face attribute stored under key `attr`.

  Arguments:
    attr          - key used both in (:attribs mesh) and in (:attr face)
    set-attrib-fn - function (buf idx value) which writes one attribute value
                    into `buf` starting at offset `idx`
    stride        - number of buffer slots occupied by one attribute value

  The returned map contains:
    :source       - fn [mesh], returns the :items of the mesh's index for `attr`
    :face-attribs - fn [mesh f], returns the three attribute values referenced
                    by face `f` as a vector
    :compile-face - fn [buf idx src mesh f], writes the three attribute values
                    of face `f` into `buf` at idx, idx+stride and idx+2*stride,
                    looking the values up in the already resolved `src` items
    :stride       - the given stride"
  [attr set-attrib-fn stride]
  (letfn [(src [mesh] (:items (get (:attribs mesh) attr)))]
    {:source src
     :face-attribs (fn [mesh f]
                     ;; Resolve the item collection first: `src` itself is a
                     ;; function, and `get` on a function always yields nil.
                     (let [items (src mesh)
                           [a0 a1 a2] (get (:attr f) attr)]
                       [(get items a0) (get items a1) (get items a2)]))
     :compile-face
       (let [s2 (* 2 stride)]
         ;; Here `src` is the parameter (the resolved items), which shadows
         ;; the local lookup function of the same name.
         (fn [buf idx src mesh f]
           (let [[a0 a1 a2] (get (:attr f) attr)]
             (set-attrib-fn buf idx (get src a0))
             (set-attrib-fn buf (+ stride idx) (get src a1))
             (set-attrib-fn buf (+ s2 idx) (get src a2)))))
     :stride stride}))

;; Built-in accessors used when compiling a mesh into flat buffers.
;; Every entry provides:
;;   :stride       - buffer slots per value
;;   :source       - fn [mesh] returning the item collection to read from
;;   :compile-face - fn [buf idx src mesh f] writing the three values of
;;                   face `f` into `buf` starting at offset `idx`
(def m3d-default-accessors
  (let [vertex-items (fn [mesh] (:items (:vertices mesh)))
        normal-items (fn [mesh] (:items (:normals mesh)))]
    {;; vertex positions, looked up through the face's :a/:b/:c ids
     :vertices
     {:stride 3
      :source vertex-items
      :compile-face
      (fn [buf idx src mesh {:keys [a b c]}]
        (utils/aset-vec3d buf idx (get src a))
        (utils/aset-vec3d buf (+ idx 3) (get src b))
        (utils/aset-vec3d buf (+ idx 6) (get src c)))}

     ;; face normals: the same normal (face key :n) is written three times
     :fnormals
     {:stride 3
      :source normal-items
      :compile-face
      (fn [buf idx src mesh face]
        (let [normal (get src (:n face))]
          (utils/aset-vec3d buf idx normal)
          (utils/aset-vec3d buf (+ idx 3) normal)
          (utils/aset-vec3d buf (+ idx 6) normal)))}

     ;; vertex normals: (:vnormals mesh) maps a vertex id to a normal id
     :vnormals
     {:stride 3
      :source normal-items
      :compile-face
      (fn [buf idx src mesh {:keys [a b c]}]
        (let [normal-ids (:vnormals mesh)]
          (utils/aset-vec3d buf idx (get src (get normal-ids a)))
          (utils/aset-vec3d buf (+ idx 3) (get src (get normal-ids b)))
          (utils/aset-vec3d buf (+ idx 6) (get src (get normal-ids c)))))}

     ;; 2D texture coordinates stored as a generic face attribute
     :uv (mesh-attribute-accessor :uv utils/aset-vec2d 2)}))

(defn- m3d-vertex-at-point
  "Looks up the :xyz swizzle of `v` in the mesh's vertex index.
  Returns {:id vertex-id :v stored-vertex} when found, else nil."
  [mesh v]
  (let [{:keys [index items]} (:vertices mesh)
        id (get index (utils/swizzle v :xyz))]
    (if (nil? id)
      nil
      {:id id :v (get items id)})))
      
(defn- m3d-conjf
  "Adds triangle faces to `mesh` and returns the updated mesh.

  Arities:
    [mesh a b c]          - adds one face without attributes
    [mesh a b c attribs]  - adds one face; `attribs` maps an attribute key to
                            the collection of values for that face
    [mesh faces]          - adds every [a b c attribs] tuple in `faces`

  The :xyz swizzles of the corner points are run through the mesh's vertex
  index, attribute values through the per-attribute index stored under
  (:attribs mesh); a fresh index is created for attributes not seen before.
  The new face stores the resulting ids only."
  ([mesh a b c] (m3d-conjf mesh a b c nil))
  ([mesh a b c attribs]
    (let [xyz (fn [p] (utils/swizzle p :xyz))
          [verts ida] (index/index-item (:vertices mesh) (xyz a))
          [verts idb] (index/index-item verts (xyz b))
          [verts idc] (index/index-item verts (xyz c))
          ;; indexes all values of one attribute => [updated-index [id ...]]
          index-values (fn [start-index values]
                         (reduce
                           (fn [[current ids] value]
                             (let [[updated id] (index/index-item current value)]
                               [updated (conj ids id)]))
                           [start-index []]
                           values))
          ;; indexes every attribute of the face
          ;; => [updated attribute indices, attribute key -> ids for this face]
          [attr-indices face-attr-ids]
          (reduce
            (fn [[indices face-ids] [attr-id values]]
              (let [existing (get indices attr-id)
                    start (if (nil? existing) (index/make-index) existing)
                    [updated ids] (index-values start values)]
                [(assoc indices attr-id updated)
                 (assoc face-ids attr-id ids)]))
            [(:attribs mesh) {}]
            attribs)
          face (mesh/face3d ida idb idc face-attr-ids)]
      (assoc mesh
        :vertices verts
        :attribs attr-indices
        :faces (conj (:faces mesh) face))))
  ([mesh faces]
    (reduce
      (fn [acc [a b c attribs]] (m3d-conjf acc a b c attribs))
      mesh
      faces)))

(defn- m3d-face-verts
  "Returns a vector with the three vertices referenced by the :a, :b and :c
  ids of face `f`, looked up in the items of the mesh's vertex index."
  [mesh f]
  (let [items (:items (:vertices mesh))
        {:keys [a b c]} f]
    [(get items a) (get items b) (get items c)]))

(defn- m3d-face-attribs
  "Resolves the attribute ids stored in (:attr f) to their values.
  Returns a map of attribute key -> vector of the three values, each looked
  up in the items of the matching index in (:attribs mesh)."
  [mesh f]
  (let [attr-indices (:attribs mesh)
        resolve-entry (fn [[attr-id [i0 i1 i2]]]
                        (let [items (:items (get attr-indices attr-id))]
                          [attr-id [(get items i0) (get items i1) (get items i2)]]))]
    (into {} (map resolve-entry (:attr f)))))
  
(defn- m3d-merge-meshes
  "Adds every face of every mesh in `coll` to `mesh` and returns the result.
  Faces are resolved to their actual vertices and attribute values first and
  then re-added via m3d-conjf, so they get indexed against the target mesh."
  [mesh coll]
  (letfn [(add-face [source target face]
            (let [[a b c] (m3d-face-verts source face)
                  attribs (m3d-face-attribs source face)]
              (m3d-conjf target a b c attribs)))
          (add-mesh [target source]
            (reduce (partial add-face source) target (:faces source)))]
    (reduce add-mesh mesh coll)))

(defn- m3d-face-normals
  "Computes a normal for every face of `mesh`.
  Each normal is the normalized cross product of (a - b) and (a - c), is
  added to a freshly created index, and its id is stored on the face under
  :n. Returns the mesh with :normals and :faces replaced."
  [mesh]
  (let [add-normal (fn [[normal-index faces] face]
                     (let [[a b c] (m3d-face-verts mesh face)
                           ab (geom/sub a b)
                           ac (geom/sub a c)
                           normal (geom/normalize (geom/cross ab ac))
                           [updated normal-id] (index/index-item normal-index normal)]
                       [updated (conj faces (assoc face :n normal-id))]))
        [normal-index faces] (reduce add-normal
                                     [(index/make-index) []]
                                     (:faces mesh))]
    (assoc mesh :normals normal-index :faces faces)))

(defn- m3d-vertex-normals
  "Computes one normal per vertex from the existing face normals.
  For every vertex id the normals of all faces using it are summed; each sum
  is then normalized and added to the mesh's normal index. Returns the mesh
  with the updated :normals index and :vnormals set to a vector holding the
  normal id of each vertex id.
  Reads the :n key of every face, so face normals must be present already."
  [mesh]
  (let [face-normals (:items (:normals mesh))
        vertex-count (count (:items (:vertices mesh)))
        ;; vertex id -> sum of the normals of all faces touching it
        sums (reduce
               (fn [acc face]
                 (let [n (get face-normals (get face :n))
                       ;; always reads from `acc` as it was before this face,
                       ;; so a repeated vertex id receives `n` only once
                       plus-n (fn [vertex-id]
                                (let [current (get acc vertex-id)]
                                  (if (nil? current) n (geom/add current n))))
                       ida (get face :a)
                       idb (get face :b)
                       idc (get face :c)]
                   (assoc acc
                     ida (plus-n ida)
                     idb (plus-n idb)
                     idc (plus-n idc))))
               {}
               (:faces mesh))
        ;; normalize and index the sums in vertex id order
        [normal-index normal-ids]
        (reduce
          (fn [[current ids] vertex-id]
            (let [[updated id] (index/index-item
                                 current
                                 (geom/normalize (get sums vertex-id)))]
              [updated (conj ids id)]))
          [(:normals mesh) []]
          (range vertex-count))]
    (assoc mesh :normals normal-index :vnormals normal-ids)))
             
(defn- m3d-bounds
  "Computes the axis-aligned bounding box of all vertices of `mesh`.
  Folds geom/minv and geom/maxv over the vertex items and passes the two
  resulting corner points to geom/aabb-from-minmax."
  [mesh]
  (let [vertices (:items (:vertices mesh))
        mi (reduce geom/minv (geom/vec3d Double/MAX_VALUE) vertices)
        ;; The max fold has to start at the most negative finite double.
        ;; Double/MIN_VALUE is the smallest *positive* double (~4.9e-324) and
        ;; would produce a wrong maximum for coordinates that are all <= 0.
        mx (reduce geom/maxv (geom/vec3d (- Double/MAX_VALUE)) vertices)]
    (geom/aabb-from-minmax mi mx)))

(defn- m3d-bsphere
  "Computes a bounding sphere for `mesh`.
  The result of m3d-bounds is used as the sphere's origin; the radius is the
  largest distance from it to any vertex (compared as squared distances,
  with a single square root at the end)."
  [mesh]
  (let [center (m3d-bounds mesh)
        dist-squared (fn [v] (geom/mag-squared (geom/sub v center)))
        max-squared (reduce
                      (fn [best v] (max best (dist-squared v)))
                      0
                      (:items (:vertices mesh)))]
    (geom/sphere center (Math/sqrt max-squared))))

(defn- m3d-centroid
  "Returns the :xyz swizzle of the mesh's bounds (as given by geom/bounds)."
  [mesh]
  (-> mesh
      geom/bounds
      (utils/swizzle :xyz)))

(defn- m3d-transform
  "Transforms every vertex of `mesh` with `matrix` (via math/transform-point)
  and returns the mesh with a newly built vertex index. Faces, normals and
  attributes are left untouched."
  [mesh matrix]
  ;; NOTE(review): vertices are re-indexed in their original order; this
  ;; presumably keeps ids stable only as long as no two transformed vertices
  ;; are treated as the same item by the index - confirm against index-item.
  (let [moved (map (fn [v] (math/transform-point matrix v))
                   (:items (:vertices mesh)))
        reindexed (reduce
                    (fn [idx p] (first (index/index-item idx p)))
                    (index/make-index)
                    moved)]
    (assoc mesh :vertices reindexed)))

(defn- m3d-compile
  "Compiles `mesh` into flat float arrays, one per requested attribute.

  Arguments:
    attribs   - collection of attribute keys to compile
                (default: [:vertices :fnormals])
    accessors - optional map of attribute key -> accessor map with the keys
                :stride, :source and :compile-face (same layout as the
                entries of m3d-default-accessors). Entries given here take
                precedence over the defaults with the same key.

  Returns a map of attribute key -> float array. Each array holds
  3 * stride values per face, written in face order by the accessor's
  :compile-face function. For a mesh without faces the arrays are empty."
  ([mesh] (m3d-compile mesh [:vertices :fnormals] nil))
  ([mesh attribs] (m3d-compile mesh attribs nil))
  ([mesh attribs accessors]
    (let [;; defaults first, so that caller supplied accessors win
          accessors (merge m3d-default-accessors accessors)
          faces (:faces mesh)
          numv (* 3 (count faces))
          buffers (reduce
                    (fn [bufmap attr]
                      (assoc bufmap attr
                        (float-array (* numv (:stride (get accessors attr))))))
                    {} attribs)]
      ;; The arrays are filled in place, hence doseq instead of a reduce.
      (doseq [attr attribs]
        (let [buf (get buffers attr)
              acc (get accessors attr)
              src ((:source acc) mesh)
              face-stride (* 3 (:stride acc))
              process-face-fn (:compile-face acc)]
          ;; Walk the face seq so that an empty mesh writes nothing, instead
          ;; of handing a nil face to the accessor.
          (loop [remaining (seq faces) idx 0]
            (when remaining
              (process-face-fn buf idx src mesh (first remaining))
              (recur (next remaining) (+ idx face-stride))))))
      buffers)))

(defn- m3d-flip
  "Returns `mesh` with the :b and :c vertex ids of every face exchanged,
  i.e. with reversed vertex order per face. All other face keys are kept
  as they are; :faces is always returned as a vector."
  [mesh]
  (let [swap-bc (fn [face] (assoc face :b (:c face) :c (:b face)))]
    (assoc mesh :faces (into [] (map swap-bc (:faces mesh))))))

(defn extend-mesh
  "Extends `type` with the triangle mesh implementations of this namespace
  for the protocols geom/IShape, geom/IShape3D, geom/ITransformable,
  geom/IFlippable, mesh/IMesh and mesh/IMesh3D."
  [type]
  (extend type
    geom/IShape
    {:bounds   m3d-bounds
     :centroid m3d-centroid}

    geom/IShape3D
    {:bounding-sphere m3d-bsphere
     ;; the type is a mesh already, hence no conversion
     :->mesh          identity}

    geom/ITransformable
    {:transform m3d-transform}

    geom/IFlippable
    {:flip m3d-flip}

    mesh/IMesh
    {:conjf           m3d-conjf
     :merge-meshes    m3d-merge-meshes
     :compile-mesh    m3d-compile
     :face-attribs    m3d-face-attribs
     :face-vertices   m3d-face-verts
     :vertex-at-point m3d-vertex-at-point}

    mesh/IMesh3D
    {:compute-face-normals   m3d-face-normals
     :compute-vertex-normals m3d-vertex-normals}))

;; Install the protocol implementations above on the imported
;; TriangleMesh3D type when this namespace is loaded.
(extend-mesh TriangleMesh3D)