-- | A state monad transformer whose state value is kept in a 'TVar', so
-- that computations started with 'forkM' operate on the same state as the
-- computation that forked them.
module Control.Concurrent.MState
(
-- * The MState monad
MState
, runMState
, evalMState
, execMState
, mapMState
, withMState
, modifyM
-- * Concurrency
, Forkable (..)
, forkM
) where
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception (SomeException)
import Control.Exception.Peel
import Control.Monad.Cont
import Control.Monad.Error
import Control.Monad.IO.Peel
import Control.Monad.Reader
import Control.Monad.State.Class
import Control.Monad.Writer
-- | A computation in the base monad @m@ that yields an @a@ and has access to
-- a shared state of type @t@. It is represented as a function from a pair of
-- transactional variables to an action in @m@: the first variable holds the
-- state value, the second a list of 'TMVar's to which 'forkM' adds one entry
-- per forked computation.
newtype MState t m a = MState { runMState' :: (TVar t, TVar [TMVar ()]) -> m a }
-- | Base monads in which a unit-returning action can be handed to a new
-- thread. The instances in this module implement 'fork' with 'forkIO'.
class MonadPeelIO m => Forkable m where
    -- | Start the given action and return the 'ThreadId' it runs under.
    fork :: m () -> m ThreadId
-- | In plain 'IO', forking is exactly 'forkIO'.
instance Forkable IO where
    fork action = forkIO action
-- | For 'ReaderT' over 'IO', the forked action is run with the environment
-- that is current at the point of the 'fork' call.
instance Forkable (ReaderT s IO) where
    fork newT = do
        env <- ask
        liftIO (forkIO (runReaderT newT env))
-- | Block until every 'TMVar' in the given list has been filled, taking
-- (and thereby emptying) each of them.
--
-- Reading the list and taking all of its entries happen in one STM
-- transaction, so the call only completes once the whole list as read in
-- that transaction can be taken.
waitForTermination :: MonadIO m
                   => TVar [TMVar ()]
                   -> m ()
waitForTermination threads =
    liftIO . atomically $ readTVar threads >>= mapM_ takeTMVar
-- | Run an 'MState' computation against a freshly created shared state.
--
-- The action is started through 'forkM', so it runs in its own thread. This
-- call returns only after 'waitForTermination' has seen that thread and all
-- other threads registered through 'forkM' finish. It then yields the
-- action's result together with the state value read at that point.
--
-- If the action ends with an exception, the exception is captured in the
-- worker thread and rethrown to the caller once all threads have finished.
-- Without this the result 'MVar' would never be filled, and the caller would
-- block on it instead of learning about the failure.
runMState :: Forkable m
          => MState t m a      -- ^ Action to run
          -> t                 -- ^ Initial state value
          -> m (a,t)
runMState m t = do
    ref <- liftIO $ newTVarIO t
    c <- liftIO $ newTVarIO []
    mv <- liftIO newEmptyMVar
    -- 'try' is applied in the base monad, and its outcome (failure or
    -- result) is always handed over through the MVar.
    let guarded = MState $ \s ->
            try (runMState' m s) >>= liftIO . putMVar mv
    _ <- runMState' (forkM guarded) (ref, c)
    waitForTermination c
    outcome <- liftIO $ takeMVar mv
    case outcome of
        Left e -> throwIO (e :: SomeException)
        Right a -> do
            t' <- liftIO . atomically $ readTVar ref
            return (a, t')
-- | Like 'runMState', but return only the result of the action and drop the
-- final state.
evalMState :: Forkable m
           => MState t m a     -- ^ Action to run
           -> t                -- ^ Initial state value
           -> m a
evalMState m t = do
    outcome <- runMState m t
    return (fst outcome)
-- | Like 'runMState', but return only the final state and drop the result
-- of the action.
execMState :: Forkable m
           => MState t m a     -- ^ Action to run
           -> t                -- ^ Initial state value
           -> m t
execMState m t = do
    outcome <- runMState m t
    return (snd outcome)
-- | Transform both the result and the state of a computation, possibly
-- changing the base monad.
--
-- The wrapped computation is run first; its result is paired with the state
-- value read afterwards and passed to @f@. The state returned by @f@ is then
-- written back to the shared variable. The read and the write-back are two
-- separate transactions, and the write-back is unconditional.
mapMState :: (MonadIO m, MonadIO n)
          => (m (a,t) -> n (b,t))
          -> MState t m a
          -> MState t n b
mapMState f m = MState $ \env@(stateVar, _) -> do
    let inner = do
            result <- runMState' m env
            current <- liftIO (atomically (readTVar stateVar))
            return (result, current)
    -- The lazy pattern keeps the tuple from being forced at the bind.
    ~(mapped, updated) <- f inner
    liftIO (atomically (writeTVar stateVar updated))
    return mapped
-- | Apply @f@ to the shared state in a single transaction, then run the
-- given computation against the modified state.
withMState :: (MonadIO m)
           => (t -> t)
           -> MState t m a
           -> MState t m a
withMState f m = MState $ \env@(stateVar, _) ->
    liftIO (atomically (readTVar stateVar >>= writeTVar stateVar . f))
        >> runMState' m env
-- | Replace the shared state with the result of applying @f@ to it.
--
-- The read and the write happen in one STM transaction, so no update from
-- another thread can slip in between them.
--
-- The new value is evaluated to weak head normal form before it is stored.
-- Storing the unevaluated application @f v@ would let repeated calls pile up
-- a chain of thunks inside the 'TVar' until some reader finally forces it.
modifyM :: (MonadIO m) => (t -> t) -> MState t m ()
modifyM f = MState $ \(t,_) ->
    liftIO . atomically $ do
        v <- readTVar t
        writeTVar t $! f v
-- | Run a computation in a new thread that shares the caller's state.
--
-- Before forking, a fresh empty 'TMVar' is prepended to the shared list of
-- termination signals. The forked computation fills it when it ends, whether
-- normally or with an exception, because the fill is attached with 'finally'.
--
-- NOTE(review): the 'finally' is installed inside the new thread. An
-- asynchronous exception delivered before that point would presumably leave
-- the 'TMVar' empty; confirm whether masking is needed here.
forkM :: Forkable m
      => MState t m ()         -- ^ Computation to run in the new thread
      -> MState t m ThreadId
forkM m = MState $ \env@(_, threads) -> do
    done <- liftIO newEmptyTMVarIO
    liftIO . atomically $ readTVar threads >>= writeTVar threads . (done :)
    let signalDone = liftIO . atomically $ putTMVar done ()
    fork (runMState' m env `finally` signalDone)
-- | Sequencing passes the same pair of shared variables to both sides of a
-- bind; 'return' and 'fail' ignore the variables and defer to the base monad.
instance (Monad m) => Monad (MState t m) where
    return = MState . const . return
    m >>= k = MState $ \env ->
        runMState' m env >>= \a -> runMState' (k a) env
    fail = MState . const . fail
-- | Map over the result by running the computation in the base monad and
-- applying the function to what it returns.
instance (Monad m) => Functor (MState t m) where
    fmap f m = MState $ \env -> runMState' m env >>= return . f
-- | Failure and choice are those of the base monad; both alternatives of
-- 'mplus' receive the same shared variables.
instance (MonadPlus m) => MonadPlus (MState t m) where
    mzero = MState (const mzero)
    mplus m n = MState $ \env ->
        mplus (runMState' m env) (runMState' n env)
-- | 'get' reads and 'put' overwrites the shared state variable. Each is its
-- own transaction, so a 'get' followed by a 'put' is not atomic as a pair;
-- 'modifyM' performs the read and the write in a single transaction.
instance (MonadIO m) => MonadState t (MState t m) where
    get = MState $ \(stateVar, _) ->
        liftIO (atomically (readTVar stateVar))
    put val = MState $ \(stateVar, _) ->
        liftIO (atomically (writeTVar stateVar val))
-- | The fixed point is taken in the base monad, with the shared variables
-- held constant across the knot.
instance (MonadFix m) => MonadFix (MState t m) where
    mfix f = MState $ \env ->
        let step a = runMState' (f a) env
        in mfix step
-- | A lifted base-monad action ignores the shared variables.
instance MonadTrans (MState t) where
    lift = MState . const
-- | 'IO' actions are lifted through the base monad and do not touch the
-- shared variables.
instance (MonadIO m) => MonadIO (MState t m) where
    liftIO io = MState $ \_ -> liftIO io
-- | 'callCC' is delegated to the base monad. The escape function handed to
-- @f@ wraps the base monad's continuation in an 'MState' that ignores the
-- shared variables.
instance (MonadCont m) => MonadCont (MState t m) where
    callCC f = MState $ \env ->
        callCC $ \exit ->
            let escape a = MState (const (exit a))
            in runMState' (f escape) env
-- | Errors are thrown and caught in the base monad. The handler runs with
-- the same shared variables as the computation it guards.
instance (MonadError e m) => MonadError e (MState t m) where
    throwError err = lift (throwError err)
    catchError m h = MState $ \env ->
        catchError (runMState' m env) (\e -> runMState' (h e) env)
-- | The reader environment is that of the base monad; 'local' modifies it
-- around the wrapped computation.
instance (MonadReader r m) => MonadReader r (MState t m) where
    ask = lift ask
    local f m = MState $ local f . runMState' m
-- | Writer operations are delegated to the base monad, with 'listen' and
-- 'pass' applied to the computation after it has been given the shared
-- variables.
instance (MonadWriter w m) => MonadWriter w (MState t m) where
    tell w = lift (tell w)
    listen m = MState $ \env -> listen (runMState' m env)
    pass m = MState $ \env -> pass (runMState' m env)