{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances, FunctionalDependencies #-}
-- | Monads, monad transformers and monadic utilities.
--
-- The State, Reader and Writer transformers are newtypes over a single
-- 'RWST' implementation, with 'Void' in the slots they do not use.
module SimpleH.Monad(
  module SimpleH.Applicative,
  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),
  -- * Monad utilities
  Kleisli(..),_Kleisli,
  (=<<),(<=<),(>=>),(>>),(<*=),return,
  foldlM,foldrM,while,until,
  bind2,bind3,(>>>=),(>>>>=),
  -- * Common monads
  -- ** The RWS Monad
  RWST(..),RWS,
  -- *** The State Monad
  MonadState(..),
  IOLens,_ioref,_mvar,
  StateT,State,
  stateT,eval,exec,state,
  (=~),(=-),gets,saving,
  mapAccum,mapAccum_,mapAccumR,mapAccumR_,push,pop,withPrev,withNext,
  -- **** The State Arrow
  StateA(..),stateA,
  -- *** The Reader monad
  MonadReader(..),
  ReaderT,Reader,
  _readerT,_reader,
  -- *** The Writer monad
  MonadWriter(..),
  WriterT,Writer,
  _writerT,_writer,
  mute,intercept,
  -- ** The Continuation monad
  MonadCont(..),
  ContT(..),Cont,
  evalContT, evalCont,
  -- ** The List monad
  MonadList(..),
  ListT,
  _listT,
  -- ** The Error Monad
  MonadError(..),try,tryMay,
  EitherT,
  _eitherT
  ) where

import SimpleH.Classes
import SimpleH.Applicative
import SimpleH.Core hiding (flip)
import SimpleH.Traversable
import SimpleH.Lens
import qualified Control.Exception as Ex
import qualified Control.Monad.Fix as Fix
import Data.IORef
import Control.Concurrent

-- Composition of two monads. It requires the inner monad @g@ to be
-- 'Traversable', so that @g (f x)@ can be commuted to @f (g x)@ before the
-- two layers are joined separately.
instance (Traversable g,Monad f,Monad g) => Monad (f:.:g) where
  join = Compose .map join.join.map sequence.getCompose.map getCompose

-- | Monads that support value recursion: 'mfix' feeds the result of a
-- monadic computation back in as its own input.
class Monad m => MonadFix m where
  mfix :: (a -> m a) -> m a

-- 'Id' and functions are 'Contravariant', so their fixpoint is 'cfix'.
instance MonadFix Id where
  mfix = cfix
instance MonadFix ((->) b) where
  mfix = cfix
-- Lists, Either and IO reuse the instances from base ("Control.Monad.Fix").
-- For lists, base gives each element of the result its own fixpoint: the
-- i-th element is the fixpoint of the i-th alternative, which the
-- left-shrinking law requires. A single @fix (f . head)@ would instead feed
-- the first element's fixpoint to every alternative.
instance MonadFix [] where
  mfix = Fix.mfix
instance MonadFix (Either e) where
  mfix = Fix.mfix
instance MonadFix IO where
  mfix = Fix.mfix
-- The fixpoint is taken in the wrapped monad, through the 'Backwards' iso.
instance MonadFix m => MonadFix (Backwards m) where
  mfix f = at _Backwards $ mfix (at' _Backwards.f)
-- The fixpoint of a composed monad is taken in the outer monad @f@, over
-- values of type @g a@.
instance (MonadFix f,Traversable g,Monad g) => MonadFix (f:.:g) where
  mfix f = Compose $ mfix (map join . traverse (getCompose . f))

-- | Fixpoint for 'Contravariant' functors. 'collect' turns @a -> c a@ into
-- @c (a -> a)@, and 'fix' is then applied inside.
cfix :: Contravariant c => (a -> c a) -> c a
cfix = map fix . collect

-- | 'mfix' over a pair: only the second component is fed back, and the
-- first component is returned.
mfixing :: MonadFix f => (b -> f (a, b)) -> f a
mfixing f = fst<$>mfix (\ ~(_,b) -> f b )

-- | Monad transformers: 'lift' embeds a computation of the base monad.
class MonadTrans t where
  lift :: Monad m => m a -> t m a
-- | Transformers whose own layer can be carried through a base-monad
-- computation as an extra pair component @c@. Used below to lift 'local',
-- 'listen' and 'censor' through transformers.
class MonadTrans t => MonadInternal t where
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> (t m a -> t m b)

-- | Kleisli arrows @a -> m b@ of a monad @m@.
newtype Kleisli m a b = Kleisli { runKleisli :: a -> m b }
-- Composition runs the right-hand arrow first, then feeds its result to the
-- left-hand arrow.
instance Monad m => Category (Kleisli m) where
  id = Kleisli pure
  Kleisli f . Kleisli g = Kleisli (\a -> g a >>= f)
instance Monad m => Choice (Kleisli m) where
  Kleisli f <|> Kleisli g = Kleisli (f <|> g)
-- Runs the left arrow on the first component, then the right arrow on the
-- second.
instance Monad m => Split (Kleisli m) where
  Kleisli f <#> Kleisli g = Kleisli (\(a,c) -> (,)<$>f a<*>g c)
instance Isomorphic (a -> m b) (c -> m' d) (Kleisli m a b) (Kleisli m' c d) where
  _iso = iso Kleisli runKleisli
-- | The isomorphism between 'Kleisli' arrows and plain monadic functions.
_Kleisli :: Iso (Kleisli m a b) (Kleisli m' c d) (a -> m b) (c -> m' d)
_Kleisli = _iso

-- | Generic fold helper.
--
-- * Each element is turned into a function by @f@.
-- * The functions are viewed as a monoid through @i@ and combined with
--   'foldMap'.
-- * The result is converted back and applied to the seed @e@.
folding :: (Foldable t,Monoid w) => Iso' (a -> c) w -> (b -> a -> c) -> a -> t b -> c
folding i f e t = at (from i) (foldMap (at i . f) t) e

-- | Monadic fold that composes the per-element steps as 'Kleisli' 'Endo's
-- under 'Dual'.
foldlM :: (Foldable t,Monad m) => (b -> a -> m a) -> a -> t b -> m a
foldlM = folding (_Kleisli._Endo._Dual)
-- | Monadic fold that composes the per-element steps as 'Kleisli' 'Endo's.
-- It differs from 'foldlM' only by the 'Dual', which reverses the order in
-- which the steps are composed.
foldrM :: (Foldable t,Monad m) => (b -> a -> m a) -> a -> t b -> m a
foldrM = folding (_Kleisli._Endo)

-- | Repeat the action while it returns 'Just'; stop at the first 'Nothing'.
while :: Monad m => m (Maybe a) -> m ()
while e = fix (\w -> e >>= maybe unit (const w))
-- | Repeat the action until it returns @Just x@, then yield @x@.
until :: Monad m => m (Maybe a) -> m a
until e = fix (\w -> e >>= maybe w return)

-- | Run both actions, then bind their results with a two-argument function.
bind2 :: Monad m => (a -> b -> m c) -> m a -> m b -> m c
bind2 f a b = join (f<$>a<*>b)
-- | Infix, pair-argument form of 'bind2'.
(>>>=) :: Monad m => (m a,m b) -> (a -> b -> m c) -> m c
(a,b) >>>= f = bind2 f a b
-- | Run three actions, then bind their results with a three-argument function.
bind3 :: Monad m => (a -> b -> c -> m d) -> m a -> m b -> m c -> m d
bind3 f a b c = join (f<$>a<*>b<*>c)
-- | Infix, triple-argument form of 'bind3'.
(>>>>=) :: Monad m => (m a,m b,m c) -> (a -> b -> c -> m d) -> m d
(a,b,c) >>>>= f = bind3 f a b c

infixr 2 =<<
infixl 1 <*=,>>
-- | Sequence two actions, keeping only the second result (same as '*>').
(>>) :: Applicative f => f a -> f b -> f b
(>>) = (*>)
-- | Bind with the arguments flipped.
(=<<) :: Monad m => (a -> m b) -> m a -> m b
(=<<) = flip (>>=)
-- | Right-to-left Kleisli composition: @g@ runs first, then @f@.
(<=<) :: Monad m => (b -> m c) -> (a -> m b) -> (a -> m c)
f <=< g = \a -> g a >>= f
-- | Left-to-right Kleisli composition: @f@ runs first, then @g@.
(>=>) :: Monad m => (a -> m b) -> (b -> m c) -> (a -> m c)
(>=>) = flip (<=<)
-- | Run an action, then run @f@ on its result, and return the original result.
(<*=) :: Monad m => m a -> (a -> m b) -> m a
a <*= f = a >>= (>>)<$>f<*>return
-- | Synonym for 'pure'.
return :: Unit f => a -> f a
return = pure

-- | The reader-writer-state transformer. It maps an environment and an
-- input state to an action yielding a result, an output state and an
-- accumulated output.
newtype RWST r w s m a = RWST { runRWST :: (r,s) -> m (a,s,w) }
type RWS r w s a = RWST r w s Id a
-- | The isomorphism between 'RWST' and its underlying function.
_RWST :: Iso (RWST r w s m a) (RWST r' w' s' m' a') ((r,s) -> m (a,s,w)) ((r',s') -> m' (a',s',w'))
_RWST = iso RWST runRWST

instance (Unit f,Monoid w) => Unit (RWST r w s f) where
  pure a = RWST (\ ~(_,s) -> pure (a,s,zero))
instance Functor f => Functor (RWST r w s f) where
  map f (RWST fa) = RWST (fa >>> map (\ ~(a,s,w) -> (f a,s,w)))
instance (Monoid w,Monad m) => Applicative (RWST r w s m)
-- 'join' runs the outer computation, then the inner one with the updated
-- state, and combines the two outputs with '+'.
instance (Monoid w,Monad m) => Monad (RWST r w s m) where
  join mm = RWST (\ ~(r,s) -> do
    ~(m,s',w) <- runRWST mm (r,s)
    ~(a,s'',w') <- runRWST m (r,s')
    return (a,s'',w+w'))
instance (Monoid w,MonadFix m) => MonadFix (RWST r w s m) where
  mfix f = RWST (\x -> mfix (\ ~(a,_,_) -> runRWST (f a) x))
-- The escape continuation resumes with the state captured when 'callCC' was
-- entered and with empty output.
instance (Monoid w,MonadCont m) => MonadCont (RWST r w s m) where
  callCC f = RWST $ \(r,s) -> callCC $ \k -> runRWST (f (\a -> lift (k (a,s,zero)))) (r,s)
deriving instance Semigroup (m (a,s,w)) => Semigroup (RWST r w s m a)
deriving instance Monoid (m (a,s,w)) => Monoid (RWST r w s m a)
deriving instance Ring (m (a,s,w)) => Ring (RWST r w s m a)
instance (Monad m,Monoid w) => MonadState s (RWST r w s m) where
  get = RWST (\ ~(_,s) -> pure (s,s,zero) )
  put s = RWST (\ _ -> pure ((),s,zero) )
  modify f = RWST (\ ~(_,s) -> pure ((),f s,zero) )
instance (Monad m,Monoid w) => MonadReader r (RWST r w s m) where
  ask = RWST (\ ~(r,s) -> pure (r,s,zero) )
  local f (RWST m) = RWST (\ ~(r,s) -> m (f r,s) )
instance (Monad m,Monoid w) => MonadWriter w (RWST r w s m) where
  tell w = RWST (\ ~(_,s) -> pure ((),s,w) )
  listen (RWST m) = RWST (m >>> map (\ ~(a,s,w) -> ((w,a),s,w) ) )
  censor (RWST m) = RWST (m >>> map (\ ~(~(a,f),s,w) -> (a,s,f w) ) )
-- With no environment and no state, the computation is run once on
-- @(zero,zero)@ and its results are folded / traversed.
instance Foldable m => Foldable (RWST Void w Void m) where
  fold (RWST m) = foldMap (\(w,_,_) -> w).m $ (zero,zero)
instance Traversable m => Traversable (RWST Void w Void m) where
  sequence (RWST m) = map (RWST . const . map (\((s,w),a) -> (a,s,w))) . sequence . map (\(a,s,w) -> sequence ((s,w),a)) $ m (zero,zero)
instance (Monoid w,MonadError e m) => MonadError e (RWST r w s m) where
  throw = lift.throw
  catch f (RWST m) = RWST (\x -> catch (flip runRWST x.f) (m x))
instance Monoid w => MonadTrans (RWST r w s) where
  lift m = RWST (\ ~(_,s) -> (,s,zero) <$> m)
-- The state and output are carried as the extra pair component.
instance (Monoid w) => MonadInternal (RWST r w s) where
  internal f (RWST m) = RWST (\ x -> f (m x <&> \ ~(a,s,w) -> ((s,w),a) ) <&> \ ~((s,w),b) -> (b,s,w) )

{-| A simple State Monad: monads with one piece of state @s@, which is
determined by @m@.

Minimal complete definition: 'get', plus either 'put' or 'modify'. Each of
those two has a default defined in terms of the other. -}
class Monad m => MonadState s m | m -> s where
  get :: m s
  put :: s -> m ()
  put = modify . const
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f

-- IO viewed as a state monad whose state is an @IO ()@ action.
--
-- * 'put' runs the given action.
-- * 'get' always yields 'unit' (presumably the no-op action).
instance MonadState (IO ()) IO where
  get = return unit
  put a = a
  modify f = put (f unit)

-- | Lenses from an accumulated @IO ()@ action onto an @IO a@ read of some
-- mutable cell. Setting appends an update of the cell to the accumulated
-- action.
type IOLens a = Lens' (IO ()) (IO a)
-- | 'IOLens' onto an 'IORef'. The getter reads the ref. The setter runs the
-- accumulated action, then the new-value action, then writes the result.
_ioref :: IORef a -> IOLens a
_ioref r = lens (const (readIORef r)) (\x a -> x >> a >>= writeIORef r)
-- | 'IOLens' onto a full 'MVar' used as a mutable cell.
--
-- * The getter uses 'readMVar', which leaves the value in place.
-- * The setter therefore replaces the contents with 'modifyMVar_'.
--   'putMVar' would block forever on the full MVar.
-- * The new value is computed before the MVar is taken, because that action
--   may itself read the same MVar (e.g. when built by '=~').
_mvar :: MVar a -> IOLens a
_mvar r = lens (const (readMVar r)) (\x a -> x >> a >>= \v -> modifyMVar_ r (const (pure v)))

-- | Lift 'get' through a transformer.
get_ :: (MonadTrans t, MonadState a m) => t m a
get_ = lift get
-- | Lift 'put' through a transformer.
put_ :: (MonadTrans t, MonadState s m) => s -> t m ()
put_ = lift . put
-- | Lift 'modify' through a transformer.
modify_ :: (MonadTrans t, MonadState s m) => (s -> s) -> t m ()
modify_ = lift . modify

-- | State transformer: an 'RWST' with no environment and no output.
newtype StateT s m a = StateT (RWST Void Void s m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadCont,MonadState s)
type State s a = StateT s Id a
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
deriving instance MonadError e m => MonadError e (StateT s m)
deriving instance Semigroup (m (a,s,Void)) => Semigroup (StateT s m a)
deriving instance Monoid (m (a,s,Void)) => Monoid (StateT s m a)
deriving instance Ring (m (a,s,Void)) => Ring (StateT s m a)

-- | The isomorphism between 'StateT' and the 'RWST' it wraps.
_StateT :: Iso (StateT s m a) (StateT t n b) (RWST Void Void s m a) (RWST Void Void t n b)
_StateT = iso StateT (\ ~(StateT s) -> s)
-- | The isomorphism between 'StateT' and state-passing functions
-- @s -> m (s,a)@, with the new state first in the pair.
stateT :: (Functor m,Functor n) => Iso (StateT s m a) (StateT t n b) (s -> m (s,a)) (t -> n (t,b))
stateT = _mapping (_mapping $ iso (\ ~(s,a) -> (a,s,zero) ) (\(a,s,_) -> (s,a))) ._promapping _iso._RWST._StateT
-- | Keep the second component of a pair nested under two functors. In the
-- @(s,a)@ layout of 'stateT', that is the result.
eval :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' b)
eval = map2 snd
-- | Keep the first component of a pair nested under two functors. In the
-- @(s,a)@ layout of 'stateT', that is the final state.
exec :: (Functor f, Functor f') => f (f' (a, b)) -> f (f' a)
exec = map2 fst
-- | 'stateT' specialised to the 'Id' monad.
state :: Iso (State s a) (State t b) (s -> (s,a)) (t -> (t,b))
state = _mapping _Id.stateT

-- | Set the focus of a lens in the state.
(=-) :: MonadState s m => Lens' s s' -> s' -> m ()
infixl 0 =-,=~
l =- x = modify (set l x)
-- | Modify the focus of a lens in the state.
(=~) :: MonadState s m => Lens' s s' -> (s' -> s') -> m ()
l =~ f = modify (warp l f)
-- | Read the focus of a lens from the state.
gets :: MonadState s m => Lens' s s' -> m s'
gets l = at l<$>get
-- | Run an action and return its result. Afterwards, restore the focus of
-- the lens to the value it had before the action ran.
saving :: MonadState s m => Lens' s s' -> m a -> m a
saving l st = gets l >>= \s -> st <* (l =- s)

-- * The State Arrow
-- | State arrows: a @StateT s m a@ viewed as an arrow from @s@ to @a@.
newtype StateA m s a = StateA (StateT s m a)
-- | The isomorphism between 'StateA' and the 'StateT' it wraps.
stateA :: Iso (StateA m s a) (StateA m' s' a') (StateT s m a) (StateT s' m' a')
stateA = iso StateA (\(StateA s) -> s)
-- Composition:
--
-- * The right-hand arrow runs on the input state.
-- * The left-hand arrow then runs with that arrow's result as its state.
-- * The composite's new state is the right-hand arrow's new state.
instance Monad m => Category (StateA m) where
  id = StateA get
  StateA sbc . StateA sab = StateA $ (^.stateT) $ \a -> (sab^..stateT) a >>= \(a',b) -> (a',).snd <$> (sbc^..stateT) b
instance Monad m => Split (StateA m) where
  StateA sac <#> StateA sbd = StateA $ (^.stateT) $ map2 (\((a',c),(b',d)) -> ((a',b'),(c,d))) $ (Kleisli (sac^..stateT) <#> Kleisli (sbd^..stateT)) ^.. _Kleisli
instance Monad m => Choice (StateA m) where
  StateA sac <|> StateA sbc = StateA $ (^.stateT) $ l Left (sac^..stateT)<|>l Right (sbc^..stateT)
    where l = map2 . first

-- | Traverse while threading an accumulator. Returns the final accumulator
-- and the rebuilt structure.
mapAccum :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccum f t = traverse (at state<$>f) t^..state
-- | 'mapAccum', discarding the final accumulator.
mapAccum_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccum_ = (map.map.map) snd mapAccum
-- | 'mapAccum' with the state threaded through 'Backwards' (the reverse
-- traversal order).
mapAccumR :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> (s, t b)
mapAccumR f t = traverse (at (state._Backwards)<$>f) t^..state._Backwards
-- | 'mapAccumR', discarding the final accumulator.
mapAccumR_ :: Traversable t => (a -> s -> (s, b)) -> t a -> s -> t b
mapAccumR_ = (map.map.map) snd mapAccumR

-- | Shift each element one slot along the traversal order. The given value
-- fills the first slot, and the last element is dropped.
push :: Traversable t => t a -> a -> t a
push = mapAccum_ (,)
-- | 'push' in the reverse traversal order. The given value fills the last
-- slot, and the first element is dropped.
pop :: Traversable t => t a -> a -> t a
pop = mapAccumR_ (,)
-- | Pair each element with its predecessor as @(previous, current)@. The
-- given value serves as the first element's predecessor.
withPrev :: Traversable t => a -> t a -> t (a,a)
withPrev = flip (mapAccum_ (\a p -> (a,(p,a))))
-- | Pair each element with its successor as @(next, current)@. The given
-- value serves as the last element's successor.
withNext :: Traversable t => t a -> a -> t (a,a)
withNext = mapAccumR_ (\a p -> (a,(p,a)))

-- | Monads with a read-only environment @r@. 'local' runs a computation
-- with a modified environment.
class Monad m => MonadReader r m | m -> r where
  ask :: m r
  local :: (r -> r) -> m a -> m a
instance MonadReader r ((->) r) where
  ask = id
  local = (>>>)
-- | Lift 'ask' through a transformer.
ask_ :: (MonadTrans t, MonadReader a m) => t m a
ask_ = lift ask
-- | Lift 'local' through a transformer, using 'internal'.
local_ :: (MonadInternal t, MonadReader r m) => (r -> r) -> t m a -> t m a
local_ f = internal (local f)

{-| A simple Reader monad: an 'RWST' with no state and no output. -}
newtype ReaderT r m a = ReaderT (RWST r Void Void m a)
  deriving (Functor,Unit,Applicative,Monad,MonadFix,
            MonadTrans,MonadInternal,
            MonadReader r,MonadCont)
type Reader r a = ReaderT r Id a
-- | The isomorphism between 'ReaderT' and functions @r -> m a@.
_readerT :: (Functor m,Functor m') => Iso (ReaderT r m a) (ReaderT r' m' b) (r -> m a) (r' -> m' b)
_readerT = iso readerT runReaderT
  where
    readerT f = ReaderT (RWST (\ ~(r,_) -> f r<&>(,zero,zero) ))
    runReaderT (ReaderT (RWST f)) r = f (r,zero) <&> \ ~(a,_,_) -> a
-- | '_readerT' specialised to the 'Id' monad.
_reader :: Iso (Reader r a) (Reader r' b) (r -> a) (r' -> b)
_reader = _mapping _Id._readerT
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_
  put = put_
  modify = modify_
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_
  listen = listen_
  censor = censor_
deriving instance Semigroup (m (a,Void,Void)) => Semigroup (ReaderT r m a)
deriving instance Monoid (m (a,Void,Void)) => Monoid (ReaderT r m a)
deriving instance Ring (m (a,Void,Void)) => Ring (ReaderT r m a)

-- | Monads that accumulate an output @w@.
--
-- * 'tell' emits output.
-- * 'listen' also returns the output an action produced.
-- * 'censor' runs an action that yields a function, and applies that
--   function to the action's output.
class (Monad m,Monoid w) => MonadWriter w m | m -> w where
  tell :: w -> m ()
  listen :: m a -> m (w,a)
  censor :: m (a,w -> w) -> m a
-- | Lift 'tell' through a transformer.
tell_ :: (MonadWriter w m, MonadTrans t) => w -> t m ()
tell_ = lift . tell
-- | Lift 'listen' through a transformer, using 'internal'.
listen_ :: (MonadInternal t, MonadWriter w m) => t m a -> t m (w, a)
listen_ = internal (\m -> listen m <&> \(w,(c,a)) -> (c,(w,a)) )
-- | Lift 'censor' through a transformer, using 'internal'.
censor_ :: (MonadInternal t, MonadWriter w m) => t m (a, w -> w) -> t m a
censor_ = internal (\m -> censor (m <&> \(c,(a,f)) -> ((c,a),f)))
instance Monoid w => MonadWriter w ((,) w) where
  tell w = (w,())
  listen m@(w,_) = (w,m)
  censor ~(w,~(a,f)) = (f w,a)

-- | Run an action, discarding the output it produces.
mute :: (MonadWriter w m,Monoid w) => m a -> m a
mute m = censor (m<&>(,const zero))
-- | Run an action, returning its output instead of emitting it.
intercept :: (MonadWriter w m,Monoid w) => m a -> m (w,a)
intercept = listen >>> mute

{-| A simple Writer monad: an 'RWST' with no environment and no state. -}
newtype WriterT w m a = WriterT (RWST Void w Void m a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable
           ,MonadTrans,MonadInternal
           ,MonadWriter w,MonadCont)
type Writer w a = WriterT w Id a
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_
  local = local_
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_
  put = put_
  modify = modify_
deriving instance Semigroup (m (a,Void,w)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (a,Void,w)) => Monoid (WriterT w m a)
deriving instance Ring (m (a,Void,w)) => Ring (WriterT w m a)
-- | The isomorphism between 'WriterT' and actions yielding
-- @(output, result)@.
_writerT :: (Functor m,Functor m') => Iso (WriterT w m a) (WriterT w' m' b) (m (w,a)) (m' (w',b))
_writerT = iso writerT runWriterT
  where
    writerT mw = WriterT (RWST (pure (mw <&> \ ~(w,a) -> (a,zero,w) )))
    runWriterT (WriterT (RWST m)) = m (zero,zero) <&> \ ~(a,_,w) -> (w,a)
-- | '_writerT' specialised to the 'Id' monad.
_writer :: Iso (Writer w a) (Writer w' b) (w,a) (w',b)
_writer = _Id._writerT

{-| A simple continuation monad implementation: monads with first-class
continuations. -}
class Monad m => MonadCont m where
  callCC :: ((a -> m b) -> m a) -> m a

-- | Continuation transformer: a computation that passes its result to a
-- continuation producing @m r@.
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
  deriving (Semigroup,Monoid,Ring)
type Cont r a = ContT r Id a
instance Unit m => Unit (ContT r m) where
  pure a = ContT ($ a)
instance Functor f => Functor (ContT r f) where
  map f (ContT c) = ContT (\kb -> c (kb . f))
instance Applicative m => Applicative (ContT r m) where
  ContT cf <*> ContT ca = ContT (\kb -> cf (\f -> ca (\a -> kb (f a))))
instance Monad m => Monad (ContT r m) where
  ContT k >>= f = ContT (\cc -> k (\a -> runContT (f a) cc))
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
-- The escape continuation ignores the continuation it is called in and
-- jumps to the one captured by 'callCC'.
instance Monad m => MonadCont (ContT r m) where
  callCC f = ContT (\k -> runContT (f (\a -> ContT (\_ -> k a))) k)

-- | Run a continuation computation with 'return' as the final continuation.
evalContT :: Unit m => ContT r m r -> m r
evalContT c = runContT c return
-- | 'evalContT' specialised to the 'Id' monad.
evalCont :: Cont r r -> r
evalCont = getId . evalContT

instance MonadTrans Backwards where
  lift = Backwards
-- The inner action's effects run before those of the outer action that
-- produced it. The outer result is fed back to the inner action with
-- 'mfixing'.
instance MonadFix m => Monad (Backwards m) where
  join (Backwards ma) = Backwards $ mfixing (\a -> liftA2 (,) (forwards a) ma)

-- | Monads that can nondeterministically pick an element of a list.
class Monad m => MonadList m where
  fork :: [a] -> m a
instance MonadList [] where
  fork = id

-- | List transformer: the monad @m@ composed with the list monad.
newtype ListT m a = ListT ((m:.:[]) a)
  deriving (Semigroup,Monoid,
            Functor,Applicative,Unit,Monad,
            Foldable,Traversable)
-- | The isomorphism between 'ListT' and actions yielding a list.
_listT :: Iso (ListT m a) (ListT m' a') (m [a]) (m' [a'])
_listT = iso (ListT . Compose) (\(ListT (Compose m)) -> m)
instance Monad m => MonadList (ListT m) where
  fork = at _listT . return
-- NOTE(review): this feeds the fixpoint of the first result to every
-- alternative (the same shape that was wrong for plain lists). Confirm this
-- is intended.
instance MonadFix m => MonadFix (ListT m) where
  mfix f = at _listT (mfix (at' _listT . f . head))
instance MonadTrans ListT where
  lift ma = (return<$>ma)^._listT
instance MonadState s m => MonadState s (ListT m) where
  get = get_
  modify = modify_
  put = put_
instance MonadWriter w m => MonadWriter w (ListT m) where
  tell = lift.tell
  listen = _listT-.map sequence.listen.-_listT
  censor = _listT-.censor.map (\l -> (fst<$>l,compose (snd<$>l))).-_listT
-- 'throw' yields no results. 'catch' runs the handler when the action yields
-- no results.
instance Monad m => MonadError Void (ListT m) where
  throw = const zero
  catch f mm = mm & _listT %%~ (\m -> m >>= \_l -> case _l of [] -> f zero^.._listT; l -> pure l)

-- | Monads that can throw and catch errors of type @e@. Note that 'catch'
-- takes the handler first.
class Monad m => MonadError e m | m -> e where
  throw :: e -> m a
  catch :: (e -> m a) -> m a -> m a

-- | @try d m@ runs @m@, falling back to @d@ if it fails. It is meant for
-- monads whose errors carry no information.
try :: MonadError Void m => m a -> m a -> m a
try d = catch (\x -> const d (x::Void))
-- | Run an action, returning 'Nothing' if it throws any exception.
-- NOTE(review): matching on 'Ex.SomeException' also catches asynchronous
-- exceptions.
tryMay :: MonadError Ex.SomeException m => m a -> m (Maybe a)
tryMay m = catch (\(Ex.SomeException _) -> return Nothing) (Just<$>m)

instance MonadError e (Either e) where
  throw = Left
  catch f = f<|>Right
instance MonadError Void [] where
  throw = const zero
  catch f [] = f zero
  catch _ l = l

-- | Error transformer: the monad @m@ composed with @Either e@.
newtype EitherT e m a = EitherT ((m:.:Either e) a)
  deriving (Unit,Functor,Applicative,Monad,MonadFix
           ,Foldable,Traversable)
instance MonadTrans (EitherT e) where
  lift m = (pure<$>m)^._eitherT
-- | The isomorphism between 'EitherT' and actions yielding an 'Either'.
_eitherT :: (Functor m) => Iso (EitherT e m a) (EitherT f m b) (m (e:+:a)) (m (f:+:b))
_eitherT = iso (EitherT . Compose) (\(EitherT (Compose e)) -> e)

instance Applicative Maybe
instance Monad Maybe where
  join = fold
instance MonadError Void Maybe where
  throw = const Nothing
  catch f Nothing = f zero
  catch _ a = a

-- Exceptions are raised with 'Ex.throwIO', so they are sequenced with the
-- surrounding IO actions. Plain 'Ex.throw' would raise as soon as the
-- action value is evaluated, even if it is never executed.
instance Ex.Exception e => MonadError e IO where
  throw = Ex.throwIO
  catch = flip Ex.catch