{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Database.PostgreSQL.Tmp
( defaultDB
, DBInfo(..)
, withTmpDB
, withTmpDB'
, createTmpDB
, dropTmpDB
, newRole
, dropRole
, newDB
, dropDB
) where
import Control.Applicative (pure)
import Control.Exception
import Data.ByteString (ByteString)
import Data.Coerce
import Data.Int
import Data.Monoid
import qualified Data.Text as T
import Database.PostgreSQL.Simple
import Database.PostgreSQL.Simple.Types
-- | Connection string that 'withTmpDB' passes to 'withTmpDB''. It names the
-- @postgres@ database and the @postgres@ user.
defaultDB :: ByteString
defaultDB = "dbname='postgres' user='postgres'"
-- | Names identifying a temporary database and the role that owns it, as
-- produced by 'createTmpDB'.
data DBInfo = DBInfo
  { dbName :: T.Text
    -- ^ Name of the temporary database.
  , roleName :: T.Text
    -- ^ Name of the role created alongside the database.
  } deriving (Show, Read, Eq, Ord)
-- | Run an action against a temporary database, connecting with 'defaultDB'.
--
-- This is 'withTmpDB'' specialised to the default connection string.
withTmpDB :: (DBInfo -> IO a) -> IO a
withTmpDB action = withTmpDB' defaultDB action
-- | Run an action against a temporary database, connecting with the given
-- connection string.
--
-- The database and role are set up by 'createTmpDB' and torn down by
-- 'dropTmpDB'; 'bracket' runs the teardown whether the action returns or
-- throws. The action only receives the 'DBInfo', not the connection.
withTmpDB' :: ByteString -> (DBInfo -> IO a) -> IO a
withTmpDB' conStr f = bracket acquire dropTmpDB (f . snd)
  where
    acquire = createTmpDB conStr
-- | Connect with the given connection string, then create a fresh role and a
-- fresh database owned by that role.
--
-- Returns the open connection together with the names that were created; pass
-- the pair to 'dropTmpDB' to tear everything down again.
--
-- If setup fails part-way, nothing is left half-built: the connection is
-- closed, and a role that was already created is dropped before the
-- exception propagates.
createTmpDB :: ByteString -> IO (Connection, DBInfo)
createTmpDB conStr =
  -- 'bracketOnError' closes the connection only when setup throws; on
  -- success the connection is handed to the caller still open.
  bracketOnError (connectPostgreSQL conStr) close $ \conn -> do
    role <- newRole conn
    -- The role exists at this point, so remove it again if the database
    -- cannot be created.
    db <- newDB conn role `onException` dropRole conn role
    pure (conn, DBInfo {dbName = db, roleName = role})
-- | Tear down what 'createTmpDB' set up: drop the database, drop the role,
-- and close the connection.
--
-- The connection is closed even when one of the drops throws; the exception
-- is then re-raised to the caller. The database is dropped before the role,
-- and the role drop is skipped if the database drop fails.
dropTmpDB :: (Connection, DBInfo) -> IO ()
dropTmpDB (conn, DBInfo db role) = dropAll `finally` close conn
  where
    dropAll :: IO ()
    dropAll = do
      _ <- dropDB conn db
      _ <- dropRole conn role
      pure ()
-- | Create a role with a name not currently listed in @pg_roles@ and return
-- that name.
--
-- The name is chosen by 'freshName' with the prefix @tmp@, and the role is
-- created with @CREATE USER ... WITH CREATEDB@.
newRole :: Connection -> IO T.Text
newRole conn = do
  existing <- query_ conn "SELECT rolname FROM pg_roles" :: IO [Only T.Text]
  let fresh = freshName "tmp" (map fromOnly existing)
  execute conn "CREATE USER ? WITH CREATEDB" (Only (Identifier fresh))
    >> pure fresh
-- | Issue @DROP ROLE@ for the named role, quoting the name as an SQL
-- identifier. Returns the count reported by 'execute'.
dropRole :: Connection -> T.Text -> IO Int64
dropRole conn = execute conn "DROP ROLE ?" . Only . Identifier
-- | Create a database with a name not currently listed in @pg_database@,
-- owned by the given role, and return the database name.
--
-- The name is chosen by 'freshName' with the prefix @tmp@.
newDB :: Connection -> T.Text -> IO T.Text
newDB conn role = do
  existing <- query_ conn "SELECT datname FROM pg_database" :: IO [Only T.Text]
  let fresh = freshName "tmp" (map fromOnly existing)
  execute conn "CREATE DATABASE ? OWNER ?" (Identifier fresh, Identifier role)
    >> pure fresh
-- | Issue @DROP DATABASE@ for the named database, quoting the name as an SQL
-- identifier. Returns the count reported by 'execute'.
dropDB :: Connection -> T.Text -> IO Int64
dropDB conn = execute conn "DROP DATABASE ?" . Only . Identifier
-- | Pick the first name of the form @template <> show n@, counting @n@ up
-- from 0, that does not occur in the list of existing names.
freshName :: T.Text -> [T.Text] -> T.Text
freshName template existingNames = candidate (until isFree (+ 1) 0)
  where
    -- Name obtained by appending the counter to the template.
    candidate :: Int -> T.Text
    candidate n = template <> T.pack (show n)

    -- True when the candidate for this counter is not already taken.
    isFree :: Int -> Bool
    isFree n = candidate n `notElem` existingNames