{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module Control.Carrier.State.Church
(
runState
, evalState
, execState
, StateC(StateC)
, module Control.Effect.State
) where
import Control.Algebra
import Control.Applicative (Alternative(..), liftA2)
import Control.Effect.State
import Control.Monad (MonadPlus)
import Control.Monad.Fail as Fail
import Control.Monad.Fix
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
-- | Run a 'State' effect from the given initial state.
--
-- The supplied continuation receives the final state and the computation's
-- result, and its return value becomes the overall result. Passing
-- @'curry' 'pure'@ as the continuation yields the final state paired with the
-- result.
runState :: forall s m a b . (s -> a -> m b) -> s -> StateC s m a -> m b
runState finish initial action = case action of
  StateC run -> run finish initial
{-# INLINE runState #-}
-- | Run a 'State' effect from the given initial state, returning only the
-- computation's result and discarding the final state.
evalState :: forall s m a . Applicative m => s -> StateC s m a -> m a
evalState initial action = runState keepResult initial action
  where
    -- Ignore the final state; hand back just the result.
    keepResult _finalState result = pure result
{-# INLINE evalState #-}
-- | Run a 'State' effect from the given initial state, returning only the
-- final state and discarding the computation's result.
execState :: forall s m a . Applicative m => s -> StateC s m a -> m s
execState initial action = runState keepState initial action
  where
    -- Ignore the result; hand back just the final state.
    keepState finalState _result = pure finalState
{-# INLINE execState #-}
-- | A carrier for the 'State' effect in Church (continuation-passing)
-- encoding.
--
-- A value of type @'StateC' s m a@ is a function that, for any answer type
-- @r@, takes two arguments:
--
-- * a continuation that receives the final state and the result;
-- * the initial state.
--
-- It produces an @m r@. The state is threaded by passing it as an argument to
-- each continuation rather than by returning a pair.
newtype StateC s m a = StateC (forall r . (s -> a -> m r) -> s -> m r)
  deriving (Functor)
-- | Sequencing threads the state left to right. The left computation runs
-- first, and the state it finishes in is the state the right computation
-- starts from.
instance Applicative (StateC s m) where
  -- Hand the unchanged state and the value straight to the continuation.
  pure x = StateC $ \ done st -> done st x
  {-# INLINE pure #-}

  -- Run the function-producing computation, then the argument-producing one,
  -- and apply the function to the argument before continuing.
  StateC runF <*> StateC runX = StateC $ \ done ->
    runF $ \ st1 g -> runX (\ st2 x -> done st2 (g x)) st1
  {-# INLINE (<*>) #-}

  -- Run both computations in order and combine their results.
  liftA2 combine (StateC runL) (StateC runR) = StateC $ \ done ->
    runL $ \ st1 l -> runR (\ st2 r -> done st2 (combine l r)) st1
  {-# INLINE liftA2 #-}

  -- Run both, keeping only the right result (the left result is dropped).
  StateC runL *> StateC runR = StateC $ \ done ->
    runL $ \ st1 _ -> runR done st1
  {-# INLINE (*>) #-}

  -- Run both, keeping only the left result (the right result is dropped).
  StateC runL <* StateC runR = StateC $ \ done ->
    runL $ \ st1 l -> runR (\ st2 _ -> done st2 l) st1
  {-# INLINE (<*) #-}
-- | Choice is delegated to the underlying 'Alternative'.
--
-- Both branches receive the same continuation and the same starting state, so
-- the right branch begins from the state as it was before the left branch
-- ran.
instance Alternative m => Alternative (StateC s m) where
  -- Failure ignores both the continuation and the state.
  empty = StateC (\ _ _ -> empty)
  {-# INLINE empty #-}

  StateC left <|> StateC right = StateC $ \ done st ->
    let tryLeft  = left done st
        tryRight = right done st
    in  tryLeft <|> tryRight
  {-# INLINE (<|>) #-}
-- | Binding runs the first computation, then continues with the computation
-- @f@ builds from its result, starting from the state the first computation
-- finished in.
instance Monad (StateC s m) where
  StateC run >>= f = StateC $ \ done ->
    run $ \ st x -> case f x of
      StateC next -> next done st
  {-# INLINE (>>=) #-}
-- | Failure is delegated to the underlying monad's 'Fail.fail', lifted into
-- the carrier.
instance Fail.MonadFail m => Fail.MonadFail (StateC s m) where
  fail message = lift (Fail.fail message)
  {-# INLINE fail #-}
-- | Value recursion is delegated to the underlying monad's 'mfix'.
--
-- The computation is run to a @(state, result)@ pair, with the result fed
-- back into @f@. The resulting pair is then handed to the continuation.
instance MonadFix m => MonadFix (StateC s m) where
  mfix f = StateC $ \ done st ->
    -- The lazy patterns are essential: 'mfix' passes in a pair that is not
    -- yet available, so it must not be forced before @f@ is applied.
    let step ~(_, x) = runState (curry pure) st (f x)
    in  mfix step >>= \ ~(st', x') -> done st' x'
  {-# INLINE mfix #-}
-- | IO actions are lifted through the underlying monad's 'liftIO' and then
-- into the carrier with 'lift'.
instance MonadIO m => MonadIO (StateC s m) where
  liftIO action = lift (liftIO action)
  {-# INLINE liftIO #-}
-- | 'MonadPlus' relies entirely on the class's default methods, which are
-- defined in terms of this carrier's 'Alternative' instance.
instance (Alternative m, Monad m) => MonadPlus (StateC s m)
-- | Lifting runs the action in the underlying monad. Its result goes to the
-- continuation together with the current, unchanged state.
instance MonadTrans (StateC s) where
  lift action = StateC $ \ done st -> do
    x <- action
    done st x
  {-# INLINE lift #-}
-- | Interprets 'State' operations directly and forwards all other effects to
-- the underlying algebra.
--
-- * 'Get' continues with the current state, which is also placed in the
--   context as the result.
-- * 'Put' continues with the new state and the unchanged context.
-- * Any other effect is handled with 'thread'. Each sub-computation is run to
--   a @(state, value)@ pair via 'runState', and the resulting pair is then
--   passed to the continuation.
instance Algebra sig m => Algebra (State s :+: sig) (StateC s m) where
  alg handler op ctx = StateC $ \ done st -> case op of
    L Get       -> done st (st <$ ctx)
    L (Put st') -> done st' ctx
    R other     ->
      thread (uncurry (runState (curry pure)) ~<~ handler) other (st, ctx)
        -- Lazy pattern keeps the same laziness as applying 'uncurry'.
        >>= \ ~(st', result) -> done st' result
  {-# INLINE alg #-}