{-# LANGUAGE TemplateHaskell #-}
module Data.Geometry.IntervalTree( NodeData(..)
                                 , splitPoint, intervalsLeft, intervalsRight
                                 , IntervalTree(..), unIntervalTree
                                 , IntervalLike(..)
                                 , createTree, fromIntervals
                                 , insert, delete
                                 , stab, search
                                 , toList
                                 ) where
import           Control.DeepSeq
import           Control.Lens
import           Data.BinaryTree
import           Data.Ext
import           Data.Geometry.Interval
import           Data.Geometry.Interval.Util
import           Data.Geometry.Properties
import qualified Data.List as List
import qualified Data.Map as M
import           GHC.Generics (Generic)
-- | The data stored at a single node of an 'IntervalTree'.
--
-- 'insert' stores an interval at the first node on its search path whose
-- split point lies in the interval's range, so every interval kept at a node
-- contains that node's split point. Each stored interval is recorded twice:
-- once in the left-endpoint map and once in the right-endpoint map.
data NodeData i r = NodeData { _splitPoint     :: !r
                               -- ^ The split value of this node.
                             , _intervalsLeft  :: !(M.Map (L r) [i])
                               -- ^ Intervals stored here, grouped by their
                               -- left endpoint.
                             , _intervalsRight :: !(M.Map (R r) [i])
                               -- ^ Intervals stored here, grouped by their
                               -- right endpoint.
                             } deriving (Show,Eq,Ord,Generic)
-- Generates the lenses splitPoint, intervalsLeft and intervalsRight.
makeLenses ''NodeData
-- Uses the Generic-based default implementation of rnf.
instance (NFData i, NFData r) => NFData (NodeData i r)
-- | An interval tree: a binary tree of 'NodeData'.
--
-- The searches in 'stab', 'insert' and 'delete' go left or right by
-- comparing against a node's split point. They therefore assume that an
-- in-order traversal yields the split points in ascending order.
newtype IntervalTree i r =
  IntervalTree { _unIntervalTree :: BinaryTree (NodeData i r) }
  deriving (Show,Eq,Generic)
-- Generates the lens unIntervalTree.
makeLenses ''IntervalTree
-- Uses the Generic-based default implementation of rnf.
instance (NFData i, NFData r) => NFData (IntervalTree i r)
-- | Build an interval tree with the given split points and no intervals.
--
-- The points are passed to 'asBalancedBinTree' exactly as given; they are
-- neither sorted nor deduplicated here. Because the tree searches compare
-- against split points, the list should presumably be in ascending order.
-- 'fromIntervals' sorts its endpoints before calling this function.
createTree :: Ord r => [r] -> IntervalTree i r
createTree = IntervalTree . asBalancedBinTree . map emptyNode
  where
    -- A node with split point m and no stored intervals.
    emptyNode m = NodeData m M.empty M.empty
-- | Build an interval tree that contains all of the given intervals.
--
-- The split points are the distinct endpoint values of the intervals, in
-- ascending order. Repeated endpoint values are dropped for two reasons:
--
-- * Equal values sit next to each other in the sorted list. Assuming
--   'asBalancedBinTree' keeps list order as in-order order, one copy ends up
--   as a descendant of the other.
-- * Every interval that reaches the descendant has already failed to contain
--   that value at the ancestor, so the duplicate node could never store
--   anything.
--
-- NOTE(review): 'insert' silently drops an interval when no split point on
-- its search path lies inside it. This can happen for an interval that is
-- open at both ends with no other endpoint strictly inside it.
fromIntervals    :: (Ord r, IntervalLike i, NumType i ~ r)
                 => [i] -> IntervalTree i r
fromIntervals is = foldr insert (createTree pts) is
  where
    endPoints (toRange -> Range' a b) = [a,b]
    -- 'List.group' never produces an empty group, so @take 1@ keeps exactly
    -- one representative of each run of equal values (total, unlike 'head').
    pts = concatMap (take 1) . List.group . List.sort . concatMap endPoints $ is
-- | All intervals stored in the tree.
--
-- Only each node's left-endpoint map is read. Every stored interval also
-- appears in the right-endpoint map, which is skipped so that it is not
-- reported twice. Within a node, intervals come out in ascending order of
-- left-endpoint key. Nodes are visited root first, then the left subtree,
-- then the right subtree.
toList :: IntervalTree i r -> [i]
toList (IntervalTree root) = go root
  where
    go Nil                   = []
    go (Internal lt node rt) =
      concat (M.elems (_intervalsLeft node)) ++ go lt ++ go rt
-- | Alias for 'stab': all stored intervals that contain the query point.
search     :: Ord r => r -> IntervalTree i r -> [i]
search x t = stab x t
-- | Report every stored interval that contains the query point @x@.
--
-- Every interval kept at a node contains that node's split point @m@, so
-- only one of the node's two maps has to be consulted:
--
-- * When @x <= m@, the matches are the entries whose left-endpoint key is
--   @<= L (Closed x)@. The map is read in ascending order.
-- * Otherwise, the matches are the entries whose right-endpoint key is
--   @>= R (Closed x)@. The map is read in descending order.
--
-- In both cases the qualifying keys come first and the scan stops at the
-- first key that fails. Afterwards only one subtree is searched.
stab :: Ord r => r -> IntervalTree i r -> [i]
stab x (IntervalTree root) = go root
  where
    go Nil = []
    go (Internal lt (NodeData m lefts rights) rt)
      | x <= m    = scan (<= L (Closed x)) (M.toAscList lefts)   ++ go lt
      | otherwise = scan (>= R (Closed x)) (M.toDescList rights) ++ go rt
    -- Concatenate the interval lists of the leading entries whose key
    -- satisfies the predicate.
    scan keep = concatMap snd . List.takeWhile (keep . fst)
-- | Insert an interval into the tree.
--
-- The tree is searched from the root. At each node with split point @m@:
--
-- * If @m@ lies in the interval's range, the interval is stored there, under
--   its left endpoint in the left-endpoint map and under its right endpoint
--   in the right-endpoint map.
-- * Otherwise the search continues left when the right endpoint is
--   @<= Closed m@, and right in all other cases.
--
-- If the search reaches 'Nil', the tree is returned unchanged and the
-- interval is /not/ stored. The tree's split points should therefore include
-- a point inside every interval that will be inserted.
insert :: (Ord r, IntervalLike i, NumType i ~ r)
       => i -> IntervalTree i r -> IntervalTree i r
insert i (IntervalTree root) = IntervalTree (go root)
  where
    rng@(Range lo hi) = toRange i
    go Nil = Nil
    go (Internal lt node rt)
      | m `inRange` rng = Internal lt (store node) rt
      | hi <= Closed m  = Internal (go lt) node rt
      | otherwise       = Internal lt node (go rt)
      where
        m = _splitPoint node
    -- Record the interval under both of its endpoint keys; intervals sharing
    -- a key are kept together in one list.
    store (NodeData sp ls rs) =
      NodeData sp (M.insertWith (++) (L lo) [i] ls)
                  (M.insertWith (++) (R hi) [i] rs)
-- | Remove one occurrence of an interval (compared with '==') from the tree.
--
-- The search follows the same path as 'insert'. At the node whose split
-- point lies in the interval's range, one matching occurrence is removed
-- from the list under the interval's left-endpoint key and from the list
-- under its right-endpoint key. A key whose list becomes empty is removed
-- from its map. If the interval is not present, the result is equal to the
-- input tree.
delete :: (Ord r, IntervalLike i, NumType i ~ r, Eq i)
          => i -> IntervalTree i r -> IntervalTree i r
delete i (IntervalTree root) = IntervalTree (go root)
  where
    rng@(Range lo hi) = toRange i
    go Nil = Nil
    go (Internal lt node rt)
      | m `inRange` rng = Internal lt (remove node) rt
      | hi <= Closed m  = Internal (go lt) node rt
      | otherwise       = Internal lt node (go rt)
      where
        m = _splitPoint node
    remove (NodeData sp ls rs) =
      NodeData sp (M.update dropOne (L lo) ls) (M.update dropOne (R hi) rs)
    -- Drop one occurrence of the interval; signal key removal with Nothing
    -- once nothing is left under that key.
    dropOne xs = case List.delete i xs of
                   [] -> Nothing
                   ys -> Just ys
-- | Types that can be viewed as a 'Range' over their numeric type.
--
-- The tree uses 'toRange' to obtain an element's endpoints when building,
-- inserting and deleting.
class IntervalLike i where
  -- | The range covered by the element.
  toRange :: i -> Range (NumType i)
-- | A 'Range' is its own range.
instance IntervalLike (Range r) where
  toRange rng = rng
-- | An 'Interval' is viewed as the range of its endpoints' core values. The
-- extra data attached to each endpoint is discarded.
instance IntervalLike (Interval p r) where
  toRange iv = (^. core) <$> _unInterval iv