module Control.Monad.Representable.State
   ( State
   , runState
   , evalState
   , execState
   , mapState
   , StateT(..)
   , stateT
   , runStateT
   , evalStateT
   , execStateT
   , mapStateT
   , liftCallCC
   , liftCallCC'
   , MonadState(..)
   ) where
import Control.Applicative
import Data.Functor.Bind
import Data.Functor.Bind.Trans
import Control.Monad.State.Class
import Control.Monad.Cont.Class
import Control.Monad.Reader.Class
import Control.Monad.Writer.Class
import Control.Monad.Free.Class
import Control.Monad.Trans.Class
import Control.Monad.Identity
import Data.Functor.Rep
-- | A pure state monad whose state type is @'Rep' g@: 'StateT' over the
-- 'Identity' monad.
type State g = StateT g Identity
-- | Run a 'State' computation from an initial state, producing the result
-- paired with the final state.
runState :: Representable g
         => State g a   -- ^ computation to run
         -> Rep g       -- ^ initial state
         -> (a, Rep g)  -- ^ result together with the final state
runState m s = runIdentity (runStateT m s)
-- | Run a 'State' computation from an initial state and keep only the
-- result, discarding the final state.
evalState :: Representable g
          => State g a  -- ^ computation to run
          -> Rep g      -- ^ initial state
          -> a          -- ^ result of the computation
evalState m = fst . runState m
-- | Run a 'State' computation from an initial state and keep only the final
-- state, discarding the result.
execState :: Representable g
          => State g a  -- ^ computation to run
          -> Rep g      -- ^ initial state
          -> Rep g      -- ^ final state
execState m = snd . runState m
-- | Rewrite the @(result, final state)@ pair of a 'State' computation with a
-- pure function. The function is applied under 'Identity' via 'mapStateT'.
mapState :: Functor g => ((a, Rep g) -> (b, Rep g)) -> State g a -> State g b
mapState f = mapStateT (fmap f)
-- | A state transformer whose state type is @'Rep' g@.
--
-- Rather than wrapping a function @'Rep' g -> m (a, 'Rep' g)@ directly, the
-- transition is stored as a value of shape @g@: 'stateT' builds it with
-- 'tabulate' and 'runStateT' reads it back with 'index'.
newtype StateT g m a = StateT { getStateT :: g (m (a, Rep g)) }
-- | Build a 'StateT' from a state-transition function by tabulating the
-- function over the representable functor @g@.
stateT :: Representable g => (Rep g -> m (a, Rep g)) -> StateT g m a
stateT transition = StateT (tabulate transition)
-- | Run a 'StateT' from the given initial state. This indexes the wrapped
-- @g@ structure at that state.
runStateT :: Representable g => StateT g m a -> Rep g -> m (a, Rep g)
runStateT = index . getStateT
-- | Transform the underlying monadic action (and its result/state pair) at
-- every position of the wrapped @g@ structure.
mapStateT :: Functor g => (m (a, Rep g) -> n (b, Rep g)) -> StateT g m a -> StateT g n b
mapStateT f = StateT . fmap f . getStateT
-- | Run a 'StateT' from an initial state and return only the result.
-- The result pair is matched strictly, exactly as a @do@-bind would.
evalStateT :: (Representable g, Monad m) => StateT g m a -> Rep g -> m a
evalStateT m s = runStateT m s >>= \(result, _) -> return result
-- | Run a 'StateT' from an initial state and return only the final state.
-- The result pair is matched strictly, exactly as a @do@-bind would.
execStateT :: (Representable g, Monad m) => StateT g m a -> Rep g -> m (Rep g)
execStateT m s = runStateT m s >>= \(_, final) -> return final
-- | Maps over the result component of every @(result, state)@ pair and leaves
-- the state untouched.
instance (Functor g, Functor m) => Functor (StateT g m) where
  fmap f (StateT table) = StateT (fmap (fmap onResult) table)
    where
      -- Lazy match: the pair is only taken apart when a component is demanded.
      onResult ~(a, s) = (f a, s)
-- | Sequential application built from '>>-'. The function's effects (and
-- state changes) run before the argument's.
instance (Representable g, Bind m) => Apply (StateT g m) where
  mf <.> ma = mf >>- (<$> ma)
-- | 'pure' returns its argument without touching the state; '<*>' is defined
-- via the 'Monad' instance's '>>='.
instance (Representable g, Functor m, Monad m) => Applicative (StateT g m) where
  pure a = StateT (leftAdjunctRep return a)
  mf <*> ma = mf >>= (<$> ma)
-- | Semigroupoid bind. At each position of the wrapped @g@, the inner
-- action's result and new state are fed to the continuation.
instance (Representable g, Bind m) => Bind (StateT g m) where
  StateT table >>- f = StateT (fmap bindStep table)
    where
      bindStep act = act >>- rightAdjunctRep (runStateT . f)
-- | 'return' leaves the state unchanged. '>>=' threads the state produced by
-- the first action into the continuation at each position of @g@.
instance (Representable g, Monad m) => Monad (StateT g m) where
  -- Defined directly rather than as 'pure' so the instance needs only @Monad m@.
  return a = StateT (leftAdjunctRep return a)
  StateT table >>= f = StateT (fmap bindStep table)
    where
      bindStep act = act >>= rightAdjunctRep (runStateT . f)
-- | Lift a 'Bind' action, pairing its result with the unchanged current state.
instance Representable f => BindTrans (StateT f) where
  liftB m = stateT (\s -> fmap (flip (,) s) m)
-- | Lift a monadic action, pairing its result with the unchanged current state.
instance Representable f => MonadTrans (StateT f) where
  lift m = stateT (\s -> m >>= \a -> return (a, s))
-- | The state visible through 'MonadState' is @'Rep' g@.
instance (Representable g, Monad m, Rep g ~ s) => MonadState s (StateT g m) where
  -- Return the current state and keep it.
  get = stateT (\s -> return (s, s))
  -- Same action at every position: whatever the old state was, it becomes @s@.
  put s = StateT (pureRep (return ((), s)))
#if MIN_VERSION_transformers(0,3,0)
  state f = stateT (\s -> return (f s))
#endif
-- | Reader operations are delegated to the underlying monad; 'local' is
-- applied to the inner action at every position of @g@.
instance (Representable g, MonadReader e m) => MonadReader e (StateT g m) where
  ask = lift ask
  local f = mapStateT (local f)
-- | Writer operations delegated to the underlying monad. 'listen' and 'pass'
-- shuffle the state component so it stays outermost in the result pair.
instance (Representable g, MonadWriter w m) => MonadWriter w (StateT g m) where
  tell w = lift (tell w)
  -- Move the captured output next to the result, keeping the state second.
  listen = mapStateT (\ma -> listen ma >>= \((a, s'), w) -> return ((a, w), s'))
  -- Expose the output-transforming function to the inner 'pass'.
  pass = mapStateT (\ma -> pass (ma >>= \((a, f), s') -> return ((a, s'), f)))
-- | 'callCC' via 'liftCallCC''. Jumping to the captured continuation keeps
-- the state current at the time of the jump.
instance (Representable g, MonadCont m) => MonadCont (StateT g m) where
  callCC f = liftCallCC' callCC f
-- | Wrap a layer of @f@: every branch is run from the same current state and
-- the resulting layer is wrapped in the underlying monad.
instance (Functor f, Representable g, MonadFree f m) => MonadFree f (StateT g m) where
  wrap as = stateT (\s -> wrap (flip runStateT s <$> as))
-- | Fix the first component of a pair-consuming function and tabulate the
-- remaining @'Rep' u@ argument into @u@.
leftAdjunctRep :: Representable u => ((a, Rep u) -> b) -> a -> u b
leftAdjunctRep f a = tabulate (curry f a)
-- | Apply @f@ to the first component of the pair and index the resulting @u@
-- at the second. 'uncurry' takes the pair apart lazily.
rightAdjunctRep :: Representable u => (a -> u b) -> (a, Rep u) -> b
rightAdjunctRep f = uncurry (index . f)
-- | Lift a @callCC@ operation to 'StateT'.
--
-- Invoking the captured continuation resumes with the state as it was when
-- 'liftCallCC' was entered. Any state changes made before the jump are
-- discarded. Compare 'liftCallCC'', which keeps the current state.
liftCallCC :: Representable g => ((((a,Rep g) -> m (b,Rep g)) -> m (a,Rep g)) -> m (a,Rep g)) ->
    ((a -> StateT g m b) -> StateT g m a) -> StateT g m a
liftCallCC callCC' f = stateT $ \s ->
  let -- Same action at every position: it ignores the current state and
      -- jumps with the entry state @s@.
      escape c a = StateT (pureRep (c (a, s)))
  in callCC' (\c -> runStateT (f (escape c)) s)
-- | Lift a @callCC@ operation to 'StateT'.
--
-- Invoking the captured continuation resumes with the state current at the
-- moment of the jump, so state changes made before it are kept.
liftCallCC' :: Representable g => ((((a,Rep g) -> m (b,Rep g)) -> m (a,Rep g)) -> m (a,Rep g)) ->
    ((a -> StateT g m b) -> StateT g m a) -> StateT g m a
liftCallCC' callCC' f = stateT (\s -> callCC' (\c -> runStateT (f (jump c)) s))
  where
    -- Read the state at jump time and hand it to the captured continuation.
    jump c a = stateT (\s' -> c (a, s'))