{-# LANGUAGE MultiParamTypeClasses, TupleSections, Rank2Types, UndecidableInstances #-}
module Clean.Monad(
  module Clean.Applicative,

  -- * The basic Monad interface
  Monad(..),MonadFix(..),MonadTrans(..),
  (=<<),(>>),return,
  
  -- * Common monads
  -- ** The State Monad
  MonadState(..),
  StateT(..),State,
  evalStateT,execStateT,runState,execState,evalState,

  -- ** The Reader monad
  MonadReader(..),ReaderT(..),Reader,

  -- ** The Writer monad
  MonadWriter(..),WriterT(..),Writer,runWriter,

  -- ** The Continuation monad
  MonadCont(..),ContT(..),Cont,evalContT,evalCont


  ) where

import Clean.Classes
import Clean.Applicative
import Clean.Core hiding (flip)
import Clean.Traversable

-- | Monads that support value recursion: the function given to 'mfix'
-- receives (lazily) the result of the very computation it returns.
class Monad m => MonadFix m where
  -- | Tie a recursive knot through the monad.
  mfix :: (a -> m a) -> m a
-- | Both the identity monad and the function monad get their value
-- recursion from 'cfix'.
instance MonadFix Id where
  mfix = cfix
instance MonadFix ((->) b) where
  mfix = cfix
-- | Value recursion in the list monad, using the standard definition.
--
-- The i-th element of the result is a fixed point of the i-th alternative
-- produced by @f@, and each alternative is solved independently. Taking
-- only @fix (f . head)@ would feed the FIRST element into every
-- alternative. That breaks the left-shrinking law: for example,
-- @mfix (\x -> [1:x, 2:x])@ must be @[repeat 1, repeat 2]@, not
-- @[repeat 1, 2 : repeat 1]@.
instance MonadFix [] where
  mfix f = case fix (f . head) of
    []      -> []
    -- Keep the first solution, then recurse on the remaining alternatives.
    (x : _) -> x : mfix (\a -> case f a of { (_ : rest) -> rest; [] -> [] })
-- | Least fixed point of a function, i.e. the value @x@ with @x = f x@.
--
-- The result is let-bound and refers to itself. A self-referential value
-- such as @fix (1 :)@ is therefore one shared cyclic structure, rather than
-- being re-unfolded by a fresh @f (fix f)@ call at every step.
fix :: (a -> a) -> a
fix f = let x = f x in x
-- | Generic 'mfix' for functors that 'collect' applies to.
-- 'collect' turns @f :: a -> m a@ into @m (a -> a)@ (forced by the types
-- here), and 'fix' is then applied under the functor.
-- NOTE(review): 'collect' comes from Clean.Traversable; confirm its contract.
cfix = map fix . collect

-- | Monad transformers: layers @t@ that can be put on top of a monad @m@.
class MonadTrans t where
  -- | Embed an action of the underlying monad into the transformer.
  lift :: Monad m => m a -> t m a
  -- | Apply a transformation of the underlying monad to a transformed
  -- computation. The transformation must be polymorphic in the extra
  -- component @c@, which is whatever the transformer carries next to the
  -- result in @m@: the state for 'StateT', the log for 'WriterT', and @()@
  -- for 'ReaderT'.
  internal :: Monad m
           => (forall c. m (c,a) -> m (c,b))
           -> t m a
           -> t m b
-- | A 'pure' for any transformer: wrap the value in the underlying monad, then lift it.
pure_ a = lift (pure a)

-- Fixities matching the standard Prelude. Without these declarations both
-- operators default to @infixl 9@, binding tighter than almost every other
-- operator. That misparses common idioms: for example, @m >>= k >> n@
-- would be read as @m >>= (k >> n)@ whenever '>>=' has its usual
-- lower precedence.
infixl 1 >>
infixr 1 =<<

-- | Sequence two actions, discarding the result of the first (same as '*>').
(>>) = (*>)

-- | Bind with its arguments flipped: @f =<< m@ is @m >>= f@.
(=<<) = flip (>>=)

-- | Synonym for 'pure'.
return = pure

-- | Monads with access to a piece of state of type @s@.
--
-- Minimal complete definition: 'get', plus either 'put' or 'modify'. The
-- default 'put' is defined via 'modify' and the default 'modify' via
-- 'put', so an instance that defines neither would loop forever. The
-- MINIMAL pragma makes GHC warn about such an instance.
--
-- NOTE(review): the class has no functional dependency @m -> s@, so the
-- state type at a use of 'get' may need to be fixed by an annotation.
class Monad m => MonadState s m where
  {-# MINIMAL get, (put | modify) #-}
  -- | Read the current state.
  get :: m s
  -- | Replace the current state.
  put :: s -> m ()
  put = modify . const
  -- | Update the current state with a function.
  modify :: (s -> s) -> m ()
  modify f = get >>= put . f
-- | 'MonadState' methods for a transformer over a 'MonadState' monad:
-- each one lifts the corresponding operation of the underlying monad.
get_ = lift get
put_ s = lift (put s)
modify_ f = lift (modify f)

-- | A simple state monad transformer.
--
-- A computation maps an initial state to an action in @m@. That action
-- yields the final state paired with the result; note the order is
-- @(state, result)@.
newtype StateT s m a = StateT
  { runStateT :: s -> m (s,a)
  }
-- | 'StateT' over the identity monad.
type State s a = StateT s Id a
-- | 'pure' returns the value and leaves the state untouched.
instance Unit m => Unit (StateT s m) where
  pure x = StateT (\st -> pure (st, x))
-- Empty bodies: 'Functor' and 'Applicative' come from the class defaults.
-- NOTE(review): presumably derived from the 'Monad' instance; confirm in
-- Clean.Classes.
instance Monad m => Functor (StateT s m)
instance Monad m => Applicative (StateT s m)
-- | Run the first computation, then run the continuation on the state it
-- produced. The lazy pattern keeps bind from forcing the intermediate
-- (state, result) pair.
instance Monad m => Monad (StateT s m) where
  StateT run >>= k =
    StateT (\s0 -> run s0 >>= \ ~(s1, x) -> runStateT (k x) s1)
-- | 'lift' runs the inner action and keeps the state as is.
-- 'internal' applies the transformation to the @m (s, a)@ produced from
-- each initial state (the state plays the role of @c@).
instance MonadTrans (StateT s) where
  lift m = StateT (\st -> map (\x -> (st, x)) m)
  internal f (StateT run) = StateT (\st -> f (run st))
-- | Direct state access for 'StateT'.
instance Monad m => MonadState s (StateT s m) where
  -- Return the current state and keep it.
  get = StateT (\st -> pure (st, st))
  -- Ignore the current state and install the new one.
  put new = StateT (const (pure (new, ())))
  -- Apply the function to the current state.
  modify f = StateT (\st -> pure (f st, ()))
-- | Reader operations pass through 'StateT' to the underlying monad.
instance MonadReader r m => MonadReader r (StateT s m) where
  ask = ask_
  local = local_
-- | Writer operations pass through 'StateT' to the underlying monad.
instance MonadWriter w m => MonadWriter w (StateT s m) where
  tell = tell_
  listen = listen_
  censor = censor_
-- | 'callCC' through 'StateT'.
--
-- The escape continuation resumes with the state @s0@ captured when
-- 'callCC' was entered, not the state current when it is invoked. Any
-- state changes made in between are therefore rolled back on escape.
instance MonadCont m => MonadCont (StateT s m) where
  callCC f = StateT (\s0 ->
    callCC (\exit -> runStateT (f (\x -> lift (exit (s0, x)))) s0))
-- | Value recursion for 'StateT'. Only the result component of the
-- (state, result) pair is fed back, and it is projected lazily so the
-- knot is not forced early.
instance MonadFix m => MonadFix (StateT s m) where
  mfix f = StateT (\s0 -> mfix (\pair -> runStateT (f (snd pair)) s0))
-- Algebraic structure for 'StateT', obtained by deriving through the
-- newtype from its representation @s -> m (s,a)@.
-- NOTE(review): these standalone, non-stock derivings need
-- StandaloneDeriving and GeneralizedNewtypeDeriving, which are not in this
-- file's LANGUAGE pragma. Presumably they are enabled project-wide, and the
-- function type gets its instances pointwise from @m (s,a)@. Confirm both.
deriving instance Semigroup (m (s,a)) => Semigroup (StateT s m a)
deriving instance Monoid (m (s,a)) => Monoid (StateT s m a)
deriving instance Ring (m (s,a)) => Ring (StateT s m a)

-- | Run a 'StateT' from an initial state and keep only the result.
evalStateT :: Functor m => StateT s m a -> s -> m a
evalStateT st s0 = map snd (runStateT st s0)

-- | Run a 'StateT' from an initial state and keep only the final state.
execStateT :: Functor m => StateT s m a -> s -> m s
execStateT st s0 = map fst (runStateT st s0)

-- | Run a pure 'State', returning the final state and the result.
runState :: State s a -> s -> (s,a)
runState st s0 = getId (runStateT st s0)

-- | Run a pure 'State' and keep only the final state.
execState :: State s a -> s -> s
execState st s0 = fst (runState st s0)

-- | Run a pure 'State' and keep only the result.
evalState :: State s a -> s -> a
evalState st s0 = snd (runState st s0)

-- | Monads that can consult a shared, read-only environment of type @r@.
--
-- NOTE(review): there is no functional dependency @m -> r@, so the
-- environment type at a use of 'ask' may need an annotation.
class Monad m => MonadReader r m where
  -- | Fetch the environment.
  ask :: m r
  -- | Run a computation under a modified environment.
  local :: (r -> r)
        -> m a
        -> m a
-- | 'MonadReader' methods for a transformer over a 'MonadReader' monad.
-- 'ask_' lifts 'ask'. 'local_' applies 'local' to the underlying monad's
-- computation through 'internal'.
ask_ = lift ask
local_ f = internal (local f)
-- | A simple reader monad transformer: a computation in @m@ that receives
-- an environment of type @r@.
newtype ReaderT r m a = ReaderT
  { runReaderT :: r -> m a
  } deriving (Semigroup,Monoid,Ring)
-- | 'ReaderT' over the identity monad.
type Reader r a = ReaderT r Id a
-- | 'lift' ignores the environment.
-- 'internal' pairs each result with @()@ (the reader carries no extra
-- component), applies the transformation, then drops the @()@ again.
instance MonadTrans (ReaderT r) where
  lift m = ReaderT (\_ -> m)
  internal f (ReaderT rd) =
    ReaderT (\env -> map snd (f (map (\x -> ((), x)) (rd env))))
-- | 'pure' is the transformer-generic 'pure_', which ignores the environment.
instance Monad m => Unit (ReaderT r m) where
  pure x = pure_ x
-- Empty bodies: 'Functor' and 'Applicative' come from the class defaults.
-- NOTE(review): presumably derived from the 'Monad' instance; confirm in
-- Clean.Classes.
instance Monad m => Functor (ReaderT r m)
instance Monad m => Applicative (ReaderT r m)
-- | Both the computation and the continuation see the same environment.
instance Monad m => Monad (ReaderT r m) where
  m >>= k = ReaderT (\env -> runReaderT m env >>= \x -> runReaderT (k x) env)
-- | Direct environment access for 'ReaderT'.
instance Monad m => MonadReader r (ReaderT r m) where
  -- Return the environment itself.
  ask = ReaderT (\env -> pure env)
  -- Run the computation on the transformed environment.
  local f m = ReaderT (\env -> runReaderT m (f env))
-- | State operations pass through 'ReaderT' to the underlying monad.
instance MonadState s m => MonadState s (ReaderT r m) where
  get = get_
  put = put_
  modify = modify_
-- | Writer operations pass through 'ReaderT' to the underlying monad.
instance MonadWriter w m => MonadWriter w (ReaderT r m) where
  tell = tell_
  listen = listen_
  censor = censor_
-- | 'callCC' through 'ReaderT'. The escape continuation is lifted, so
-- jumping to it ignores the environment at the call site.
instance MonadCont m => MonadCont (ReaderT r m) where
  callCC f = ReaderT (\env ->
    callCC (\exit -> runReaderT (f (\x -> lift (exit x))) env))
-- | Value recursion for 'ReaderT', with a fixed environment.
instance MonadFix m => MonadFix (ReaderT r m) where
  mfix f = ReaderT (\env -> mfix (\x -> runReaderT (f x) env))

-- | Monads that accumulate a monoidal log of type @w@.
--
-- NOTE(review): there is no functional dependency @m -> w@, so the log
-- type at a use of 'tell' may need an annotation.
class (Monad m,Monoid w) => MonadWriter w m where
  -- | Append an entry to the log.
  tell :: w -> m ()
  -- | Run a computation and also return the log it produced.
  listen :: m a -> m (w,a)
  -- | Run a computation that yields a value together with a function, and
  -- apply that function to the computation's log.
  censor :: m (a,w -> w) -> m a
-- | 'MonadWriter' methods for a transformer over a 'MonadWriter' monad.
-- 'tell_' lifts 'tell'.
tell_ w = lift (tell w)

-- | 'listen' on the underlying monad, reshuffled so the transformer's own
-- component @c@ stays outermost.
listen_ = internal (\m -> listen m <&> moveLog)
  where moveLog (w,(c,a)) = (c,(w,a))

-- | 'censor' on the underlying monad: the log-editing function is moved out
-- past the transformer's component @c@ before calling 'censor'.
censor_ = internal (\m -> censor (m <&> moveEdit))
  where moveEdit (c,(a,f)) = ((c,a),f)

-- | A simple writer monad transformer: an action in @m@ yielding the
-- accumulated log paired with the result, in @(log, result)@ order.
newtype WriterT w m a = WriterT
  { runWriterT :: m (w,a)
  }
-- | 'WriterT' over the identity monad.
type Writer w a = WriterT w Id a
-- | 'lift' pairs the inner result with the empty log ('zero').
-- 'internal' transforms the wrapped @m (w, a)@ directly; the log plays the
-- role of @c@.
instance Monoid w => MonadTrans (WriterT w) where
  lift m = WriterT (map (\x -> (zero, x)) m)
  internal f m = WriterT (f (runWriterT m))
-- | 'pure' is the transformer-generic 'pure_', which pairs the value with
-- an empty log.
instance (Monoid w,Monad m) => Unit (WriterT w m) where
  pure x = pure_ x
-- Empty bodies: 'Functor' and 'Applicative' come from the class defaults.
-- NOTE(review): presumably derived from the 'Monad' instance; confirm in
-- Clean.Classes.
instance (Monoid w,Monad m) => Functor (WriterT w m)
instance (Monoid w,Monad m) => Applicative (WriterT w m)
-- | Run the first computation, then the continuation, and put the first
-- log in front of the second (@w1 + w2@).
-- NOTE(review): unlike 'StateT''s bind, the pair pattern here is strict.
instance (Monoid w,Monad m) => Monad (WriterT w m) where
  WriterT m >>= k = WriterT $ do
    (w1,x) <- m
    map (first (w1+)) (runWriterT (k x))
-- | Direct log access for 'WriterT'. The reshaping functions use lazy
-- patterns so the pair is not forced.
instance (Monad m,Monoid w) => MonadWriter w (WriterT w m) where
  tell w = WriterT (pure (w,()))
  listen (WriterT m) = WriterT (m <&> dup)
    where dup ~(w,a) = (w,(w,a))
  censor (WriterT m) = WriterT (m <&> apply)
    where apply ~(w,~(a,f)) = (f w,a)
-- | Reader operations pass through 'WriterT' to the underlying monad.
instance (Monoid w,MonadReader r m) => MonadReader r (WriterT w m) where
  ask = ask_
  local = local_
-- | State operations pass through 'WriterT' to the underlying monad.
instance (Monoid w,MonadState r m) => MonadState r (WriterT w m) where
  get = get_
  put = put_
  modify = modify_
-- Algebraic structure for 'WriterT', obtained by deriving through the
-- newtype from the wrapped @m (w,a)@.
-- NOTE(review): these standalone, non-stock derivings need
-- StandaloneDeriving and GeneralizedNewtypeDeriving, which are not in this
-- file's LANGUAGE pragma; presumably they are enabled project-wide.
deriving instance Semigroup (m (w,a)) => Semigroup (WriterT w m a)
deriving instance Monoid (m (w,a)) => Monoid (WriterT w m a)
deriving instance Ring (m (w,a)) => Ring (WriterT w m a)
-- | Value recursion for 'WriterT'. Only the result (not the log) is fed
-- back, through a lazy pattern.
instance (Monoid w,MonadFix m) => MonadFix (WriterT w m) where
  mfix f = WriterT (mfix (\ ~(_,x) -> runWriterT (f x)))

-- | Run a pure 'Writer', returning the accumulated log and the result.
runWriter :: Writer w a -> (w,a)
runWriter m = getId (runWriterT m)

-- | Monads with first-class escape continuations.
class Monad m => MonadCont m where
  -- | Call a function with the current continuation. Invoking that
  -- continuation abandons the rest of the function and makes its argument
  -- the result of the 'callCC' call.
  callCC :: ((a -> m b) -> m a)
         -> m a

-- | A simple continuation monad transformer. Given a continuation for its
-- result, a computation produces the final answer of type @m r@.
newtype ContT r m a = ContT
  { runContT :: (a -> m r) -> m r
  } deriving (Semigroup,Monoid,Ring)
-- | 'ContT' over the identity monad.
type Cont r a = ContT r Id a
-- | 'pure' hands the value straight to the continuation.
instance Unit m => Unit (ContT r m) where
  pure x = ContT (\k -> k x)
-- Empty bodies: 'Functor' and 'Applicative' come from the class defaults.
-- NOTE(review): presumably derived from the 'Monad' instance; confirm in
-- Clean.Classes.
instance Monad m => Functor (ContT r m)
instance Monad m => Applicative (ContT r m)
-- | Bind extends the continuation: the first computation's result is fed
-- to @f@, whose computation then runs with the original continuation.
instance Monad m => Monad (ContT r m) where
  m >>= f = ContT (\k -> runContT m (\x -> runContT (f x) k))
-- | 'lift' binds the inner action's result into the continuation.
--
-- 'internal' is not implemented for 'ContT'. A 'ContT' computation is a
-- function of its continuation, with no underlying @m (c, a)@ for the
-- transformation to act on. It therefore fails with a descriptive message
-- instead of a bare 'undefined'. The generic helpers built on 'internal'
-- ('local_', 'listen_', 'censor_') cannot be used over 'ContT'.
instance MonadTrans (ContT r) where
  lift m = ContT (m >>=)
  internal _ _ = error "Clean.Monad: 'internal' is not supported for ContT"
-- | 'callCC' for 'ContT'. The escape function discards whatever
-- continuation is current at the jump and resumes with the one captured
-- when 'callCC' was entered.
instance Monad m => MonadCont (ContT r m) where
  callCC f = ContT (\k -> runContT (f (\x -> ContT (const (k x)))) k)

-- | Run a continuation computation whose answer type equals its result
-- type, finishing with 'return'.
evalContT :: Unit m => ContT r m r -> m r
evalContT c = runContT c return

-- | Run a pure 'Cont' computation to its final answer.
evalCont :: Cont r r -> r
evalCont c = getId (evalContT c)

-- | 'Backwards' adds no data of its own. 'lift' just wraps, and 'internal'
-- pairs each result with @()@ around the transformation.
-- NOTE(review): 'Backwards' is not defined in this module; it presumably
-- comes from Clean.Applicative, which would make this an orphan instance.
instance MonadTrans Backwards where
  lift m = Backwards m
  internal f (Backwards m) = Backwards (snd <$> f ((,) () <$> m))
-- | Bind for 'Backwards': the continuation's effects run before those of
-- the first computation, while the first computation's result still flows
-- into the continuation.
--
-- This is done by tying a knot with 'mfix' in the underlying monad. @r@ is
-- the pair (continuation result, first result). @snd r@ is handed to @k@
-- before @ma@ has run, and the outer 'fst' keeps the continuation's result.
-- NOTE(review): the effect order assumes 'liftA2' runs its first action
-- first. As with any 'mfix', this diverges if @k@ needs its argument's
-- value to decide its effects.
-- NOTE(review): neither 'Backwards' nor 'Monad' is defined in this module,
-- so this is presumably an orphan instance.
instance MonadFix m => Monad (Backwards m) where
  Backwards ma >>= k = Backwards$fst<$>mfix (\r -> liftA2 (,) (forwards (k (snd r))) ma)