{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE UndecidableInstances  #-}

module Snap.Snaplet.Internal.RST where

import           Control.Applicative         (Alternative (..),
                                              Applicative (..))
import           Control.Monad
import           Control.Monad.Base          (MonadBase (..))
import qualified Control.Monad.Fail as Fail
import           Control.Monad.Reader        (MonadReader (..))
import           Control.Monad.State.Class   (MonadState (..))
import           Control.Monad.Trans         (MonadIO (..), MonadTrans (..))
import           Control.Monad.Trans.Control (ComposeSt, MonadBaseControl (..),
                                              MonadTransControl (..),
                                              defaultLiftBaseWith,
                                              defaultRestoreM)
import           Snap.Core                   (MonadSnap (..))


------------------------------------------------------------------------------
-- | Like @RWST@, but with no writer component to bog things down.
--
-- A computation receives a read-only environment of type @r@ and a state of
-- type @s@, and produces (in the underlying monad @m@) a result paired with
-- the new state. 'runRST' unwraps the computation to that function.
newtype RST r s m a = RST { runRST :: r -> s -> m (a, s) }


-- | Run an 'RST' computation with the given environment and initial state,
-- returning only the result and discarding the final state.
evalRST :: Monad m => RST r s m a -> r -> s -> m a
evalRST m r s = do
    (a,_) <- runRST m r s
    return a
{-# INLINE evalRST #-}


-- | Run an 'RST' computation with the given environment and initial state,
-- returning only the final state and discarding the result.
--
-- The bang pattern forces the final state to weak head normal form before it
-- is returned.
execRST :: Monad m => RST r s m a -> r -> s -> m s
execRST m r s = do
    (_,!s') <- runRST m r s
    return s'
{-# INLINE execRST #-}


-- | Adapt a computation to run under a different environment type: the
-- supplied function converts the outer environment @r'@ into the environment
-- @r@ that the wrapped computation expects. The state is passed through
-- unchanged.
withRST :: Monad m => (r' -> r) -> RST r s m a -> RST r' s m a
withRST f m = RST $ \r' s -> runRST m (f r') s
{-# INLINE withRST #-}


-- | Reader interface over the environment component of 'RST'.
instance (Monad m) => MonadReader r (RST r s m) where
    -- Return the environment as the result; the state is left untouched.
    ask = RST $ \r s -> return (r,s)
    -- Run the computation with the environment transformed by @f@.
    local f m = RST $ \r s -> runRST m (f r) s


-- | Maps over the result component of the @(result, state)@ pair produced in
-- the underlying functor; the state component is passed along unchanged.
instance (Functor m) => Functor (RST r s m) where
    fmap f m = RST $ \r s -> fmap (\(a,s') -> (f a, s')) $ runRST m r s


-- | Applicative instance defined in terms of the 'Monad' instance:
-- 'pure' is 'return' and '<*>' is 'ap'.
instance (Functor m, Monad m) => Applicative (RST r s m) where
    pure = return
    (<*>) = ap


-- | Alternative instance defined in terms of the 'MonadPlus' instance:
-- 'empty' is 'mzero' and '<|>' is 'mplus'.
instance (Functor m, MonadPlus m) => Alternative (RST r s m) where
    empty = mzero
    (<|>) = mplus


-- | State interface over the state component of 'RST'.
instance (Monad m) => MonadState s (RST r s m) where
    -- Return the current state as the result, leaving it in place.
    get   = RST $ \_ s -> return (s,s)
    -- Replace the current state with @x@; the old state is ignored.
    put x = RST $ \_ _ -> return ((),x)


-- | Transform the underlying computation of an 'RST': the wrapped action is
-- run with the current environment and state, and the supplied function is
-- applied to the resulting @m (a, s)@ to obtain an @n (b, s)@.
mapRST :: (m (a, s) -> n (b, s)) -> RST r s m a -> RST r s n b
mapRST f m = RST $ \r s -> f (runRST m r s)


-- | Lifts a @Snap@ action by first lifting it into the underlying
-- 'MonadSnap' monad and then into 'RST' via 'lift'.
instance (MonadSnap m) => MonadSnap (RST r s m) where
    liftSnap s = lift $ liftSnap s

-- | Monadic bind for 'RST'. Runs the first computation, then feeds its
-- result to the continuation, running that with the same environment and the
-- state produced by the first computation.
--
-- The bang patterns force both the incoming state and the intermediate state
-- to weak head normal form.
rwsBind :: Monad m =>
           RST r s m a
        -> (a -> RST r s m b)
        -> RST r s m b
rwsBind m f = RST go
  where
    go r !s = do
        (a, !s')  <- runRST m r s
        runRST (f a) r s'
{-# INLINE rwsBind #-}

-- | Monad instance: 'return' injects a value without touching the state and
-- '>>=' is 'rwsBind'.
instance (Monad m) => Monad (RST r s m) where
    return a = RST $ \_ s -> return (a, s)
    (>>=)    = rwsBind
    -- For base versions before 4.13, 'fail' is also defined as a 'Monad'
    -- method, delegating to the underlying monad's 'fail'.
    -- NOTE(review): this conditional requires the CPP extension, but no CPP
    -- LANGUAGE pragma is visible in this file's header -- confirm it is
    -- enabled (pragma or build configuration).
#if !MIN_VERSION_base(4,13,0)
    fail msg = RST $ \_ _ -> fail msg
#endif

-- | Failure delegates to the underlying monad's 'Fail.fail'; the environment
-- and state are ignored.
instance Fail.MonadFail m => Fail.MonadFail (RST r s m) where
    fail msg = RST $ \_ _ -> Fail.fail msg

-- | MonadPlus instance delegating to the underlying monad.
instance (MonadPlus m) => MonadPlus (RST r s m) where
    -- The underlying 'mzero', ignoring environment and state.
    mzero       = RST $ \_ _ -> mzero
    -- Both alternatives are run with the same environment and the same
    -- initial state, and combined with the underlying 'mplus'.
    m `mplus` n = RST $ \r s -> runRST m r s `mplus` runRST n r s


-- | Lifts an 'IO' action by first lifting it into the underlying 'MonadIO'
-- monad and then into 'RST' via 'lift'.
instance (MonadIO m) => MonadIO (RST r s m) where
    liftIO = lift . liftIO


-- | Lifts an action of the underlying monad into 'RST'. The environment is
-- ignored and the state is passed through unchanged.
instance MonadTrans (RST r s) where
    lift m = RST $ \_ s -> do
        a <- m
        -- 'seq' ties evaluation of the state to evaluation of the result
        -- pair: forcing the pair forces the state to WHNF first.
        return $ s `seq` (a, s)


-- | Lifts a base-monad action by first lifting it into the underlying monad
-- with 'liftBase' and then into 'RST' via 'lift'.
instance MonadBase b m => MonadBase b (RST r s m) where
    liftBase = lift . liftBase


-- | 'MonadBaseControl' instance built from the 'MonadTransControl' instance
-- for @RST r s@ using the default implementations
-- ('defaultLiftBaseWith' / 'defaultRestoreM'); the monadic state type is the
-- composed state 'ComposeSt'.
instance MonadBaseControl b m => MonadBaseControl b (RST r s m) where
     type StM (RST r s m) a = ComposeSt (RST r s) m a
     liftBaseWith = defaultLiftBaseWith
     restoreM = defaultRestoreM
     {-# INLINE liftBaseWith #-}
     {-# INLINE restoreM #-}


-- | 'MonadTransControl' instance. The captured transformer state of a
-- computation is its @(result, state)@ pair.
instance MonadTransControl (RST r s) where
    type StT (RST r s) a = (a, s)
    -- The runner handed to @f@ runs an 'RST' computation with the
    -- environment and state current at the point of the 'liftWith' call.
    -- The state returned by 'liftWith' itself is that same incoming state;
    -- any state produced inside the runner is only visible through the
    -- @(result, state)@ pair the runner returns.
    liftWith f = RST $ \r s -> do
        res <- f $ \(RST g) -> g r s
        return (res, s)
    -- Reinstate a captured @(result, state)@ pair: the current environment
    -- and state are ignored and the pair produced by @k@ is used as-is.
    restoreT k = RST $ \_ _ -> k
    {-# INLINE liftWith #-}
    {-# INLINE restoreT #-}