{-# LANGUAGE MultiParamTypeClasses, FlexibleContexts, FunctionalDependencies, TypeOperators #-}

module Control.Monad.Distributive where

import qualified Control.Monad.State.Strict as Strict
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Reader
import Control.Monad.Maybe
import Control.Monad.Error
import Control.Monad.List
import Control.Monad.Identity
import Control.Monad.Morph

-- | Exchange the two components of a pair.
--
--   The pattern is irrefutable (@~@), so the argument pair is not forced
--   until one of the result components is demanded: @swap undefined@
--   still evaluates to a pair constructor.
swap :: (a, b) -> (b, a)
swap ~(x, y) = (y, x)

-- | Transformers whose effect can be split away from the monad they run
--   over and later re-applied. The transformer @m@ determines the carrier
--   @y@ used to hold its effect's outcome.
class (MonadTrans m) => Takeout m y | m -> y where
    -- | Separate a transformer from its underlying monad.
    --
    --   The @m base@ computation is re-expressed as an @m Identity@
    --   computation whose result is the @base@ action to run; that action
    --   yields the transformer's own outcome packed into the carrier @y@
    --   (for example the final state alongside the result).
    takeout :: (Monad base) => m base a -> m Identity (base (y a))
    -- | Inverse step of 'takeout': unpack a carrier-wrapped result and
    --   re-apply it as this transformer's effect, leaving the plain value.
    combine :: (Monad inner) => m inner (y a) -> m inner a

-- Strict state: the carrier is a (state, result) pair.
instance Takeout (Strict.StateT s) ((,) s) where
    -- Read the current state, then hand back the underlying action that
    -- runs the computation from it, with its pair flipped to (state, result).
    takeout action = do
        s0 <- Strict.get
        return (liftM swap (Strict.runStateT action s0))
    -- Install the carried state as the current state and yield the result.
    combine carried = do
        (s1, result) <- carried
        Strict.put s1
        return result

-- Lazy state: the carrier is a (state, result) pair.
instance Takeout (StateT s) ((,) s) where
    -- Read the current state, then hand back the underlying action that
    -- runs the computation from it, with its pair flipped to (state, result).
    takeout action = do
        s0 <- get
        return (liftM swap (runStateT action s0))
    -- Install the carried state as the current state and yield the result.
    combine carried = do
        (s1, result) <- carried
        put s1
        return result

-- Reader: there is no outcome besides the result, so the carrier is Identity.
instance Takeout (ReaderT r) Identity where
    -- Fetch the environment and return the underlying action that runs the
    -- reader with it, wrapping each result in Identity.
    takeout action = do
        env <- ask
        return (liftM Identity (runReaderT action env))
    -- Nothing to re-apply: just unwrap the Identity.
    combine wrapped = wrapped >>= return . runIdentity

-- Writer: the carrier is an (output, result) pair.
instance (Monoid w) => Takeout (WriterT w) ((,) w) where
    -- Nothing is read up front; the underlying action runs the writer and
    -- reports its pair flipped to (output, result).
    takeout action = return (liftM swap (runWriterT action))
    -- Re-emit the carried output, then yield the result.
    combine carried = do
        (output, result) <- carried
        tell output
        return result

-- | Counterpart to 'takeout': fold an inner transformer that was left running
--   over 'Identity' back into the stack, underneath the outer transformer.
--
--   The inner action's base is lifted from 'Identity' via @return . runIdentity@,
--   the outer transformer is re-based with 'lift', and the inner action is
--   then run as the final step. No signature is declared; the inferred shape
--   is presumably @n x (m Identity t) -> n (m x) t@ (NOTE(review): confirm).
putin outer = do
    inner <- hoist lift (liftM (hoist (return . runIdentity)) outer)
    lift inner

-- | Move an @n Identity@ computation onto base monad @x@ (via
--   @return . runIdentity@) and run the @x@ action it produces as its final
--   step. Inferred shape, no signature declared: @n Identity (x t) -> n x t@.
putin1 wrapped = do
    inner <- hoist (return . runIdentity) wrapped
    lift inner

-- | Transformers that can be pulled out from above another transformer,
--   used to reorganize a monad stack.
--
--   'ldist' runs the @m@ layer against the @n base@ monad beneath it and
--   returns its outcome re-packaged as an @m Identity@ computation.
class Leftdistr m where
    ldist :: (Monad (n base), Monad base) => m (n base) a -> n base (m Identity a)

-- | Transformers that can be pushed underneath another transformer @n@.
--
--   'rdist' takes an @n Identity@ computation producing an @m base@ action
--   and returns an @m (n base)@ computation whose result is an @n base@ action.
class Rightdistr m where
    rdist :: (Monad (n Identity), Monad (n base), MonadTrans n, MFunctor n, Monad base)
          => n Identity (m base a) -> m (n base) (n base a)

-- A missing result becomes mzero; a present one is returned as-is.
instance Leftdistr MaybeT where
    ldist action = do
        outcome <- runMaybeT action
        return $ case outcome of
            Nothing -> mzero
            Just v  -> return v

-- A Left is re-thrown with throwError; a Right is returned as-is.
instance (Error t) => Leftdistr (ErrorT t) where
    ldist action = do
        outcome <- runErrorT action
        return $ case outcome of
            Left err -> throwError err
            Right v  -> return v

-- The collected output is replayed with tell before returning the value.
instance (Monoid w) => Leftdistr (WriterT w) where
    ldist action = do
        (value, output) <- runWriterT action
        return (tell output >> return value)

-- Every produced element becomes one alternative, combined with msum.
instance Leftdistr ListT where
    ldist action = do
        results <- runListT action
        return (msum (map return results))

-- Each inner state action is started from the state read here.
-- NOTE(review): it is run with evalStateT, so its final state is dropped and
-- never written back; confirm this is intended.
instance Rightdistr (Strict.StateT v) where
    rdist action = do
        s0 <- get
        return (putin1 (liftM (\st -> Strict.evalStateT st s0) action))

-- Each inner state action is started from the state read here.
-- NOTE(review): it is run with evalStateT, so its final state is dropped and
-- never written back; confirm this is intended.
instance Rightdistr (StateT v) where
    rdist action = do
        s0 <- get
        return (putin1 (liftM (\st -> evalStateT st s0) action))

-- Each inner reader action is run against the environment read here.
instance Rightdistr (ReaderT v) where
    rdist action = do
        env <- ask
        return (putin1 (liftM (\rd -> runReaderT rd env) action))

-- | Left distributivity: move transformer @m@ from above @n@ to below it,
--   by applying 'ldist' and then re-stacking the result with 'putin'.
--   No signature is declared; the inferred shape is presumably
--   @m (n x) t -> n (m x) t@ (NOTE(review): confirm).
ldist' action = putin (ldist action)

-- | Right distributivity: push the 'Takeout' transformer @n@ underneath the
--   'Rightdistr' transformer @m@. The computation is split with 'takeout',
--   moved with 'rdist', and the carried outcome is re-applied with 'combine'
--   before the resulting action is run. Inferred shape (no signature
--   declared): presumably @n (m x) t -> m (n x) t@.
--
--   NOTE(review): with the StateT instances of 'rdist', state changes made
--   inside the moved computation are discarded (they run via evalStateT).
rdist' action = do
    inner <- liftM combine (rdist (takeout action))
    lift inner