module Control.Monad.Distributive where
import qualified Control.Monad.State.Strict as Strict
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.Maybe
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Identity
import Control.Monad.Morph
-- | Exchange the two components of a pair.
--
-- The argument is matched with a lazy pattern, so the outer pair
-- constructor of the result is available without forcing the input.
-- The 'Takeout' instances in this module use it to turn the
-- @(result, extra)@ pairs coming out of @runStateT@ / @runWriterT@ into
-- @(extra, result)@ pairs.
swap :: (a, b) -> (b, a)
swap ~(x, y) = (y, x)
-- | Monad transformers @m@ whose effect can be separated from the monad
-- they wrap.  The type constructor @y@ describes the extra payload that
-- travels alongside a result once the transformer layer has been taken
-- out; it is determined by @m@ through the functional dependency.
--
-- NOTE(review): no LANGUAGE pragmas are visible in this file; a
-- two-parameter class with a functional dependency needs
-- MultiParamTypeClasses and FunctionalDependencies -- presumably enabled
-- by the build configuration.
class MonadTrans m => Takeout m y | m -> y where
  -- | Re-express an action over an arbitrary base monad as an action over
  -- 'Identity' that /returns/ the base-monad computation, whose result is
  -- wrapped in @y@.
  takeout :: Monad inner => m inner a -> m Identity (inner (y a))

  -- | Fold a @y@-wrapped result back into the transformer layer, leaving
  -- the bare result.
  combine :: Monad base => m base (y a) -> m base a
-- | Strict 'Strict.StateT': the payload is a @(state, result)@ pair.
instance Takeout (Strict.StateT s) ((,) s) where
  -- Read the current state and return, without executing it, the
  -- base-monad action that runs @m@ from that state.  The
  -- @(result, state)@ pair it produces is flipped to @(state, result)@.
  takeout m = do
    s0 <- Strict.get
    return (liftM swap (Strict.runStateT m s0))

  -- Run @m@, install the state it yielded, and hand back the bare result.
  combine m = do
    (s, x) <- m
    Strict.put s
    return x
-- | 'StateT' as exported by "Control.Monad.State": the payload is a
-- @(state, result)@ pair.
instance Takeout (StateT s) ((,) s) where
  -- Capture the current state, then return (unexecuted) the base-monad
  -- action that runs @m@ from it, with its output pair flipped so the
  -- state comes first.
  takeout m = do
    current <- get
    return (liftM swap (runStateT m current))

  -- Run @m@, store the state component it produced, return the result
  -- component.
  combine m = do
    (s, x) <- m
    put s
    return x
-- | 'ReaderT': there is no extra payload, so results are merely wrapped
-- in 'Identity'.
instance Takeout (ReaderT r) Identity where
  -- Fetch the environment and return (unexecuted) the base-monad action
  -- that runs @m@ in it, wrapping its result in 'Identity'.
  takeout m = do
    env <- ask
    return (liftM Identity (runReaderT m env))

  -- Strip the 'Identity' wrapper from the result.
  combine m = do
    wrapped <- m
    return (runIdentity wrapped)
-- | 'WriterT': the payload is an @(output, result)@ pair.
instance Monoid w => Takeout (WriterT w) ((,) w) where
  -- Return (unexecuted) the base-monad action underlying @m@, with its
  -- @(result, output)@ pair flipped to @(output, result)@.  Nothing is
  -- written to the outer writer at this point.
  takeout m = return (liftM swap (runWriterT m))

  -- Run @m@, emit the output component with 'tell', return the result.
  combine m = do
    (w, x) <- m
    tell w
    return x
-- | Push a transformer layer @u@ underneath a transformer layer @t@.
--
-- The argument is a @t x@ action that yields a pure ('Identity'-based)
-- @u@ computation.  The result is a single action in @t (u x)@:
--
-- 1. the yielded @u Identity@ computation is re-based onto @x@ by
--    hoisting @return . runIdentity@ over it;
-- 2. the outer @t x@ action is moved over @u x@ with @hoist lift@;
-- 3. the yielded @u x@ computation is then run by lifting it into
--    @t (u x)@.
putin
  :: ( MFunctor t, MonadTrans t
     , MFunctor u, MonadTrans u
     , Monad x, Monad (u x), Monad (t x), Monad (t (u x)) )
  => t x (u Identity a)
  -> t (u x) a
putin m = do
  inner <- hoist lift (liftM (hoist (return . runIdentity)) m)
  lift inner
-- | Collapse a @t 'Identity'@ action that yields an @n@-computation into
-- a single @t n@ action.
--
-- The outer action is re-based from 'Identity' onto @n@ by hoisting
-- @return . runIdentity@, and the @n@-computation it yields is then run
-- by lifting it into @t n@.
putin1
  :: (MFunctor t, MonadTrans t, Monad n, Monad (t n))
  => t Identity (n a)
  -> t n a
putin1 m = do
  action <- hoist (return . runIdentity) m
  lift action
-- | Transformers @m@ that can be run on top of another transformer stack
-- @n base@ so that their outcome is captured as a pure
-- ('Identity'-based) @m@ computation.
class Leftdistr m where
  -- | Run the @m@ layer in the @n base@ monad and return its outcome as
  -- an @m 'Identity'@ computation.
  ldist :: (Monad (n base), Monad base) => m (n base) a -> n base (m Identity a)
-- | Transformers @m@ that can be moved from beneath a transformer @n@ to
-- above it.
class Rightdistr m where
  -- | Given an @n 'Identity'@ action that yields an @m base@ computation,
  -- produce an @m (n base)@ action that returns (rather than runs) the
  -- corresponding @n base@ computation.
  rdist
    :: ( Monad (n Identity)
       , Monad (n base)
       , MonadTrans n
       , MFunctor n
       , Monad base
       )
    => n Identity (m base a)
    -> m (n base) (n base a)
-- | 'MaybeT': run the layer, then rebuild the success or failure as a
-- pure 'MaybeT' computation.
instance Leftdistr MaybeT where
  ldist m = do
    outcome <- runMaybeT m
    return $ case outcome of
      Nothing -> mzero
      Just x  -> return x
-- | 'ErrorT': run the layer, then rebuild the thrown error or the result
-- as a pure 'ErrorT' computation.
instance Error t => Leftdistr (ErrorT t) where
  ldist m = do
    outcome <- runErrorT m
    return $ case outcome of
      Left err -> throwError err
      Right x  -> return x
-- | 'WriterT': run the layer, then rebuild a pure writer computation
-- that emits the collected output and returns the result.
instance Monoid x => Leftdistr (WriterT x) where
  ldist m = do
    (result, output) <- runWriterT m
    return $ do
      tell output
      return result
-- | 'ListT': run the layer to obtain the list of results, then rebuild a
-- pure 'ListT' computation offering each of them as an alternative.
instance Leftdistr ListT where
  ldist m = do
    results <- runListT m
    return (msum [return r | r <- results])
-- | Strict 'Strict.StateT': snapshot the current state and return
-- (without running it) the computation obtained by evaluating the inner
-- stateful computation from that snapshot.
instance Rightdistr (Strict.StateT v) where
  rdist m = do
    s <- get
    -- 'Strict.evalStateT' keeps only the result, so the state produced
    -- by the inner computation is not propagated.
    return (putin1 (liftM (\inner -> Strict.evalStateT inner s) m))
-- | 'StateT' as exported by "Control.Monad.State": snapshot the current
-- state and return (without running it) the computation obtained by
-- evaluating the inner stateful computation from that snapshot.
instance Rightdistr (StateT v) where
  rdist m = do
    snapshot <- get
    -- 'evalStateT' keeps only the result, so the state produced by the
    -- inner computation is not propagated.
    return (putin1 (liftM (\inner -> evalStateT inner snapshot) m))
-- | 'ReaderT': fetch the environment and return (without running it) the
-- computation obtained by running the inner reader computation in that
-- environment.
instance Rightdistr (ReaderT v) where
  rdist m = do
    env <- ask
    return (putin1 (liftM (\inner -> runReaderT inner env) m))
-- | Swap two transformer layers, moving a 'Leftdistr' transformer @m@
-- from the outside of @n@ to the inside.
--
-- 'ldist' runs the @m@ layer and captures its outcome as a pure @m@
-- computation; 'putin' then re-installs that computation beneath @n@.
ldist'
  :: ( Leftdistr m, MFunctor m, MonadTrans m
     , MFunctor n, MonadTrans n
     , Monad x, Monad (m x), Monad (n x), Monad (n (m x)) )
  => m (n x) t
  -> n (m x) t
ldist' m = putin (ldist m)
-- | Swap two transformer layers, moving a 'Rightdistr' transformer @m@
-- from the inside of a 'Takeout' transformer @n@ to the outside.
--
-- 'takeout' separates the @n@ layer from the @m x@ computation beneath
-- it, 'rdist' moves @m@ above @n@, 'combine' folds the payload that
-- 'takeout' attached to the result back into @n@, and the resulting
-- @n x@ computation is finally run by lifting it into @m (n x)@.
rdist'
  :: ( Takeout n y, MFunctor n
     , Rightdistr m, MonadTrans m
     , Monad x, Monad (m x), Monad (n x)
     , Monad (n Identity), Monad (m (n x)) )
  => n (m x) t
  -> m (n x) t
rdist' m = do
  action <- liftM combine (rdist (takeout m))
  lift action