{-# LANGUAGE Safe #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-| This module contains a collection of monads that
   are defined in terms of the monad transformers from
   "MonadLib".  The definitions in this module are
   completely mechanical, so this module may become
   obsolete once automatic derivation of instances
   is well supported across implementations.
 -}
module MonadLib.Monads (
  Reader, Writer, State, Exception, Cont,
  runReader, runWriter, runState, runException, runCont,
  module MonadLib
) where
import MonadLib
import MonadLib.Derive
import Control.Monad.Fix

-- | A computation that can read a value of type @i@
-- (a newtype over @ReaderT i Id@).
newtype Reader i a = R' { unR :: ReaderT i Id a }

-- | A computation that can emit output of type @i@
-- (a newtype over @WriterT i Id@).
newtype Writer i a = W' { unW :: WriterT i Id a }

-- | A computation with a piece of state of type @i@
-- (a newtype over @StateT i Id@).
newtype State i a = S' { unS :: StateT i Id a }

-- | A computation that may raise an exception of type @i@
-- (a newtype over @ExceptionT i Id@).
newtype Exception i a = X' { unX :: ExceptionT i Id a }

-- | A computation in continuation-passing style whose final answer has
-- type @i@ (a newtype over @ContT i Id@).
newtype Cont i a = C' { unC :: ContT i Id a }

-- Isomorphisms between each underlying transformer stack and its newtype.
-- In every case the first component wraps (the data constructor) and the
-- second unwraps (the record selector); "MonadLib.Derive" uses these to
-- transport instances across the newtype.

-- | Wrap/unwrap between @ReaderT i Id@ and 'Reader'.
iso_R :: Iso (ReaderT i Id) (Reader i)
iso_R = Iso R' unR

-- | Wrap/unwrap between @WriterT i Id@ and 'Writer'.
iso_W :: Iso (WriterT i Id) (Writer i)
iso_W = Iso W' unW

-- | Wrap/unwrap between @StateT i Id@ and 'State'.
iso_S :: Iso (StateT i Id) (State i)
iso_S = Iso S' unS

-- | Wrap/unwrap between @ExceptionT i Id@ and 'Exception'.
iso_X :: Iso (ExceptionT i Id) (Exception i)
iso_X = Iso X' unX

-- | Wrap/unwrap between @ContT i Id@ and 'Cont'.
iso_C :: Iso (ContT i Id) (Cont i)
iso_C = Iso C' unC

-- Every monad in this module is its own base monad, so 'inBase' does
-- nothing but hand back its argument.
instance BaseM (Reader i) (Reader i) where
  inBase m = m

instance (Monoid i) => BaseM (Writer i) (Writer i) where
  inBase m = m

instance BaseM (State i) (State i) where
  inBase m = m

instance BaseM (Exception i) (Exception i) where
  inBase m = m

instance BaseM (Cont i) (Cont i) where
  inBase m = m

-- | Bind is derived from the underlying @ReaderT i Id@ through 'iso_R'.
instance Monad (Reader i) where
  m >>= k = derive_bind iso_R m k
#if !MIN_VERSION_base(4,11,0)
  -- Only compiled against base older than 4.11: 'fail' just calls 'error'.
  fail msg = error msg
#endif

-- | Bind is derived from the underlying @WriterT i Id@ through 'iso_W'.
instance (Monoid i) => Monad (Writer i) where
  m >>= k = derive_bind iso_W m k
#if !MIN_VERSION_base(4,11,0)
  -- Only compiled against base older than 4.11: 'fail' just calls 'error'.
  fail msg = error msg
#endif

-- | Bind is derived from the underlying @StateT i Id@ through 'iso_S'.
instance Monad (State i) where
  m >>= k = derive_bind iso_S m k
#if !MIN_VERSION_base(4,11,0)
  -- Only compiled against base older than 4.11: 'fail' just calls 'error'.
  fail msg = error msg
#endif


-- | Bind is derived from the underlying @ExceptionT i Id@ through 'iso_X'.
instance Monad (Exception i) where
  m >>= k = derive_bind iso_X m k
#if !MIN_VERSION_base(4,11,0)
  -- Only compiled against base older than 4.11: 'fail' just calls 'error'.
  fail msg = error msg
#endif

-- | Bind is derived from the underlying @ContT i Id@ through 'iso_C'.
instance Monad (Cont i) where
  m >>= k = derive_bind iso_C m k
#if !MIN_VERSION_base(4,11,0)
  -- Only compiled against base older than 4.11: 'fail' just calls 'error'.
  fail msg = error msg
#endif

-- Functor instances: each 'fmap' is transported from the underlying
-- transformer over 'Id' through the matching isomorphism.
instance Functor (Reader i) where
  fmap f m = derive_fmap iso_R f m

instance (Monoid i) => Functor (Writer i) where
  fmap f m = derive_fmap iso_W f m

instance Functor (State i) where
  fmap f m = derive_fmap iso_S f m

instance Functor (Exception i) where
  fmap f m = derive_fmap iso_X f m

instance Functor (Cont i) where
  fmap f m = derive_fmap iso_C f m

-- | 'pure' comes from @ReaderT i Id@ via 'iso_R'; '<*>' is 'ap',
-- i.e. it is expressed through this type's 'Monad' instance.
instance Applicative (Reader i) where
  pure x    = derive_return iso_R x
  mf <*> mx = ap mf mx

-- | 'pure' comes from @WriterT i Id@ via 'iso_W'; '<*>' is 'ap',
-- i.e. it is expressed through this type's 'Monad' instance.
instance (Monoid i) => Applicative (Writer i) where
  pure x    = derive_return iso_W x
  mf <*> mx = ap mf mx

-- | 'pure' comes from @StateT i Id@ via 'iso_S'; '<*>' is 'ap',
-- i.e. it is expressed through this type's 'Monad' instance.
instance Applicative (State i) where
  pure x    = derive_return iso_S x
  mf <*> mx = ap mf mx

-- | 'pure' comes from @ExceptionT i Id@ via 'iso_X'; '<*>' is 'ap',
-- i.e. it is expressed through this type's 'Monad' instance.
instance Applicative (Exception i) where
  pure x    = derive_return iso_X x
  mf <*> mx = ap mf mx

-- | 'pure' comes from @ContT i Id@ via 'iso_C'; '<*>' is 'ap',
-- i.e. it is expressed through this type's 'Monad' instance.
instance Applicative (Cont i) where
  pure x    = derive_return iso_C x
  mf <*> mx = ap mf mx

-- Value recursion, transported from the underlying transformers.
-- Note that 'Cont' is given no 'MonadFix' instance here.
instance MonadFix (Reader i) where
  mfix f = derive_mfix iso_R f

instance (Monoid i) => MonadFix (Writer i) where
  mfix f = derive_mfix iso_W f

instance MonadFix (State i) where
  mfix f = derive_mfix iso_S f

instance MonadFix (Exception i) where
  mfix f = derive_mfix iso_X f

-- Effect-class instances: each operation is the underlying transformer's
-- operation, carried across the newtype by the matching isomorphism.
instance ReaderM (Reader i) i where
  ask = derive_ask iso_R

instance (Monoid i) => WriterM (Writer i) i where
  put x = derive_put iso_W x

instance StateM (State i) i where
  get   = derive_get iso_S
  set s = derive_set iso_S s

instance ExceptionM (Exception i) i where
  raise e = derive_raise iso_X e

instance ContM (Cont i) where
  callWithCC f = derive_callWithCC iso_C f

-- | Evaluate a 'Reader' computation with the given environment value.
runReader :: i -> Reader i a -> a
runReader env m = runId (runReaderT env (unR m))

-- | Evaluate a 'Writer' computation, pairing its result with the
-- output it produced.
runWriter :: Writer i a -> (a,i)
runWriter m = runId (runWriterT (unW m))

-- | Evaluate a 'State' computation from an initial state, pairing its
-- result with the final state.
runState :: i -> State i a -> (a,i)
runState s0 m = runId (runStateT s0 (unS m))

-- | Evaluate an 'Exception' computation: 'Left' carries a raised
-- exception, 'Right' a normal result.
runException :: Exception i a -> Either i a
runException m = runId (runExceptionT (unX m))

-- | Evaluate a 'Cont' computation, using the supplied function as the
-- final continuation.
runCont :: (a -> i) -> Cont i a -> i
runCont k m = runId (runContT (\a -> return (k a)) (unC m))

-- | 'local' is derived from @ReaderT i Id@ through 'iso_R'.
instance RunReaderM (Reader i) i where
  local env m = derive_local iso_R env m

-- | 'collect' is derived from @WriterT i Id@ through 'iso_W'.
instance (Monoid i) => RunWriterM (Writer i) i where
  collect m = derive_collect iso_W m

-- | 'try' is derived from @ExceptionT i Id@ through 'iso_X'.
instance RunExceptionM (Exception i) i where
  try m = derive_try iso_X m