{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, TupleSections, Rank2Types #-}
module Clean.Monad(
  module Clean.Applicative,

  Monad(..),MonadFix(..),MonadTrans(..),

  MonadState(..),
  MonadReader(..),MonadWriter(..),

  StateT(..),State,
  ReaderT(..),Reader,WriterT(..),Writer,
  ContT(..),Cont,
  
  (=<<),(>>),return
  ) where

import Clean.Classes
import Clean.Applicative
import Clean.Core hiding (flip)
import Clean.Traversable

{-| Monads that support a monadic fixed point.

'mfix' takes a function from a value to an action yielding a value of the
same type, and returns a single action of that type.
-}
class MonadFix m where
  mfix :: (a -> m a) -> m a
{-| Monads carrying a state of type @s@.

'put' and 'modify' have defaults written in terms of each other (through
'get'), so an instance must define 'get' and at least one of 'put' or
'modify'; an instance defining neither gets defaults that call each other
forever.
-}
class Monad m => MonadState s m where
  -- | Read the state.
  get :: m s
  -- | Replace the state. Default: a 'modify' that ignores the old state.
  put :: s -> m ()
  put s = modify (const s)
  -- | Transform the state. Default: 'get' the state, then 'put' the image.
  modify :: (s -> s) -> m ()
  modify f = get >>= \s -> put (f s)
{-| Monads with access to an environment of type @r@. -}
class Monad m => MonadReader r m where
  -- | Retrieve the environment.
  ask :: m r
  -- | Run an action with the environment transformed by the given function
  -- (the 'ReaderT' instance in this file pre-composes the reader with it).
  local :: (r -> r) -> m a -> m a
{-| Monads that produce an output of monoidal type @w@ next to their result.

The descriptions below follow the 'WriterT' instance given in this file.
-}
class (Monad m,Monoid w) => MonadWriter w m where
  -- | Emit the given output.
  tell :: w -> m ()
  -- | Run an action and return its output alongside its result; the output
  -- is still emitted.
  listen :: m a -> m (w,a)
  -- | Run an action whose result carries a function on outputs, and apply
  -- that function to the action's output.
  censor :: m (a,w -> w) -> m a

{-| Monad transformers.

In addition to 'lift', an instance provides 'internal', which applies a
transformation of base-monad actions to a transformed action. The
transformation must be uniform in an extra component @c@ paired with the
result; the instances in this file use that slot for the state ('StateT'),
the output ('WriterT') or a dummy @()@ ('ReaderT').
-}
class MonadTrans t where
  -- | Embed an action of the base monad into the transformed monad.
  lift :: Monad m => m a -> t m a
  -- | Apply a transformation of base-monad actions inside a transformed
  -- action.
  internal :: Monad m => (forall c. m (c,a) -> m (c,b)) -> t m a -> t m b
-- Lifted versions of the class operations. The transformer instances below
-- use them to forward a capability of the base monad through a transformer:
-- plain actions go through 'lift', operations that take an action as argument
-- go through 'internal'.

-- | 'pure' of the base monad, lifted.
pure_ :: (MonadTrans t, Monad m) => a -> t m a
pure_ a = lift (pure a)

-- | 'get' of the base monad, lifted.
get_ :: (MonadTrans t, MonadState s m) => t m s
get_ = lift get

-- | 'put' of the base monad, lifted.
put_ :: (MonadTrans t, MonadState s m) => s -> t m ()
put_ s = lift (put s)

-- | 'modify' of the base monad, lifted.
modify_ :: (MonadTrans t, MonadState s m) => (s -> s) -> t m ()
modify_ f = lift (modify f)

-- | 'ask' of the base monad, lifted.
ask_ :: (MonadTrans t, MonadReader r m) => t m r
ask_ = lift ask

-- | 'local' of the base monad, applied underneath the transformer.
local_ :: (MonadTrans t, MonadReader r m) => (r -> r) -> t m a -> t m a
local_ f m = internal (local f) m

-- | 'tell' of the base monad, lifted.
tell_ :: (MonadTrans t, MonadWriter w m) => w -> t m ()
tell_ w = lift (tell w)

-- | 'listen' of the base monad, applied underneath the transformer.
listen_ :: (MonadTrans t, MonadWriter w m) => t m a -> t m (w, a)
listen_ = internal (\m -> listen m <&> regroup)
  where
    -- Move the transformer's extra component back to the outside and keep
    -- the listened output next to the result.
    regroup (w, (c, a)) = (c, (w, a))

-- | 'censor' of the base monad, applied underneath the transformer.
censor_ :: (MonadTrans t, MonadWriter w m) => t m (a, w -> w) -> t m a
censor_ = internal (\m -> censor (m <&> regroup))
  where
    -- Expose the output-rewriting function to the base monad's 'censor',
    -- keeping the transformer's extra component with the result.
    regroup (c, (a, f)) = ((c, a), f)

-- | Least fixed point of a function: @fix f@ is a value @x@ such that
-- @x = f x@.
--
-- The result is tied through a local binding so that every recursive
-- reference shares the same value, instead of re-applying @f@ from scratch
-- each time the fixed point is unfolded.
fix :: (a -> a) -> a
fix f = let x = f x in x
-- | Builds an 'mfix' implementation out of 'collect': the function is passed
-- through 'collect' and 'fix' is mapped over the result.
-- NOTE(review): 'collect' is defined outside this file; presumably it turns
-- @a -> m a@ into @m (a -> a)@ for the instances below — confirm.
cfix f = map fix (collect f) 
instance MonadFix Id where mfix = cfix
instance MonadFix ((->) b) where mfix = cfix

{-| A simple State monad transformer: a computation takes the current state
and produces, in the base monad, the new state paired with a result. -}
newtype StateT s m a = StateT { runStateT :: s -> m (s,a) }
{-| 'StateT' over 'Id'. The synonym deliberately leaves the result type
unapplied, so that @State s@ can be used on its own wherever a monad is
expected; @State s a@ still means exactly @StateT s Id a@. -}
type State s = StateT s Id
-- A pure value is returned with the state passed through unchanged.
instance Unit m => Unit (StateT s m) where
  pure x = StateT $ \s -> pure (s, x)
-- No methods are given for these two instances: they rely entirely on the
-- defaults of the classes, which are declared outside this file.
instance Monad m => Functor (StateT s m)
instance Monad m => Applicative (StateT s m)
-- Sequencing runs the first computation, then feeds its result and its final
-- state to the continuation.
instance Monad m => Monad (StateT s m) where
  m >>= k = StateT step
    where
      -- The pair is matched lazily: it is only forced when the continuation
      -- actually looks at the state or at the result.
      step s0 = runStateT m s0 >>= \ ~(s1, x) -> runStateT (k x) s1
instance MonadTrans (StateT s) where
  -- A lifted action returns the incoming state untouched.
  lift m = StateT $ \s -> map (\a -> (s, a)) m
  -- The state occupies the extra component of the transformation.
  internal f st = StateT $ \s -> f (runStateT st s)
-- State operations act directly on the threaded state.
instance Monad m => MonadState s (StateT s m) where
  get = StateT $ \s -> pure (s, s)
  put s' = StateT $ \_ -> pure (s', ())
  modify f = StateT $ \s -> pure (f s, ())
-- Reader operations are forwarded to the base monad.
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
-- Writer operations are forwarded to the base monad.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
  
{-| A simple Reader monad transformer: a computation is a function from the
environment to an action of the base monad. -}
newtype ReaderT r m a = ReaderT { runReaderT :: r -> m a }
{-| 'ReaderT' over 'Id'. The synonym deliberately leaves the result type
unapplied, so that @Reader r@ can be used on its own wherever a monad is
expected; @Reader r a@ still means exactly @ReaderT r Id a@. -}
type Reader r = ReaderT r Id
instance MonadTrans (ReaderT r) where
  -- A lifted action ignores the environment.
  lift m = ReaderT (\_ -> m)
  -- A reader has no extra component of its own, so the result is paired with
  -- a dummy () before the transformation and unpaired afterwards.
  internal f rd = ReaderT $ \r ->
    map snd (f (map (\a -> ((), a)) (runReaderT rd r)))
instance Monad m => Unit (ReaderT r m) where
  pure = pure_
-- No methods are given for these two instances: they rely entirely on the
-- defaults of the classes, which are declared outside this file.
instance Monad m => Functor (ReaderT r m)
instance Monad m => Applicative (ReaderT r m)
-- Both the first computation and the continuation see the same environment.
instance Monad m => Monad (ReaderT r m) where
  m >>= k = ReaderT $ \env ->
    runReaderT m env >>= \a -> runReaderT (k a) env
instance Monad m => MonadReader r (ReaderT r m) where
  -- The environment itself is the result.
  ask = ReaderT (\r -> pure r)
  -- The wrapped function receives the transformed environment.
  local f m = ReaderT (\r -> runReaderT m (f r))
-- State operations are forwarded to the base monad.
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_
  put = put_
  modify = modify_
-- Writer operations are forwarded to the base monad.
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_
  listen = listen_
  censor = censor_
  
{-| A simple Writer monad transformer: a computation is an action of the base
monad yielding an output paired with a result. -}
newtype WriterT w m a = WriterT { runWriterT :: m (w,a) }
{-| 'WriterT' over 'Id'. The synonym deliberately leaves the result type
unapplied, so that @Writer w@ can be used on its own wherever a monad is
expected; @Writer w a@ still means exactly @WriterT w Id a@. -}
type Writer w = WriterT w Id
instance Monoid w => MonadTrans (WriterT w) where
  -- A lifted action is paired with the 'zero' output.
  lift m = WriterT (map (\a -> (zero, a)) m)
  -- The output occupies the extra component of the transformation.
  internal f wr = WriterT (f (runWriterT wr))
instance (Monoid w,Monad m) => Unit (WriterT w m) where
  pure = pure_
-- No methods are given for these two instances: they rely entirely on the
-- defaults of the classes, which are declared outside this file.
instance (Monoid w,Monad m) => Functor (WriterT w m)
instance (Monoid w,Monad m) => Applicative (WriterT w m)
-- Sequencing runs the first computation, then the continuation, and puts the
-- first output in front of the continuation's output.
instance (Monoid w,Monad m) => Monad (WriterT w m) where
  WriterT m >>= k = WriterT (m >>= step)
    where
      step (w, a) = map (first (w +)) (runWriterT (k a))
instance (Monad m,Monoid w) => MonadWriter w (WriterT w m) where
  -- The given output, with a unit result.
  tell out = WriterT (pure (out, ()))
  -- The output is kept and also duplicated into the result.
  listen act = WriterT (runWriterT act <&> \(out, a) -> (out, (out, a)))
  -- The function carried by the result rewrites the output.
  censor act = WriterT (runWriterT act <&> \(out, (a, f)) -> (f out, a))
-- Reader operations are forwarded to the base monad.
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_
  local = local_
-- State operations are forwarded to the base monad.
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_
  put = put_
  modify = modify_

{-| A simple continuation monad transformer: a computation takes the
continuation that consumes its result and produces the final answer in the
base monad. -}
newtype ContT r m a = ContT { runContT :: (a -> m r) -> m r }
{-| 'ContT' over 'Id'. The synonym deliberately leaves the result type
unapplied, so that @Cont r@ can be used on its own wherever a monad is
expected; @Cont r a@ still means exactly @ContT r Id a@. -}
type Cont r = ContT r Id
-- A pure value is handed straight to the continuation.
instance Unit m => Unit (ContT r m) where
  pure a = ContT (\k -> k a)
-- No methods are given for these two instances: they rely entirely on the
-- defaults of the classes, which are declared outside this file.
instance Monad m => Functor (ContT r m)
instance Monad m => Applicative (ContT r m)
-- The first computation is run with a continuation that runs the second one
-- under the outer continuation.
instance Monad m => Monad (ContT r m) where
  m >>= f = ContT $ \cc ->
    runContT m (\a -> runContT (f a) cc)
instance MonadTrans (ContT r) where
  -- A lifted action is bound to the continuation in the base monad.
  lift m = ContT (\k -> m >>= k)
  -- NOTE(review): 'internal' is not implemented for 'ContT'; evaluating the
  -- result of any call fails with 'undefined'. No 'MonadReader' or
  -- 'MonadWriter' instances are given for 'ContT' in this part of the file,
  -- so nothing here reaches it.
  internal _ _ = undefined

-- Aliases for the usual monadic vocabulary, in terms of the class operators.
-- NOTE(review): no fixity declaration for '>>' or '=<<' is visible in this
-- part of the file, so both get the Haskell default (infixl 9) rather than
-- the Prelude's infixl 1 / infixr 1 — confirm this is intended.

-- | Sequence two actions, keeping the result of the second ('*>').
a >> b = a *> b

-- | '>>=' with its arguments swapped.
f =<< m = m >>= f

-- | Synonym for 'pure'.
return a = pure a