{-# LANGUAGE StrictData #-}

-- |
-- Module     : Network.Wai.RateLimitMiddleware
--
-- WAI Rate Limiting Middleware.
--
-- Rate limiting is configured by providing a function that maps a 'Request' to
-- a 'key' and providing a 'Rate' for it. The 'Rate' configures the maximum
-- number of requests to allow after a period of inactivity (or at the
-- beginning), and an average rate at which requests should be allowed.
--
-- Note that rate limiting state is maintained in memory by this module and a
-- web server restart will reset all state. Thus, this module is more
-- appropriate for limits that apply over shorter periods of time.
module Network.Wai.RateLimitMiddleware
  ( -- * Types, constructors and setters
    Rate (..),
    mkRate,
    infRate,
    RateLimitSettings,
    newRateLimitSettings,
    setRateLimitExceededResponse,
    setResetInterval,

    -- * Middleware
    rateLimitMiddleware,

    -- * Other useful code
    stopResetThread,
    Cache,
    newCache,
    tryAllocate,
  )
where

import Control.Concurrent (ThreadId, forkIO, killThread, threadDelay)
import Control.Concurrent.MVar qualified as M
import Control.Concurrent.TokenBucket
import Control.Monad (forever)
import Data.HashMap.Strict qualified as H
import Data.Hashable qualified as H
import Data.Maybe (isNothing)
import Data.Word (Word64)
import Network.HTTP.Types (status429)
import Network.Wai (Middleware, Request, Response, responseLBS)

-- | Internal state of a 'Cache': the 'TokenBucket' currently stored for each
-- request key.
newtype CacheData key = CacheData
  { contents :: H.HashMap key TokenBucket
  }

-- | A cache that maps request 'key's to a 'TokenBucket'. The map itself is
-- held inside an @MVar@; this module exports the type without its
-- constructor, so values are created with 'newCache'.
newtype Cache key = Cache (M.MVar (CacheData key))

-- | Create a new 'Cache' that contains no token buckets.
newCache :: IO (Cache key)
newCache = Cache <$> M.newMVar (CacheData {contents = H.empty})

-- | Discard every token bucket held in the 'Cache' by replacing its contents
-- with an empty map. The previous contents are ignored.
resetCache :: Cache key -> IO ()
resetCache (Cache mv) =
  M.modifyMVar_ mv $
    const $
      return CacheData {contents = H.empty}

-- | @RateLimitSettings@ holds settings for the rate limiting middleware.
data RateLimitSettings key = RateLimitSettings
  { -- | Maps a request to its rate limiting key and the 'Rate' to apply to it.
    getRequestKeyAndRate :: Request -> IO (key, Rate),
    -- | Response sent for a request that is not allowed through.
    limitExceededResponse :: Response,
    -- | Token buckets, one per key seen so far.
    tbCache :: Cache key,
    -- | Cache reset period in seconds, as given to 'setResetInterval';
    -- @Nothing@ if it was never set.
    resetInterval :: Maybe Int,
    -- | Thread started by 'setResetInterval', if any.
    resetThreadId :: Maybe ThreadId
  }

-- | Create new rate limit settings by providing a function to map a request to
-- a key and a 'Rate'. The settings start with a fresh, empty token bucket
-- cache, no reset interval and no reset thread. The default response for
-- requests that exceed the rate limit has a 429 status code, no headers and a
-- simple error message as its body.
newRateLimitSettings :: (Request -> IO (key, Rate)) -> IO (RateLimitSettings key)
newRateLimitSettings keyRateFunc = do
  c <- newCache
  return $
    RateLimitSettings
      { getRequestKeyAndRate = keyRateFunc,
        limitExceededResponse = responseLBS status429 [] "Rate limit exceeded",
        tbCache = c,
        resetInterval = Nothing,
        resetThreadId = Nothing
      }

-- | Set a custom error response, sent in place of the default one whenever a
-- request is rejected. All other settings are left unchanged.
setRateLimitExceededResponse :: RateLimitSettings key -> Response -> RateLimitSettings key
setRateLimitExceededResponse s rsp = s {limitExceededResponse = rsp}

-- | @setResetInterval@ starts a thread to reset (clear) the token buckets map
-- after every given interval period of time (expressed in seconds). This is
-- useful if your webserver generates a lot of request 'key's that go idle after
-- some activity. In this situation the token bucket cache memory usage grows as
-- it contains an entry for every key seen. When the cache is reset, the memory
-- can be garbage collected. Though this will cause all rate limit token buckets
-- to go "full" (i.e. allow the full burst of requests immediately), this
-- solution is acceptable as this is the case when the webserver is restarted as
-- well. By default, there is no reset thread launched (unless this function is
-- called).
--
-- Any reset thread recorded in the given settings is stopped first. The id of
-- the newly started thread is stored only in the /returned/ settings, so use
-- the returned value for later calls to 'stopResetThread' or
-- @setResetInterval@.
--
-- The interval is not validated here; it is multiplied by one million and
-- passed to 'threadDelay' as a number of microseconds.
setResetInterval :: RateLimitSettings key -> Int -> IO (RateLimitSettings key)
setResetInterval s seconds = do
  -- Make sure at most one reset thread is associated with these settings.
  stopResetThread s
  tid <- forkIO $ forever $ do
    threadDelay $ seconds * 1000000
    resetCache $ tbCache s
  return $
    s
      { resetInterval = Just seconds,
        resetThreadId = Just tid
      }

-- | @stopResetThread@ stops the thread launched by 'setResetInterval' and is
-- provided for completeness. If your application automatically restarts the
-- web-server using the rate limit middleware, call this function in your web
-- server's shutdown handler to ensure that the reset thread is killed (and does
-- not leak). It does nothing when the settings carry no reset thread id.
stopResetThread :: RateLimitSettings key -> IO ()
stopResetThread = maybe (return ()) killThread . resetThreadId

-- | @tryAllocate@ attempts to allocate the given @amount@ from the
-- 'TokenBucket' corresponding to the given @key@ and @Rate@ in the @Cache@. On
-- success it returns @Nothing@, otherwise it returns the minimum time (in
-- nanoseconds) to wait after which the allocation can succeed.
--
-- If the cache has no bucket for the key yet, a new one is created and stored
-- before allocating from it.
tryAllocate :: (H.Hashable key) => key -> Rate -> Word64 -> Cache key -> IO (Maybe Word64)
tryAllocate k r amount (Cache mv) = do
  -- Only the lookup (and insertion of a missing bucket) runs while holding the
  -- cache's MVar; the token allocation itself happens afterwards.
  tb <- M.modifyMVar mv $ \CacheData {..} ->
    case H.lookup k contents of
      Just tb' -> return (CacheData contents, tb')
      Nothing -> do
        tb' <- newTokenBucket r
        return (CacheData $ H.insert k tb' contents, tb')
  tryAllocateTokens tb amount r

-- | @rateLimitMiddleware@ performs rate limiting according to the given
-- settings. Every request tries to take one token from the bucket of its key:
-- if that succeeds the request is passed on to the wrapped application,
-- otherwise the configured limit-exceeded response is sent instead.
rateLimitMiddleware :: (H.Hashable a) => RateLimitSettings a -> Middleware
rateLimitMiddleware RateLimitSettings {..} app req sendResponse = do
  -- get request key and configured rate
  (key, rate) <- getRequestKeyAndRate req

  -- check that the client has not exceeded its rate limit; the suggested wait
  -- time returned on failure is not used
  allowRequest <- isNothing <$> tryAllocate key rate 1 tbCache

  if allowRequest
    then app req sendResponse
    else sendResponse limitExceededResponse