{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
module Data.Capnp.TraversalLimit
( MonadLimit(..)
, LimitT
, runLimitT
, evalLimitT
, execLimitT
, defaultLimit
) where
import Control.Monad (when)

import Control.Monad.Catch (MonadThrow(throwM))
import Control.Monad.Primitive (PrimMonad(primitive), PrimState)
import Control.Monad.Reader (ReaderT)
import Control.Monad.RWS (RWST)
import Control.Monad.State.Strict
    (MonadState, StateT, evalStateT, execStateT, get, put, runStateT, state)
import qualified Control.Monad.State.Lazy as LazyState
import Control.Monad.Trans.Class (MonadTrans(lift))
import Control.Monad.Writer (WriterT)

import Data.Capnp.Errors (Error(TraversalLimitError))
-- | Monads that can be charged against a traversal budget.
--
-- @'invoice' n@ charges @n@ units. For 'LimitT' (defined in this module) the
-- charge is deducted from the remaining budget, and 'TraversalLimitError' is
-- thrown if fewer than @n@ units remain. The other instances in this module
-- just lift the charge into their underlying monad.
class (Monad m) => MonadLimit m where
    -- | Charge the given amount against the current traversal budget.
    invoice :: Int -> m ()
-- | Monad transformer that threads a remaining traversal budget (an 'Int')
-- through the computation. Internally it is a strict 'StateT' over that
-- budget. The constructor is not exported, so the budget can only be
-- changed through 'invoice'.
newtype LimitT m a = LimitT (StateT Int m a)
    deriving (Monad, Applicative, Functor)
-- | Run a 'LimitT' computation starting with the given budget. Returns both
-- the result and the budget left unspent.
--
-- No 'MonadThrow' constraint is needed. Running the transformer only unwraps
-- the underlying 'StateT'; charging via 'invoice' is what requires
-- 'MonadThrow'. Callers that previously satisfied 'MonadThrow' still
-- type-check.
runLimitT :: Int -> LimitT m a -> m (a, Int)
runLimitT limit (LimitT stateT) = runStateT stateT limit
-- | Run a 'LimitT' computation starting with the given budget and return only
-- its result, discarding the leftover budget.
--
-- Only 'Monad' is required, because that is all 'evalStateT' needs.
-- 'MonadThrow' implies 'Monad', so existing callers are unaffected.
evalLimitT :: Monad m => Int -> LimitT m a -> m a
evalLimitT limit (LimitT stateT) = evalStateT stateT limit
-- | Run a 'LimitT' computation starting with the given budget and return only
-- the leftover budget, discarding the result.
--
-- Only 'Monad' is required, because that is all 'execStateT' needs.
-- 'MonadThrow' implies 'Monad', so existing callers are unaffected.
execLimitT :: Monad m => Int -> LimitT m a -> m Int
execLimitT limit (LimitT stateT) = execStateT stateT limit
-- | A default traversal budget: @64 * 1024 * 1024 = 2^26 = 67108864@ units.
--
-- NOTE(review): the unit charged by 'invoice' is not visible in this module
-- (presumably message words); confirm before treating this as a byte count.
defaultLimit :: Int
defaultLimit = 2 ^ (26 :: Int)
-- | Throw exceptions in the underlying monad. The budget is not consulted.
instance MonadThrow m => MonadThrow (LimitT m) where
    throwM err = lift (throwM err)
-- | Deduct the charge from the remaining budget. Throws
-- 'TraversalLimitError' (via 'throwM') when the charge exceeds what is left;
-- in that case the budget is left unchanged. Charging exactly the remaining
-- amount succeeds and leaves a budget of 0.
instance MonadThrow m => MonadLimit (LimitT m) where
    invoice cost = LimitT $ do
        budget <- get
        when (cost > budget) $
            throwM TraversalLimitError
        -- NOTE(review): a negative 'cost' is not rejected and would grow the
        -- budget; confirm whether callers ever rely on that.
        put $! budget - cost
-- | Lift an action from the underlying monad without touching the budget.
instance MonadTrans LimitT where
    lift action = LimitT (lift action)
-- | Pass 'MonadState' operations through to the underlying monad. The
-- traversal budget is held in 'LimitT''s own internal 'StateT' and cannot be
-- reached through this instance.
instance MonadState s m => MonadState s (LimitT m) where
    get = lift get
    put = lift . put
    -- Forward 'state' explicitly instead of relying on the class default,
    -- which is built from separate 'get' and 'put' calls. This lets the
    -- underlying monad's own 'state' implementation be used.
    state = lift . state
-- | Run primitive state-token operations in the underlying monad, sharing its
-- 'PrimState'.
--
-- The former @s ~ PrimState m@ constraint was removed. Its type variable @s@
-- never appeared in the instance head, so it constrained nothing.
instance PrimMonad m => PrimMonad (LimitT m) where
    type PrimState (LimitT m) = PrimState m
    primitive = lift . primitive
-- | Charges inside a strict 'StateT' are forwarded to the underlying monad.
instance MonadLimit m => MonadLimit (StateT s m) where
    invoice n = lift (invoice n)
-- | Charges inside a lazy 'LazyState.StateT' are forwarded to the underlying
-- monad.
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
    invoice n = lift (invoice n)
-- | Charges inside the 'WriterT' imported from "Control.Monad.Writer" are
-- forwarded to the underlying monad. 'Monoid' @w@ is needed for
-- @'WriterT' w m@ to be a monad at all.
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
    invoice n = lift (invoice n)
-- | Charges inside a 'ReaderT' are forwarded to the underlying monad.
instance MonadLimit m => MonadLimit (ReaderT r m) where
    invoice n = lift (invoice n)
-- | Charges inside the 'RWST' imported from "Control.Monad.RWS" are forwarded
-- to the underlying monad. 'Monoid' @w@ is needed for the writer component.
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
    invoice n = lift (invoice n)