{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Capnp.TraversalLimit
( MonadLimit(..)
, LimitT
, runLimitT
, evalLimitT
, execLimitT
, defaultLimit
) where
import Prelude hiding (fail)
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow(throwM))
import Control.Monad.Fail (MonadFail (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Primitive (PrimMonad(primitive), PrimState)
import Control.Monad.State.Strict
(MonadState, StateT, evalStateT, execStateT, get, put, runStateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad.RWS (RWST)
import Control.Monad.Reader (ReaderT)
import Control.Monad.Writer (WriterT)
import qualified Control.Monad.State.Lazy as LazyState
import Capnp.Bits (WordCount)
import Capnp.Errors (Error(TraversalLimitError))
-- | mtl-style capability for monads that carry a traversal limit: a budget,
-- measured in 'WordCount', that is drawn down by calls to 'invoice'.
--
-- What happens when the budget runs out is up to each instance; the
-- 'LimitT' instance in this module throws 'TraversalLimitError'.
class Monad m => MonadLimit m where
    -- | @'invoice' n@ charges @n@ words against the remaining limit.
    invoice :: WordCount -> m ()
-- | Monad transformer that implements 'MonadLimit' by threading the
-- remaining limit through a strict 'StateT' over 'WordCount'.
--
-- Only 'Functor', 'Applicative' and 'Monad' are inherited from the wrapped
-- 'StateT' (via GeneralizedNewtypeDeriving). 'MonadState' is deliberately
-- not derived, so the counter stays hidden; the 'MonadState' instance for
-- 'LimitT' forwards to the underlying monad instead.
newtype LimitT m a = LimitT (StateT WordCount m a)
    deriving (Functor, Applicative, Monad)
-- | Run a 'LimitT' computation starting from the given traversal limit.
-- Returns the computation's result together with the limit left unused.
--
-- Unwrapping the underlying 'StateT' places no requirement on @m@, so this
-- carries no constraint. 'MonadThrow' is only needed to 'invoice' inside
-- 'LimitT', not to run it.
runLimitT :: WordCount -> LimitT m a -> m (a, WordCount)
runLimitT limit (LimitT stateT) = runStateT stateT limit
-- | Run a 'LimitT' computation starting from the given traversal limit and
-- return only its result, discarding the remaining limit.
--
-- Only 'Monad' is required, because that is all 'evalStateT' needs.
-- 'MonadThrow' is only needed to 'invoice' inside 'LimitT'.
evalLimitT :: Monad m => WordCount -> LimitT m a -> m a
evalLimitT limit (LimitT stateT) = evalStateT stateT limit
-- | Run a 'LimitT' computation starting from the given traversal limit and
-- return only the limit left over, discarding the result.
--
-- Only 'Monad' is required, because that is all 'execStateT' needs.
-- 'MonadThrow' is only needed to 'invoice' inside 'LimitT'.
execLimitT :: Monad m => WordCount -> LimitT m a -> m WordCount
execLimitT limit (LimitT stateT) = execStateT stateT limit
-- | A default traversal limit: @64 * 1024 * 1024@ divided by 8, i.e. 64 MiB
-- expressed in words, assuming a word is 8 bytes.
defaultLimit :: WordCount
defaultLimit = byteBudget `div` bytesPerWord
  where
    byteBudget = 64 * 1024 * 1024
    bytesPerWord = 8
-- | Exceptions are raised in the underlying monad; the limit is not
-- consulted.
instance MonadThrow m => MonadThrow (LimitT m) where
    throwM err = lift (throwM err)
-- | Charging @cost@ words either deducts it from the remaining limit or,
-- when @cost@ exceeds what is left, throws 'TraversalLimitError' without
-- updating the limit. A charge exactly equal to the remainder succeeds and
-- leaves the limit at zero.
instance MonadThrow m => MonadLimit (LimitT m) where
    invoice cost = LimitT $ do
        remaining <- get
        -- Check before subtracting, so only a covered cost is ever deducted.
        when (cost > remaining) $
            throwM TraversalLimitError
        put $ remaining - cost
-- | Lift an action from the underlying monad into 'LimitT' without touching
-- the limit.
instance MonadTrans LimitT where
    lift action = LimitT (lift action)
-- | 'get' and 'put' address the underlying monad's state, not the traversal
-- limit; the limit counter is only reachable through 'invoice'.
instance MonadState s m => MonadState s (LimitT m) where
    get = lift get
    put newState = lift (put newState)
-- | Primitive state-token operations are passed straight through to the
-- underlying monad, which shares its 'PrimState'. The traversal limit is
-- untouched.
--
-- The context is just @PrimMonad m@. The type family instance already ties
-- @PrimState (LimitT m)@ to @PrimState m@, so no extra equality constraint
-- is needed.
instance PrimMonad m => PrimMonad (LimitT m) where
    type PrimState (LimitT m) = PrimState m
    primitive = lift . primitive
-- | Pattern-match failures are reported by the underlying monad's 'fail'.
instance MonadFail m => MonadFail (LimitT m) where
    fail msg = lift (fail msg)
-- | 'IO' actions run in the underlying monad and are not charged against the
-- limit.
instance MonadIO m => MonadIO (LimitT m) where
    liftIO io = lift (liftIO io)
-- | Strict 'StateT' adds no accounting of its own; charges go to the
-- underlying monad.
instance MonadLimit m => MonadLimit (StateT s m) where
    invoice cost = lift (invoice cost)
-- | Lazy 'LazyState.StateT' adds no accounting of its own; charges go to the
-- underlying monad.
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
    invoice cost = lift (invoice cost)
-- | 'WriterT' adds no accounting of its own; charges go to the underlying
-- monad. @Monoid w@ is what makes @WriterT w m@ a 'Monad'.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
    invoice cost = lift (invoice cost)
-- | 'ReaderT' adds no accounting of its own; charges go to the underlying
-- monad.
instance MonadLimit m => MonadLimit (ReaderT r m) where
    invoice cost = lift (invoice cost)
-- | 'RWST' adds no accounting of its own; charges go to the underlying
-- monad. @Monoid w@ is what makes @RWST r w s m@ a 'Monad'.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
    invoice cost = lift (invoice cost)