{-# LANGUAGE ScopedTypeVariables #-}
module Data.PlanarGraph.EdgeOracle where
import           Control.Applicative (Alternative(..))
import           Control.Lens hiding ((.=))
import           Control.Monad.ST (ST)
import           Control.Monad.State.Strict
import           Data.Bitraversable
import           Data.Ext
import qualified Data.Foldable as F
import           Data.Maybe (catMaybes, isJust)
import           Data.PlanarGraph.Core
import           Data.PlanarGraph.Dart
import           Data.Traversable (fmapDefault,foldMapDefault)
import qualified Data.Vector as V
import qualified Data.Vector.Generic as GV
import qualified Data.Vector.Mutable as MV
import qualified Data.Vector.Unboxed as UV
import qualified Data.Vector.Unboxed.Mutable as UMV
-- | Edge oracle: a table of adjacency entries indexed by vertex id.
--
-- Slot @i@ holds (some of) the neighbours of vertex @i@, each paired via
-- ':+' with the payload stored for that edge. When built by
-- 'buildEdgeOracle' from symmetric adjacency lists, an edge is kept in the
-- list of only one of its two endpoints. Lookups therefore have to inspect
-- both endpoints' lists.
newtype EdgeOracle s w a = EdgeOracle
  { _unEdgeOracle :: V.Vector (V.Vector (VertexId s w :+ a))
  } deriving (Show, Eq)
instance Functor (EdgeOracle s w) where
  -- Map over the edge payloads by running the Traversable instance with the
  -- identity functor (lens's 'over'); vertex ids are left untouched.
  fmap f = over traverse f
instance Foldable (EdgeOracle s w) where
  -- Fold the edge payloads by running the Traversable instance with the
  -- constant functor (lens's 'foldMapOf'); vertex ids are not visited.
  foldMap f = foldMapOf traverse f
instance Traversable (EdgeOracle s w) where
  -- Walk every row, and within a row every entry. Only the payload (the
  -- 'extra' half of each @VertexId :+ a@ entry) is passed to @f@; the vertex
  -- id is re-wrapped unchanged with 'pure'.
  traverse f (EdgeOracle rows) =
      fmap EdgeOracle (traverse visitRow rows)
    where
      visitRow   = traverse visitEntry
      visitEntry = bitraverse pure f
-- | Constructs an edge oracle for a planar graph. Given a pair of vertices,
-- the oracle can look up the dart of the edge between them, if any.
--
-- Every vertex @v@ contributes the darts returned by 'incidentEdges' @v@.
-- Each dart is tagged with its endpoint other than @v@: the head when the
-- tail is @v@, otherwise the tail. The resulting lists are handed to
-- 'buildEdgeOracle'.
--
-- NOTE(review): a self-loop would tag @v@ with itself; callers presumably
-- guarantee there are none — confirm.
edgeOracle   :: PlanarGraph s w v e f -> EdgeOracle s w (Dart s)
edgeOracle g = buildEdgeOracle (map adjacencyOf (F.toList (vertices' g)))
  where
    adjacencyOf v = (v, fmap (tagWithNeighbour v) (incidentEdges v g))
    tagWithNeighbour v d = oppositeEnd v d :+ d
    oppositeEnd v d
      | t == v    = headOf d g
      | otherwise = t
      where
        t = tailOf d g
-- | Builds an edge oracle from adjacency lists. The oracle can then be used to
-- test efficiently whether two vertices are connected by an edge.
--
-- Input: one pair @(v, adj)@ per vertex. @adj@ lists the neighbours of @v@,
-- each tagged with the payload to store for that edge. Vertex ids are used
-- directly as vector indices, so they must lie in @0 .. n-1@, where @n@ is the
-- length of the input list. An id outside that range makes the construction
-- fail (bounds-checked write). A vertex absent from the input is treated as
-- having no neighbours.
-- NOTE(review): the degree bookkeeping presumably assumes every edge appears
-- in both endpoints' lists (no self-loops, no multi-edges) — confirm.
--
-- Algorithm:
--
-- * Keep, per vertex, a count of its not-yet-processed neighbours and a
--   "processed" mark.
-- * Vertices whose count is at most 6 are queued.
-- * Processing a vertex stores its still-unprocessed neighbours as its
--   oracle list, marks the vertex, and decrements those neighbours' counts.
--   Any neighbour whose count drops to 6 or below is queued.
--
-- An edge is therefore stored in the list of whichever endpoint is processed
-- first, and only there. A vertex processed from the queue has at most 6
-- unprocessed neighbours at that moment, so its list has length at most 6.
-- For planar graphs (which always contain a vertex of degree at most 5) the
-- queue reaches every vertex.
--
-- If the queue runs dry while vertices remain unprocessed (e.g. non-planar
-- input), those vertices are processed in index order. The oracle thus stays
-- complete, although their lists are no longer bounded by 6.
--
-- running time: linear in the total size of the adjacency lists.
buildEdgeOracle        :: forall f s w e. (Foldable f)
                       => [(VertexId s w, f (VertexId s w :+ e))] -> EdgeOracle s w e
buildEdgeOracle inAdj' = EdgeOracle $ V.create $ do
                          counts <- UV.thaw initCounts
                          marks  <- UMV.replicate n False
                          -- Start from empty lists rather than 'MV.new', so
                          -- no slot can ever be left holding an
                          -- "uninitialised element" error thunk.
                          outV   <- MV.replicate n V.empty
                          build counts marks outV initQ
                          -- Fallback sweep: process any vertex the queue
                          -- never reached, so every edge is stored somewhere.
                          -- Already-marked vertices are skipped by 'build',
                          -- so for planar input this only reads the marks.
                          forM_ [0 .. n - 1] $ \i -> build counts marks outV [i]
                          pure outV
  where
    -- Number of vertices (= number of slots in every per-vertex vector).
    n = length inAdj'

    -- Adjacency lists as a vector indexed by vertex id. Slots for ids that do
    -- not occur in the input stay empty.
    inAdj = V.create $ do
              mv <- MV.replicate n V.empty
              forM_ inAdj' $ \(VertexId i,adjI) ->
                MV.write mv i (V.fromList . F.toList $ adjI)
              pure mv

    -- Initial counts: the full length of each adjacency list.
    initCounts = V.convert . fmap GV.length $ inAdj

    -- Vertices that may be processed right away (count at most 6).
    initQ = GV.ifoldr (\i k q -> if k <= 6 then i : q else q) [] initCounts

    -- Returns the entries of vertex i's adjacency list whose vertex has not
    -- been marked (processed) yet.
    extractAdj         :: UMV.MVector s' Bool -> Int
                       -> ST s' (V.Vector (VertexId s w :+ e))
    extractAdj marks i = let p = fmap not . UMV.read marks . (^.core.unVertexId)
                         in GV.filterM  p $ inAdj V.! i

    -- Decrements the count of vertex j. Returns @Just j@ when the new count
    -- is at most 6, i.e. when j should be (re)queued.
    decrease                          :: UMV.MVector s' Int -> (VertexId s w :+ e')
                                      -> ST s' (Maybe Int)
    decrease counts (VertexId j :+ _) = do k <- UMV.read counts j
                                           let k'  = k - 1
                                           UMV.write counts j k'
                                           pure $ if k' <= 6 then Just j else Nothing

    -- Work-list loop. A vertex may be queued several times; the mark check
    -- ensures it is processed only once.
    build :: UMV.MVector s' Int -> UMV.MVector s' Bool
          -> MV.MVector s' (V.Vector (VertexId s w :+ e)) -> [Int] -> ST s' ()
    build _      _     _    []    = pure ()
    build counts marks outV (i:q) = do
             b <- UMV.read marks i
             nq <- if b then pure []
                        else do
                          -- Read the neighbours before marking i itself.
                          adjI <- extractAdj marks i
                          MV.write outV i adjI
                          UMV.write marks i True
                          V.toList <$> mapM (decrease counts) adjI
             build counts marks outV (catMaybes nq <> q)
-- | Tests whether @u@ and @v@ are connected by an edge in the oracle. True
-- exactly when 'findEdge' finds a payload for the pair.
hasEdge :: VertexId s w -> VertexId s w -> EdgeOracle s w a -> Bool
hasEdge u v oracle = maybe False (const True) (findEdge u v oracle)
-- | Finds the payload stored for the edge between @u@ and @v@, if any.
--
-- An edge may be stored in either endpoint's list, so the lookup first
-- searches @v@'s list for @u@ and then @u@'s list for @v@. Each search is a
-- linear scan of one list. Indexing with 'V.!' means an id outside the
-- oracle's range raises an error.
findEdge :: VertexId s w -> VertexId s w -> EdgeOracle s w a -> Maybe a
findEdge (VertexId u) (VertexId v) (EdgeOracle rows) =
    lookupIn v u <|> lookupIn u v
  where
    -- Payload of the entry for @target@ in the list of @owner@.
    lookupIn owner target =
      case F.find (\(VertexId k :+ _) -> k == target) (rows V.! owner) of
        Nothing    -> Nothing
        Just entry -> Just (entry^.extra)
-- | Finds the dart for the edge between @u@ and @v@, if any.
--
-- @v@'s list is searched for @u@ first, and a dart found there is returned
-- as its 'twin'. Otherwise @u@'s list is searched for @v@, and a dart found
-- there is returned unchanged.
-- NOTE(review): this yields the dart from @u@ to @v@ provided every stored
-- dart has as its tail the vertex whose list holds it. That presumably holds
-- for 'edgeOracle' — confirm against 'incidentEdges'.
findDart :: VertexId s w -> VertexId s w -> EdgeOracle s w (Dart s) -> Maybe (Dart s)
findDart (VertexId u) (VertexId v) (EdgeOracle rows) =
    (twin <$> dartIn v u) <|> dartIn u v
  where
    -- Dart stored in @owner@'s list whose other endpoint is @target@.
    dartIn owner target =
      view extra <$> F.find (\(VertexId k :+ _) -> k == target) (rows V.! owner)