{-# LANGUAGE StrictData #-}
module Network.Wai.RateLimitMiddleware
(
Rate (..),
mkRate,
infRate,
RateLimitSettings,
newRateLimitSettings,
setRateLimitExceededResponse,
setResetInterval,
rateLimitMiddleware,
stopResetThread,
Cache,
newCache,
tryAllocate,
)
where
import Control.Concurrent (ThreadId, forkIO, killThread, threadDelay)
import Control.Concurrent.MVar qualified as M
import Control.Concurrent.TokenBucket
import Control.Monad (forever)
import Data.HashMap.Strict qualified as H
import Data.Hashable qualified as H
import Data.Maybe (isNothing)
import Data.Word (Word64)
import Network.HTTP.Types (status429)
import Network.Wai (Middleware, Request, Response, responseLBS)
-- | Mutable state behind a 'Cache': one token bucket per key.
newtype CacheData key = CacheData
  { contents :: H.HashMap key TokenBucket
  -- ^ Token bucket for every key allocated against since the last reset.
  }

-- | A key-to-token-bucket map held in an 'M.MVar'. The functions in this
-- module only read or replace it through 'M.modifyMVar' / 'M.modifyMVar_'.
newtype Cache key = Cache (M.MVar (CacheData key))
-- | Create a 'Cache' that holds no token buckets yet.
newCache :: IO (Cache key)
newCache = Cache <$> M.newMVar (CacheData {contents = H.empty})
-- | Discard every token bucket in the cache. The next allocation for any key
-- then finds no bucket and creates a fresh one (see 'tryAllocate').
resetCache :: Cache key -> IO ()
resetCache (Cache mv) =
  M.modifyMVar_ mv (\_ -> return CacheData {contents = H.empty})
-- | Configuration for 'rateLimitMiddleware'. Create it with
-- 'newRateLimitSettings' and adjust it with 'setRateLimitExceededResponse'
-- and 'setResetInterval'.
data RateLimitSettings key = RateLimitSettings
  { getRequestKeyAndRate :: Request -> IO (key, Rate),
    -- ^ Picks the key a request is limited under and the rate for that key.
    limitExceededResponse :: Response,
    -- ^ Response sent instead of calling the application when no token is
    -- available.
    tbCache :: Cache key,
    -- ^ Per-key token buckets.
    resetInterval :: Maybe Int,
    -- ^ Seconds between cache resets as configured by 'setResetInterval';
    -- 'Nothing' when no periodic reset was configured.
    resetThreadId :: Maybe ThreadId
    -- ^ Background thread that periodically clears 'tbCache', if one was
    -- started.
  }
-- | Build settings from a function that assigns each request a key and a
-- rate. The result starts with an empty cache, a plain 429
-- "Rate limit exceeded" response, and no periodic cache reset.
newRateLimitSettings :: (Request -> IO (key, Rate)) -> IO (RateLimitSettings key)
newRateLimitSettings keyRateFunc = do
  c <- newCache
  return
    RateLimitSettings
      { getRequestKeyAndRate = keyRateFunc,
        limitExceededResponse = responseLBS status429 [] "Rate limit exceeded",
        tbCache = c,
        resetInterval = Nothing,
        resetThreadId = Nothing
      }
-- | Replace the response sent to requests rejected by the rate limiter.
setRateLimitExceededResponse :: RateLimitSettings key -> Response -> RateLimitSettings key
setRateLimitExceededResponse s rsp = s {limitExceededResponse = rsp}
-- | Clear the token-bucket cache every @seconds@ seconds.
--
-- Any reset thread recorded in @s@ is killed first. A new thread is then
-- forked that, forever, sleeps for the interval and empties 'tbCache'. The
-- returned settings record the interval and the new thread's id, so the
-- thread can later be stopped with 'stopResetThread'.
--
-- A non-positive @seconds@ disables periodic resets. The old thread is still
-- stopped, but no new thread is forked. Otherwise the reset loop would run
-- without sleeping and keep contending for the cache lock.
--
-- Only the thread stored in @s@ is stopped. Calling this twice on the same
-- original settings value leaves the first reset thread running.
setResetInterval :: RateLimitSettings key -> Int -> IO (RateLimitSettings key)
setResetInterval s seconds = do
  stopResetThread s
  if seconds <= 0
    then return s {resetInterval = Nothing, resetThreadId = Nothing}
    else do
      tid <- forkIO $ forever $ do
        threadDelay delayMicros
        resetCache (tbCache s)
      return s {resetInterval = Just seconds, resetThreadId = Just tid}
  where
    -- 'threadDelay' takes microseconds as an 'Int'; clamp the interval so
    -- the conversion cannot overflow into a negative (non-sleeping) delay.
    delayMicros = min seconds (maxBound `div` 1000000) * 1000000
-- | Kill the cache-reset thread recorded in the settings. Does nothing when
-- no reset thread was started.
stopResetThread :: RateLimitSettings key -> IO ()
stopResetThread s = mapM_ killThread (resetThreadId s)
-- | Try to take @amount@ tokens from the bucket for key @k@ at rate @r@.
--
-- The key's bucket is looked up while holding the cache lock. On first use
-- (or after a 'resetCache') it is created with 'newTokenBucket' and stored.
-- The allocation itself runs after the lock is released.
--
-- The result is exactly what 'tryAllocateTokens' returns.
-- 'rateLimitMiddleware' treats 'Nothing' as a successful allocation.
-- NOTE(review): 'Just' presumably carries a wait hint; confirm against
-- "Control.Concurrent.TokenBucket".
tryAllocate :: (H.Hashable key) => key -> Rate -> Word64 -> Cache key -> IO (Maybe Word64)
tryAllocate k r amount (Cache mv) = do
  tb <- M.modifyMVar mv lookupOrCreate
  tryAllocateTokens tb amount r
  where
    -- Return the existing bucket for @k@, or create and insert a new one.
    lookupOrCreate cd@(CacheData buckets) =
      case H.lookup k buckets of
        Just tb -> return (cd, tb)
        Nothing -> do
          tb <- newTokenBucket r
          return (CacheData (H.insert k tb buckets), tb)
-- | WAI middleware that rate-limits requests with one token bucket per key.
--
-- For each request, 'getRequestKeyAndRate' chooses the key and rate, and one
-- token is requested from that key's bucket via 'tryAllocate'. If
-- 'tryAllocate' returns 'Nothing', the request is passed to the wrapped
-- application. Otherwise 'limitExceededResponse' is sent and the application
-- is not called.
rateLimitMiddleware :: (H.Hashable a) => RateLimitSettings a -> Middleware
rateLimitMiddleware settings app req sendResponse = do
  (key, rate) <- getRequestKeyAndRate settings req
  allocation <- tryAllocate key rate 1 (tbCache settings)
  if isNothing allocation
    then app req sendResponse
    else sendResponse (limitExceededResponse settings)