{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
module Control.Monad.Search
(
Search
, runSearch
, runSearchBest
, SearchT
, runSearchT
, runSearchBestT
, MonadSearch
, cost
, cost'
, junction
, abandon
, seal
, collapse
, winner
) where
import Control.Applicative ( Alternative(..) )
import Control.Monad ( MonadPlus(..) )
import Control.Monad.Cont ( MonadCont )
import Control.Monad.Except ( ExceptT(..), MonadError
, runExceptT )
import Control.Monad.IO.Class ( MonadIO )
import qualified Control.Monad.RWS.Lazy as Lazy ( MonadRWS, RWST(..)
, runRWST )
import qualified Control.Monad.RWS.Strict as Strict ( RWST(..), runRWST )
import Control.Monad.Reader ( MonadReader, ReaderT(..)
, runReaderT )
import qualified Control.Monad.State.Lazy as Lazy ( MonadState
, StateT(..)
, runStateT )
import qualified Control.Monad.State.Strict as Strict ( StateT(..)
, runStateT )
import Control.Monad.Trans.Class ( MonadTrans, lift )
import Control.Monad.Trans.Free ( FreeF(Free, Pure), FreeT
, runFreeT, wrap )
import Control.Monad.Trans.Free.Church ( FT, fromFT )
import Control.Monad.Trans.State ( evalStateT, gets, modify )
import qualified Control.Monad.Writer.Lazy as Lazy ( MonadWriter
, WriterT(..)
, runWriterT )
import qualified Control.Monad.Writer.Strict as Strict ( WriterT(..)
, runWriterT )
import Data.Functor.Identity ( Identity, runIdentity )
import qualified Data.IntPSQ as PSQ
import qualified Data.IntMap.Strict as Map
import Data.Maybe ( catMaybes, listToMaybe )
import qualified Data.IntSet as Set
-- | Membership table for 'seal' scopes: maps each scope id to the set of
-- candidate numbers registered in that scope.
newtype Scopemap = Scopemap
    { unScopemap :: Map.IntMap Set.IntSet
    }

-- | A table holding exactly one scope @k@ whose only member is @v@.
singleton :: Int -> Int -> Scopemap
singleton k v = Scopemap (Map.singleton k (Set.singleton v))

-- | Register member @v@ in scope @k@, creating the scope if it is absent.
insert :: Int -> Int -> Scopemap -> Scopemap
insert k v (Scopemap sm) =
    Scopemap (Map.insertWith Set.union k (Set.singleton v) sm)

-- | Unregister member @v@ from scope @k@. A scope left without members is
-- dropped from the table; an absent scope is left untouched.
delete :: Int -> Int -> Scopemap -> Scopemap
delete k v (Scopemap sm) = Scopemap (Map.update shrink k sm)
  where
    shrink members
        | Set.null remaining = Nothing
        | otherwise          = Just remaining
      where
        remaining = Set.delete v members

-- | Members of scope @k@ in ascending order, or @[]@ if the scope is absent.
list :: Int -> Scopemap -> [Int]
list k (Scopemap sm) = Set.toList (Map.findWithDefault Set.empty k sm)

-- | Every member of every scope, ascending and without duplicates.
listAll :: Scopemap -> [Int]
listAll (Scopemap sm) = Set.toList (Set.unions (Map.elems sm))
-- | A search over cost type @c@ with no underlying effects.
type Search c = SearchT c Identity

-- | Run a pure search and return every solution with its accumulated cost.
runSearch :: (Ord c, Monoid c) => Search c a -> [(c, a)]
runSearch search = runIdentity (runSearchT search)

-- | Run a pure search and return the first solution found by
-- 'runSearchBestT', if there is one.
runSearchBest :: (Ord c, Monoid c) => Search c a -> Maybe (c, a)
runSearchBest search = runIdentity (runSearchBestT search)
-- | One instruction of the search language; @a@ is the continuation.
data SearchF c a
    = Cost c c a
      -- ^ Add the first @c@ to the accumulated cost; the second is an
      -- estimate added on top of it when computing the branch priority.
    | Alt a a
      -- ^ Fork the current branch into two alternative continuations.
    | Enter a
      -- ^ Open a new scope (emitted by 'seal').
    | Exit a
      -- ^ Close the innermost scope (emitted by 'seal').
    | Collapse a
      -- ^ Discard the other branches registered in the innermost scope.
    | Abandon
      -- ^ Terminate the current branch without a result.
    deriving Functor
-- | The search monad transformer: a Church-encoded free monad over
-- 'SearchF' on top of @m@. All listed classes are obtained by
-- generalized newtype deriving from the instances of 'FT'.
newtype SearchT c m a = SearchT { unSearchT :: FT (SearchF c) m a }
    deriving ( Functor
             , Applicative
             , Monad
             , MonadTrans
             , MonadIO
             , MonadReader r
             , Lazy.MonadWriter w
             , Lazy.MonadState s
             , MonadError e
             , MonadCont
             )
-- | 'empty' abandons the current branch and '<|>' forks the search into
-- two alternative branches.
instance (Ord c, Monoid c, Monad m) => Alternative (SearchT c m) where
    empty = abandon
    lhs <|> rhs = junction lhs rhs
-- | Reuses the 'Alternative' operations: 'mzero' abandons, 'mplus' forks.
instance (Ord c, Monoid c, Monad m) => MonadPlus (SearchT c m) where
    mzero = empty
    mplus = (<|>)
-- | Lift 'Lazy.MonadRWS' through the search transformer; its Reader, Writer
-- and State superclasses come from the deriving clause of 'SearchT'.
deriving instance Lazy.MonadRWS r w s m
    => Lazy.MonadRWS r w s (SearchT c m)
-- | A pending branch of the search, as stored in the interpreter queue.
data Cand c m a = Cand
    { candCost  :: !c
      -- ^ Cost accumulated along this branch so far.
    , candScope :: ![Int]
      -- ^ Scope ids this branch is inside, innermost first.
    , candPath  :: FreeT (SearchF c) m a
      -- ^ Remaining program of this branch.
    }
-- | Interpreter state threaded through 'runSearchT'.
data St c m a = St
    { stNum    :: !Int
      -- ^ Most recently allocated candidate number.
    , stScope  :: !Int
      -- ^ Most recently allocated scope id.
    , stActive :: !Scopemap
      -- ^ Which candidate numbers are registered in which scopes.
    , stQueue  :: !(PSQ.IntPSQ c (Cand c m a))
      -- ^ Waiting candidates keyed by number, prioritised by cost plus
      -- estimate.
    }
-- | Run a search and return every solution together with its accumulated
-- cost.
--
-- Each branch of the search is a candidate ('Cand') stored in a priority
-- queue under a unique candidate number. The interpreter repeatedly pops the
-- candidate with the lowest priority and advances it until it produces a
-- result, abandons, or yields because a queued candidate's priority is no
-- greater than its own. A branch's priority is @max prio (cost <> estimate)@,
-- so it never decreases. Abandoning and yielding steps contribute no result.
--
-- 'stActive' records which live candidates belong to which 'seal' scopes, so
-- that 'collapse' can discard the competitors in the innermost scope. A
-- candidate is unregistered from its scopes as soon as it finishes, abandons,
-- or is discarded by 'collapse'. Otherwise the table would keep every
-- candidate number ever allocated, and each 'collapse' would walk all of
-- them.
runSearchT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m [(c, a)]
runSearchT m = catMaybes <$> evalStateT go state
  where
    -- Pop the best candidate, advance it, and loop until the queue is empty.
    go = do
        mmin <- gets (PSQ.minView . stQueue)
        case mmin of
            Nothing -> return []
            Just (num, prio, cand, q) -> do
                updateQueue $ const q
                (:) <$> step num prio cand <*> go

    -- Interpret one candidate. It is never in the queue while it runs:
    -- 'go' popped it, and it re-enters the queue only when it yields.
    step num prio cand@Cand{..} = do
        path' <- lift $ runFreeT candPath
        case path' of
            Pure a -> do
                releaseScopes candScope num
                return $ Just (candCost, a)
            Free Abandon -> do
                releaseScopes candScope num
                return Nothing
            Free (Cost c e p) ->
                let newCost = candCost `mappend` c
                    newPriority = max prio $ newCost `mappend` e
                in do
                    -- Yield to any queued candidate that is at least as good.
                    reschedule <- gets (maybe False
                                              (\(_, x, _) -> x <= newPriority) .
                                              PSQ.findMin . stQueue)
                    let cand' = cand { candCost = newCost, candPath = p }
                    if reschedule
                        then do
                            -- Still alive: its scope registrations stay.
                            updateQueue $ PSQ.insert num newPriority cand'
                            return Nothing
                        else step num newPriority cand'
            Free (Alt lhs rhs) -> do
                -- The new branch inherits every scope of its parent.
                num' <- nextNum
                addScopes candScope num'
                updateQueue $ PSQ.insert num' prio cand { candPath = rhs }
                step num prio cand { candPath = lhs }
            Free (Enter p) -> do
                scope <- nextScope
                addScope scope num
                step num
                     prio
                     cand { candScope = scope : candScope, candPath = p }
            Free (Exit p) -> case candScope of
                inner : outer -> do
                    delScope inner num
                    step num prio cand { candScope = outer, candPath = p }
                -- 'seal' emits Exit only after a matching Enter, and the root
                -- scope 0 is never exited, so the stack cannot be empty here.
                [] -> error
                    "Control.Monad.Search.runSearchT: Exit with no open scope"
            Free (Collapse p) -> do
                members <- listScope $ listToMaybe candScope
                -- Every member of the scope except the running candidate.
                let victims = filter (/= num) members
                pq <- gets stQueue
                mapM_ (releaseQueued pq) victims
                updateQueue $ \q -> foldr PSQ.delete q victims
                step num prio cand { candPath = p }
    nextNum = do
        modify $ \s -> s { stNum = stNum s + 1 }
        gets stNum
    nextScope = do
        modify $ \s -> s { stScope = stScope s + 1 }
        gets stScope
    addScope scope c = updateActive $ insert scope c
    addScopes scopes c = updateActive $ \sm -> foldr (`insert` c) sm scopes
    delScope scope c = updateActive $ delete scope c
    -- Remove candidate @c@ from each of the given scopes.
    releaseScopes scopes c = updateActive $ \sm -> foldr (`delete` c) sm scopes
    -- Unregister candidate @c@ if it is waiting in @pq@; the caller is about
    -- to drop it from the queue.
    releaseQueued pq c =
        mapM_ (\(_, victim) -> releaseScopes (candScope victim) c)
              (PSQ.lookup c pq)
    listScope scope = gets $ maybe listAll list scope . stActive
    updateQueue f = modify $ \s -> s { stQueue = f (stQueue s) }
    updateActive f = modify $ \s -> s { stActive = f (stActive s) }
    state = St 0 0 (singleton 0 0) queue
    queue = PSQ.singleton 0 mempty (Cand mempty [ 0 ] (fromFT $ unSearchT m))
-- | Run a search and return only its first solution. Every branch that
-- finishes @m@ then runs 'collapse', which discards the competing candidates
-- of the outermost scope.
runSearchBestT :: (Ord c, Monoid c, Monad m) => SearchT c m a -> m (Maybe (c, a))
runSearchBestT = fmap listToMaybe . runSearchT . (<* collapse)
-- | Monads supporting cost-directed search with cost type @c@.
class ( Ord c
      , Monoid c
      , Monad m
      ) => MonadSearch c m | m -> c where
    -- | @cost c e@ adds @c@ to the current branch's cost; @e@ is an estimate
    -- of the remaining cost, used when prioritising branches.
    cost :: c -> c -> m ()
    -- | Explore both computations as alternative branches.
    junction :: m a -> m a -> m a
    -- | Terminate the current branch without producing a result.
    abandon :: m a
    -- | Run a computation inside a fresh scope for 'collapse'.
    seal :: m a -> m a
    -- | Discard the other branches of the innermost enclosing scope.
    collapse :: m ()
-- | Each operation emits the corresponding 'SearchF' instruction.
instance (Ord c, Monoid c, Monad m) => MonadSearch c (SearchT c m) where
    cost c e = SearchT (wrap (Cost c e (return ())))
    junction (SearchT lhs) (SearchT rhs) = SearchT (wrap (Alt lhs rhs))
    abandon = SearchT (wrap Abandon)
    -- Exit follows the sealed body, so only branches completing it exit.
    seal (SearchT body) =
        SearchT (wrap (Enter (body >>= \x -> wrap (Exit (return x)))))
    collapse = SearchT (wrap (Collapse (return ())))
-- | Lift search operations through 'ReaderT'; both branches of a junction
-- receive the same environment.
instance MonadSearch c m => MonadSearch c (ReaderT r m) where
    cost c e = lift (cost c e)
    junction (ReaderT left) (ReaderT right) =
        ReaderT (\env -> junction (left env) (right env))
    abandon = lift abandon
    seal (ReaderT body) = ReaderT (seal . body)
    collapse = lift collapse
-- | Lift search operations through lazy 'Lazy.WriterT'; each branch of a
-- junction produces its own output.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.WriterT w m) where
    cost c e = lift (cost c e)
    junction (Lazy.WriterT left) (Lazy.WriterT right) =
        Lazy.WriterT (junction left right)
    abandon = lift abandon
    seal (Lazy.WriterT body) = Lazy.WriterT (seal body)
    collapse = lift collapse
-- | Lift search operations through strict 'Strict.WriterT'; each branch of a
-- junction produces its own output.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.WriterT w m) where
    cost c e = lift (cost c e)
    junction (Strict.WriterT left) (Strict.WriterT right) =
        Strict.WriterT (junction left right)
    abandon = lift abandon
    seal (Strict.WriterT body) = Strict.WriterT (seal body)
    collapse = lift collapse
-- | Lift search operations through lazy 'Lazy.StateT'; both branches of a
-- junction start from the same state.
instance MonadSearch c m => MonadSearch c (Lazy.StateT s m) where
    cost c e = lift (cost c e)
    junction (Lazy.StateT left) (Lazy.StateT right) =
        Lazy.StateT (\st -> junction (left st) (right st))
    abandon = lift abandon
    seal (Lazy.StateT body) = Lazy.StateT (seal . body)
    collapse = lift collapse
-- | Lift search operations through strict 'Strict.StateT'; both branches of
-- a junction start from the same state.
instance MonadSearch c m => MonadSearch c (Strict.StateT s m) where
    cost c e = lift (cost c e)
    junction (Strict.StateT left) (Strict.StateT right) =
        Strict.StateT (\st -> junction (left st) (right st))
    abandon = lift abandon
    seal (Strict.StateT body) = Strict.StateT (seal . body)
    collapse = lift collapse
-- | Lift search operations through lazy 'Lazy.RWST'; both branches of a
-- junction receive the same environment and starting state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Lazy.RWST r w s m) where
    cost c e = lift (cost c e)
    junction (Lazy.RWST left) (Lazy.RWST right) =
        Lazy.RWST (\env st -> junction (left env st) (right env st))
    abandon = lift abandon
    seal (Lazy.RWST body) = Lazy.RWST (\env st -> seal (body env st))
    collapse = lift collapse
-- | Lift search operations through strict 'Strict.RWST'; both branches of a
-- junction receive the same environment and starting state.
instance (Monoid w, MonadSearch c m) => MonadSearch c (Strict.RWST r w s m) where
    cost c e = lift (cost c e)
    junction (Strict.RWST left) (Strict.RWST right) =
        Strict.RWST (\env st -> junction (left env st) (right env st))
    abandon = lift abandon
    seal (Strict.RWST body) = Strict.RWST (\env st -> seal (body env st))
    collapse = lift collapse
-- | Lift search operations through 'ExceptT'; each branch of a junction
-- carries its own 'Either' result.
instance MonadSearch c m => MonadSearch c (ExceptT e m) where
    cost c e = lift (cost c e)
    junction (ExceptT left) (ExceptT right) = ExceptT (junction left right)
    abandon = lift abandon
    seal (ExceptT body) = ExceptT (seal body)
    collapse = lift collapse
-- | Add a cost to the current branch with an empty ('mempty') estimate.
cost' :: MonadSearch c m => c -> m ()
cost' = flip cost mempty
-- | Run @m@ in its own scope; every branch that completes @m@ then runs
-- 'collapse', discarding the competing branches of that scope.
winner :: MonadSearch c m => m a -> m a
winner = seal . (<* collapse)