module Network.Wai.RateLimit.Postgres
( PGBackendError (..),
postgresBackend,
)
where
import Control.Concurrent (forkIO, threadDelay)
import Control.Exception (Handler (..), catches, throwIO, try)
import Data.Pool (Pool, withResource)
import qualified Data.Text as T
import qualified Database.PostgreSQL.Simple as PG
import Network.Wai.RateLimit.Backend (Backend (..))
-- | Errors reported by the PostgreSQL rate-limit backend.
data PGBackendError
  = -- | The @CREATE TABLE IF NOT EXISTS@ statement issued during
    -- initialisation raised a 'PG.SqlError'.
    PGBackendErrorInit PG.SqlError
  | -- | A query or statement raised a 'PG.FormatError'.
    PGBackendErrorBugFmt PG.FormatError
  | -- | A query or statement raised a 'PG.QueryError'.
    PGBackendErrorBugQry PG.QueryError
  | -- | A query or statement raised a 'PG.ResultError'.
    PGBackendErrorBugRes PG.ResultError
  | -- | A query or statement raised a 'PG.SqlError' after initialisation.
    PGBackendErrorBugSql PG.SqlError
  | -- | The usage lookup returned more than one row for a key.
    PGBackendErrorAtMostOneRow
  | -- | The increment statement did not return exactly one row.
    PGBackendErrorExactlyOneRow
  | -- | The expiry update did not affect exactly one row.
    PGBackendErrorExactlyOneUpdate
  deriving stock (Eq, Show)

instance Exception PGBackendError
-- | Create the rate-limit table unless it already exists.
--
-- A 'PG.SqlError' raised by the statement is rethrown as
-- 'PGBackendErrorInit'; any other exception propagates unchanged.
-- The table name is spliced into the statement verbatim.
initPostgresBackend :: Pool PG.Connection -> Text -> IO ()
initPostgresBackend p tableName = withResource p $ \c -> do
  res <- try $ PG.execute_ c createTableQuery
  either
    (throwIO . PGBackendErrorInit)
    (const $ return ())
    res
  where
    createTableQuery :: PG.Query
    createTableQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "CREATE TABLE IF NOT EXISTS",
              tableName,
              "(key VARCHAR PRIMARY KEY,",
              "usage INT8 NOT NULL,",
              "expires_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP + '1 week'::INTERVAL)"
            ]
-- | Exception handlers that rethrow each postgresql-simple exception type
-- ('PG.FormatError', 'PG.QueryError', 'PG.ResultError', 'PG.SqlError')
-- wrapped in the matching 'PGBackendError' constructor.
--
-- Intended for use with 'catches', so that a surrounding 'try' only has to
-- deal with 'PGBackendError'.
sqlHandlers :: [Handler a]
sqlHandlers =
  [ Handler (throwIO . PGBackendErrorBugFmt),
    Handler (throwIO . PGBackendErrorBugQry),
    Handler (throwIO . PGBackendErrorBugRes),
    Handler (throwIO . PGBackendErrorBugSql)
  ]
-- | Look up the current usage recorded for a key.
--
-- Only rows whose @expires_at@ lies in the future are considered. The
-- result is @Right 0@ when no such row exists, @Right n@ for a single row,
-- and @Left 'PGBackendErrorAtMostOneRow'@ when more than one row comes
-- back. postgresql-simple exceptions are returned as @Left@ via
-- 'sqlHandlers'.
pgBackendGetUsage :: Pool PG.Connection -> Text -> ByteString -> IO (Either PGBackendError Integer)
pgBackendGetUsage p tableName key = withResource p $ \c -> do
  -- 'catches' converts driver exceptions to 'PGBackendError', which 'try'
  -- then captures as a 'Left'.
  res <- try $ PG.query c getUsageQuery (PG.Only key) `catches` sqlHandlers
  return $ do
    rows <- res
    case rows of
      [] -> Right 0
      [PG.Only a] -> Right a
      _ -> Left PGBackendErrorAtMostOneRow
  where
    getUsageQuery :: PG.Query
    getUsageQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "SELECT usage FROM",
              tableName,
              "WHERE key = ?",
              "AND expires_at > CURRENT_TIMESTAMP"
            ]
-- | Add @usage@ to the counter stored for a key and return the new total.
--
-- The statement inserts a fresh row, or on a key conflict either adds to
-- the existing usage (row not yet expired) or restarts the counter with a
-- one-week expiry (row expired). Exactly one row must be returned;
-- anything else yields @Left 'PGBackendErrorExactlyOneRow'@.
-- postgresql-simple exceptions are returned as @Left@ via 'sqlHandlers'.
pgBackendIncAndGetUsage :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO (Either PGBackendError Integer)
pgBackendIncAndGetUsage p tableName key usage = withResource p $ \c -> do
  -- 'catches' converts driver exceptions to 'PGBackendError', which 'try'
  -- then captures as a 'Left'.
  res <- try $ PG.query c incAndGetQuery (key, usage) `catches` sqlHandlers
  return $ do
    rows <- res
    case rows of
      [PG.Only a] -> Right a
      _ -> Left PGBackendErrorExactlyOneRow
  where
    incAndGetQuery :: PG.Query
    incAndGetQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "INSERT INTO",
              tableName,
              "as rl",
              "(key, usage) VALUES (?, ?)",
              "ON CONFLICT (key) DO UPDATE SET",
              "usage = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.usage + EXCLUDED.usage ELSE EXCLUDED.usage END,",
              "expires_at = CASE WHEN rl.expires_at > CURRENT_TIMESTAMP THEN rl.expires_at ELSE CURRENT_TIMESTAMP + '1 week'::INTERVAL END",
              "RETURNING usage"
            ]
-- | Set the expiry of a key to @seconds@ from the current timestamp.
--
-- Returns @Right ()@ when exactly one row was updated and
-- @Left 'PGBackendErrorExactlyOneUpdate'@ otherwise (including when the key
-- has no row). postgresql-simple exceptions are returned as @Left@ via
-- 'sqlHandlers'.
pgBackendExpireIn :: Pool PG.Connection -> Text -> ByteString -> Integer -> IO (Either PGBackendError ())
pgBackendExpireIn p tableName key seconds = withResource p $ \c -> do
  -- 'catches' converts driver exceptions to 'PGBackendError', which 'try'
  -- then captures as a 'Left'.
  res <- try $ PG.execute c expireInQuery (seconds, key) `catches` sqlHandlers
  return $ do
    count <- res
    if count /= 1
      then Left PGBackendErrorExactlyOneUpdate
      else Right ()
  where
    -- NOTE(review): the first placeholder sits inside a quoted SQL literal.
    -- This presumably relies on postgresql-simple substituting @?@ textually
    -- and rendering an 'Integer' without quotes -- confirm before changing
    -- the parameter type.
    expireInQuery :: PG.Query
    expireInQuery =
      fromString $
        toString $
          T.intercalate
            " "
            [ "UPDATE",
              tableName,
              "SET expires_at = CURRENT_TIMESTAMP + '? second'::interval",
              "WHERE key = ?"
            ]
-- | Fork a background thread that deletes expired rows forever.
--
-- Each iteration deletes at most 'batchSize' expired rows and then sleeps:
--
-- * 100 ms when a full batch was deleted (more rows are likely waiting),
-- * 1 s when some rows, but fewer than a full batch, were deleted,
-- * 10 s when nothing was deleted or the statement failed with a
--   'PGBackendError'.
--
-- Exceptions that 'sqlHandlers' does not translate are not caught here.
-- The thread id is discarded, so the caller has no handle to the thread.
pgBackendCleanup :: Pool PG.Connection -> Text -> IO ()
pgBackendCleanup p tableName = void $
  forkIO $
    forever $ do
      res <- withResource p $ \c ->
        tryDBErr $ PG.execute_ c removeExpired `catches` sqlHandlers
      case res of
        Left _ -> threadDelay d10s
        Right n -> delay n
  where
    -- Maximum number of rows removed per iteration; shared by the SQL
    -- LIMIT and the back-off decision so the two cannot drift apart.
    batchSize :: Int
    batchSize = 5000

    -- Delays in microseconds, as expected by 'threadDelay'.
    d10s, d1s, d100ms :: Int
    d10s = 10_000_000
    d1s = 1_000_000
    d100ms = 100_000

    -- Choose the pause from the number of rows the last pass deleted.
    delay n
      | n == fromIntegral batchSize = threadDelay d100ms
      | n > 0 = threadDelay d1s
      | otherwise = threadDelay d10s

    -- 'try' pinned to 'PGBackendError' so the exception type is unambiguous.
    tryDBErr :: IO a -> IO (Either PGBackendError a)
    tryDBErr a = try a

    removeExpired :: PG.Query
    removeExpired =
      fromString $
        toString $
          T.intercalate
            " "
            [ "DELETE FROM",
              tableName,
              "WHERE key IN (SELECT key FROM",
              tableName,
              "WHERE expires_at < CURRENT_TIMESTAMP LIMIT",
              T.pack (show batchSize) <> ")"
            ]
-- | Build a rate-limit 'Backend' that stores its counters in PostgreSQL.
--
-- The action first creates the table if it is missing (throwing
-- 'PGBackendErrorInit' when that statement fails with a 'PG.SqlError'),
-- then forks the background thread that removes expired rows, and finally
-- returns the backend record.
--
-- The table name is spliced verbatim into every SQL statement, so it must
-- come from a trusted source.
postgresBackend :: Pool PG.Connection -> Text -> IO (Backend ByteString PGBackendError)
postgresBackend p tableName = do
  initPostgresBackend p tableName
  pgBackendCleanup p tableName
  return $
    MkBackend
      { backendGetUsage = pgBackendGetUsage p tableName,
        backendIncAndGetUsage = pgBackendIncAndGetUsage p tableName,
        backendExpireIn = pgBackendExpireIn p tableName
      }