{-# LANGUAGE CPP, MultiParamTypeClasses, FunctionalDependencies,
UndecidableInstances, FlexibleInstances #-}
{-# LANGUAGE DataKinds, TypeFamilies, TypeOperators #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE Safe #-}
module MonadLib (
Id, Lift, IdT, ReaderT, WriterT,
StateT,
ExceptionT,
ChoiceT, ContT,
MonadT(..), BaseM(..),
ReaderM(..), WriterM(..), StateM(..), ExceptionM(..), ContM(..), AbortM(..),
Label, labelCC, labelCC_, jump, labelC, callCC,
runId, runLift,
runIdT, runReaderT, runWriterT,
runStateT, runExceptionT, runContT,
runChoiceT, findOne, findAll,
RunM(..),
RunReaderM(..), RunWriterM(..), RunExceptionM(..),
asks, puts, sets, sets_, raises,
mapReader, mapWriter, mapException,
handle,
WithBase,
module Control.Monad
) where
import Control.Applicative
import Control.Monad
import Control.Monad.Fix
import Control.Monad.ST (ST)
import qualified Control.Exception as IO (throwIO,try,SomeException)
import System.Exit(ExitCode,exitWith)
import Data.Kind(Type)
import Prelude hiding (Ordering(..))
import qualified Control.Monad.Fail as MF
newtype Id a = I a
data Lift a = L a
newtype IdT m a = IT (m a)
newtype ReaderT i m a = R (i -> m a)
newtype WriterT i m a = W { unW :: m (P a i) }
data P a i = P a !i
newtype StateT i m a = S (i -> m (a,i))
newtype ExceptionT i m a = X (m (Either i a))
data ChoiceT m a = NoAnswer
| Answer a
| Choice (ChoiceT m a) (ChoiceT m a)
| ChoiceEff (m (ChoiceT m a))
newtype ContT i m a = C ((a -> m i) -> m i)
runId :: Id a -> a
runId (I a) = a
runLift :: Lift a -> a
runLift (L a) = a
runIdT :: IdT m a -> m a
runIdT (IT a) = a
runReaderT :: i -> ReaderT i m a -> m a
runReaderT i (R m) = m i
runWriterT :: (Monad m) => WriterT i m a -> m (a,i)
runWriterT (W m) = liftM to_pair m
where to_pair ~(P a w) = (a,w)
runStateT :: i -> StateT i m a -> m (a,i)
runStateT i (S m) = m i
runExceptionT :: ExceptionT i m a -> m (Either i a)
runExceptionT (X m) = m
runChoiceT :: (Monad m) => ChoiceT m a -> m (Maybe (a,ChoiceT m a))
runChoiceT (Answer a) = return (Just (a,NoAnswer))
runChoiceT NoAnswer = return Nothing
runChoiceT (Choice l r) = do x <- runChoiceT l
case x of
Nothing -> runChoiceT r
Just (a,l1) -> return (Just (a,Choice l1 r))
runChoiceT (ChoiceEff m) = runChoiceT =<< m
findOne :: (Monad m) => ChoiceT m a -> m (Maybe a)
findOne m = fmap fst `liftM` runChoiceT m
findAll :: (Monad m) => ChoiceT m a -> m [a]
findAll m = all_res =<< runChoiceT m
where all_res Nothing = return []
all_res (Just (a,as)) = (a:) `liftM` findAll as
runContT :: (a -> m i) -> ContT i m a -> m i
runContT i (C m) = m i
class Monad m => RunM m a r | m a -> r where
runM :: m a -> r
instance RunM Id a a where
runM = runId
instance RunM Lift a a where
runM = runLift
instance RunM IO a (IO a) where
runM = id
instance RunM m a r => RunM (IdT m) a r where
runM = runM . runIdT
instance RunM m a r => RunM (ReaderT i m) a (i -> r) where
runM m i = runM (runReaderT i m)
instance (Monoid i, RunM m (a,i) r) => RunM (WriterT i m) a r where
runM = runM . runWriterT
instance RunM m (a,i) r => RunM (StateT i m) a (i -> r) where
runM m i = runM (runStateT i m)
instance RunM m (Either i a) r => RunM (ExceptionT i m) a r where
runM = runM . runExceptionT
instance RunM m i r => RunM (ContT i m) a ((a -> m i) -> r) where
runM m k = runM (runContT k m)
instance RunM m (Maybe (a,ChoiceT m a)) r => RunM (ChoiceT m) a r where
runM = runM . runChoiceT
class MonadT t where
lift :: (Monad m) => m a -> t m a
instance MonadT IdT where lift m = IT m
instance MonadT (ReaderT i) where lift m = R (\_ -> m)
instance MonadT (StateT i) where lift m = S (\s -> liftM (\a -> (a,s)) m)
instance (Monoid i)
=> MonadT (WriterT i) where lift m = W (liftM (\a -> P a mempty) m)
instance MonadT (ExceptionT i) where lift m = X (liftM Right m)
instance MonadT ChoiceT where lift m = ChoiceEff (liftM Answer m)
instance MonadT (ContT i) where lift m = C (\k -> m >>= k)
t_inBase :: (MonadT t, BaseM m n) => n a -> t m a
t_inBase m = lift (inBase m)
t_return :: (MonadT t, Monad m) => a -> t m a
t_return x = lift (return x)
t_fail :: (MonadT t, MF.MonadFail m) => String -> t m a
t_fail x = lift (MF.fail x)
#if !MIN_VERSION_base(4,11,0)
t_oldfail :: (MonadT t, Monad m) => String -> t m a
t_oldfail x = lift (fail x)
#endif
t_mzero :: (MonadT t, MonadPlus m) => t m a
t_mzero = lift mzero
t_ask :: (MonadT t, ReaderM m i) => t m i
t_ask = lift ask
t_put :: (MonadT t, WriterM m i) => i -> t m ()
t_put x = lift (put x)
t_get :: (MonadT t, StateM m i) => t m i
t_get = lift get
t_set :: (MonadT t, StateM m i) => i -> t m ()
t_set i = lift (set i)
t_raise :: (MonadT t, ExceptionM m i) => i -> t m a
t_raise i = lift (raise i)
t_abort :: (MonadT t, AbortM m i) => i -> t m a
t_abort i = lift (abort i)
class (Monad m, Monad n) => BaseM m n | m -> n where
inBase :: n a -> m a
instance BaseM IO IO where inBase = id
instance BaseM Maybe Maybe where inBase = id
instance BaseM [] [] where inBase = id
instance BaseM Id Id where inBase = id
instance BaseM Lift Lift where inBase = id
instance BaseM (ST s) (ST s) where inBase = id
instance (BaseM m n) => BaseM (IdT m) n where inBase = t_inBase
instance (BaseM m n) => BaseM (ReaderT i m) n where inBase = t_inBase
instance (BaseM m n) => BaseM (StateT i m) n where inBase = t_inBase
instance (BaseM m n,Monoid i)
=> BaseM (WriterT i m) n where inBase = t_inBase
instance (BaseM m n) => BaseM (ExceptionT i m) n where inBase = t_inBase
instance (BaseM m n) => BaseM (ChoiceT m) n where inBase = t_inBase
instance (BaseM m n) => BaseM (ContT i m) n where inBase = t_inBase
instance Monad Id where
m >>= k = k (runId m)
#if !MIN_VERSION_base(4,11,0)
fail = error
#endif
instance Monad Lift where
L x >>= k = k x
#if !MIN_VERSION_base(4,11,0)
fail = error
#endif
instance (Monad m) => Monad (IdT m) where
m >>= k = IT (runIdT m >>= (runIdT . k))
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m) => Monad (ReaderT i m) where
m >>= k = R (\r -> runReaderT r m >>= \a -> runReaderT r (k a))
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m) => Monad (StateT i m) where
m >>= k = S (\s -> runStateT s m >>= \ ~(a,s') -> runStateT s' (k a))
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m,Monoid i) => Monad (WriterT i m) where
m >>= k = W $ unW m >>= \ ~(P a w1) ->
unW (k a) >>= \ ~(P b w2) ->
return (P b (mappend w1 w2))
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m) => Monad (ExceptionT i m) where
m >>= k = X $ runExceptionT m >>= \e ->
case e of
Left x -> return (Left x)
Right a -> runExceptionT (k a)
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m) => Monad (ChoiceT m) where
Answer a >>= k = k a
NoAnswer >>= _ = NoAnswer
Choice m1 m2 >>= k = Choice (m1 >>= k) (m2 >>= k)
ChoiceEff m >>= k = ChoiceEff (liftM (>>= k) m)
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance (Monad m) => Monad (ContT i m) where
m >>= k = C $ \c -> runContT (\a -> runContT c (k a)) m
#if !MIN_VERSION_base(4,11,0)
fail = t_oldfail
#endif
instance Functor Id where fmap = liftM
instance Functor Lift where fmap = liftM
instance (Monad m) => Functor (IdT m) where fmap = liftM
instance (Monad m) => Functor (ReaderT i m) where fmap = liftM
instance (Monad m) => Functor (StateT i m) where fmap = liftM
instance (Monad m,Monoid i) => Functor (WriterT i m) where fmap = liftM
instance (Monad m) => Functor (ExceptionT i m) where fmap = liftM
instance (Monad m) => Functor (ChoiceT m) where fmap = liftM
instance (Monad m) => Functor (ContT i m) where fmap = liftM
instance Applicative Id where (<*>) = ap; pure x = I x
instance Applicative Lift where (<*>) = ap; pure x = L x
instance (Monad m) => Applicative (IdT m) where (<*>) = ap; pure = t_return
instance (Monad m) => Applicative (ReaderT i m) where (<*>) = ap; pure = t_return
instance (Monad m) => Applicative (StateT i m) where (<*>) = ap; pure = t_return
instance (Monad m,Monoid i)
=> Applicative (WriterT i m) where (<*>) = ap; pure = t_return
instance (Monad m) => Applicative (ExceptionT i m)
where (<*>) = ap; pure = t_return
instance (Monad m) => Applicative (ChoiceT m) where (<*>) = ap; pure = Answer
instance (Monad m) => Applicative (ContT i m) where (<*>) = ap; pure = t_return
instance (MonadPlus m)
=> Alternative (IdT m) where (<|>) = mplus; empty = mzero
instance (MonadPlus m)
=> Alternative (ReaderT i m) where (<|>) = mplus; empty = mzero
instance (MonadPlus m)
=> Alternative (StateT i m) where (<|>) = mplus; empty = mzero
instance (MonadPlus m,Monoid i)
=> Alternative (WriterT i m) where (<|>) = mplus; empty = mzero
instance (MonadPlus m)
=> Alternative (ExceptionT i m) where (<|>) = mplus; empty = mzero
instance (Monad m)
=> Alternative (ChoiceT m) where (<|>) = mplus; empty = mzero
instance (MonadPlus m)
=> Alternative (ContT i m) where (<|>) = mplus; empty = mzero
instance MonadFix Id where
mfix f = let m = f (runId m) in m
instance MonadFix Lift where
mfix f = let m = f (runLift m) in m
instance (MonadFix m) => MonadFix (IdT m) where
mfix f = IT (mfix (runIdT . f))
instance (MonadFix m) => MonadFix (ReaderT i m) where
mfix f = R $ \r -> mfix (runReaderT r . f)
instance (MonadFix m) => MonadFix (StateT i m) where
mfix f = S $ \s -> mfix (runStateT s . f . fst)
instance (MonadFix m,Monoid i) => MonadFix (WriterT i m) where
mfix f = W $ mfix (unW . f . val)
where val ~(P a _) = a
instance (MonadFix m) => MonadFix (ExceptionT i m) where
mfix f = X $ mfix (runExceptionT . f . fromRight)
where fromRight (Right a) = a
fromRight _ = error "ExceptionT: mfix looped."
instance (MonadPlus m) => MonadPlus (IdT m) where
mzero = t_mzero
mplus (IT m) (IT n) = IT (mplus m n)
instance (MonadPlus m) => MonadPlus (ReaderT i m) where
mzero = t_mzero
mplus (R m) (R n) = R (\r -> mplus (m r) (n r))
instance (MonadPlus m) => MonadPlus (StateT i m) where
mzero = t_mzero
mplus (S m) (S n) = S (\s -> mplus (m s) (n s))
instance (MonadPlus m,Monoid i) => MonadPlus (WriterT i m) where
mzero = t_mzero
mplus (W m) (W n) = W (mplus m n)
instance (MonadPlus m) => MonadPlus (ExceptionT i m) where
mzero = t_mzero
mplus (X m) (X n) = X (mplus m n)
instance (Monad m) => MonadPlus (ChoiceT m) where
mzero = NoAnswer
mplus m n = Choice m n
instance (MonadPlus m) => MonadPlus (ContT i m) where
mzero = t_mzero
mplus (C m) (C n) = C (\k -> m k `mplus` n k)
class (Monad m) => ReaderM m i | m -> i where
ask :: m i
instance (Monad m) => ReaderM (ReaderT i m) i where
ask = R return
instance (ReaderM m j) => ReaderM (IdT m) j where ask = t_ask
instance (ReaderM m j,Monoid i)
=> ReaderM (WriterT i m) j where ask = t_ask
instance (ReaderM m j) => ReaderM (StateT i m) j where ask = t_ask
instance (ReaderM m j) => ReaderM (ExceptionT i m) j where ask = t_ask
instance (ReaderM m j) => ReaderM (ChoiceT m) j where ask = t_ask
instance (ReaderM m j) => ReaderM (ContT i m) j where ask = t_ask
class (Monad m) => WriterM m i | m -> i where
put :: i -> m ()
instance (Monad m,Monoid i) => WriterM (WriterT i m) i where
put x = W (return (P () x))
instance (WriterM m j) => WriterM (IdT m) j where put = t_put
instance (WriterM m j) => WriterM (ReaderT i m) j where put = t_put
instance (WriterM m j) => WriterM (StateT i m) j where put = t_put
instance (WriterM m j) => WriterM (ExceptionT i m) j where put = t_put
instance (WriterM m j) => WriterM (ChoiceT m) j where put = t_put
instance (WriterM m j) => WriterM (ContT i m) j where put = t_put
class (Monad m) => StateM m i | m -> i where
get :: m i
set :: i -> m ()
instance (Monad m) => StateM (StateT i m) i where
get = S (\s -> return (s,s))
set s = S (\_ -> return ((),s))
instance (StateM m j) => StateM (IdT m) j where
get = t_get; set = t_set
instance (StateM m j) => StateM (ReaderT i m) j where
get = t_get; set = t_set
instance (StateM m j,Monoid i) => StateM (WriterT i m) j where
get = t_get; set = t_set
instance (StateM m j) => StateM (ExceptionT i m) j where
get = t_get; set = t_set
instance (StateM m j) => StateM (ChoiceT m) j where
get = t_get; set = t_set
instance (StateM m j) => StateM (ContT i m) j where
get = t_get; set = t_set
class (Monad m) => ExceptionM m i | m -> i where
raise :: i -> m a
#ifdef USE_BASE3
instance ExceptionM IO IO.Exception where
raise = IO.throwIO
#else
instance ExceptionM IO IO.SomeException where
raise = IO.throwIO
#endif
instance (Monad m) => ExceptionM (ExceptionT i m) i where
raise x = X (return (Left x))
instance (ExceptionM m j) => ExceptionM (IdT m) j where
raise = t_raise
instance (ExceptionM m j) => ExceptionM (ReaderT i m) j where
raise = t_raise
instance (ExceptionM m j,Monoid i) => ExceptionM (WriterT i m) j where
raise = t_raise
instance (ExceptionM m j) => ExceptionM (StateT i m) j where
raise = t_raise
instance (ExceptionM m j) => ExceptionM (ChoiceT m) j where
raise = t_raise
instance (ExceptionM m j) => ExceptionM (ContT i m) j where
raise = t_raise
class Monad m => ContM m where
callWithCC :: ((a -> Label m) -> m a) -> m a
liftJump :: (ContM m, MonadT t) =>
(a -> b) ->
((a -> Label (t m)) -> t m a) ->
((b -> Label m ) -> t m a)
liftJump ans f l = f $ \a -> Lab (lift $ jump $ l $ ans a)
instance (ContM m) => ContM (IdT m) where
callWithCC f = IT $ callWithCC $ \k -> runIdT $ liftJump id f k
instance (ContM m) => ContM (ReaderT i m) where
callWithCC f = R $ \r -> callWithCC $ \k -> runReaderT r $ liftJump id f k
instance (ContM m) => ContM (StateT i m) where
callWithCC f = S $ \s -> callWithCC $ \k -> runStateT s $ liftJump (ans s) f k
where ans s a = (a,s)
instance (ContM m,Monoid i) => ContM (WriterT i m) where
callWithCC f = W $ callWithCC $ \k -> unW $ liftJump (`P` mempty) f k
instance (ContM m) => ContM (ExceptionT i m) where
callWithCC f = X $ callWithCC $ \k -> runExceptionT $ liftJump Right f k
instance (ContM m) => ContM (ChoiceT m) where
callWithCC f = ChoiceEff $ callWithCC $ \k -> return $ liftJump Answer f k
instance (Monad m) => ContM (ContT i m) where
callWithCC f = C $ \k -> runContT k $ f $ \a -> Lab (C $ \_ -> k a)
class (ReaderM m i) => RunReaderM m i | m -> i where
local :: i -> m a -> m a
instance (Monad m) => RunReaderM (ReaderT i m) i where
local i m = lift (runReaderT i m)
instance (RunReaderM m j) => RunReaderM (IdT m) j where
local i (IT m) = IT (local i m)
instance (RunReaderM m j,Monoid i) => RunReaderM (WriterT i m) j where
local i (W m) = W (local i m)
instance (RunReaderM m j) => RunReaderM (StateT i m) j where
local i (S m) = S (local i . m)
instance (RunReaderM m j) => RunReaderM (ExceptionT i m) j where
local i (X m) = X (local i m)
instance (RunReaderM m j) => RunReaderM (ContT i m) j where
local i (C m) = C (local i . m)
class WriterM m i => RunWriterM m i | m -> i where
collect :: m a -> m (a,i)
instance (Monad m,Monoid i) => RunWriterM (WriterT i m) i where
collect m = lift (runWriterT m)
instance (RunWriterM m j) => RunWriterM (IdT m) j where
collect (IT m) = IT (collect m)
instance (RunWriterM m j) => RunWriterM (ReaderT i m) j where
collect (R m) = R (collect . m)
instance (RunWriterM m j) => RunWriterM (StateT i m) j where
collect (S m) = S (liftM swap . collect . m)
where swap (~(a,s),w) = ((a,w),s)
instance (RunWriterM m j) => RunWriterM (ExceptionT i m) j where
collect (X m) = X (liftM swap (collect m))
where swap (Right a,w) = Right (a,w)
swap (Left x,_) = Left x
instance (RunWriterM m j, MonadFix m) => RunWriterM (ContT i m) j where
collect (C m) = C $ \k -> fst `liftM`
mfix (\ ~(_,w) -> collect (m (\a -> k (a,w))))
class ExceptionM m i => RunExceptionM m i | m -> i where
try :: m a -> m (Either i a)
instance RunExceptionM IO IO.SomeException where
try = IO.try
instance (Monad m) => RunExceptionM (ExceptionT i m) i where
try m = lift (runExceptionT m)
instance (RunExceptionM m i) => RunExceptionM (IdT m) i where
try (IT m) = IT (try m)
instance (RunExceptionM m i) => RunExceptionM (ReaderT j m) i where
try (R m) = R (try . m)
instance (RunExceptionM m i,Monoid j) => RunExceptionM (WriterT j m) i where
try (W m) = W (liftM swap (try m))
where swap (Right (P a w)) = P (Right a) w
swap (Left e) = P (Left e) mempty
instance (RunExceptionM m i) => RunExceptionM (StateT j m) i where
try (S m) = S (\s -> liftM (swap s) (try (m s)))
where swap _ (Right ~(a,s)) = (Right a,s)
swap s (Left e) = (Left e, s)
class Monad m => AbortM m i where
abort :: i -> m a
instance Monad m => AbortM (ContT i m) i where
abort i = C (\_ -> return i)
instance AbortM IO ExitCode where
abort = exitWith
instance AbortM m i => AbortM (IdT m) i where abort = t_abort
instance AbortM m i => AbortM (ReaderT j m) i where abort = t_abort
instance (AbortM m i,Monoid j)
=> AbortM (WriterT j m) i where abort = t_abort
instance AbortM m i => AbortM (StateT j m) i where abort = t_abort
instance AbortM m i => AbortM (ExceptionT j m) i where abort = t_abort
instance AbortM m i => AbortM (ChoiceT m) i where abort = t_abort
newtype Label m = Lab (forall b. m b)
labelCC :: (ContM m) => a -> m (a, a -> Label m)
labelCC x = callWithCC (\l -> let label a = Lab (jump (l (a, label)))
in return (x, label))
labelCC_ :: forall m. (ContM m) => m (Label m)
labelCC_ = callWithCC $ \k -> let x :: m a
x = jump (k (Lab x))
in x
callCC :: ContM m => ((a -> m b) -> m a) -> m a
callCC f = callWithCC $ \l -> f $ \a -> jump $ l a
labelC :: (forall b. m b) -> Label m
labelC k = Lab k
jump :: Label m -> m a
jump (Lab k) = k
asks :: ReaderM m r => (r -> a) -> m a
asks f = do r <- ask
return (f r)
puts :: WriterM m w => (a,w) -> m a
puts ~(a,w) = put w >> return a
sets :: StateM m s => (s -> (a,s)) -> m a
sets f = do s <- get
let (a,s1) = f s
set s1
return a
sets_ :: StateM m s => (s -> s) -> m ()
sets_ f = do s <- get
set (f s)
raises :: ExceptionM m x => Either x a -> m a
raises (Right a) = return a
raises (Left x) = raise x
mapReader :: RunReaderM m r => (r -> r) -> m a -> m a
mapReader f m = do r <- ask
local (f r) m
mapWriter :: RunWriterM m w => (w -> w) -> m a -> m a
mapWriter f m = do ~(a,w) <- collect m
put (f w)
return a
mapException :: RunExceptionM m x => (x -> x) -> m a -> m a
mapException f m = do r <- try m
case r of
Right a -> return a
Left x -> raise (f x)
handle :: RunExceptionM m x => m a -> (x -> m a) -> m a
handle m f = do r <- try m
case r of
Right a -> return a
Left x -> f x
type family WithBase base layers :: Type -> Type where
WithBase b '[] = b
WithBase b (f ': fs) = f (WithBase b fs)
instance MF.MonadFail m => MF.MonadFail (IdT m) where fail = t_fail
instance MF.MonadFail m => MF.MonadFail (ReaderT i m) where fail = t_fail
instance (Monoid i, MF.MonadFail m)
=> MF.MonadFail (WriterT i m) where fail = t_fail
instance MF.MonadFail m => MF.MonadFail (StateT i m) where fail = t_fail
instance MF.MonadFail m => MF.MonadFail (ExceptionT i m) where fail = t_fail
instance MF.MonadFail m => MF.MonadFail (ChoiceT m) where fail = t_fail
instance MF.MonadFail m => MF.MonadFail (ContT i m) where fail = t_fail