-- | The multi-valued version of mtl's Reader / ReaderT
-- / MonadReader
module Control.Monad.MultiReader
  ( -- * MultiReaderT
    MultiReaderT(..)
  , MultiReaderTNull
  , MultiReader
  -- * MonadMultiReader class
  , MonadMultiReader(..)
  -- * functions
  , mAskRaw
  , withMultiReader
  , withMultiReaders
  , evalMultiReaderT
  , evalMultiReaderTWithInitial
  , mapMultiReaderT
) where

import Data.HList.HList

import Control.Monad.State.Strict ( StateT(..)
                                  , MonadState(..)
                                  , evalStateT
                                  , mapStateT )
import Control.Monad.Trans.Class  ( MonadTrans
                                  , lift )
import Control.Monad.Writer.Class ( MonadWriter
                                  , listen
                                  , tell
                                  , writer
                                  , pass )

import Data.Functor.Identity      ( Identity )

import Control.Applicative        ( Applicative(..) )
import Control.Monad              ( liftM
                                  , ap )

-- | A Reader transformer monad parameterized by:
--
-- * x - The list of types constituting the environment / input (to be read),
--
-- * m - The inner monad.
--
-- 'MultiReaderT' corresponds to mtl's 'ReaderT', but can contain
-- a heterogeneous list of types.
--
-- This heterogeneous list is represented using Types.Data.List, i.e:
--
--   * @'[]@ - The empty list,
--
--   * @a ': b@ - A list where @/a/@ is an arbitrary type
--     and @/b/@ is the rest list.
--
-- For example,
--
-- > MultiReaderT '[Int, Bool] :: (* -> *) -> (* -> *)
--
-- is a Reader transformer containing the types [Int, Bool].
--
-- Internally this is a newtype around a 'StateT' whose state is the
-- 'HList' holding the environment values.
newtype MultiReaderT x m a = MultiReaderT {
  runMultiReaderTRaw :: StateT (HList x) m a
}

-- | A 'MultiReaderT' whose environment is the empty type-level list @'[]@,
-- i.e. a transformer that carries no values at all.
type MultiReaderTNull = MultiReaderT '[]

-- | A reader monad parameterized by the list of types /x/ of the environment
-- / input to carry; it is 'MultiReaderT' applied to 'Identity' as the
-- inner monad.
--
-- Similar to @Reader r = ReaderT r Identity@
type MultiReader x = MultiReaderT x Identity

-- | Relates an element type /a/ to a type-level list /c/ by giving access
-- to a value of type /a/ stored inside an @HList c@.
class ContainsType a c where
  -- | Replace the stored value of type /a/ with the given one.
  setHListElem :: a -> HList c -> HList c
  -- | Extract the stored value of type /a/.
  getHListElem :: HList c -> a

-- | All methods must be defined.
--
-- The idea is: Any monad stack is an instance of @MonadMultiReader a@, iff
-- the stack contains a @MultiReaderT x@ with /a/ being an element of /x/.
class (Monad m) => MonadMultiReader a m where
  mAsk :: m a -- ^ Access to a specific type in the environment.

{-
It might seem straightforward to define the following class that
corresponds to other transformer classes. But while we can define the
class and its instances, there is a problem when we try to use it, assuming
that we do not want to annotate the full type signature of the config:
  the type of the config can not be inferred properly. We would need a feature
  like "infer, as return type for this function, the only type for
  which there exists a valid chain of instance definitions that is needed
  by this function".
  In other words, it is impossible to use the mAskRaw function without
  binding a concrete type for c, because otherwise the inference runs into
  some overlapping instances.
For this reason, I removed this type class and created a non-class function
mAskRaw, for which the type inference works because it involves no
type classes.
  lennart spitzner
-}

--class (Monad m) => MonadMultiReaderRaw c m where
--  mAskRaw :: m (HList c)

--instance (MonadTrans t, Monad (t m), MonadMultiReaderRaw c m)
--      => MonadMultiReaderRaw c (t m) where
--  mAskRaw = lift $ mAskRaw

--instance (Monad m) => MonadMultiReaderRaw a (MultiReaderT a m) where
--  mAskRaw = MultiReaderT $ get

-- | Base case: the requested type is the head of the list, so the head
-- element is the one read or replaced.
instance ContainsType a (a ': xs) where
  getHListElem (hd :+: _)       = hd
  setHListElem new (_ :+: rest) = new :+: rest

-- | Recursive case: leave the head untouched and delegate to the instance
-- for the tail of the list.
instance (ContainsType a xs) => ContainsType a (x ': xs) where
  getHListElem (_ :+: rest)      = getHListElem rest
  setHListElem new (hd :+: rest) = hd :+: setHListElem new rest

-- | Maps over the result by delegating to the 'Functor' instance of the
-- wrapped 'StateT'.
instance (Functor f) => Functor (MultiReaderT x f) where
  fmap g (MultiReaderT inner) = MultiReaderT (fmap g inner)

-- | 'pure' is taken from the wrapped 'StateT'; '<*>' is defined through
-- the 'Monad' instance via 'ap'.
instance (Applicative m, Monad m) => Applicative (MultiReaderT x m) where
  pure a    = MultiReaderT (pure a)
  mf <*> ma = mf `ap` ma

-- | Sequencing unwraps both sides and binds inside the wrapped 'StateT'.
instance Monad m => Monad (MultiReaderT x m) where
  return a = MultiReaderT (return a)
  MultiReaderT inner >>= f =
    MultiReaderT (inner >>= \a -> runMultiReaderTRaw (f a))

-- | Lifts an action of the inner monad through the wrapped 'StateT'.
instance MonadTrans (MultiReaderT x) where
  lift action = MultiReaderT (lift action)

-- | Adds an element to the environment, thereby transforming a MultiReaderT
-- carrying an environment with types /(x:xs)/ to a MultiReaderT with /xs/.
--
-- Think "Execute this computation with this additional value as environment".
--
-- The inner computation is run on the current environment extended by the
-- given value; afterwards the tail of the resulting environment is written
-- back as the outer environment.
withMultiReader :: Monad m
                => x
                -> MultiReaderT (x ': xs) m a
                -> MultiReaderT xs m a
withMultiReader x k = MultiReaderT $ do
  s <- get
  (a, s') <- lift $ runStateT (runMultiReaderTRaw k) (x :+: s)
  -- The GADT constructor is matched in a @case@ instead of in the bind
  -- pattern: a constructor pattern of a multi-constructor type in a
  -- do-bind is treated as failable and would demand a 'MonadFail'
  -- instance, which a plain 'Monad' constraint does not provide.
  case s' of
    _ :+: rest -> do
      put rest
      return a

-- | Adds a heterogeneous list of elements to the environment, thereby
-- transforming a MultiReaderT carrying an environment with values
-- over types /xs++ys/ to a MultiReaderT over /ys/.
--
-- Similar to recursively adding single values with 'withMultiReader'.
--
-- Note that /ys/ can be Null; in that case the return value can be
-- evaluated further using 'evalMultiReaderT'.
withMultiReaders :: Monad m
                 => HList xs
                 -> MultiReaderT (Append xs ys) m a
                 -> MultiReaderT ys m a
withMultiReaders HNil       k = k
withMultiReaders (v :+: vs) k = withMultiReaders vs (withMultiReader v k)

-- | Base instance: read the whole environment and project out the
-- element of the requested type.
instance (Monad m, ContainsType a c)
      => MonadMultiReader a (MultiReaderT c m) where
  mAsk = MultiReaderT (get >>= \env -> return (getHListElem env))

-- | Lifting instance: any transformer on top of a 'MonadMultiReader'
-- answers 'mAsk' by lifting the underlying monad's 'mAsk'.
instance (MonadTrans t, Monad (t m), MonadMultiReader a m)
      => MonadMultiReader a (t m) where
  mAsk = lift mAsk

-- | A raw extractor of the contained HList (i.e. the complete environment).
--
-- For a possible usecase, see 'withMultiReaders'.
mAskRaw :: Monad m => MultiReaderT a m (HList a)
mAskRaw = MultiReaderT (StateT (\env -> return (env, env)))

-- | Evaluate a computation over an empty environment.
--
-- Because the environment is empty, it does not need to be provided.
--
-- If you want to evaluate a computation over any non-Null environment, either
-- use
--
-- * 'evalMultiReaderTWithInitial'
--
-- * simplify the computation using 'withMultiReader' / 'withMultiReaders',
--   then use 'evalMultiReaderT' on the result.
evalMultiReaderT :: Monad m => MultiReaderT '[] m a -> m a
evalMultiReaderT (MultiReaderT inner) = evalStateT inner HNil

-- | Evaluate a reader computation with the given environment, returning
-- only the computation's result in the inner monad.
evalMultiReaderTWithInitial :: Monad m
                            => HList a            -- ^ The initial state
                            -> MultiReaderT a m b -- ^ The computation to evaluate
                            -> m b
evalMultiReaderTWithInitial env (MultiReaderT inner) = evalStateT inner env

-- | Map both the return value and the environment of a computation
-- using the given function.
--
-- Note that there is a difference to mtl's ReaderT,
-- where it is /not/ possible to modify the environment.
mapMultiReaderT :: (m (a, HList w)
                -> m' (a', HList w))
                -> MultiReaderT w m a
                -> MultiReaderT w m' a'
mapMultiReaderT f (MultiReaderT inner) = MultiReaderT (mapStateT f inner)

-- foreign lifting instances

-- | Passes all 'MonadState' operations through to the inner monad; the
-- reader environment itself is not involved.
instance (MonadState s m) => MonadState s (MultiReaderT c m) where
  get     = lift get
  put s   = lift (put s)
  state f = lift (state f)

-- | Passes all 'MonadWriter' operations through to the inner monad.
--
-- 'listen' and 'pass' act on the unwrapped 'StateT': the lambdas only
-- re-associate the tuples so that the environment stays in the position
-- 'mapStateT' expects while the inner monad's 'listen' / 'pass' sees the
-- writer-related component.
instance (MonadWriter w m) => MonadWriter w (MultiReaderT c m) where
  writer = lift . writer
  tell   = lift . tell
  listen = MultiReaderT .
    -- inner 'listen' yields ((result, env), output); regroup it to
    -- ((result, output), env)
    mapStateT (liftM (\((a, s), w) -> ((a, w), s)) . listen) .
    runMultiReaderTRaw
  pass = MultiReaderT .
    -- regroup ((result, f), env) to ((result, env), f) so that the inner
    -- 'pass' applies f to the output
    mapStateT (pass . liftM (\((a, f), s) -> ((a, s), f))) .
    runMultiReaderTRaw