{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Capnp.TraversalLimit
( MonadLimit(..)
, LimitT
, runLimitT
, evalLimitT
, execLimitT
, defaultLimit
) where
import Prelude hiding (fail)
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow(throwM))
import Control.Monad.Fail (MonadFail (..))
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Primitive (PrimMonad(primitive), PrimState)
import Control.Monad.State.Strict
(MonadState, StateT, evalStateT, execStateT, get, put, runStateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad.RWS (RWST)
import Control.Monad.Reader (ReaderT)
import Control.Monad.Writer (WriterT)
import qualified Control.Monad.State.Lazy as LazyState
import Capnp.Bits (WordCount)
import Capnp.Errors (Error(TraversalLimitError))
-- | Monads whose work can be charged against a traversal limit measured
-- in 'WordCount'. What happens when the limit is exceeded is decided by
-- the instance; the 'LimitT' instance in this module throws
-- 'TraversalLimitError'.
class (Monad m) => MonadLimit m where
    -- | Charge a number of words against the traversal limit.
    invoice
        :: WordCount -- ^ number of words to deduct
        -> m ()
-- | Monad transformer that threads a remaining-words budget (a
-- 'WordCount') through the computation via a strict 'StateT'; 'invoice'
-- draws the budget down. The constructor is not exported, so the budget
-- can only be touched through 'invoice' and the run functions.
newtype LimitT m a = LimitT (StateT WordCount m a)
    deriving (Functor, Applicative, Monad)
-- | Run a 'LimitT' computation starting from the given budget, returning
-- both the result and the budget that was left unspent.
runLimitT :: MonadThrow m => WordCount -> LimitT m a -> m (a, WordCount)
runLimitT budget (LimitT action) = action `runStateT` budget
-- | Run a 'LimitT' computation starting from the given budget, keeping
-- only the result and discarding the leftover budget.
evalLimitT :: MonadThrow m => WordCount -> LimitT m a -> m a
evalLimitT budget (LimitT action) = action `evalStateT` budget
-- | Run a 'LimitT' computation starting from the given budget, keeping
-- only the budget that was left unspent.
execLimitT :: MonadThrow m => WordCount -> LimitT m a -> m WordCount
execLimitT budget (LimitT action) = action `execStateT` budget
-- | Default traversal limit: @64 * 1024 * 1024@ divided by 8, i.e.
-- 8,388,608 words (64 MiB if a word is 8 bytes).
defaultLimit :: WordCount
defaultLimit = totalBytes `div` bytesPerWord
  where
    totalBytes = 64 * 1024 * 1024
    bytesPerWord = 8
-- | Exceptions are thrown in the underlying monad.
instance MonadThrow m => MonadThrow (LimitT m) where
    throwM err = lift (throwM err)
-- | Deducts the invoiced words from the remaining budget, throwing
-- 'TraversalLimitError' when the charge exceeds what is left. A charge
-- equal to the remaining budget is allowed and leaves it at zero.
instance MonadThrow m => MonadLimit (LimitT m) where
    invoice cost = LimitT (get >>= charge)
      where
        -- The check runs before the new balance is written, so 'put' is
        -- never reached with a charge larger than the budget (which could
        -- underflow if 'WordCount' is unsigned).
        charge budget = do
            when (cost > budget) $ throwM TraversalLimitError
            put (budget - cost)
-- | Lifting goes through the wrapped 'StateT' and leaves the budget untouched.
instance MonadTrans LimitT where
    lift action = LimitT (lift action)
-- | Exposes the underlying monad's state (not the traversal budget, which
-- lives in the hidden 'StateT'). All three methods delegate to the
-- underlying monad, mirroring mtl's own lifting instances.
instance MonadState s m => MonadState s (LimitT m) where
    get = lift get
    put = lift . put
    -- Without this override, 'state' (and therefore 'modify') would fall
    -- back to the class default of a separate lifted 'get' followed by a
    -- lifted 'put'; delegating lets the underlying monad do it in one
    -- step. 'state' is in scope via the qualified lazy-State import; it
    -- is the same mtl class method.
    state = lift . LazyState.state
-- | Primitive state-token operations are lifted into the underlying
-- monad, and 'LimitT' shares that monad's state-token type.
--
-- The context is just @PrimMonad m@: the previous @s ~ PrimState m@
-- equality mentioned a type variable absent from the instance head and
-- constrained nothing.
instance PrimMonad m => PrimMonad (LimitT m) where
    type PrimState (LimitT m) = PrimState m
    primitive = lift . primitive
-- | Pattern-match failures are reported by the underlying monad.
instance MonadFail m => MonadFail (LimitT m) where
    fail msg = lift (fail msg)
-- | IO actions run in the underlying monad and are not charged.
instance MonadIO m => MonadIO (LimitT m) where
    liftIO io = lift (liftIO io)
-- | The strict 'StateT' (from "Control.Monad.State.Strict") forwards
-- charges to the monad beneath it.
instance MonadLimit m => MonadLimit (StateT s m) where
    invoice cost = lift (invoice cost)
-- | The lazy 'LazyState.StateT' forwards charges to the monad beneath it.
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
    invoice cost = lift (invoice cost)
-- | 'WriterT' (from "Control.Monad.Writer") forwards charges to the monad
-- beneath it; 'Monoid' @w@ is needed to lift into it.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
    invoice cost = lift (invoice cost)
-- | 'ReaderT' forwards charges to the monad beneath it.
instance MonadLimit m => MonadLimit (ReaderT r m) where
    invoice cost = lift (invoice cost)
-- | 'RWST' (from "Control.Monad.RWS") forwards charges to the monad
-- beneath it; 'Monoid' @w@ is needed to lift into it.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
    invoice cost = lift (invoice cost)