{-# LANGUAGE CPP #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.PostgreSQL.Simple.Util
( existsTable
, withTransactionRolledBack
) where
import Control.Exception (bracket_, finally)
import GHC.Int (Int64)

import Database.PostgreSQL.Simple (Connection, Only (..), begin, query,
                                   rollback)
-- | Report whether at least one relation named @table@ appears in
-- @pg_class@.
--
-- The lookup matches on @relname@ alone, so a match in any schema counts.
-- NOTE(review): @pg_class@ presumably also lists non-table relations
-- (indexes, views, sequences, ...); add a @relkind@ and/or schema filter if
-- only ordinary tables on the search path should match.
existsTable :: Connection -> String -> IO Bool
existsTable con table =
  checkRowCount <$> query con q (Only table)
  where
    q = "select count(relname) from pg_class where relname = ?"

    -- The aggregate query yields a single row holding the match count.
    -- The same name may occur more than once (e.g. in different schemas),
    -- so any positive count means the relation exists; requiring exactly 1
    -- would wrongly report False in that case.
    checkRowCount :: [[Int64]] -> Bool
    checkRowCount ((n:_):_) = n > 0
    checkRowCount _         = False
-- | Run @f@ inside a transaction on @con@ that is always rolled back.
--
-- Issues 'begin', runs @f@, then issues 'rollback' whether @f@ returns
-- normally or throws. The result (or exception) of @f@ is passed through,
-- and nothing @f@ writes through @con@ is committed.
--
-- 'bracket_' masks asynchronous exceptions from the start of 'begin' until
-- the rollback handler is installed. With a plain
-- @begin con >> finally f (rollback con)@, an asynchronous exception
-- arriving between those two steps would leave the transaction open.
-- If 'begin' itself fails, no rollback is attempted (same as before).
withTransactionRolledBack :: Connection -> IO a -> IO a
withTransactionRolledBack con f =
  bracket_ (begin con) (rollback con) f