module Data.Geometry.KDTree where
import           Control.Lens hiding (imap, element, Empty, (:<))
import           Data.BinaryTree
import           Data.Coerce
import           Data.Ext
import qualified Data.Foldable as F
import           Data.Geometry.Box
import           Data.Geometry.Point
import           Data.Geometry.Properties
import           Data.Geometry.Vector
import qualified Data.List.NonEmpty as NonEmpty
import           Data.Maybe (fromJust)
import           Data.Proxy
import           Data.Seq (LSeq(..), ViewL(..))
import qualified Data.Seq as Seq
import           Data.Util
import qualified Data.Vector.Fixed as FV
import           GHC.TypeLits
import           Prelude hiding (replicate)
-- | A coordinate (axis) index, tagged at the type level with the dimension
-- @d@ of the space it belongs to.
--
-- The wrapped 'Int' is stored as given; the 'Eq', 'Show' and 'Enum'
-- instances of this type interpret it modulo @d@.
newtype Coord (d :: Nat) = Coord
  { unCoord :: Int  -- raw index, returned without any normalisation
  }
-- | Two coordinates are considered equal when their raw indices agree
-- modulo the dimension @d@.
instance KnownNat d => Eq (Coord d) where
  Coord i == Coord j = reduce i == reduce j
    where
      -- bring an index into the range of residues modulo the dimension
      reduce = (`mod` dim)
      -- the type-level dimension, reflected to a value
      dim    = fromInteger (natVal (Proxy :: Proxy d))
-- | Renders the number @1 + (i `mod` d)@, where @i@ is the raw index and @d@
-- the type-level dimension.
instance KnownNat d => Show (Coord d) where
  show (Coord i) = show (1 + i `mod` dim)
    where
      -- the type-level dimension, reflected to a value
      dim = fromInteger (natVal (Proxy :: Proxy d))
-- | 'toEnum' maps an 'Int' @n@ to the coordinate @1 + (n `mod` d)@, and
-- 'fromEnum' returns the raw index minus one. Combined with the class
-- default for 'succ' this makes stepping through the coordinates wrap
-- around after @d@ steps.
instance KnownNat d => Enum (Coord d) where
  toEnum n = Coord ((n `mod` dim) + 1)
    where
      -- the type-level dimension, reflected to a value
      dim = fromInteger (natVal (Proxy :: Proxy d))
  fromEnum (Coord i) = i - 1
-- | The value stored at an internal node of the tree. All fields are strict.
--
-- In trees produced by 'buildKDTree'' the fields are, in order: the
-- coordinate the node was split on, the split value in that coordinate, and
-- the box obtained by combining the boxes of the two children.
data Split d r = Split
  !(Coord d)
  !r
  !(Box d () r)

deriving instance (Show r, Arity d, KnownNat d) => Show (Split d r)
deriving instance (Eq r, Arity d, KnownNat d) => Eq (Split d r)
-- | A split without its box: a strict pair of the split coordinate and the
-- split value. 'build' produces these; 'buildKDTree'' later turns them into
-- full 'Split's.
type Split' d r = SP (Coord d) r
-- | A kd-tree that is backed by a leaf tree: the points (with their extra
-- data @p@) live in the leaves and every internal node carries a 'Split'.
newtype KDTree' d p r = KDT
  { unKDT :: BinLeafTree (Split d r) (Point d r :+ p)
  }

deriving instance (Show p, Show r, Arity d, KnownNat d) => Show (KDTree' d p r)
deriving instance (Eq p, Eq r, Arity d, KnownNat d) => Eq (KDTree' d p r)
-- | A kd-tree that may also be empty: either no tree at all, or a wrapped
-- 'KDTree''.
data KDTree d p r
  = Empty
  | Tree (KDTree' d p r)

deriving instance (Show p, Show r, Arity d, KnownNat d) => Show (KDTree d p r)
deriving instance (Eq p, Eq r, Arity d, KnownNat d) => Eq (KDTree d p r)
-- | Views a possibly empty tree as a 'Maybe': 'Empty' becomes 'Nothing' and
-- @'Tree' t@ becomes @'Just' t@.
toMaybe :: KDTree d p r -> Maybe (KDTree' d p r)
toMaybe kdt = case kdt of
  Empty  -> Nothing
  Tree t -> Just t
-- | Builds a kd-tree from a list of points. An empty list yields 'Empty';
-- any other list is handed to 'buildKDTree'' and wrapped in 'Tree'.
buildKDTree :: (Arity d, KnownNat d, Index' 0 d, Ord r)
            => [Point d r :+ p] -> KDTree d p r
buildKDTree pts = case NonEmpty.nonEmpty pts of
  Nothing  -> Empty
  Just ne  -> Tree (buildKDTree' ne)
-- | Builds a kd-tree from a non-empty collection of points.
--
-- The tree skeleton is built first, starting the splitting at coordinate 1;
-- afterwards every internal node is annotated with a box.
buildKDTree' :: (Arity d, KnownNat d, Index' 0 d, Ord r)
             => NonEmpty.NonEmpty (Point d r :+ p) -> KDTree' d p r
buildKDTree' pts = KDT (addBoxes skeleton)
  where
    -- the tree whose nodes carry only (coordinate, split value) pairs
    skeleton = build (Coord 1) (toPointSet (Seq.fromNonEmpty pts))

    -- pair every node of the skeleton with the box computed for it
    addBoxes t = zipExactWith mkSplit const t boxes
      where
        -- same shape as t: a leaf contributes the box of its point, a node
        -- the box combining the boxes of its two children
        boxes = foldUpData (\l _ r -> boundingBoxList' [l, r])
                           (boundingBox . (^.core))
                           t
        mkSplit (SP c m) b = Split c m b
-- | Sorts a non-empty list and removes duplicates, keeping one
-- representative of every run of equal elements.
ordNub :: Ord a => NonEmpty.NonEmpty a -> NonEmpty.NonEmpty a
ordNub xs = NonEmpty.map NonEmpty.head runs
  where
    -- after sorting, equal elements are adjacent, so each run is one group
    runs = NonEmpty.group1 (NonEmpty.sort xs)
-- | Turns a sequence of points into a 'PointSet': the sequence is replicated
-- into every component of the vector, and the component at (0-based) index
-- @k@ is sorted with @'compareOn' (k + 1)@.
toPointSet :: (Arity d, Ord r)
           => LSeq n (Point d r :+ p) -> PointSet (LSeq n) d p r
toPointSet pts = FV.imap sortOnAxis (FV.replicate pts)
  where
    -- vector indices start at 0 whereas coordinates start at 1
    sortOnAxis ix = Seq.unstableSortBy (compareOn (ix + 1))
-- | Compares two points primarily on their @i@-th coordinate; when those
-- are equal the comparison falls back to comparing the points themselves.
-- The extra data @e@ plays no role in the comparison.
compareOn :: (Ord r, Arity d)
          => Int -> Point d r :+ e -> Point d r :+ e -> Ordering
compareOn i p q = key p `compare` key q
  where
    -- (i-th coordinate, whole point), compared lexicographically
    key x = (x^.core.unsafeCoord i, x^.core)
-- | Recursively builds the skeleton of the tree.
--
-- A point set holding a single point becomes a 'Leaf'. Any larger set is
-- split on the given coordinate with 'splitOn', and both halves are built
-- recursively using the successor of that coordinate.
build :: (Index' 0 d, Arity d, KnownNat d, Ord r)
      => Coord d
      -> PointSet (LSeq 1) d p r
      -> BinLeafTree (Split' d r) (Point d r :+ p)
build i = either Leaf branch . asSingleton
  where
    -- at least two points: split, then recurse on the next coordinate
    branch ps' = Node (build next lefts) split (build next rights)
      where
        (lefts, split, rights) = splitOn i ps'
        next                   = succ i
-- | Lists all points stored in the tree.
--
-- Uses the partial 'NonEmpty.fromList'; presumably a 'BinLeafTree' always
-- contains at least one leaf, which would make this safe (confirm against
-- "Data.BinaryTree").
reportSubTree :: KDTree' d p r -> NonEmpty.NonEmpty (Point d r :+ p)
reportSubTree (KDT t) = NonEmpty.fromList (F.toList t)
-- | Reports the points of a possibly empty tree that fall in the query box.
-- An 'Empty' tree reports nothing; otherwise the work is delegated to
-- 'searchKDTree''.
searchKDTree :: (Arity d, Ord r)
             => Box d q r -> KDTree d p r -> [Point d r :+ p]
searchKDTree _  Empty    = []
searchKDTree qr (Tree t) = searchKDTree' qr t
-- | Reports the points of the tree that fall in the query box.
--
-- A leaf is reported when its point intersects the query box. At an
-- internal node whose box is 'containedIn' the query box, every point below
-- it is reported without further tests. Otherwise the search descends only
-- into those children whose box (see 'boxOf') intersects the query box;
-- results of the left child precede those of the right child.
searchKDTree' :: (Arity d, Ord r)
              => Box d q r -> KDTree' d p r -> [Point d r :+ p]
searchKDTree' qr = go . unKDT
  where
    go t = case t of
      Leaf p
        | (p^.core) `intersects` qr -> [p]
        | otherwise                 -> []
      Node l (Split _ _ b) r
        | b `containedIn` qr        -> F.toList t
        | otherwise                 -> visit l ++ visit r

    -- descend into a subtree only if its box meets the query box
    visit sub
      | qr `intersects` boxOf sub = go sub
      | otherwise                 = []
-- | The box associated with a (sub)tree: for a leaf the bounding box of its
-- point, for an internal node the box stored in its 'Split'.
boxOf :: (Arity d, Ord r) => BinLeafTree (Split d r) (Point d r :+ p) -> Box d () r
boxOf t = case t of
  Leaf p                 -> boundingBox (p^.core)
  Node _ (Split _ _ b) _ -> b
-- | Tests whether the first box lies inside the second one, by checking
-- that both its minimum and its maximum corner intersect the second box.
containedIn :: (Arity d, Ord r) => Box d q r -> Box d p r -> Bool
containedIn (Box (CWMin lo :+ _) (CWMax hi :+ _)) outer =
  lo `intersects` outer && hi `intersects` outer
-- | A @d@-dimensional vector of sequences of points. As produced by
-- 'toPointSet', every component holds the same points, and the component at
-- (0-based) index @k@ is sorted on coordinate @k + 1@.
type PointSet seq d p r = Vector d (seq (Point d r :+ p))
-- | Splits a point set on the given coordinate.
--
-- The pivot is the element at position @length `div` 2@ of the sequence
-- that is sorted on the split coordinate. Every component of the point set
-- is then partitioned into the points that are strictly smaller than the
-- pivot (with respect to 'compareOn' on that coordinate) and the remaining
-- points, which include the pivot itself. The result is the left point set,
-- the split (coordinate and pivot value in that coordinate), and the right
-- point set.
--
-- Preconditions, not checked here:
--
-- * the coordinate index @i@ satisfies @1 <= i <= d@; otherwise the
--   'fromJust' below fails. 'build' starts at @Coord 1@ and only steps with
--   'succ', which keeps the index in that range.
--
-- * NOTE(review): the points are presumably expected to be pairwise
--   distinct. If every point compares equal to the pivot, nothing is
--   strictly smaller and the left side comes out empty, even though its type
--   claims at least one element.
splitOn                 :: (Arity d, KnownNat d, Ord r)
                        => Coord d
                        -> PointSet (LSeq 2) d p r
                        -> ( PointSet (LSeq 1) d p r
                           , Split' d r
                           , PointSet (LSeq 1) d p r)
splitOn c@(Coord i) pts = (l, SP c (m^.core.unsafeCoord i), r)
  where
    -- Coordinates are 1-based while the components of the point set are
    -- 0-based ('toPointSet' sorts component k on coordinate k + 1), so the
    -- sequence sorted on coordinate i is the component at index i - 1.
    m = let xs = fromJust $ pts^?element' (i - 1)
        in xs `Seq.index` (F.length xs `div` 2)

    -- The input has at least two elements, so the pivot index
    -- (length `div` 2) is at least 1: the pivot has a predecessor in sorted
    -- order, and the pivot itself ends up on the right. 'Seq.promise' is
    -- used to give both halves the 'LSeq 1' type (presumably without a
    -- runtime check; confirm in "Data.Seq").
    f = bimap Seq.promise Seq.promise
      . Seq.partition (\p -> compareOn i p m == LT)

    (l,r) = unzip' . fmap f $ pts

    -- turn a vector of (left, right) pairs into a pair of vectors
    unzip' = bimap vectorFromListUnsafe vectorFromListUnsafe . unzip . F.toList
-- | Decides whether a point set holds exactly one point or at least two.
--
-- Only the component at index 0 is inspected; this presumably relies on all
-- components holding the same number of points, as is the case for point
-- sets made by 'toPointSet'. With a single point, that point is returned on
-- the 'Left'. Otherwise the same point set is returned on the 'Right',
-- converted with 'coerce' to the @'LSeq' 2@ type required by the signature.
asSingleton   :: (Index' 0 d, Arity d) => PointSet (LSeq 1) d p r
              -> Either (Point d r :+ p) (PointSet (LSeq 2) d p r)
asSingleton v = case Seq.viewl $ v^.element (C :: C 0) of
                  -- a head whose tail again has a first element: two or more
                  -- points. This alternative must be tried first, as the next
                  -- one matches every view.
                  _ :< _ Seq.:<< _ -> Right $ coerce v
                  -- the tail had no first element: exactly one point
                  p :< _           -> Left p