-- |
-- Copyright: (c) 2022 Aditya Manthramurthy
-- SPDX-License-Identifier: Apache-2.0
-- Maintainer: Aditya Manthramurthy <aditya.mmy@gmail.com>
--
-- A wai-rate-limit backend using PostgreSQL.
module Network.Wai.RateLimit.Postgres
  ( PGBackendError (..),
    postgresBackend,
  )
where

import Control.Concurrent (forkIO, threadDelay)
import Control.Exception (Handler (..), catches, throwIO, try)
import Data.Pool (Pool, withResource)
import qualified Data.Text as T
import qualified Database.PostgreSQL.Simple as PG
import Network.Wai.RateLimit.Backend (Backend (..))

-- | Reasons a request to the Postgres rate-limit backend can fail.
--
-- 'PGBackendErrorInit' is thrown as an exception while setting up the
-- backend; the other constructors are produced while running the backend's
-- queries and are reported to callers in 'Left'.
data PGBackendError
  = -- | A 'PG.SqlError' raised while creating the storage table.
    PGBackendErrorInit PG.SqlError
  | -- | A 'PG.FormatError' raised while running a backend query.
    PGBackendErrorBugFmt PG.FormatError
  | -- | A 'PG.QueryError' raised while running a backend query.
    PGBackendErrorBugQry PG.QueryError
  | -- | A 'PG.ResultError' raised while running a backend query.
    PGBackendErrorBugRes PG.ResultError
  | -- | A 'PG.SqlError' raised while running a backend query.
    PGBackendErrorBugSql PG.SqlError
  | -- | The usage lookup returned more than one row for a key.
    PGBackendErrorAtMostOneRow
  | -- | The increment-and-fetch query did not return exactly one row.
    PGBackendErrorExactlyOneRow
  | -- | Setting a key's expiry did not update exactly one row.
    PGBackendErrorExactlyOneUpdate
  deriving stock (Show, Eq)

instance Exception PGBackendError

-- | Ensure the storage table @tableName@ exists, creating it when missing.
--
-- Each row holds a key, its usage count, and an @expires_at@ timestamp whose
-- column default is @CURRENT_TIMESTAMP + 1 week@. A 'PG.SqlError' raised by
-- the @CREATE TABLE@ statement is rethrown as 'PGBackendErrorInit'.
initPostgresBackend :: Pool PG.Connection -> Text -> IO ()
initPostgresBackend p tableName =
  withResource p $ \conn -> do
    created <- try (PG.execute_ conn createTableQuery)
    case created of
      Left sqlErr -> throwIO (PGBackendErrorInit sqlErr)
      Right _ -> return ()
  where
    -- NOTE: tableName is spliced verbatim into the statement text.
    createTableQuery =
      fromString . toString $
        T.intercalate
          " "
          [ "CREATE TABLE IF NOT EXISTS",
            tableName,
            "(key VARCHAR PRIMARY KEY,",
            "usage INT8 NOT NULL,",
            "expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP + '1 week'::INTERVAL)"
          ]

-- | Handlers for use with 'catches'. Each one rethrows one of
-- postgresql-simple's 'PG.FormatError', 'PG.QueryError', 'PG.ResultError' or
-- 'PG.SqlError' wrapped in the matching 'PGBackendError' constructor. A single
-- @try@ at type 'PGBackendError' can then capture all four.
sqlHandlers :: [Handler a]
sqlHandlers =
  [ rethrowAs PGBackendErrorBugFmt,
    rethrowAs PGBackendErrorBugQry,
    rethrowAs PGBackendErrorBugRes,
    rethrowAs PGBackendErrorBugSql
  ]
  where
    -- Build a handler that wraps the caught exception and throws it again.
    rethrowAs :: Exception e => (e -> PGBackendError) -> Handler b
    rethrowAs wrap = Handler (throwIO . wrap)

-- | Look up the current usage count for @key@ in @tableName@.
--
-- The lookup only considers rows whose @expires_at@ is still in the future.
-- When there is no such row, the usage is 0. More than one matching row
-- yields 'PGBackendErrorAtMostOneRow'. postgresql-simple errors raised by the
-- query are converted by 'sqlHandlers' and returned in 'Left'.
pgBackendGetUsage ::
  Pool PG.Connection -> Text -> ByteString -> IO (Either PGBackendError Integer)
pgBackendGetUsage p tableName key =
  withResource p $ \conn -> do
    fetched <-
      try (PG.query conn getUsageQuery (PG.Only key) `catches` sqlHandlers)
    pure (fetched >>= atMostOne)
  where
    -- Interpret the result rows: no row means no usage recorded yet.
    atMostOne [] = Right 0
    atMostOne [PG.Only n] = Right n
    atMostOne _ = Left PGBackendErrorAtMostOneRow

    getUsageQuery =
      fromString . toString $
        T.intercalate
          " "
          [ "SELECT usage FROM",
            tableName,
            "WHERE key = ?",
            "AND expires_at > CURRENT_TIMESTAMP"
          ]

-- | Add @usage@ to the counter stored for @key@ in @tableName@ and return the
-- resulting total.
--
-- The upsert behaves as follows:
--
-- * No row exists for @key@: a row is inserted with usage @usage@, and its
--   @expires_at@ comes from the column default.
-- * The row's @expires_at@ is still in the future: @usage@ is added to its
--   count, and the expiry is left unchanged.
-- * The row has expired: its count is reset to @usage@, and its expiry is set
--   to one week from now.
--
-- Anything other than exactly one returned row yields
-- 'PGBackendErrorExactlyOneRow'. postgresql-simple errors raised by the query
-- are converted by 'sqlHandlers' and returned in 'Left'.
pgBackendIncAndGetUsage ::
  Pool PG.Connection -> Text -> ByteString -> Integer -> IO (Either PGBackendError Integer)
pgBackendIncAndGetUsage p tableName key usage = withResource p $ \c -> do
  res <-
    try $
      PG.query c incAndGetQuery (key, usage) `catches` sqlHandlers
  return $ do
    rows <- res
    case rows of
      [PG.Only total] -> Right total
      _ -> Left PGBackendErrorExactlyOneRow
  where
    -- NOTE: tableName is spliced verbatim into the statement text.
    incAndGetQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "INSERT INTO",
              tableName,
              "as rl",
              "(key, usage) VALUES (?, ?)",
              "ON CONFLICT (key) DO UPDATE SET",
              "usage = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.usage + EXCLUDED.usage ELSE EXCLUDED.usage END,",
              "expires_at = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.expires_at ELSE CURRENT_TIMESTAMP + '1 week'::INTERVAL END",
              "RETURNING usage"
            ]

-- | Set the expiry of the row for @key@ in @tableName@ to @seconds@ seconds
-- after the current timestamp.
--
-- The call returns 'PGBackendErrorExactlyOneUpdate' unless exactly one row
-- was updated. Because @key@ is the table's primary key, this happens when no
-- row exists for @key@. postgresql-simple errors raised by the statement are
-- converted by 'sqlHandlers' and returned in 'Left'.
pgBackendExpireIn ::
  Pool PG.Connection -> Text -> ByteString -> Integer -> IO (Either PGBackendError ())
pgBackendExpireIn p tableName key seconds = withResource p $ \c -> do
  res <-
    try $
      PG.execute c expireInQuery (seconds, key) `catches` sqlHandlers
  return $ do
    count <- res
    if count /= 1
      then Left PGBackendErrorExactlyOneUpdate
      else Right ()
  where
    -- The number of seconds is bound as a standalone parameter and scaled by
    -- a one-second interval. It is not placed inside a quoted interval literal
    -- ('? second'), which only worked because the placeholder was substituted
    -- textually inside the quotes.
    -- NOTE: tableName is spliced verbatim into the statement text.
    expireInQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "UPDATE",
              tableName,
              "SET expires_at = CURRENT_TIMESTAMP + ? * INTERVAL '1 second'",
              "WHERE key = ?"
            ]

-- | Fork a background thread that repeatedly deletes expired rows from
-- @tableName@, then return immediately. The thread's 'ThreadId' is discarded.
--
-- Each pass deletes at most 5000 rows whose @expires_at@ is in the past, then
-- sleeps before the next pass:
--
-- * 100ms when exactly 5000 rows (a full batch) were removed;
-- * 1s when some rows were removed;
-- * 10s when nothing was removed, or when the DELETE failed with an error
--   converted by 'sqlHandlers'.
--
-- Exceptions that 'sqlHandlers' does not convert are not caught here, so they
-- end the thread.
pgBackendCleanup :: Pool PG.Connection -> Text -> IO ()
pgBackendCleanup p tableName =
  void . forkIO . forever $ do
    outcome <-
      withResource p $ \conn ->
        try (PG.execute_ conn removeExpired `catches` sqlHandlers)
    threadDelay (pauseAfter outcome)
  where
    -- Sleep duration (microseconds) after a pass. Sleep less when more
    -- garbage is likely waiting.
    pauseAfter :: (Ord n, Num n) => Either PGBackendError n -> Int
    pauseAfter (Left _) = tenSeconds
    pauseAfter (Right deleted)
      | deleted == 5000 = hundredMillis
      | deleted > 0 = oneSecond
      | otherwise = tenSeconds

    tenSeconds, oneSecond, hundredMillis :: Int
    tenSeconds = 10_000_000
    oneSecond = 1_000_000
    hundredMillis = 100_000

    -- NOTE: tableName is spliced verbatim into the statement text.
    removeExpired =
      fromString . toString $
        T.intercalate
          " "
          [ "DELETE FROM",
            tableName,
            "WHERE key IN (SELECT key FROM",
            tableName,
            "WHERE expires_at < CURRENT_TIMESTAMP LIMIT 5000)"
          ]

-- | Build a PostgreSQL-backed 'Backend' for rate limiting.
--
-- The function takes a connection pool and the name of the table used for
-- storage. It first creates the table if it does not already exist, throwing
-- 'PGBackendErrorInit' on failure. It then starts a background thread that
-- periodically deletes expired rows from the table.
--
-- The table name is inserted verbatim into the generated SQL, so it must be a
-- trusted identifier.
postgresBackend :: Pool PG.Connection -> Text -> IO (Backend ByteString PGBackendError)
postgresBackend pool table = do
  initPostgresBackend pool table
  pgBackendCleanup pool table
  pure backend
  where
    backend =
      MkBackend
        { backendGetUsage = pgBackendGetUsage pool table,
          backendIncAndGetUsage = pgBackendIncAndGetUsage pool table,
          backendExpireIn = pgBackendExpireIn pool table
        }