{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}
-- | Monad classes, monad transformers, and the usual State / Reader /
-- Writer / Continuation monads, built on the "Clean" class hierarchy.
module Clean.Monad(
  module Clean.Applicative,

  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),
  (=<<),(>>),return,

  -- * Common monads
  -- ** The State Monad
  MonadState(..),
  StateT(..),State,
  evalStateT,execStateT,runState,execState,evalState,
  -- ** The Reader monad
  MonadReader(..),ReaderT(..),Reader,
  -- ** The Writer monad
  MonadWriter(..),WriterT(..),Writer,runWriter,
  -- ** The Continuation monad
  MonadCont(..),ContT(..),Cont,evalContT,evalCont
  ) where

import Clean.Classes
import Clean.Applicative
import Clean.Core hiding (flip)
import Clean.Traversable

-- | Monads supporting value recursion: @mfix f@ runs @f@ once, feeding it
-- the (lazily available) result of that same run.
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a

instance MonadFix Id where
  mfix = cfix
instance MonadFix ((->) b) where
  mfix = cfix
-- The knot is tied separately for every position of the result: the n-th
-- element is the fixed point of "the n-th element of f's output".
-- (Tying it only through the first element, as @fix (f . head)@ alone does,
-- feeds element 0 back into every other element and breaks left shrinking.)
-- 'head' is only forced when @f@ is strict in its argument, in which case
-- the fixed point diverges anyway.
instance MonadFix [] where
  mfix f = case fix (f . head) of
      []    -> []
      (x:_) -> x : mfix (dropHead . f)
    where
      -- total replacement for 'tail'
      dropHead (_:t) = t
      dropHead []    = []

-- | Least fixed point. The result is bound once and shared, so
-- self-referential values are built in place rather than recomputed at
-- every unfolding.
fix :: (a -> a) -> a
fix f = let x = f x in x

-- | Fixed point for monads whose 'collect' turns @a -> m a@ into
-- @m (a -> a)@ (used for 'Id' and @(->) b@): the fixed point is then taken
-- inside the monad.
cfix f = map fix (collect f)

-- | Monad transformers.
class MonadTrans t where
  -- | Embed a computation of the underlying monad.
  lift :: Monad m => m a -> t m a
  -- | Transform the underlying computation. The polymorphic @c@ stands for
  -- whatever extra component the transformer threads through the inner
  -- monad (the state for 'StateT', the output for 'WriterT', @()@ for
  -- 'ReaderT'), which the transformation must leave untouched.
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> t m a -> t m b

-- 'pure' for transformers whose 'pure' is just a lifted inner 'pure'.
pure_ = lift . pure

-- NOTE(review): no fixity declaration for (>>) or (=<<) is visible here, so
-- they get the default infixl 9 (base uses infixl 1 / infixr 1); e.g.
-- @a >>= f >> b@ parses as @a >>= (f >> b)@. Changing it would alter how
-- client code parses -- confirm before adding one.

-- | Sequence two computations, keeping only the second result.
(>>) = (*>)
-- | '>>=' with its arguments flipped.
(=<<) = flip (>>=)
-- | Synonym for 'pure'.
return = pure

-- | Monads carrying a state of type @s@.
--
-- The defaults for 'put' and 'modify' are defined in terms of each other:
-- an instance must define 'get' and at least one of 'put' / 'modify'.
class Monad m => MonadState s m where
  -- | Read the current state.
  get :: m s
  -- | Replace the current state.
  put :: s -> m ()
  put = modify . const
  -- | Update the current state with a function.
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f

-- State operations lifted through a transformer, for pass-through instances.
get_ = lift get
put_ = lift . put
modify_ = lift . modify

-- | A simple State monad transformer: a function from an initial state to
-- an inner computation returning the final state and a result.
newtype StateT s m a = StateT { runStateT :: s -> m (s,a) }
-- | 'StateT' over 'Id'.
type State s = StateT s Id

instance Unit m => Unit (StateT s m) where
  pure a = StateT (\s -> pure (s,a))
instance Monad m => Functor (StateT s m)
instance Monad m => Applicative (StateT s m)
instance Monad m => Monad (StateT s m) where
  -- lazy pattern: the state/result pair is only taken apart on demand
  StateT st >>= k = StateT (\s -> st s >>= \ ~(s',a) -> runStateT (k a) s')
instance MonadTrans (StateT s) where
  lift m = StateT (\s -> map (s,) m)
  internal f (StateT st) = StateT (f . st)
instance Monad m => MonadState s (StateT s m) where
  get = StateT (\s -> pure (s,s))
  put x = StateT (\_ -> pure (x,()))
  modify f = StateT (\s -> pure (f s,()))
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
-- Invoking the captured continuation resumes with the state that was
-- current when 'callCC' was entered (@s@), i.e. the state is rolled back.
instance MonadCont m => MonadCont (StateT s m) where
  callCC f = StateT (\s -> callCC $ \k -> runStateT (f (\a -> lift (k (s,a)))) s)
instance MonadFix m => MonadFix (StateT s m) where
  mfix f = StateT (\s -> mfix (\ ~(_,a) -> runStateT (f a) s))
deriving instance Semigroup (m (s,a)) => Semigroup (StateT s m a)
deriving instance Monoid (m (s,a)) => Monoid (StateT s m a)
deriving instance Ring (m (s,a)) => Ring (StateT s m a)

-- | Run a 'StateT', keeping only the result.
evalStateT = map (map snd) . runStateT
-- | Run a 'StateT', keeping only the final state.
execStateT = map (map fst) . runStateT

-- | Run a 'State', returning the final state and the result.
runState :: State s a -> s -> (s,a)
runState = map getId . runStateT
-- | Run a 'State', returning only the final state.
execState :: State s a -> s -> s
execState = map fst . runState
-- | Run a 'State', returning only the result.
evalState :: State s a -> s -> a
evalState = map snd . runState

-- | Monads with a read-only environment of type @r@.
class Monad m => MonadReader r m where
  -- | Read the environment.
  ask :: m r
  -- | Run a computation in a modified environment.
  local :: (r -> r) -> m a -> m a

-- Reader operations lifted through a transformer, for pass-through instances.
ask_ = lift ask
local_ f m = internal (local f) m

-- | A simple Reader monad transformer.
newtype ReaderT r m a = ReaderT { runReaderT :: r -> m a }
  deriving (Semigroup,Monoid,Ring)
-- | 'ReaderT' over 'Id'.
type Reader r = ReaderT r Id

instance MonadTrans (ReaderT r) where
  lift m = ReaderT (const m)
  -- ReaderT threads nothing extra, so @()@ fills the @c@ slot
  internal f (ReaderT r) = ReaderT (map snd . f . map ((),) . r)
instance Monad m => Unit (ReaderT r m) where
  pure = pure_
instance Monad m => Functor (ReaderT r m)
instance Monad m => Applicative (ReaderT r m)
instance Monad m => Monad (ReaderT r m) where
  ReaderT rd >>= k = ReaderT (\r -> rd r >>= \a -> runReaderT (k a) r)
instance Monad m => MonadReader r (ReaderT r m) where
  ask = ReaderT pure
  local f (ReaderT rd) = ReaderT (rd . f)
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_
  put = put_
  modify = modify_
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_
  listen = listen_
  censor = censor_
instance MonadCont m => MonadCont (ReaderT r m) where
  callCC f = ReaderT (\r -> callCC (\k -> runReaderT (f (lift . k)) r))
instance MonadFix m => MonadFix (ReaderT r m) where
  mfix f = ReaderT (\r -> mfix (\a -> runReaderT (f a) r))

-- | Monads that accumulate an output of type @w@.
class (Monad m,Monoid w) => MonadWriter w m where
  -- | Emit @w@ as output.
  tell :: w -> m ()
  -- | Run a computation and also return the output it produced.
  listen :: m a -> m (w,a)
  -- | Run a computation yielding a value and a function, and apply that
  -- function to the computation's output.
  censor :: m (a,w -> w) -> m a

-- Writer operations lifted through a transformer, for pass-through
-- instances; the transformer's own component @c@ is carried alongside.
tell_ = lift . tell
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)))
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))

-- | A simple Writer monad transformer: an inner computation returning the
-- accumulated output alongside the result.
newtype WriterT w m a = WriterT { runWriterT :: m (w,a) }
-- | 'WriterT' over 'Id'.
type Writer w = WriterT w Id

instance Monoid w => MonadTrans (WriterT w) where
  lift m = WriterT (map (zero,) m)
  internal f (WriterT m) = WriterT (f m)
instance (Monoid w,Monad m) => Unit (WriterT w m) where
  pure = pure_
instance (Monoid w,Monad m) => Functor (WriterT w m)
instance (Monoid w,Monad m) => Applicative (WriterT w m)
instance (Monoid w,Monad m) => Monad (WriterT w m) where
  -- strict in the (output, value) pair, unlike listen/censor below;
  -- outputs are combined with '+', earlier output on the left
  wr >>= k = WriterT $ do
    (w,a) <- runWriterT wr
    map (first (w+)) (runWriterT (k a))
instance (Monad m,Monoid w) => MonadWriter w (WriterT w m) where
  tell w = WriterT (pure (w,()))
  listen (WriterT m) = WriterT (m <&> \ ~(w,a) -> (w,(w,a)))
  censor (WriterT m) = WriterT (m <&> \ ~(w,~(a,f)) -> (f w,a))
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_
  local = local_
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_
  put = put_
  modify = modify_
deriving instance Semigroup (m (w,a)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (w,a)) => Monoid (WriterT w m a)
deriving instance Ring (m (w,a)) => Ring (WriterT w m a)
instance (Monoid w,MonadFix m) => MonadFix (WriterT w m) where
  mfix f = WriterT (mfix (runWriterT . f . snd))

-- | Run a 'Writer', returning the accumulated output and the result.
runWriter :: Writer w a -> (w,a)
runWriter = getId . runWriterT

-- | Monads that can capture the current continuation.
class Monad m => MonadCont m where
  -- | Call a function with an escape continuation; invoking it returns
  -- its argument from the 'callCC' call.
  callCC :: ((a -> m b) -> m a) -> m a

-- | A simple continuation monad transformer.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
  deriving (Semigroup,Monoid,Ring)
-- | 'ContT' over 'Id'.
type Cont r = ContT r Id

instance Unit m => Unit (ContT r m) where
  pure a = ContT ($ a)
instance Monad m => Functor (ContT r m)
instance Monad m => Applicative (ContT r m)
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
-- 'internal' has no implementation for ContT: its result is 'undefined'
-- and diverges as soon as it is run.
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
  internal _ (ContT _) = undefined
instance Monad m => MonadCont (ContT r m) where
  -- the escape continuation discards its own continuation and jumps to k
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Run a 'ContT' whose answer type is its result type.
evalContT c = runContT c return
-- | Run a 'Cont' whose answer type is its result type.
evalCont :: Cont r r -> r
evalCont = getId . evalContT

instance MonadTrans Backwards where
  lift = Backwards
  internal f (Backwards m) = Backwards (snd <$> f (((),) <$> m))
-- Bind for reversed effect order: k's effects are sequenced before ma's,
-- and the value produced by ma is fed back to k through 'mfix'
-- (assumes liftA2 in m sequences its left argument first).
instance MonadFix m => Monad (Backwards m) where
  Backwards ma >>= k = Backwards $ fst <$> mfix (\r -> liftA2 (,) (forwards (k (snd r))) ma)