(ns morri.lib.interval
  (:require [clojure.string :as str]
            [clojure.pprint :refer [pprint]]
            [clojure.java.io :as io]
            [clojure.set :as set])
  (:import [org.biojava3.genome.parsers.gff GFF3Reader]))

;; Some simple genomic interval calculations.  Note limitations below.

(defn tprn
  "Debugging aid: print `x` with `prn` and return `x` unchanged, so the call
  can be wrapped around any sub-expression without altering its value."
  [x]
  (doto x prn))

;; Add functions for read/write to string

;; Ordering and compatibility checks shared by every interval type.
(defprotocol GenomicIntervalOrder
  ;; True when this-interval ends at or before the point where that-interval
  ;; starts, according to the implementation's comparison key.
  (precedes? [this-interval that-interval])
  ;; Returns the comparison key that sort-intervals passes to sort-by.
  (interval-sort-order [this-interval])
  ;; In the implementations in this file this returns true, or throws an
  ;; Exception when the intervals are not on the same chr (and strand, for
  ;; the stranded implementation); it never returns false.
  (can-combine? [this-interval that-interval]))

;; Operations that build a new interval out of two existing ones.
(defprotocol GenomicIntervalCombination
  ;; Interval running from the larger of the two starts to the smaller of the
  ;; two stops (the intersection, when the inputs overlap).
  (interval-overlap [this-interval that-interval])
  ;; Interval running from the smaller of the two starts to the larger of the
  ;; two stops (the span covering both inputs).
  (combine-intervals [this-interval that-interval]))

;; Serialisation of an interval to text.
(defprotocol GenomicIntervalOutput
  ;; Returns the interval as one tab-separated, newline-terminated string,
  ;; which write-bed-file writes out verbatim.
  (interval->bed-line [this-interval]))

;; An interval identified only by chromosome, start and stop (no strand).
;; Implements ordering, combination and BED3 output inline.
(defrecord UnstrandedInterval [chr start stop]
  GenomicIntervalOrder
  (precedes?
   [this-interval that-interval]
   ;; True when [chr stop] of this interval compares <= [chr start] of the
   ;; other one, i.e. an interval whose stop equals the other's start still
   ;; counts as preceding it.
   (not (pos? (compare
               [chr stop]
               ((juxt :chr :start) that-interval)))))
  (interval-sort-order
   [this-interval]
   ;; Sort key: chromosome first, then start, then stop.
   [chr start stop])
  (can-combine?
   [this-interval that-interval]
   ;; Returns true, or throws when the chromosomes differ; never returns false.
   (if-not (= chr (:chr that-interval))
     (throw (Exception.
             (format "chr does not match for %s and %s"
                     this-interval
                     that-interval)))
     true))
  GenomicIntervalCombination
  (interval-overlap
   [this-interval that-interval]
   ;; Larger start to smaller stop.  No check is made that the two intervals
   ;; really overlap, so disjoint inputs yield start > stop.
   (when (can-combine? this-interval that-interval)
     (->UnstrandedInterval
      chr
      (max start (:start that-interval))
      (min stop (:stop that-interval)))))
  (combine-intervals
   [this-interval that-interval]
   ;; Smaller start to larger stop: the span covering both intervals.
   (when (can-combine? this-interval that-interval)
     (->UnstrandedInterval
      chr
      (min start (:start that-interval))
      (max stop (:stop that-interval)))))
  GenomicIntervalOutput
  (interval->bed-line
   [this-interval]
   ;; Join only the data fields with tabs and append the newline afterwards;
   ;; putting \newline inside the joined vector left a stray tab at the end
   ;; of every line.
   (str (str/join \tab [chr start stop]) \newline)))

(defn interval->ucsc
  "Format an interval (any map with :chr, :start and :stop) as a position
  string of the form \"chr:start-stop\"."
  [{:keys [chr start stop]}]
  (let [template "%s:%d-%d"]
    (apply format template [chr start stop])))

;; Method map for GenomicIntervalOrder on strand-aware intervals.  Every
;; comparison is keyed on strand first, then chromosome, then position.
(def stranded-interval-order
  (let [strand+chr (juxt :strand :chr)]
    {:precedes?
     ;; True when this interval's [strand chr stop] sorts at or before the
     ;; other interval's [strand chr start].
     (fn [this-interval that-interval]
       (let [left  (conj (strand+chr this-interval) (:stop this-interval))
             right (conj (strand+chr that-interval) (:start that-interval))]
         (<= (compare left right) 0)))

     :interval-sort-order
     (fn [interval]
       (conj (strand+chr interval) (:start interval) (:stop interval)))

     :can-combine?
     ;; Returns true, or throws when strand or chromosome differ.
     (fn [this-interval that-interval]
       (when (not= (strand+chr this-interval) (strand+chr that-interval))
         (throw (Exception.
                 (format "chr or strand do not match for %s and %s"
                         this-interval
                         that-interval))))
       true)}))

;; An interval with chromosome, start, stop and strand.  Ordering and
;; combination are attached separately through `extend`; only the BED
;; output is implemented inline.
(defrecord StrandedInterval [chr start stop strand]
  GenomicIntervalOutput
  (interval->bed-line
   [this-interval]
   ;; Six-column line with "." in the name column and 0 in the score column.
   ;; The newline is appended after joining so the line does not end in a
   ;; stray tab.
   (str (str/join \tab [chr start stop "." 0 strand]) \newline)))

;; Method map for GenomicIntervalCombination on strand-aware intervals.
;; Both operations return a StrandedInterval, whatever record type the
;; inputs were, taking chr and strand from the first argument.
(def stranded-interval-combination
  (letfn [(build [pick-start pick-stop a b]
            ;; can-combine? either returns true or throws.
            (when (can-combine? a b)
              (->StrandedInterval
               (:chr a)
               (pick-start (:start a) (:start b))
               (pick-stop (:stop a) (:stop b))
               (:strand a))))]
    {:interval-overlap
     ;; Larger start to smaller stop.
     (fn [this-interval that-interval]
       (build max min this-interval that-interval))
     :combine-intervals
     ;; Smaller start to larger stop.
     (fn [this-interval that-interval]
       (build min max this-interval that-interval))}))

;; StrandedInterval takes its ordering and combination behaviour from the
;; shared strand-aware method maps.
(extend StrandedInterval
  GenomicIntervalOrder       stranded-interval-order
  GenomicIntervalCombination stranded-interval-combination)

(defn combine-names
  "Build the name for an interval derived from two others.  When both names
  are the placeholder \".\" the result is \".\"; otherwise the two names are
  joined with \"+\"."
  [name1 name2]
  (let [placeholder "."]
    (if (= [placeholder placeholder] [name1 name2])
      placeholder
      (format "%s+%s" name1 name2))))

;; A six-column BED record: chr, start, stop, name, score, strand.
(defrecord Bed6 [chr start stop name score strand]
  GenomicIntervalOutput
  (interval->bed-line [this-interval]
    ;; Tab-join the six fields, then append the newline; including \newline
    ;; in the joined vector left a stray tab at the end of every line.
    (str (str/join \tab [chr start stop name score strand]) \newline)))

;; Method map for GenomicIntervalCombination on BED records.  Both operations
;; return a Bed6 whose name comes from combine-names, whose score is reset to
;; 0, and whose chr and strand are taken from the first argument.
(def bed-interval-combination
  (letfn [(build [pick-start pick-stop a b]
            ;; can-combine? either returns true or throws.
            (when (can-combine? a b)
              (->Bed6
               (:chr a)
               (pick-start (:start a) (:start b))
               (pick-stop (:stop a) (:stop b))
               (combine-names (:name a) (:name b))
               0
               (:strand a))))]
    {:interval-overlap
     ;; Larger start to smaller stop.
     (fn [this-interval that-interval]
       (build max min this-interval that-interval))
     :combine-intervals
     ;; Smaller start to larger stop.
     (fn [this-interval that-interval]
       (build min max this-interval that-interval))}))

;; Bed6 is ordered with the strand-aware ordering and combined with the
;; BED-specific combination functions.
(extend Bed6
  GenomicIntervalCombination bed-interval-combination
  GenomicIntervalOrder       stranded-interval-order)

(defn format-comb-field
  "Render the collection `f` as a comma-separated string with a trailing
  comma, e.g. [1 2 3] -> \"1,2,3,\".  An empty collection yields \",\"."
  [f]
  (let [joined (str/join "," f)]
    (str joined ",")))

;; A twelve-column BED record.  item-rgb, block-sizes and block-starts are
;; held as collections and written back out through format-comb-field.
(defrecord Bed12 [chr start stop name score strand
                  thick-start thick-end item-rgb
                  block-count block-sizes block-starts]
  GenomicIntervalOutput
  (interval->bed-line [this-interval]
    ;; Tab-join the twelve fields, then append the newline; including
    ;; \newline in the joined vector left a stray tab at the end of every
    ;; line.
    (str (str/join \tab [chr start stop name score strand
                         thick-start thick-end
                         (format-comb-field item-rgb)
                         block-count
                         (format-comb-field block-sizes)
                         (format-comb-field block-starts)])
         \newline)))

;; Bed12 reuses the Bed6 behaviour: strand-aware ordering, and combination
;; functions that return Bed6 records.
(extend Bed12
  GenomicIntervalCombination bed-interval-combination
  GenomicIntervalOrder       stranded-interval-order)

;; One GFF feature.  Field order is significant: it is the positional
;; argument order of ->GFF.
(defrecord GFF
  [chr
   source
   feature
   start
   stop
   score
   strand
   frame
   attributes])

;; GFF records use the strand-aware ordering and combination maps, so
;; combining two GFF records produces a StrandedInterval.
(extend GFF
  GenomicIntervalCombination stranded-interval-combination
  GenomicIntervalOrder       stranded-interval-order)

(defn bed12->bed6
  "Reduce a record to a Bed6 by keeping only its chr, start, stop, name,
  score and strand fields."
  [{:keys [chr start stop name score strand]}]
  (apply ->Bed6 [chr start stop name score strand]))

(defn interval->stranded
  "Build a StrandedInterval from the chr, start, stop and strand of any
  interval-like map or record."
  [{:keys [chr start stop strand]}]
  (apply ->StrandedInterval [chr start stop strand]))

(defn interval->unstranded
  "Build an UnstrandedInterval from the chr, start and stop of any
  interval-like map or record, dropping every other field."
  [{:keys [chr start stop]}]
  (apply ->UnstrandedInterval [chr start stop]))

(defn sort-intervals
  "Return `intervals` sorted by the key each interval reports through
  interval-sort-order, compared with `compare`."
  [intervals]
  (sort (fn [left right]
          (compare (interval-sort-order left)
                   (interval-sort-order right)))
        intervals))

(defn consolidate-intervals
  "Merge neighbouring intervals that overlap and return the result as a
  vector.

  The collection is walked front to back.  When the first interval
  `precedes?` the second it is emitted unchanged; otherwise the pair is
  replaced by `(combine-intervals first second)` and that merged interval is
  compared against the next one.  An empty or nil collection yields [].

  The function does not sort; it presumably expects `intervals` to be in
  `sort-intervals` order already.  `combine-intervals` may throw when two
  adjacent non-preceding intervals cannot be combined."
  [intervals]
  (if (empty? intervals)
    ;; Without this guard the loop below would conj a nil `first` and
    ;; return [nil] for empty input.
    []
    (loop [consolidated [] intervals (seq intervals)]
      (let [f (first intervals)
            s (second intervals)]
        (if s
          (if (precedes? f s)
            (recur (conj consolidated f) (rest intervals))
            ;; Overlap: push the merged interval back onto the front so it
            ;; is compared with whatever follows.
            (recur consolidated (conj (drop 2 intervals)
                                      (combine-intervals f s))))
          (conj consolidated f))))))

(defn check-update-cache
  "Return `cache` without the intervals that precede `interval` (as decided
  by `precedes?`).  An empty or nil cache yields an empty list."
  [cache interval]
  (if (seq cache)
    (remove #(precedes? % interval) cache)
    ()))

(defn cache-hits
  "Drop from `cache` every interval that `interval` precedes, then call
  `(overlap-fn interval cached)` on each remaining cached interval.
  Returns a lazy seq of the results."
  [cache interval overlap-fn]
  (map (partial overlap-fn interval)
       (remove #(precedes? interval %) cache)))

;; Walk intervals-a and intervals-b in step and collect (overlap-fn x y) for
;; the pairs of intervals found not to precede one another.  Returns a vector.
;;
;; a-cache / b-cache hold intervals already consumed from a / b that might
;; still overlap upcoming intervals from the other list.  On every step each
;; cache is first pruned of the intervals that precede the head of the
;; opposite list.
;;
;; Argument order passed to overlap-fn: the direct head-to-head hit is
;; (overlap-fn top-a top-b); hits against a cache are
;; (overlap-fn <current head> <cached interval>), so for a-cache hits the b
;; interval comes first.
;;
;; Presumably both inputs must already be in sort-intervals order; nothing
;; here sorts them -- confirm with callers.
;;
;; NOTE(review): the loop stops as soon as either list is empty, so intervals
;; still sitting in a cache are never compared with the remaining items of
;; the other list (e.g. one long interval in a spanning two intervals in b
;; appears to report only the first overlap) -- confirm whether intended.
(defn find-overlaps [intervals-a intervals-b overlap-fn]
  (loop [overlaps [] a intervals-a b intervals-b a-cache () b-cache ()]
    (if (and (seq a) (seq b))
      (let [top-a (first a)
            top-b (first b)
            new-a-cache (check-update-cache a-cache top-b)
            new-b-cache (check-update-cache b-cache top-a)]
        (cond (precedes? top-a top-b)   ;Discarding top-a. Process
                                        ;top-a -> b-cache.
              (recur (into overlaps (cache-hits new-b-cache top-a overlap-fn))
                     (rest a) b new-a-cache new-b-cache)
              (precedes? top-b top-a)    ;Discarding top-b.  Process
                                        ;top-b -> a-cache.  Keep
                                        ;b-cache
              (recur (into overlaps (cache-hits new-a-cache top-b overlap-fn))
                     a (rest b) new-a-cache new-b-cache)
              :else                     ;We've got a match
              ;; Record the head-to-head hit plus each head's hits against
              ;; the opposite cache, then advance BOTH lists.
              (let [new-overlaps (into
                                  (conj overlaps (overlap-fn top-a top-b))
                                  (concat
                                   (cache-hits new-b-cache top-a overlap-fn)
                                   (cache-hits new-a-cache top-b overlap-fn)))
                    sa (second a)
                    sb (second b)
                    ;; Keep a head in its cache only when the opposite list
                    ;; has a next interval that the head does not precede.
                    final-a-cache (if (and sb (not (precedes? top-a sb)))
                                    (conj new-a-cache top-a)
                                    new-a-cache)
                    final-b-cache (if (and sa (not (precedes? top-b sa)))
                                    (conj new-b-cache top-b)
                                    new-b-cache)]
                (recur new-overlaps
                       (rest a)
                       (rest b)
                       final-a-cache
                       final-b-cache))))
      overlaps)))

(defn str->int
  "Parse `s` as a base-10 integer and return it as an Integer.  Throws
  NumberFormatException when the text is not a valid int.

  Uses Integer/parseInt rather than the Integer constructor, which is
  deprecated since Java 9 and was resolved by reflection here.  The argument
  is passed through `str` first so a numeric argument is still accepted."
  [s]
  (Integer/parseInt (str s)))

(defn parse-comb-field
  "Take a field of the form \"a,b,c,\" (the trailing comma is optional) and
  return a lazy seq (a b c), converting each element to an int with
  str->int."
  [f]
  (map str->int (str/split f #",")))

(defn write-comb-field
  "Render the collection `f` as a comma-separated string followed by a
  trailing comma, e.g. [10 20] -> \"10,20,\"."
  [f]
  (let [joined (str/join "," f)]
    (str joined ",")))

(defn bed-line->interval
  "Parse one whitespace-separated BED line into a record, chosen by column
  count: 3 -> UnstrandedInterval, 6 -> Bed6, 12 -> Bed12.  Numeric columns
  are converted with str->int and the comma-separated BED12 columns with
  parse-comb-field.

  Surrounding whitespace is trimmed before splitting so leading blanks do
  not shift the column count.  Throws IllegalArgumentException, naming the
  offending line, for any other column count."
  [l]
  (let [bed-components (str/split (str/trim l) #"\s+")
        n-columns (count bed-components)]
    (case n-columns
      3 (let [[chr start stop] bed-components]
          (->UnstrandedInterval chr (str->int start) (str->int stop)))
      6 (let [[chr start stop name score strand] bed-components]
          (->Bed6 chr (str->int start) (str->int stop)
                  name (str->int score) strand))
      12 (let [[chr start stop name score strand
                thick-start thick-end item-rgb
                block-count block-sizes block-starts] bed-components]
           (->Bed12 chr (str->int start) (str->int stop)
                    name (str->int score) strand
                    (str->int thick-start) (str->int thick-end)
                    (parse-comb-field item-rgb)
                    (str->int block-count)
                    (parse-comb-field block-sizes)
                    (parse-comb-field block-starts)))
      ;; Same exception type a clause-less `case` would throw, but with a
      ;; message that identifies the line.
      (throw (IllegalArgumentException.
              (format "Unsupported BED line with %d column(s): %s"
                      n-columns (pr-str l)))))))

(defn read-bed-file
  "Read `bed-file` (anything clojure.java.io/reader accepts) and return a
  fully realised seq of interval records, one per data line, parsed with
  bed-line->interval.

  Blank lines, `#` comment lines and `track` / `browser` header lines are
  skipped rather than handed to the parser, which rejects them."
  [bed-file]
  (letfn [(skippable? [line]
            (or (str/blank? line)
                (re-find #"^\s*(?:#|track(?:\s|$)|browser(?:\s|$))" line)))]
    (with-open [rdr (io/reader bed-file)]
      ;; doall forces the whole file to be read before the reader closes.
      (doall (map bed-line->interval
                  (remove skippable? (line-seq rdr)))))))

(defn write-bed-file
  "Write every interval in `interval-list` to `bed-file` (anything
  clojure.java.io/writer accepts), using the text interval->bed-line returns
  for each one.  Returns nil."
  [bed-file interval-list]
  (with-open [wrtr (io/writer bed-file)]
    (doseq [interval interval-list
            :let [line (interval->bed-line interval)]]
      (.write wrtr line))))

(defn gff-feature->GFF
  "Convert a feature object produced by the GFF3 reader into a GFF record.
  Coordinates come from the feature's location, the strand is stored as a
  string, and the attribute map has its keys turned into keywords."
  [feat]
  (let [seq-name     (.seqname feat)
        origin       (.source feat)
        feature-type (.type feat)
        location     (.location feat)
        begin        (.bioStart location)
        end          (.bioEnd location)
        feat-score   (.score feat)
        strand-str   (str (.bioStrand location))
        phase        (.frame feat)
        ;; Copy into a Clojure map first, then keywordise each key.
        raw-attrs    (into {} (.getAttributes feat))
        keyed-attrs  (into {}
                           (for [[k v] raw-attrs]
                             [(keyword k) v]))]
    (->GFF seq-name origin feature-type
           begin end feat-score strand-str
           phase keyed-attrs)))

(defn read-gff-file
  "Read `gff-file` with GFF3Reader and return a lazy seq of GFF records, one
  per feature.  With the optional `:type` keyword argument only the features
  selected by the reader's `selectByType` for that type are converted."
  [gff-file & {:keys [type]}]
  (let [all-features (GFF3Reader/read gff-file)
        selected     (if type
                       (.selectByType all-features type)
                       all-features)]
    (for [feature selected]
      (gff-feature->GFF feature))))

