module Control.Monad.StateStack
(
MonadStateStack(..)
, StateStackT(..), StateStack
, runStateStackT, evalStateStackT, execStateStackT
, runStateStack, evalStateStack, execStateStack
, liftState
) where
#if __GLASGOW_HASKELL__ < 710
import Data.Monoid
import Control.Applicative
#endif
import Control.Arrow (second)
import Control.Monad.Identity
import qualified Control.Monad.State as St
import Control.Arrow (first, (&&&))
import Control.Monad.Trans
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Except
import Control.Monad.Trans.Identity
import Control.Monad.Trans.List
import Control.Monad.Trans.Maybe
import Control.Monad.Trans.Reader (ReaderT)
import Control.Monad.Trans.State.Lazy as Lazy
import Control.Monad.Trans.State.Strict as Strict
import Control.Monad.Trans.Writer.Lazy as Lazy
import Control.Monad.Trans.Writer.Strict as Strict
import qualified Control.Monad.Cont.Class as CC
import qualified Control.Monad.State.Class as StC
import qualified Control.Monad.IO.Class as IC
-- | A monad transformer carrying a current state of type @s@ together with
-- a stack of previously saved states.
--
-- The representation is a 'St.StateT' over a pair: the first component is
-- the current state, the second is the list of saved states, with the most
-- recently saved one at the head.
newtype StateStackT s m a = StateStackT
  { unStateStackT :: St.StateT (s, [s]) m a
  }
  deriving
    ( Functor
    , Applicative
    , Monad
    , IC.MonadIO
    , MonadTrans
    )
-- | Stateful monads that can additionally checkpoint their current state
-- and later return to the most recent checkpoint.
class St.MonadState s m => MonadStateStack s m where
  -- | Push a copy of the current state onto the stack of saved states.
  save    :: m ()

  -- | Make the most recently saved state current again, removing it from
  -- the stack. (For 'StateStackT', an empty stack leaves the state as is.)
  restore :: m ()
-- | The visible state of a 'StateStackT' is the first component of the
-- internal pair; the saved-state stack is hidden from 'St.get'/'St.put'.
instance Monad m => St.MonadState s (StateStackT s m) where
  -- Project out the current state.
  get = StateStackT (liftM fst St.get)

  -- Replace the current state and keep the stack. The irrefutable pattern
  -- avoids forcing the old state pair.
  put new = StateStackT (St.modify (\ ~(_, stack) -> (new, stack)))
-- | Saving pushes the current state onto the stack; restoring pops it back.
instance Monad m => MonadStateStack s (StateStackT s m) where
  save = StateStackT (St.modify push)
    where
      -- Duplicate the current state onto the top of the stack. The
      -- irrefutable pattern leaves the pair unforced until it is used.
      push ~(cur, stack) = (cur, cur : stack)

  restore = StateStackT (St.modify pop)
    where
      -- Nothing has been saved: keep the current state unchanged.
      pop (cur, [])          = (cur, [])
      -- Otherwise the newest saved state becomes current and is dropped
      -- from the stack.
      pop (_, saved : older) = (saved, older)
-- | Run a 'StateStackT' computation from the given initial state and an
-- empty stack of saved states. Returns the result paired with the final
-- current state; any states still on the stack are discarded.
runStateStackT :: Monad m => StateStackT s m a -> s -> m (a, s)
runStateStackT m s = liftM forgetStack (St.runStateT (unStateStackT m) (s, []))
  where
    -- Keep the result and the current state, dropping the stack. Both
    -- patterns are irrefutable so nothing is forced before it is demanded.
    forgetStack ~(a, ~(cur, _)) = (a, cur)
-- | Run a 'StateStackT' computation from the given initial state and an
-- empty stack, returning only the computed result.
evalStateStackT :: Monad m => StateStackT s m a -> s -> m a
evalStateStackT m s = St.evalStateT (unStateStackT m) (s, [])
-- | Run a 'StateStackT' computation from the given initial state and an
-- empty stack, returning only the final current state (the stack is
-- discarded).
execStateStackT :: Monad m => StateStackT s m a -> s -> m s
execStateStackT m s = liftM fst (St.execStateT (unStateStackT m) (s, []))
-- | The non-transformer state-stack monad: 'StateStackT' over 'Identity'.
--
-- The synonym is eta-reduced (it has no trailing result-type parameter) so
-- that @StateStack s@ can be used on its own, for example as the base monad
-- of another transformer (@MaybeT (StateStack s)@). Fully applied uses such
-- as @StateStack s a@ expand exactly as before.
type StateStack s = StateStackT s Identity
-- | Pure counterpart of 'runStateStackT': the result paired with the final
-- current state.
runStateStack :: StateStack s a -> s -> (a, s)
runStateStack m = runIdentity . runStateStackT m
-- | Pure counterpart of 'evalStateStackT': only the computed result.
evalStateStack :: StateStack s a -> s -> a
evalStateStack m = runIdentity . evalStateStackT m
-- | Pure counterpart of 'execStateStackT': only the final current state.
execStateStack :: StateStack s a -> s -> s
execStateStack m = runIdentity . execStateStackT m
-- | Embed a plain 'St.StateT' computation into 'StateStackT'. The embedded
-- computation reads and writes the current state only; the stack of saved
-- states is passed through untouched.
liftState :: Monad m => St.StateT s m a -> StateStackT s m a
liftState st = StateStackT $ St.StateT $ \(cur, stack) -> do
  -- Irrefutable so the inner result pair is not forced here.
  ~(a, cur') <- St.runStateT st cur
  return (a, (cur', stack))
-- | Forward 'save' and 'restore' through 'ContT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (ContT r m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through 'ExceptT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (ExceptT e m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through 'IdentityT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (IdentityT m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through 'ListT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (ListT m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through 'MaybeT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (MaybeT m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through 'ReaderT' to the underlying monad.
instance MonadStateStack s m => MonadStateStack s (ReaderT r m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through a lazy 'Lazy.StateT' to the
-- underlying monad.
--
-- NOTE(review): 'save'/'restore' act on the inner monad's state stack,
-- whereas get/put for @StateT s m@ presumably address this StateT layer's
-- own state (mtl's non-lifting instance) — confirm this split is intended.
instance MonadStateStack s m => MonadStateStack s (Lazy.StateT s m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through a strict 'Strict.StateT' to the
-- underlying monad.
--
-- NOTE(review): 'save'/'restore' act on the inner monad's state stack,
-- whereas get/put for @StateT s m@ presumably address this StateT layer's
-- own state (mtl's non-lifting instance) — confirm this split is intended.
instance MonadStateStack s m => MonadStateStack s (Strict.StateT s m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through a lazy 'Lazy.WriterT' to the
-- underlying monad.
instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Lazy.WriterT w m) where
  restore = lift restore
  save    = lift save
-- | Forward 'save' and 'restore' through a strict 'Strict.WriterT' to the
-- underlying monad.
instance (Monoid w, MonadStateStack s m) => MonadStateStack s (Strict.WriterT w m) where
  restore = lift restore
  save    = lift save
-- | Continuations for 'StateStackT' come from 'CC.callCC' on the wrapped
-- 'St.StateT'. The newtype is unwrapped around the body and re-wrapped
-- around the escape continuation handed to it.
instance CC.MonadCont m => CC.MonadCont (StateStackT s m) where
  callCC body = StateStackT . CC.callCC $ \escape ->
    unStateStackT (body (\x -> StateStackT (escape x)))