{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}

-- | Distributive laws between monads, and the 'Monad' instance they induce
-- on 'Compose'.
--
-- NOTE(review): several instance heads below overlap. For example,
-- @Distributive Maybe n@ and @Distributive m (StateT v Identity)@ both match
-- @Distributive Maybe (StateT v Identity)@. Likewise, the two 'Compose'
-- instances both match @Distributive (Compose a b) (Compose c d)@. No
-- overlap pragmas are given, so GHC reports such constraints as overlapping
-- at the use site.
module Control.Monad.Distributive
  ( Distributive (dist)
  ) where

-- 'join' and 'MonadTrans' are imported explicitly instead of relying on the
-- mtl monad modules re-exporting them (mtl 2.3 trimmed those re-exports).
import Control.Monad (join)
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import qualified Control.Monad.State.Strict as Strict
import Control.Monad.Trans (MonadTrans (lift))
import Control.Monad.Writer
import Data.Functor.Compose

-- | Monads that distribute over one another: @'Distributive' m n@ says how
-- to pull an @n@ layer out from under an @m@ layer.
--
-- Given @'Distributive' m n@, the 'Monad' instance at the end of this module
-- makes @'Compose' n m@ a monad. For that monad to be lawful, 'dist' is
-- expected to satisfy Beck's distributive-law equations, such as
--
-- > dist . return      == fmap return
-- > dist . fmap return == return
--
-- NOTE(review): nothing in this module checks those equations. Some
-- instances below drop information (see the State instances), so verify the
-- laws before relying on the composite monad.
class (Monad m, Monad n) => Distributive m n where
  -- | Swap the two monadic layers.
  dist :: m (n t) -> n (m t)

-- | 'Maybe' over any monad. A missing value stays missing; a present inner
-- computation has its result wrapped in 'Just'.
instance (Monad n) => Distributive Maybe n where
  dist Nothing = return Nothing
  dist (Just inner) = Just <$> inner

-- | 'Either' over any monad. A 'Left' is returned unchanged; a 'Right'
-- computation has its result wrapped in 'Right'.
instance (Monad n) => Distributive (Either t) n where
  dist (Left err) = return (Left err)
  dist (Right inner) = Right <$> inner

-- | A pure 'Writer' over any monad. The inner computation is run, and the
-- log the writer accumulated is re-emitted (via 'tell') in front of its
-- result.
instance (Monad n, Monoid x) => Distributive (WriterT x Identity) n where
  dist wr =
    let (inner, logged) = runWriter wr
     in (\v -> tell logged >> return v) <$> inner

-- | Lists over any monad: run each element's computation in order and
-- collect the results.
instance (Monad n) => Distributive [] n where
  dist = sequence

-- | Any monad over a pure strict 'Strict.StateT'. Each inner state action is
-- evaluated against the state read when 'dist' runs. Any state those inner
-- actions produce is discarded ('Strict.evalState'), and the outer state is
-- never modified.
instance (Monad m) => Distributive m (Strict.StateT v Identity) where
  dist outer = do
    s <- get
    return ((\st -> Strict.evalState st s) <$> outer)

-- | Any monad over a pure lazy 'StateT'. This behaves like the strict
-- instance above: inner state changes are discarded ('evalState'), and the
-- outer state is never modified.
instance (Monad m) => Distributive m (StateT v Identity) where
  dist outer = do
    s <- get
    return ((\st -> evalState st s) <$> outer)

-- | Any monad over a pure 'ReaderT': every inner reader action receives the
-- environment visible when 'dist' runs.
instance (Monad m) => Distributive m (ReaderT v Identity) where
  dist outer = do
    env <- ask
    return ((\rd -> runReader rd env) <$> outer)

-- | Any monad over 'Identity': unwrap each inner 'Identity'.
instance (Monad m) => Distributive m Identity where
  dist outer = Identity (runIdentity <$> outer)

-- | 'Identity' over any monad: re-wrap the inner computation's result.
instance (Monad n) => Distributive Identity n where
  dist (Identity inner) = Identity <$> inner

-- | Distributing @m@ over a composite @'Compose' n n2@: push @m@ past @n@
-- first, then past @n2@. The @'Distributive' n2 n@ constraint is what makes
-- @'Compose' n n2@ itself a monad.
instance (Distributive n2 n, Distributive m n2, Distributive m n) =>
    Distributive m (Compose n n2) where
  dist = Compose . fmap dist . dist . fmap getCompose

-- | Distributing a composite @'Compose' m m2@ over @n@: push @m2@ past @n@
-- first, then @m@. The @'Distributive' m2 m@ constraint is what makes
-- @'Compose' m m2@ itself a monad.
instance (Distributive m n, Distributive m2 m, Distributive m2 n) =>
    Distributive (Compose m m2) n where
  dist = fmap Compose . dist . fmap dist . getCompose

-- | Flatten two layers of @'Compose' m n@. The distributive law turns the
-- middle @n (m _)@ into @m (n _)@, after which @n@'s and @m@'s own 'join'
-- collapse the duplicated layers.
join' :: (Distributive n m) => m (n (m (n a))) -> Compose m n a
join' mnmn = Compose (join (fmap (fmap join . dist) mnmn))

-- | Monads with a distributive law compose to give another monad.
--
-- 'return' is not defined here. It defaults to 'pure' from the
-- 'Applicative' instance of 'Compose', which avoids a non-canonical
-- definition.
instance (Distributive n m) => Monad (Compose m n) where
  Compose outer >>= f = join' (fmap (fmap (getCompose . f)) outer)

-- | Lift an inner-monad computation into @'Compose' m n@ by wrapping it in
-- @m@'s 'return'.
--
-- NOTE(review): transformers >= 0.6 reportedly gives 'MonadTrans' the
-- quantified superclass @forall n. Monad n => Monad (t n)@. @'Compose' m@
-- cannot satisfy it, because the 'Monad' instance above also needs
-- @'Distributive' n m@. This instance presumably only builds against older
-- transformers; confirm.
instance (Monad m) => MonadTrans (Compose m) where
  lift = Compose . return