{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UndecidableInstances #-}
{- |
Module: Capnp.TraversalLimit
Description: Support for managing message traversal limits.
This module is used to mitigate several pitfalls with the capnproto format,
which could potentially lead to denial of service vulnerabilities.
In particular, while they are illegal according to the spec, it is possible to
encode objects which have many pointers pointing the same place, or even
cycles. A naive traversal therefore could involve quite a lot of computation
for a message that is very small on the wire.
Accordingly, most implementations of the format keep track of how many bytes
of a message have been accessed, and start signaling errors after a certain
value (the "traversal limit") has been reached. The Haskell implementation is
no exception; this module implements that logic. We provide a monad
transformer and mtl-style type class to track the limit; reading from the
message happens inside of this monad.
-}
module Capnp.TraversalLimit
( MonadLimit(..)
, LimitT
, runLimitT
, evalLimitT
, execLimitT
, defaultLimit
) where
import Prelude hiding (fail)
import Control.Monad (when)
import Control.Monad.Catch (MonadThrow(throwM))
import Control.Monad.Fail (MonadFail(..))
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Primitive (PrimMonad(primitive), PrimState)
import Control.Monad.State.Strict
(MonadState, StateT, evalStateT, execStateT, get, put, runStateT)
import Control.Monad.Trans.Class (MonadTrans(lift))
-- Just to define 'MonadLimit' instances:
import Control.Monad.Reader (ReaderT)
import Control.Monad.RWS (RWST)
import Control.Monad.Writer (WriterT)
import qualified Control.Monad.State.Lazy as LazyState
import Capnp.Bits (WordCount)
import Capnp.Errors (Error(TraversalLimitError))
-- | mtl-style type class to track the traversal limit. This is used
-- by other parts of the library which actually do the reading.
class Monad m => MonadLimit m where
-- | @'invoice' n@ deducts @n@ from the traversal limit, signaling
-- an error if the limit is exhausted.
invoice :: WordCount -> m ()
-- | Monad transformer implementing 'MonadLimit'. The underlying monad
-- must implement 'MonadThrow'. 'invoice' calls @'throwM' 'TraversalLimitError'@
-- when the limit is exhausted.
newtype LimitT m a = LimitT (StateT WordCount m a)
deriving(Functor, Applicative, Monad)
-- | Run a 'LimitT', returning the value from the computation and the remaining
-- traversal limit.
runLimitT :: MonadThrow m => WordCount -> LimitT m a -> m (a, WordCount)
runLimitT limit (LimitT stateT) = runStateT stateT limit
-- | Run a 'LimitT', returning the value from the computation.
evalLimitT :: MonadThrow m => WordCount -> LimitT m a -> m a
evalLimitT limit (LimitT stateT) = evalStateT stateT limit
-- | Run a 'LimitT', returning the remaining traversal limit.
execLimitT :: MonadThrow m => WordCount -> LimitT m a -> m WordCount
execLimitT limit (LimitT stateT) = execStateT stateT limit
-- | A sensible default traversal limit. Currently 64 MiB.
defaultLimit :: WordCount
defaultLimit = (64 * 1024 * 1024) `div` 8
------ Instances of mtl type classes for 'LimitT'.
instance MonadThrow m => MonadThrow (LimitT m) where
throwM = lift . throwM
instance MonadThrow m => MonadLimit (LimitT m) where
invoice deduct = LimitT $ do
limit <- get
when (limit < deduct) $ throwM TraversalLimitError
put (limit - deduct)
instance MonadTrans LimitT where
lift = LimitT . lift
instance MonadState s m => MonadState s (LimitT m) where
get = lift get
put = lift . put
instance (PrimMonad m, s ~ PrimState m) => PrimMonad (LimitT m) where
type PrimState (LimitT m) = PrimState m
primitive = lift . primitive
instance MonadFail m => MonadFail (LimitT m) where
fail = lift . fail
instance MonadIO m => MonadIO (LimitT m) where
liftIO = lift . liftIO
------ Instances of 'MonadLimit' for standard monad transformers
instance MonadLimit m => MonadLimit (StateT s m) where
invoice = lift . invoice
instance MonadLimit m => MonadLimit (LazyState.StateT s m) where
invoice = lift . invoice
instance (Monoid w, MonadLimit m) => MonadLimit (WriterT w m) where
invoice = lift . invoice
instance (MonadLimit m) => MonadLimit (ReaderT r m) where
invoice = lift . invoice
instance (Monoid w, MonadLimit m) => MonadLimit (RWST r w s m) where
invoice = lift . invoice