module Control.Monad.State.Concurrent.Strict (
module Control.Monad.State,
StateC,
runStateC, evalStateC, execStateC,
liftCallCCC, liftCallCCC', liftCatchC, liftListenC, liftPassC
) where
import Control.Applicative
import Control.Arrow (first)
import Control.Concurrent.STM
import Control.Monad
import Control.Monad.State
-- | A state-like transformer whose state lives in a 'TVar'.
--
-- A computation is a function that receives the 'TVar' holding the state
-- and produces, in the base monad @m@, a result paired with a 'TVar'.
newtype StateC s m a = StateC
    { _runStateC :: TVar s -> m (a, TVar s)
      -- ^ Unwrap the computation into its underlying function.
    }
-- Lifting runs the base-monad action and hands the received TVar back
-- untouched alongside the action's result.
instance MonadTrans (StateC s) where
    lift action = StateC $ \tv ->
        action >>= \result -> return (result, tv)
-- Maps over the result component of the (result, TVar) pair produced by the
-- wrapped function; the TVar component is passed through unchanged.
instance (Functor m, MonadIO m) => Functor (StateC s m) where
    fmap f (StateC run) = StateC (fmap (first f) . run)
-- 'pure' pairs the value with the TVar it is given.  '<*>' runs the function
-- computation first, then the argument computation with the TVar returned by
-- the first, and applies the function to the argument.
instance (Functor m, MonadIO m) => Applicative (StateC s m) where
    pure x = StateC $ \tv -> return (x, tv)
    funC <*> argC = StateC $ \tv -> do
        (f, tv') <- _runStateC funC tv
        (x, tv'') <- _runStateC argC tv'
        return (f x, tv'')
-- Failure and choice are taken from the base monad's 'MonadPlus' instance.
-- Both alternatives are run with the same TVar.
instance (MonadIO m, Functor m, MonadPlus m) => Alternative (StateC s m) where
    empty = StateC (\_ -> mzero)
    left <|> right = StateC $ \tv ->
        mplus (_runStateC left tv) (_runStateC right tv)
-- 'mzero' ignores the TVar and fails in the base monad.  'mplus' feeds the
-- same TVar to both computations and combines them with the base monad's
-- 'mplus'.
instance (MonadPlus m, MonadIO m) => MonadPlus (StateC s m) where
    mzero = StateC (\_ -> mzero)
    mplus left right = StateC $ \tv ->
        mplus (_runStateC left tv) (_runStateC right tv)
-- Sequencing threads the TVar: the TVar returned by the first computation is
-- the one handed to the continuation.
instance MonadIO m => Monad (StateC s m) where
    return x = StateC (\tv -> return (x, tv))
    action >>= continue = StateC $ \tv ->
        _runStateC action tv >>= \(x, tv') ->
            _runStateC (continue x) tv'
-- State access is implemented on top of the TVar: 'state' reads the current
-- value, applies the transition function and stores the new state, all
-- inside a single STM transaction run with 'atomically'.
instance (Functor m, MonadIO m) => MonadState s (StateC s m) where
    state f = StateC $ \tv -> do
        result <- liftIO . atomically $ do
            old <- readTVar tv
            -- Lazy binding: neither the result nor the new state is forced
            -- here; the new state is stored as-is.
            let (a, new) = f old
            -- Only the write is needed; the previous value is already in
            -- hand, so there is no reason to swap and discard it.
            writeTVar tv new
            return a
        return (result, tv)
-- Fixed points are delegated to the base monad's 'mfix' over the
-- (result, TVar) pair.
instance (MonadIO m, MonadFix m) => MonadFix (StateC s m) where
    -- The pattern must be lazy: the argument given by the base monad's 'mfix'
    -- is the not-yet-available final pair, and matching on it strictly
    -- would force the fixed point before it has been computed.
    mfix f = StateC $ \tv -> mfix $ \ ~(a, _) -> _runStateC (f a) tv
-- An 'IO' action is lifted through the base monad; the TVar is returned
-- unchanged next to the action's result.
instance MonadIO m => MonadIO (StateC s m) where
    liftIO io = StateC $ \tv ->
        liftIO io >>= \result -> return (result, tv)
-- | Run a 'StateC' computation against the given 'TVar'.
--
-- Returns the computation's result together with the state value read from
-- the 'TVar' that the computation hands back.  That value is obtained with a
-- separate 'readTVarIO' after the computation has finished, so it is
-- whatever the 'TVar' holds at that moment.
runStateC
    :: MonadIO m
    => StateC s m a  -- ^ computation to run
    -> TVar s        -- ^ variable holding the state
    -> m (a, s)
runStateC action tv =
    _runStateC action tv >>= \(result, tv') ->
        liftIO (readTVarIO tv') >>= \final ->
            return (result, final)
-- | Run a 'StateC' computation with 'runStateC' and keep only its result,
-- discarding the final state value.
evalStateC
    :: MonadIO m
    => StateC s m a  -- ^ computation to run
    -> TVar s        -- ^ variable holding the state
    -> m a
evalStateC action tv =
    runStateC action tv >>= \outcome -> return (fst outcome)
-- | Run a 'StateC' computation with 'runStateC' and keep only the final
-- state value, discarding the computation's result.
execStateC
    :: MonadIO m
    => StateC s m a  -- ^ computation to run
    -> TVar s        -- ^ variable holding the state
    -> m s
execStateC action tv =
    runStateC action tv >>= \outcome -> return (snd outcome)
-- | Lift a @callCC@ operation of the base monad to 'StateC'.
--
-- The escape continuation handed to the body ignores whatever 'TVar' is
-- current when it is invoked and resumes with the 'TVar' that was supplied
-- when 'liftCallCCC' itself was entered.
liftCallCCC :: ((((a, TVar s) -> m (b, TVar s)) -> m (a, TVar s)) -> m (a, TVar s)) ->
    ((a -> StateC s m b) -> StateC s m a) -> StateC s m a
liftCallCCC callCC body = StateC run
  where
    run tv = callCC $ \escape ->
        let exit a = StateC (const (escape (a, tv)))
        in  _runStateC (body exit) tv
-- | Lift a @callCC@ operation of the base monad to 'StateC'.
--
-- Unlike 'liftCallCCC', the escape continuation handed to the body resumes
-- with the 'TVar' that is current at the point where it is invoked.
liftCallCCC' :: ((((a, TVar s) -> m (b, TVar s)) -> m (a, TVar s))-> m (a, TVar s)) ->
    ((a -> StateC s m b) -> StateC s m a) -> StateC s m a
liftCallCCC' callCC body = StateC run
  where
    run tv = callCC $ \escape ->
        let exit a = StateC (\current -> escape (a, current))
        in  _runStateC (body exit) tv
-- | Lift a @catchError@ operation of the base monad to 'StateC'.
--
-- The protected computation and the handler are both run with the same
-- 'TVar', namely the one supplied to the combined computation.
liftCatchC :: (m (a, TVar s) -> (e -> m (a, TVar s)) -> m (a, TVar s)) ->
    StateC s m a -> (e -> StateC s m a) -> StateC s m a
liftCatchC catchError action handler = StateC $ \tv ->
    catchError
        (_runStateC action tv)
        (\err -> _runStateC (handler err) tv)
-- | Lift a @listen@ operation of the base monad to 'StateC'.
--
-- The base operation yields @((a, TVar), w)@; this regroups it so that the
-- extra output @w@ sits next to the result and the 'TVar' keeps its place as
-- the second component.
liftListenC :: Monad m =>
    (m (a, TVar s) -> m ((a, TVar s), w)) -> StateC s m a -> StateC s m (a,w)
liftListenC listen action = StateC $ \tv ->
    listen (_runStateC action tv) >>= \((result, tv'), output) ->
        return ((result, output), tv')
-- | Lift a @pass@ operation of the base monad to 'StateC'.
--
-- The wrapped computation yields @((a, b), TVar)@; this regroups it into the
-- @((a, TVar), b)@ shape the base operation expects, so that the second
-- component of the result is what gets handed to @pass@.
liftPassC :: Monad m =>
    (m ((a, TVar s), b) -> m (a, TVar s)) -> StateC s m (a, b) -> StateC s m a
liftPassC pass action = StateC $ \tv ->
    pass $
        _runStateC action tv >>= \((result, extra), tv') ->
            return ((result, tv'), extra)