{-# LANGUAGE CPP #-}
module Control.Monad.Trans.MultiReader.Strict
(
MultiReaderT(..)
, MultiReaderTNull
, MultiReader
, MonadMultiReader(..)
, MonadMultiGet(..)
, runMultiReaderT
, runMultiReaderT_
, runMultiReaderTNil
, runMultiReaderTNil_
, withMultiReader
, withMultiReader_
, withMultiReaders
, withMultiReaders_
, withoutMultiReader
, inflateReader
, mapMultiReaderT
, mGetRaw
, mPutRaw
) where
import Data.HList.HList
import Data.HList.ContainsType
import Control.Monad.Trans.MultiReader.Class
import Control.Monad.Trans.MultiState.Class
import Control.Monad.State.Strict ( StateT(..)
, MonadState(..)
, evalStateT
, mapStateT )
import Control.Monad.Reader ( ReaderT(..) )
import Control.Monad.Trans.Class ( MonadTrans
, lift )
import Control.Monad.Writer.Class ( MonadWriter
, listen
, tell
, writer
, pass )
import Data.Functor.Identity ( Identity )
import Control.Applicative ( Applicative(..)
, Alternative(..)
)
import Control.Monad ( MonadPlus(..)
, liftM
, ap
, void )
import Control.Monad.Base ( MonadBase(..)
, liftBaseDefault
)
import Control.Monad.Trans.Control ( MonadTransControl(..)
, MonadBaseControl(..)
, ComposeSt
, defaultLiftBaseWith
, defaultRestoreM
)
import Control.Monad.Fix ( MonadFix(..) )
import Control.Monad.IO.Class ( MonadIO(..) )
-- | A monad transformer that gives read access to a heterogeneous stack of
-- values whose types are listed in the type-level list @x@.
--
-- It is represented as a strict 'StateT' over an 'HList' of those values.
-- Most combinators in this module only extend or shrink that list for the
-- duration of a sub-computation (for example 'withMultiReader' and
-- 'withoutMultiReader'). Note however that 'mPutRaw' and 'mapMultiReaderT'
-- are exported too and can modify the list directly.
newtype MultiReaderT x m a = MultiReaderT {
  runMultiReaderTRaw :: StateT (HList x) m a
}
-- | A 'MultiReaderT' whose reader stack is empty.
type MultiReaderTNull = MultiReaderT '[]
-- | A 'MultiReaderT' over 'Identity', i.e. a pure multi-reader computation.
type MultiReader x = MultiReaderT x Identity
-- | Maps over the result by delegating to the wrapped 'StateT'; the reader
-- stack is threaded through unchanged.
instance Functor m => Functor (MultiReaderT x m) where
  fmap g (MultiReaderT inner) = MultiReaderT (fmap g inner)
-- | 'pure' is delegated to the wrapped 'StateT'. Application is derived
-- from the 'Monad' instance via 'ap', so effects run left to right.
instance (Applicative m, Monad m) => Applicative (MultiReaderT x m) where
  pure value = MultiReaderT (pure value)
  mf <*> mx = mf `ap` mx
-- | Sequencing threads the reader stack through the wrapped strict 'StateT'.
instance Monad m => Monad (MultiReaderT x m) where
#if !MIN_VERSION_base(4,8,0)
  -- Before base-4.8 'Applicative' is not a superclass of 'Monad' and
  -- 'return' has no default, so it must be given explicitly. On newer base
  -- the default (return = pure) is used instead, which keeps this instance
  -- canonical and avoids the noncanonical-monad-instances warning.
  return = MultiReaderT . return
#endif
  k >>= f = MultiReaderT $ runMultiReaderTRaw k >>= (runMultiReaderTRaw . f)
-- | Lifts a base-monad action through the wrapped 'StateT'. The reader
-- stack is left as it is.
instance MonadTrans (MultiReaderT x) where
  lift action = MultiReaderT (lift action)
-- | Answers 'mAsk' for a type @a@ occurring in the reader stack @c@. The
-- value is extracted from the current stack with 'getHListElem'. Which
-- entry is chosen is decided by the 'ContainsType' instances, which are
-- not visible here.
--
-- On base >= 4.8 the instance is marked OVERLAPPING, presumably so that it
-- takes precedence over a more general lifting instance defined elsewhere.
-- Confirm against the class module.
#if MIN_VERSION_base(4,8,0)
instance {-# OVERLAPPING #-} (Monad m, ContainsType a c)
#else
instance (Monad m, ContainsType a c)
#endif
    => MonadMultiReader a (MultiReaderT c m) where
  mAsk = MultiReaderT $ get >>= \stack -> return (getHListElem stack)
-- | Read access through the 'MonadMultiGet' interface. 'mGet' extracts a
-- value of type @a@ from the current reader stack with 'getHListElem'.
--
-- The OVERLAPPING marker (base >= 4.8 only) presumably gives this instance
-- precedence over a more general lifting instance defined elsewhere.
-- Confirm against the class module.
#if MIN_VERSION_base(4,8,0)
instance {-# OVERLAPPING #-} (Monad m, ContainsType a c)
#else
instance (Monad m, ContainsType a c)
#endif
    => MonadMultiGet a (MultiReaderT c m) where
  mGet = MultiReaderT $ get >>= \stack -> return (getHListElem stack)
-- | Value recursion is delegated to the 'MonadFix' instance of the wrapped
-- 'StateT'.
instance MonadFix m => MonadFix (MultiReaderT r m) where
  mfix g = MultiReaderT (mfix (\x -> runMultiReaderTRaw (g x)))
-- | Return the complete current reader stack as an 'HList'.
mGetRaw :: Monad m => MultiReaderT a m (HList a)
mGetRaw = MultiReaderT get
-- | Replace the complete reader stack with the given 'HList'. The new stack
-- is visible to the rest of the computation.
mPutRaw :: Monad m => HList s -> MultiReaderT s m ()
mPutRaw stack = MultiReaderT (put stack)
-- | Transform the computation in the wrapped monad with 'mapStateT'.
--
-- The function receives the underlying action (already applied to the
-- current reader stack), which yields the result together with the final
-- stack. It may change the monad and the result type, but not the type of
-- the stack.
mapMultiReaderT :: (m (a, HList w) -> m' (a', HList w))
                -> MultiReaderT w m a
                -> MultiReaderT w m' a'
mapMultiReaderT f (MultiReaderT inner) = MultiReaderT (mapStateT f inner)
-- | Run a 'MultiReaderT' with the given initial reader stack and return
-- only the result. The final stack is dropped.
runMultiReaderT :: Monad m => HList r -> MultiReaderT r m a -> m a
runMultiReaderT readers (MultiReaderT inner) = evalStateT inner readers

-- | Like 'runMultiReaderT', but the result is discarded as well. Only
-- requires 'Functor' on the base monad.
runMultiReaderT_ :: Functor m => HList r -> MultiReaderT r m a -> m ()
runMultiReaderT_ readers (MultiReaderT inner) = void (runStateT inner readers)
-- | Run a 'MultiReaderT' that needs no readers, starting from an empty
-- stack, and return its result.
runMultiReaderTNil :: Monad m => MultiReaderT '[] m a -> m a
runMultiReaderTNil = runMultiReaderT HNil

-- | Like 'runMultiReaderTNil', but discards the result.
runMultiReaderTNil_ :: Functor m => MultiReaderT '[] m a -> m ()
runMultiReaderTNil_ = runMultiReaderT_ HNil
-- | Run a computation with one additional value pushed onto the front of
-- the reader stack. The inner computation is run with 'evalStateT', so any
-- change it makes to its stack (e.g. via 'mPutRaw') is discarded, and the
-- outer stack is left as it was.
withMultiReader :: Monad m => r -> MultiReaderT (r ': rs) m a -> MultiReaderT rs m a
withMultiReader x k = MultiReaderT $ do
  outer <- get
  lift (evalStateT (runMultiReaderTRaw k) (x :+: outer))

-- | Like 'withMultiReader', but discards the result.
withMultiReader_ :: (Functor m, Monad m) => r -> MultiReaderT (r ': rs) m a -> MultiReaderT rs m ()
withMultiReader_ x k = void (withMultiReader x k)
-- | Run a computation with several values pushed onto the reader stack at
-- once. The elements keep their order: the head of the given 'HList' ends
-- up at the very front, followed by the rest of the list and then the
-- previously visible stack.
withMultiReaders :: Monad m => HList r1 -> MultiReaderT (Append r1 r2) m a -> MultiReaderT r2 m a
withMultiReaders HNil k = k
withMultiReaders (front :+: rest) k = withMultiReaders rest (withMultiReader front k)

-- | Like 'withMultiReaders', but discards the result.
withMultiReaders_ :: (Functor m, Monad m) => HList r1 -> MultiReaderT (Append r1 r2) m a -> MultiReaderT r2 m ()
withMultiReaders_ HNil k = liftM (const ()) k
withMultiReaders_ (front :+: rest) k = withMultiReaders_ rest (withMultiReader_ front k)
-- | Run a computation that cannot see the frontmost entry of the reader
-- stack. It is started on the remainder of the stack only, and whatever
-- it does to that copy of the stack is discarded.
withoutMultiReader :: Monad m => MultiReaderT rs m a -> MultiReaderT (r ': rs) m a
withoutMultiReader k = MultiReaderT $ do
  stack <- get
  case stack of
    (_ :+: rest) -> lift (runMultiReaderT rest k)
-- | Run a plain 'ReaderT' computation inside 'MultiReaderT'. Its
-- environment is looked up in the reader stack with 'mAsk'.
inflateReader :: (Monad m, ContainsType r rs)
              => ReaderT r m a
              -> MultiReaderT rs m a
inflateReader k = do
  env <- mAsk
  lift (runReaderT k env)
-- | State operations of the base monad are forwarded with 'lift'. They act
-- on the base monad's state, not on the reader stack.
instance (MonadState s m) => MonadState s (MultiReaderT c m) where
  get = lift get
  put newState = lift (put newState)
  state f = lift (state f)
-- | Writer operations of the base monad, made available inside
-- 'MultiReaderT'. 'writer' and 'tell' are plain lifts. 'listen' and 'pass'
-- take a computation as argument, so they are pushed through the wrapped
-- 'StateT' with 'mapStateT', regrouping the tuples so that the reader
-- stack stays out of the way of the writer output.
instance (MonadWriter w m) => MonadWriter w (MultiReaderT c m) where
  writer entry = lift (writer entry)
  tell output = lift (tell output)
  listen (MultiReaderT inner) =
      MultiReaderT (mapStateT (\act -> liftM regroup (listen act)) inner)
    where
      -- ((result, stack), output) becomes ((result, output), stack)
      regroup ((result, stack), output) = ((result, output), stack)
  pass (MultiReaderT inner) =
      MultiReaderT (mapStateT (\act -> pass (liftM regroup act)) inner)
    where
      -- ((result, f), stack) becomes ((result, stack), f)
      regroup ((result, f), stack) = ((result, stack), f)
-- | IO actions are lifted into the base monad first and then through this
-- transformer layer.
instance MonadIO m => MonadIO (MultiReaderT c m) where
  liftIO io = lift (liftIO io)
-- | Choice is delegated to the 'Alternative' instance of the wrapped
-- 'StateT', so both branches start from the same reader stack.
--
-- 'empty' uses the wrapped instance directly, like '<|>' here and 'mzero'
-- in the 'MonadPlus' instance. The previous @lift mzero@ routed through an
-- extra bind and agreed with 'mzero' only if the base monad satisfies the
-- left-zero law.
instance (Functor m, Applicative m, MonadPlus m) => Alternative (MultiReaderT c m) where
  empty = MultiReaderT empty
  MultiReaderT m <|> MultiReaderT n = MultiReaderT $ m <|> n
-- | Failure and choice are delegated to the 'MonadPlus' instance of the
-- wrapped 'StateT'.
instance MonadPlus m => MonadPlus (MultiReaderT c m) where
  mzero = MultiReaderT mzero
  mplus (MultiReaderT left) (MultiReaderT right) = MultiReaderT (mplus left right)
-- | Base-monad actions are lifted with 'liftBaseDefault', i.e. lifted into
-- the wrapped monad and then through this transformer with 'lift'.
instance MonadBase b m => MonadBase b (MultiReaderT r m) where
  liftBase action = liftBaseDefault action
-- | Lets base-monad operations that take monadic callbacks run
-- 'MultiReaderT' computations, by delegating to the 'MonadTransControl'
-- instance of the wrapped 'StateT'.
--
-- The captured state of a run is its result paired with the reader stack,
-- @(a, HList r)@, mirroring the 'StateT' representation.
instance MonadTransControl (MultiReaderT r) where
  type StT (MultiReaderT r) a = (a, HList r)
  -- Unwrap the MultiReaderT layer before handing the computation to the
  -- StateT run function. NOTE(review): presumably kept as explicit lambdas
  -- because the run function has a rank-2 type; keep this shape when
  -- editing.
  liftWith f = MultiReaderT $ liftWith $ \s -> f $ \r -> s $ runMultiReaderTRaw r
  -- Restore a captured (result, stack) pair and re-wrap it.
  restoreT = MultiReaderT . restoreT
-- | Derived from the 'MonadTransControl' instance using monad-control's
-- default helpers. The saved state composes this layer's 'StT' with the
-- base monad's state via 'ComposeSt'.
instance MonadBaseControl b m => MonadBaseControl b (MultiReaderT r m) where
  type StM (MultiReaderT r m) a = ComposeSt (MultiReaderT r) m a
  liftBaseWith = defaultLiftBaseWith
  restoreM = defaultRestoreM