module Control.Monad.ST.Logic.Internal
( LogicT
, runLogicT
, observeT
, observeAllT
, observeManyT
, liftST
, runLogicST
, observeST
, observeAllST
, observeManyST
, Ref
, newRef
, readRef
, writeRef
, modifyRef
, modifyRef'
) where
import Control.Applicative
import Control.Monad
import Control.Monad.IO.Class
import qualified Control.Monad.Logic as Logic
import Control.Monad.Logic.Class
#ifdef MODULE_Control_Monad_ST_Safe
import Control.Monad.ST.Safe
#else
import Control.Monad.ST
#endif
import Control.Monad.Trans.Class
import Control.Monad.Trans.State.Strict (StateT, evalStateT)
import qualified Control.Monad.Trans.State.Strict as State
import qualified Data.STRef as ST
-- | Monads that can embed 'ST' actions running in a single, fixed state
-- thread.
class Monad m => MonadST m where
  -- | The state-thread token of the 'ST' computations that @m@ can run.
  type World m
  -- | Embed an 'ST' action over @'World' m@ into @m@.
  liftST :: ST (World m) a -> m a
-- | 'ST' runs its own actions directly: the world is the thread token @s@
-- and embedding is the identity.
instance MonadST (ST s) where
  type World (ST s) = s
  liftST action = action

-- | 'IO' runs 'ST' actions over 'RealWorld' by way of 'stToIO'.
instance MonadST IO where
  type World IO = RealWorld
  liftST action = stToIO action
-- | A backtracking logic monad over @m@ with backtrackable references.
--
-- The wrapped state is the current 'Switch': the switch installed by the
-- innermost enclosing 'plusLogic' left branch, or the root switch allocated
-- by the run/observe functions. 'Ref' writes are tagged with it so they can
-- be recognised as undone once that switch is flipped.
--
-- @s@ is a phantom parameter. The run/observe functions quantify over it
-- (rank-2), which stops 'Ref's from escaping a run.
newtype LogicT s m a =
  LogicT { unLogicT :: StateT (Switch m) (Logic.LogicT m) a
         }
-- | Run a computation, folding its results with the success continuation
-- @next@ and the failure continuation @zero@. The rank-2 argument keeps
-- 'Ref's created inside the computation from escaping.
runLogicT :: MonadST m => (forall s . LogicT s m a) -> (a -> m r -> m r) -> m r -> m r
runLogicT computation next zero = unsafeRunLogicT computation next zero

-- | Pure variant of 'runLogicT' for computations over 'ST', executed with
-- 'runST'.
runLogicST :: (forall s . LogicT s (ST s) a) -> (a -> r -> r) -> r -> r
runLogicST computation next zero =
  runST (unsafeRunLogicT computation (\ a rest -> liftM (next a) rest) (return zero))

-- | Shared implementation of the run functions. It allocates a fresh root
-- 'Switch' and runs the underlying logic computation with it as the initial
-- state. @s@ is not quantified here, so callers must ensure that no 'Ref'
-- escapes.
unsafeRunLogicT :: MonadST m => LogicT s m a -> (a -> m r -> m r) -> m r -> m r
unsafeRunLogicT computation next zero =
  newSwitch >>= \ root ->
    Logic.runLogicT (evalStateT (unLogicT computation) root) next zero
-- | Return the first result of a computation, as 'Logic.observeT' does for
-- the underlying logic monad. The rank-2 argument keeps 'Ref's from
-- escaping.
observeT :: MonadST m => (forall s . LogicT s m a) -> m a
observeT computation = unsafeObserveT computation

-- | Pure variant of 'observeT' for computations over 'ST', executed with
-- 'runST'.
observeST :: (forall s . LogicT s (ST s) a) -> a
observeST computation = runST (unsafeObserveT computation)

-- | Shared implementation of 'observeT' and 'observeST'. It allocates a
-- fresh root 'Switch' and observes the underlying logic computation with it
-- as the initial state. Callers must ensure that no 'Ref' escapes.
unsafeObserveT :: MonadST m => LogicT s m a -> m a
unsafeObserveT computation =
  newSwitch >>= Logic.observeT . evalStateT (unLogicT computation)
-- | Collect every result of a computation, as 'Logic.observeAllT' does for
-- the underlying logic monad. The rank-2 argument keeps 'Ref's from
-- escaping.
observeAllT :: MonadST m => (forall s . LogicT s m a) -> m [a]
observeAllT computation = unsafeObserveAllT computation

-- | Pure variant of 'observeAllT' for computations over 'ST', executed with
-- 'runST'.
observeAllST :: (forall s . LogicT s (ST s) a) -> [a]
observeAllST computation = runST (unsafeObserveAllT computation)

-- | Shared implementation of 'observeAllT' and 'observeAllST'. It allocates
-- a fresh root 'Switch' and uses it as the initial state.
unsafeObserveAllT :: MonadST m => LogicT s m a -> m [a]
unsafeObserveAllT computation = do
  root <- newSwitch
  let search = evalStateT (unLogicT computation) root
  Logic.observeAllT search
-- | Collect up to @n@ results of a computation, as 'Logic.observeManyT'
-- does for the underlying logic monad. The rank-2 argument keeps 'Ref's
-- from escaping.
observeManyT :: MonadST m => Int -> (forall s . LogicT s m a) -> m [a]
observeManyT n computation = unsafeObserveManyT n computation

-- | Pure variant of 'observeManyT' for computations over 'ST', executed
-- with 'runST'.
observeManyST :: Int -> (forall s . LogicT s (ST s) a) -> [a]
observeManyST n computation = runST (unsafeObserveManyT n computation)

-- | Shared implementation of 'observeManyT' and 'observeManyST'. It
-- allocates a fresh root 'Switch' and uses it as the initial state.
unsafeObserveManyT :: MonadST m => Int -> LogicT s m a -> m [a]
unsafeObserveManyT n computation =
  newSwitch >>= \ root ->
    Logic.observeManyT n (evalStateT (unLogicT computation) root)
-- | Maps over results of the wrapped state-over-logic stack.
instance Functor (LogicT s m) where
  fmap f (LogicT inner) = LogicT (fmap f inner)
-- | Delegates to the wrapped state-over-logic stack. '*>' and '<*' are
-- defined explicitly only when @CLASS_OldApplicative@ is not defined.
instance Applicative (LogicT s m) where
  pure a = LogicT (pure a)
  LogicT f <*> LogicT a = LogicT (f <*> a)
#ifndef CLASS_OldApplicative
  LogicT a *> LogicT b = LogicT (a *> b)
  LogicT a <* LogicT b = LogicT (a <* b)
#endif
-- | Failure comes from the wrapped stack. Choice goes through 'plusLogic'
-- so that 'Ref' writes made in the left branch are rolled back before the
-- right branch runs.
instance MonadST m => Alternative (LogicT s m) where
  empty = LogicT empty
  left <|> right = plusLogic left right
-- | Bind and sequencing delegate to the wrapped
-- @'StateT' ('Switch' m) ('Logic.LogicT' m)@ stack.
instance Monad (LogicT s m) where
  return = LogicT . return
  m >>= k = LogicT $ unLogicT m >>= unLogicT . k
  m >> n = LogicT $ unLogicT m >> unLogicT n
#if !MIN_VERSION_base(4,13,0)
  -- 'fail' stopped being a 'Monad' method in base-4.13 (GHC 8.8). Defining
  -- it in a 'Monad' instance there is a compile error, so it is only
  -- provided for older bases. It delegates to the wrapped stack's 'fail'.
  fail = LogicT . fail
#endif
-- | Same behaviour as the 'Alternative' instance: 'mplus' uses 'plusLogic'
-- so that 'Ref' writes are rolled back between branches.
instance MonadST m => MonadPlus (LogicT s m) where
  mzero = LogicT mzero
  mplus left right = plusLogic left right
-- | Binary choice with reference rollback.
--
-- A fresh 'Switch' @s@ becomes the current switch for the left branch, so
-- every 'Ref' write made there (and in its continuation) is tagged with
-- @s@. Before the right branch runs, @s@ is flipped, which makes reads and
-- writes treat those tagged writes as undone. The right branch starts from
-- the switch that was current on entry.
--
-- NOTE(review): only @s@ itself is flipped. If the left branch is pruned
-- before it finishes (e.g. by 'msplit'-based combinators such as 'once'),
-- switches created inside it are never flipped. Reads only test the newest
-- write's own switch, so writes tagged with those inner switches appear to
-- stay visible after this choice point backtracks. Confirm whether that is
-- intended.
plusLogic :: MonadST m => LogicT s m a -> LogicT s m a -> LogicT s m a
plusLogic m n = newSwitch >>= \ s ->
  let left  = put s *> m
      right = flipSwitch s *> n
  in  LogicT (unLogicT left <|> unLogicT right)
-- | 'msplit' is inherited from the wrapped 'StateT'. The remaining
-- computation it returns is re-wrapped in 'LogicT'.
instance MonadST m => MonadLogic (LogicT s m) where
  msplit (LogicT inner) = LogicT (fmap (fmap rewrap) (msplit inner))
    where
      -- wrap the "rest of the results" computation back into LogicT
      rewrap (a, rest) = (a, LogicT rest)
-- | Lift an action of the base monad through both the logic and the state
-- layers.
liftLogic :: Monad m => m a -> LogicT s m a
liftLogic action = LogicT (lift (lift action))
-- | 'IO' actions are run in the base monad and then lifted with
-- 'liftLogic'.
instance MonadIO m => MonadIO (LogicT s m) where
  liftIO io = liftLogic (liftIO io)
-- | 'LogicT' shares its base monad's state thread. 'ST' actions are run
-- there and lifted with 'liftLogic'.
instance MonadST m => MonadST (LogicT s m) where
  type World (LogicT s m) = World m
  liftST st = liftLogic (liftST st)
-- | Read the current 'Switch', i.e. the one new 'Ref' writes get tagged
-- with.
get :: Monad m => LogicT s m (Switch m)
get = LogicT State.get

-- | Install a new current 'Switch'. The switch is forced before the state
-- update is built.
put :: Monad m => Switch m -> LogicT s m ()
put switch = seq switch (LogicT (State.put switch))
-- | A choice-point marker: a mutable flag that starts out 'False' and is
-- set to 'True' by 'flipSwitch'. Writes tagged with a flipped switch count
-- as undone.
type Switch m = ST.STRef (World m) Bool

-- | Allocate a fresh, unflipped switch.
newSwitch :: MonadST m => m (Switch m)
newSwitch = liftST (ST.newSTRef False)

-- | Mark a switch as flipped. The flag is set to 'True' and never reset
-- here.
flipSwitch :: MonadST m => Switch m -> m ()
flipSwitch switch = liftST (ST.writeSTRef switch True)
-- | Run the first action if the switch has been flipped, the second one
-- otherwise.
ifFlipped :: Switch (ST s) -> ST s a -> ST s a -> ST s a
ifFlipped switch onFlipped onIntact =
  ST.readSTRef switch >>= \ flipped ->
    if flipped then onFlipped else onIntact
-- | A backtrackable mutable reference. It stores the write history of the
-- value rather than a single value. @s@ ties the reference to the
-- 'LogicT' run that created it.
newtype Ref s m a = Ref (ST.STRef (World m) (Value m a))

-- | Write history of a 'Ref', newest write first. The bottom of the stack
-- is the 'New' write made when the reference was created.
data Value m a
  = New !(Write m a)
  | !(Write m a) :| !(Value m a)

-- | A single write: the 'Switch' that was current when it was made, plus
-- the (lazy) value written.
data Write m a = Write !(Switch m) a
-- | Create a reference whose initial write is tagged with the current
-- 'Switch'.
newRef :: MonadST m => a -> LogicT s m (Ref s m a)
newRef a = do
  switch <- get
  liftST (Ref <$> newSTRef a switch)

-- | Allocate the underlying 'ST.STRef' holding a single 'New' write. The
-- switch is forced before allocation.
newSTRef :: a -> Switch m -> ST (World m) (ST.STRef (World m) (Value m a))
newSTRef a = ST.newSTRef .! \ switch -> New (Write switch a)
infixr 9 .!
-- | Like '(.)', but the argument is forced to weak head normal form before
-- @g@ and then @f@ are applied.
(.!) :: (b -> c) -> (a -> b) -> a -> c
(f .! g) x = x `seq` f (g x)
-- | Read the newest write of a reference that has not been undone by
-- backtracking.
readRef :: MonadST m => Ref s m a -> LogicT s m a
readRef (Ref ref) = liftST (readSTRef ref)
-- | Return the value of the newest write whose switch is not flipped.
--
-- If the newest write is still live, its value is returned and nothing is
-- written. Otherwise flipped writes are skipped, and the history is
-- truncated to the newest surviving write so later reads do not examine
-- them again. The creating 'New' write is used without checking its
-- switch.
readSTRef :: ST.STRef (World m) (Value m a) -> ST (World m) a
readSTRef ref = do
  value <- ST.readSTRef ref
  case value of
    New (Write _ a) -> return a
    Write switch a :| older -> ifFlipped switch (restore older) (return a)
  where
    -- drop flipped writes, store the surviving suffix, return its value
    restore surviving@(New (Write _ a)) =
      ST.writeSTRef ref surviving >> return a
    restore surviving@(Write switch a :| older) =
      ifFlipped switch (restore older) $
        ST.writeSTRef ref surviving >> return a
-- | Overwrite the value of a reference. The write is undone on
-- backtracking.
writeRef :: MonadST m => Ref s m a -> a -> LogicT s m ()
writeRef ref a = modifyRef'' ref (\ switch _ -> Write switch a)

-- | Apply a function to the live value of a reference. The new value is
-- left lazy.
modifyRef :: MonadST m => Ref s m a -> (a -> a) -> LogicT s m ()
modifyRef ref f = modifyRef'' ref (\ switch -> Write switch . f)

-- | Like 'modifyRef', but the new value is forced to WHNF when the write is
-- stored.
modifyRef' :: MonadST m => Ref s m a -> (a -> a) -> LogicT s m ()
modifyRef' ref f = modifyRef'' ref $ \ switch a ->
  let a' = f a
  in  a' `seq` Write switch a'
-- | Shared write helper: read the current 'Switch' and record the write
-- built by @f@ from that switch and the live value.
modifyRef'' :: MonadST m => Ref s m a -> (Switch m -> a -> Write m a) -> LogicT s m ()
modifyRef'' (Ref ref) f = do
  current <- get
  liftST (modifySTRef ref f current)
-- | Record a new write, built by @f@ from the current switch @r@ and the
-- live value, into a reference's history.
--
-- Flipped writes at the top of the history are discarded first. If the
-- newest surviving write was made under @r@, it is replaced. Otherwise the
-- new write is pushed on top. The new history is forced to WHNF before it
-- is stored.
modifySTRef :: ST.STRef (World m) (Value m a) ->
               (Switch m -> a -> Write m a) ->
               Switch m ->
               ST (World m) ()
modifySTRef ref f r = ST.readSTRef ref >>= go
  where
    store history = ST.writeSTRef ref $! history
    go history@(New (Write switch a))
      | switch == r = store (New (f r a))
      | otherwise   = store (f r a :| history)
    go history@(Write switch a :| older) =
      ifFlipped switch (go older) $
        store (f r a :| if switch == r then older else history)