module Algorithms.Geometry.WellSeparatedPairDecomposition.WSPD where
import           Algorithms.Geometry.WellSeparatedPairDecomposition.Types
import           Control.Lens hiding (Level, levels)
import           Control.Monad.Reader
import           Control.Monad.ST (ST,runST)
import           Data.BinaryTree
import           Data.Ext
import qualified Data.Foldable as F
import           Data.Geometry.Box
import           Data.Geometry.Point
import           Data.Geometry.Properties
import           Data.Geometry.Transformation
import           Data.Geometry.Vector
import qualified Data.Geometry.Vector as GV
import qualified Data.IntMap.Strict as IntMap
import qualified Data.LSeq as LSeq
import           Data.LSeq (LSeq,toSeq,ViewL(..),ViewR(..),pattern (:<|))
import qualified Data.List as L
import qualified Data.List.NonEmpty as NonEmpty
import           Data.Maybe
import           Data.Ord (comparing)
import           Data.Range
import qualified Data.Range as Range
import qualified Data.Sequence as S
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import           GHC.TypeLits
import           Debug.Trace
-- | Build a fair split tree for a non-empty set of points.
--
-- The points are first tagged with an index giving their rank along the
-- first coordinate. Then one copy of the tagged points is sorted per
-- dimension (dimension @i@ on coordinate @i+1@) and handed to
-- 'fairSplitTree''. Every internal node of the result stores its split
-- dimension and the combination ('<>') of its children's bounding boxes.
fairSplitTree     :: (Fractional r, Ord r, Arity d, 1 <= d
                     , Show r, Show p
                     )
                  => NonEmpty.NonEmpty (Point d r :+ p) -> SplitTree d p r ()
fairSplitTree pts = foldUp mkNode Leaf (fairSplitTree' numPts perDim)
  where
    -- Stable sort on the k-th (1-based) coordinate.
    sortAlong k = NonEmpty.sortWith (^.core.unsafeCoord k)

    -- Indices 0,1,2,... follow the order along the first coordinate.
    tagged = NonEmpty.zipWith tag (NonEmpty.fromList [0..]) (sortAlong 1 pts)
    tag idx (p :+ e) = p :+ (idx :+ e)

    -- One sorted copy of the tagged points per dimension.
    perDim = imap (\i -> LSeq.fromNonEmpty . sortAlong (i + 1)) (pure tagged)
    numPts = length (perDim^.GV.element (C :: C 0))

    mkNode lc j rc = Node lc (NodeData j (bbOf lc <> bbOf rc) ()) rc
-- | Collect the well-separated pairs of a split tree for separation @s@.
--
-- At every internal node, the result first lists the pairs that 'findPairs'
-- reports between its two children. It then lists the pairs found
-- recursively in the left subtree, followed by those in the right subtree.
-- Leaves contribute nothing.
wellSeparatedPairs   :: (Floating r, Ord r, Arity d, Arity (d + 1))
                     => r -> SplitTree d p r a -> [WSP d p r a]
wellSeparatedPairs s = go
  where
    go t = case t of
      Leaf _     -> []
      Node l _ r -> concat [findPairs s l r, go l, go r]
-- | Recursively build a split tree.
--
-- Arguments:
--
-- * @n@, the number of points.
-- * The points sorted once per dimension: the sequence for dimension @i@ is
--   sorted on coordinate @i+1@. Every point carries an index in its extra
--   field, which is used to address the level vector.
--
-- With at most one point, the result is a single leaf built from the head
-- of the first sequence (so @n == 0@ is not supported).
--
-- Otherwise, 'assignLevels' assigns levels to points until at least
-- @n `div` 2@ of them have a level. The points of every level, and the
-- points left without a level, are then recursively turned into subtrees.
-- These subtrees are chained into a path of 'Node's, each labelled with the
-- split dimension recorded in its level.
fairSplitTree'       :: (Fractional r, Ord r, Arity d, 1 <= d
                        , Show r, Show p
                        )
                     => Int -> GV.Vector d (PointSeq d (Idx :+ p) r)
                     -> BinLeafTree Int (Point d r :+ p)
fairSplitTree' n pts
    | n <= 1    = let p = LSeq.head $ pts^.GV.element (C :: C 0) in Leaf (dropIdx p)
    | otherwise = foldr node' (V.last path) $ V.zip nodeLevels (V.init path)
  where
    -- levels: for every point index, its assigned level ('Nothing' = none).
    -- nodeLevels': the levels used, most recent first. Its head (maxLvl) is
    -- the level that would have been used next; that level was never
    -- written to any point.
    (levels, nodeLevels'@(maxLvl NonEmpty.:| _)) = runST $ do
        lvls  <- MV.replicate n Nothing
        ls    <- runReaderT (assignLevels (n `div` 2) 0 pts (Level 0 Nothing) []) lvls
        lvls' <- V.unsafeFreeze lvls
        pure (lvls',ls)
    -- Oldest level first, so entry b lines up with bucket b of distrPts.
    -- NOTE(review): this assumes level numbers count up from 0 by one per
    -- round (see 'nextLevel'); confirm.
    nodeLevels = V.fromList . L.reverse . NonEmpty.toList $ nodeLevels'
    -- One bucket per used level, plus a final bucket for the points that
    -- never got a level.
    distrPts = distributePoints (1 + maxLvl^.unLevel) levels pts
    path = recurse <$> distrPts 
    -- The subtree of a level becomes the left child; rc is the rest of the
    -- path, ending in the subtree of the unassigned points.
    node' (lvl,lc) rc = case lvl^?widestDim._Just of
                          Nothing -> error "Unknown widest dimension"
                          Just j  -> Node lc j rc
    -- Build the subtree of one bucket after renumbering its point indices.
    recurse pts' = fairSplitTree' (length $ pts'^.GV.element (C :: C 0))
                                  (reIndexPoints pts')
-- | Split every per-dimension sequence into @k@ buckets according to the
-- level of each point (see 'distributePoints''). The result is regrouped so
-- that entry @b@ holds bucket @b@ of every dimension.
distributePoints          :: (Arity d , Show r, Show p)
                          => Int -> V.Vector (Maybe Level)
                          -> GV.Vector d (PointSeq d (Idx :+ p) r)
                          -> V.Vector (GV.Vector d (PointSeq d (Idx :+ p) r))
distributePoints k levels pts = transpose (distributePoints' k levels <$> pts)
-- | Swap the two layers of nesting: a @d@-vector of vectors becomes a vector
-- of @d@-vectors.
--
-- NOTE(review): this assumes all inner vectors have the same length. With
-- ragged input, 'L.transpose' yields rows of differing lengths, which are
-- then handed to 'GV.vectorFromListUnsafe'.
transpose :: Arity d => GV.Vector d (V.Vector a) -> V.Vector (GV.Vector d a)
transpose outer = V.fromList (GV.vectorFromListUnsafe <$> L.transpose rows)
  where
    rows = V.toList <$> F.toList outer
-- | Distribute the points of one sequence over @k@ buckets.
--
-- A point goes to the bucket numbered by its level. Points that have no
-- level go to the last bucket (@k-1@). Points keep their relative order
-- within each bucket.
distributePoints'              :: Int                      -- ^ number of buckets
                               -> V.Vector (Maybe Level)   -- ^ level of every point, by point index
                               -> PointSeq d (Idx :+ p) r  -- ^ points to distribute
                               -> V.Vector (PointSeq d (Idx :+ p) r)
distributePoints' k levels pts = V.map fromSeqUnsafe filled
  where
    filled = V.create $ do
      buckets <- MV.replicate k mempty
      forM_ pts $ \p -> MV.modify buckets (S.|> p) (bucketOf p)
      pure buckets
    bucketOf p = maybe (k - 1) _unLevel (levels V.! (p^.extra.core))
-- | Wrap a 'S.Seq' as an 'LSeq' with an arbitrary type-level length.
--
-- Nothing is checked: callers are responsible for the promised length
-- actually holding. NOTE(review): 'LSeq' @n@ presumably means "at least @n@
-- elements"; confirm against "Data.LSeq".
fromSeqUnsafe :: S.Seq a -> LSeq n a
fromSeqUnsafe = LSeq.promise . LSeq.fromSeq
-- | Renumber the point indices to @0 .. n'-1@, where @n'@ is the number of
-- points. The numbering follows the order of the first dimension's sequence
-- and is applied to the sequences of all dimensions.
--
-- The old indices in the first dimension's sequence must be ascending,
-- because the mapping is built with 'IntMap.fromAscList'. Every point in
-- the other sequences must also occur in the first one, or 'fromJust'
-- fails.
reIndexPoints      :: (Arity d, 1 <= d)
                   => GV.Vector d (PointSeq d (Idx :+ p) r)
                   -> GV.Vector d (PointSeq d (Idx :+ p) r)
reIndexPoints ptsV = fmap (fmap renumber) ptsV
  where
    oldIdxs    = map (^.extra.core) (F.toList (ptsV^.GV.element (C :: C 0)))
    newIdxOf   = IntMap.fromAscList (zip oldIdxs [0..])
    renumber p = p&extra.core %~ (\old -> fromJust (IntMap.lookup old newIdxOf))
-- | An 'ST' computation that can read (via 'ask') the mutable vector holding
-- the level of every point.
--
-- The vector is indexed by the point's index (the 'Idx' stored in its extra
-- field). 'Nothing' means the point has no level yet.
type RST s = ReaderT (MV.MVector s (Maybe Level)) (ST s)
-- | Repeatedly split off part of the point set and give those points a level,
-- until at least @h@ points have one.
--
-- Each round works as follows:
--
-- * Drop already-assigned points from the ends of all sequences.
-- * Pick the widest dimension @j@.
-- * Use 'findAndCompact' to split the sequence of dimension @j@ at the
--   midpoint of its extent.
-- * Give the points of the smaller part the current level, with @j@
--   recorded as its widest dimension.
--
-- Returns the levels used, most recent first. The head of the result is the
-- level that would be used next; it has not been written to any point, so
-- the remaining points still have 'Nothing' in the level vector.
assignLevels                  :: (Fractional r, Ord r, Arity d
                                 , Show r, Show p
                                 )
                              => Int -- ^ number of points that must get a level
                              -> Int -- ^ number of points that got a level so far
                              -> GV.Vector d (PointSeq d (Idx :+ p) r)
                              -> Level -- ^ level to use in this round
                              -> [Level] -- ^ levels used so far, most recent first
                              -> RST s (NonEmpty.NonEmpty Level)
assignLevels h m pts l prevLvls
  | m >= h    = pure (l NonEmpty.:| prevLvls)
  | otherwise = do
    pts' <- compactEnds pts
    -- widestDimension is 1-based, while ix' indexes the vector 0-based.
    let j    = widestDimension pts'
        i    = j - 1 
        extJ = (extends pts')^.ix' i
        mid  = midPoint extJ
    -- Split the sequence of dimension j at mid. findAndCompact returns
    -- (larger part, smaller part); the larger part becomes the new sequence
    -- for this dimension, and the smaller part gets the current level.
    (lvlJPts,deletePts) <- findAndCompact j (pts'^.ix' i) mid
    let pts''     = pts'&ix' i .~ lvlJPts
        l'        = l&widestDim .~ Just j
    forM_ deletePts $ \p ->
      assignLevel p l'
    assignLevels h (m + length deletePts) pts'' (nextLevel l) (l' : prevLvls)
-- | Run 'compactEnds'' on the sequence of every dimension, one after the
-- other. This drops points that already have a level from both ends of each
-- sequence.
compactEnds        :: Arity d
                   => GV.Vector d (PointSeq d (Idx :+ p) r)
                   -> RST s (GV.Vector d (PointSeq d (Idx :+ p) r))
compactEnds pts = mapM compactEnds' pts
-- | Record level @l@ for point @p@. The entry written is the one at @p@'s
-- index in the shared level vector.
assignLevel     :: (c :+ (Idx :+ p)) -> Level -> RST s ()
assignLevel p l = do
  table <- ask
  lift (MV.write table (p^.extra.core) (Just l))
-- | Look up the level currently recorded for point @p@. The result is
-- 'Nothing' if no level has been assigned yet.
levelOf   :: (c :+ (Idx :+ p)) -> RST s (Maybe Level)
levelOf p = do
  table <- ask
  lift (MV.read table (p^.extra.core))
-- | Whether point @p@ has already been assigned a level.
hasLevel :: c :+ (Idx :+ p) -> RST s Bool
hasLevel p = isJust <$> levelOf p
-- | Drop points that already have a level from both ends of a sequence.
--
-- The front is trimmed first ('goL'), then the back ('goR'). Assigned points
-- in the middle of the sequence are left in place.
--
-- pre: the sequence is non-empty and contains at least one point without a
-- level. Otherwise the patterns below fail at runtime.
compactEnds'              :: PointSeq d (Idx :+ p) r
                          -> RST s (PointSeq d (Idx :+ p) r)
compactEnds' (l0 :<| s0) = fmap fromSeqUnsafe . goL $ l0 S.<| toSeq s0
  where
    -- Skip assigned points at the front; at the first unassigned one,
    -- continue with the back.
    goL s@(S.viewl -> l S.:< s') = hasLevel l >>= \case
                                     False -> goR s
                                     True  -> goL s'
    -- Skip assigned points at the back and stop at the first unassigned one.
    -- That point exists, because goL found one.
    goR s@(S.viewr -> s' S.:> r) = hasLevel r >>= \case
                                     False -> pure s
                                     True  -> goR s'
-- | Split the points of one sorted sequence at value @m@ of coordinate @j@,
-- skipping points that already have a level.
--
-- The sequence is walked alternately from the left ('stepL') and from the
-- right ('stepR'). The walk collects unassigned points with coordinate
-- @<= m@ on the left and @>= m@ on the right. It stops as soon as one side
-- meets an unassigned point on the wrong side of @m@, or the sequence is
-- exhausted. Everything not yet visited is then attached to the other side.
--
-- Returns @(larger part, smaller part)@. The second component contains only
-- points collected by the walk (all of them unassigned). It never holds
-- more unassigned points than the first component. Because the walk stops
-- early, the work is roughly proportional to the size of the smaller part
-- plus the number of already-assigned points that are skipped.
--
-- pre: the sequence is non-empty (the @:<|@ pattern is not exhaustive).
findAndCompact                   :: (Ord r, Arity d
                                    , Show r, Show p
                                    )
                                 => Int
                                    -- ^ the (1-based) dimension, i.e. which
                                    -- coordinate of the points to compare
                                 -> PointSeq d (Idx :+ p) r
                                 -> r -- ^ the mid point
                                 -> RST s ( PointSeq d (Idx :+ p) r
                                          , PointSeq d (Idx :+ p) r
                                          )
findAndCompact j (l0 :<| s0) m = fmap select . stepL $ l0 S.<| toSeq s0
  where
    -- stepL and stepR together build a FAC value. It holds the collected
    -- left part, the collected right part, and a tag (L or R) saying on
    -- which side the walk stopped.
    --
    -- stepL looks at the leftmost remaining point:
    --   * already assigned: skip it and keep going left;
    --   * unassigned and <= m: add it to the left part, then take a step
    --     from the right;
    --   * unassigned and > m: stop. The whole remaining sequence (this point
    --     included) becomes the right part.
    stepL s = case S.viewl s of
      S.EmptyL  -> pure $ FAC mempty mempty L
      l S.:< s' -> hasLevel l >>= \case
                     False -> if l^.core.unsafeCoord j <= m
                                 then addL l <$> stepR s'
                                 else pure $ FAC mempty s L
                     True  -> stepL s' 
    -- stepR mirrors stepL from the right end, using >= m.
    stepR s = case S.viewr s of
      S.EmptyR  -> pure $ FAC mempty mempty R
      s' S.:> r -> hasLevel r >>= \case
                     False -> if r^.core.unsafeCoord j >= m
                                 then addR r <$> stepL s'
                                 else pure $ FAC s mempty R
                     True  -> stepR s'
    addL l x = x&leftPart  %~ (l S.<|)
    addR r x = x&rightPart %~ (S.|> r)
    select = over both fromSeqUnsafe . select'
    -- The side where the walk stopped is the smaller one; it goes second.
    select' (FAC l r L) = (r, l)
    select' (FAC l r R) = (l, r)
-- | Return the 1-based dimension in which the point set has the largest
-- extent (see 'widths').
--
-- On ties, the last such dimension wins, following the behavior of
-- 'L.maximumBy'.
--
-- pre: every sequence is sorted on its own coordinate.
widestDimension :: (Num r, Ord r, Arity d) => GV.Vector d (PointSeq d p r) -> Int
widestDimension pts = dim
  where
    numbered = zip [1 ..] (F.toList (widths pts))
    (dim, _) = L.maximumBy (comparing snd) numbered
-- | The width of the extent of the point set in every dimension (see
-- 'extends').
widths :: (Num r, Arity d) => GV.Vector d (PointSeq d p r) -> GV.Vector d r
widths pts = Range.width <$> extends pts
-- | For every dimension @i@ (0-based), the closed range between coordinate
-- @i+1@ of the first point and of the last point of the @i@-th sequence.
--
-- This is the extent of the points only if each sequence is sorted on its
-- own coordinate.
extends :: Arity d => GV.Vector d (PointSeq d p r) -> GV.Vector d (Range r)
extends = imap $ \i pts ->
    let c         = i + 1
        coordOf q = q^.core.unsafeCoord c
    in ClosedRange (coordOf (LSeq.head pts)) (coordOf (LSeq.last pts))
-- | Find the well-separated pairs between two split trees.
--
-- If the two trees are well separated (according to 'areWellSeparated''),
-- they form a single pair. Otherwise the tree with the larger 'maxWidth' is
-- split (the right one on ties). Each of its children is then matched
-- recursively against the other tree, and the tree that was not split is
-- passed as the first argument of the recursive call.
findPairs                     :: (Floating r, Ord r, Arity d, Arity (d + 1))
                              => r -> SplitTree d p r a -> SplitTree d p r a
                              -> [WSP d p r a]
findPairs s = go
  where
    go l r
      | areWellSeparated' s l r  = [(l, r)]
      | maxWidth l <= maxWidth r = children' r >>= go l
      | otherwise                = children' l >>= go r
-- | Test whether two split trees are well separated with factor @s@, using
-- the box-scaling test 'boxBox' on their bounding boxes.
--
-- Two leaves are always considered well separated.
areWellSeparated                     :: (Arity d, Arity (d + 1), Fractional r, Ord r)
                                     => r -- ^ separation factor
                                     -> SplitTree d p r a
                                     -> SplitTree d p r a -> Bool
areWellSeparated s l r = case (l, r) of
  (Leaf _, Leaf _) -> True
  _                -> boxBox s (bbOf l) (bbOf r)
-- | Test whether two boxes are far enough apart for separation factor @s@.
--
-- Each box in turn is scaled by @s@ around its center point. The other box
-- must not intersect the scaled copy. The test succeeds only if both
-- checks pass.
boxBox         :: (Fractional r, Ord r, Arity d, Arity (d + 1))
               => r -> Box d p r -> Box d p r -> Bool
boxBox s lb rb = separatedFrom lb rb && separatedFrom rb lb
  where
    separatedFrom other b =
      let c     = (centerPoint b)^.vector
          grown = translateBy c . scaleUniformlyBy s . translateBy ((-1) *^ c) $ b
      in not (other `intersects` grown)
-- | Test whether two split trees are well separated with factor @s@, using
-- the center-distance test 'boxBox1' on their bounding boxes.
--
-- Two leaves are always considered well separated.
areWellSeparated'                     :: (Floating r, Ord r, Arity d)
                                      => r
                                      -> SplitTree d p r a
                                      -> SplitTree d p r a
                                      -> Bool
areWellSeparated' s l r = case (l, r) of
  (Leaf _, Leaf _) -> True
  _                -> boxBox1 s (bbOf l) (bbOf r)
-- | Center-distance separation test for two boxes.
--
-- The boxes are separated when the distance between their center points is
-- at least @(s + 1)@ times the larger of their two diameters. A diameter is
-- the distance from the box's minimum corner to its maximum corner.
boxBox1         :: (Floating r, Ord r, Arity d) => r -> Box d p r -> Box d p r -> Bool
boxBox1 s lb rb = (s + 1) * maxDiam <= euclideanDist (centerPoint lb) (centerPoint rb)
  where
    diameter b = euclideanDist (b^.minP.core.cwMin) (b^.maxP.core.cwMax)
    maxDiam    = diameter lb `max` diameter rb
-- | Width of a split tree, used to decide which tree 'findPairs' splits.
--
-- A leaf has width 0. For a node, the result is the width of its bounding
-- box in the node's split dimension. NOTE(review): this equals the maximum
-- width only if the split dimension is the widest one; also, 'fromJust'
-- fails if 'widthIn'' returns 'Nothing'.
maxWidth                             :: (Arity d, Num r)
                                     => SplitTree d p r a -> r
maxWidth t = case t of
  Leaf _                    -> 0
  Node _ (NodeData i b _) _ -> fromJust (widthIn' i b)
-- | Bounding box of a split tree.
--
-- For a leaf, the box is computed from its single point. For a node, the
-- box stored in the node is returned.
bbOf                             :: Ord r => SplitTree d p r a -> Box d () r
bbOf t = case t of
  Leaf p                    -> boundingBox (p^.core)
  Node _ (NodeData _ b _) _ -> b
-- | The direct subtrees of a tree: none for a leaf, and the left and right
-- children (in that order) for a node.
children'              :: BinLeafTree v a -> [BinLeafTree v a]
children' t = case t of
  Leaf _     -> []
  Node l _ r -> [l, r]
-- | Turn the traversal @'GV.element'' i@ into a lens using 'singular', so
-- that it can be used with '^.' and '.~'.
--
-- 'singular' yields a partial lens: it errors at runtime if the traversal
-- targets no element. NOTE(review): presumably that happens when @i@ lies
-- outside @[0 .. d-1]@.
ix'   :: (Arity d, KnownNat d) => Int -> Lens' (GV.Vector d a) a
ix' i = singular (GV.element' i)
-- | Remove the index that sits between a point and its original extra data.
dropIdx                 :: core :+ (t :+ extra) -> core :+ extra
dropIdx x = case x of
  p :+ (_ :+ e) -> p :+ e