{-# LANGUAGE OverloadedStrings, RecordWildCards, GADTs, CPP #-}
module Database.Selda.PostgreSQL
( PG, PGConnectInfo (..)
, withPostgreSQL, on, auth
, pgOpen, pgOpen', seldaClose
, pgConnString, pgPPConfig
) where
#if !MIN_VERSION_base(4, 11, 0)
import Data.Monoid
#endif
import Data.ByteString (ByteString)
import qualified Data.Text as T
import Database.Selda.Backend hiding (toText)
import Database.Selda.JSON
import Database.Selda.Unsafe as Selda (cast, operator)
import Control.Monad.Catch
import Control.Monad.IO.Class
#ifndef __HASTE__
import Control.Monad (void)
import qualified Data.ByteString as BS (foldl')
import qualified Data.ByteString.Char8 as BS (pack, unpack)
import Data.Dynamic
import Data.Foldable (for_)
import Data.Text.Encoding
import Database.Selda.PostgreSQL.Encoding
import Database.PostgreSQL.LibPQ hiding (user, pass, db, host)
#endif
-- | Empty data type (no constructors) used only as a type-level tag for the
-- PostgreSQL backend, e.g. in @SeldaT PG m a@ and @SeldaConnection PG@.
data PG
-- | JSON support for PostgreSQL: '~>' compiles to Postgres's @->@ operator,
-- and 'jsonToText' is implemented as a plain SQL cast.
instance JSONBackend PG where
  (~>) = operator "->"
  jsonToText = Selda.cast
-- | Parameters describing how to reach a PostgreSQL server.
-- Build one with 'on' and optionally add credentials with 'auth'.
data PGConnectInfo = PGConnectInfo
  {
    -- | Host to connect to; sent as @host@ in the libpq connection string.
    pgHost :: T.Text
    -- | TCP port of the server ('on' sets this to 5432).
  , pgPort :: Int
    -- | Name of the database to connect to; sent as @dbname@.
  , pgDatabase :: T.Text
    -- | When set, installed as the session's @search_path@ after connecting.
  , pgSchema :: Maybe T.Text
    -- | Optional user name; left out of the connection string when 'Nothing'.
  , pgUsername :: Maybe T.Text
    -- | Optional password; left out of the connection string when 'Nothing'.
  , pgPassword :: Maybe T.Text
  }
-- | @dbName \`on\` hostName@ describes database @dbName@ served by
-- @hostName@ on the default port 5432, with no schema and no credentials.
-- Intended for infix use, e.g. @\"mydb\" \`on\` \"localhost\"@.
on :: T.Text -> T.Text -> PGConnectInfo
on dbName hostName = PGConnectInfo
  { pgDatabase = dbName
  , pgHost = hostName
  , pgPort = 5432
  , pgUsername = Nothing
  , pgPassword = Nothing
  , pgSchema = Nothing
  }
infixl 7 `on`
-- | Attach a @(user, password)@ pair to existing connection info.
-- Its fixity (4) is looser than 'on' (7), so
-- @\"db\" \`on\` \"host\" \`auth\` (\"u\", \"p\")@ needs no parentheses.
auth :: PGConnectInfo -> (T.Text, T.Text) -> PGConnectInfo
auth info (name, secret) =
  info { pgPassword = Just secret, pgUsername = Just name }
infixl 4 `auth`
-- | Render a 'PGConnectInfo' as a libpq keyword/value connection string.
--
-- Every value is wrapped in single quotes, and embedded single quotes and
-- backslashes are backslash-escaped, as libpq requires. Without this, a
-- password (or host, database name, ...) containing a space, quote or
-- backslash would produce a malformed or silently truncated connection
-- string. User and password are only included when present. The string
-- always requests a 10 second connect timeout and UTF-8 client encoding.
pgConnString :: PGConnectInfo -> ByteString
#ifdef __HASTE__
pgConnString PGConnectInfo{..} = error "pgConnString called in JS context"
#else
pgConnString PGConnectInfo{..} = mconcat
  [ keyValue "host" pgHost
  , keyValue "port" (T.pack (show pgPort))
  , keyValue "dbname" pgDatabase
  , maybe "" (keyValue "user") pgUsername
  , maybe "" (keyValue "password") pgPassword
  , "connect_timeout=10", " "
  , "client_encoding=UTF8"
  ]
  where
    -- @key='value' @ (trailing space included), value quoted and escaped.
    keyValue :: ByteString -> T.Text -> ByteString
    keyValue key val =
      key <> "='" <> encodeUtf8 (T.concatMap escapeChar val) <> "' "

    -- Only single quote and backslash are special inside a quoted libpq
    -- value; both are escaped with a backslash.
    escapeChar :: Char -> T.Text
    escapeChar '\'' = "\\'"
    escapeChar '\\' = "\\\\"
    escapeChar ch = T.singleton ch
#endif
-- | Open a connection from the given info, run the Selda computation on it,
-- and close the connection afterwards. 'bracket' guarantees the close even
-- when the computation throws.
withPostgreSQL :: (MonadIO m, MonadMask m)
               => PGConnectInfo
               -> SeldaT PG m a
               -> m a
#ifdef __HASTE__
withPostgreSQL _ _ = pure $ error "withPostgreSQL called in JS context"
#else
withPostgreSQL connInfo action =
  bracket (pgOpen connInfo) seldaClose $ runSeldaT action
#endif
-- | Open a connection described by a 'PGConnectInfo'. The info's optional
-- schema is passed on to 'pgOpen'' to become the session search path.
pgOpen :: (MonadIO m, MonadMask m) => PGConnectInfo -> m (SeldaConnection PG)
pgOpen info = pgOpen' schema connStr
  where
    schema = pgSchema info
    connStr = pgConnString info
-- | Pretty-printer configuration used to render PostgreSQL SQL.
pgPPConfig :: PPConfig
-- | Open a connection from a raw libpq connection string.
-- When the first argument is 'Just', it is installed as the session's
-- @search_path@ immediately after connecting. A failed connection attempt
-- raises 'DbError'.
pgOpen' :: (MonadIO m, MonadMask m)
        => Maybe T.Text
        -> ByteString
        -> m (SeldaConnection PG)
#ifdef __HASTE__
-- Haste (JavaScript) builds have no libpq, so both definitions are stubs
-- whose values raise an error when forced.
pgOpen' _ _ = return $ error "pgOpen' called in JS context"
pgPPConfig = error "pgPPConfig evaluated in JS context"
#else
-- Connect with libpq; on success, lower the notice level to WARNING, apply
-- the optional search path and wrap the connection as a Selda connection.
-- 'bracketOnError' finishes the libpq connection if anything below throws.
--
-- The schema is embedded in a single-quoted SQL literal, so single quotes
-- are doubled first. Otherwise a schema containing @'@ would break (or
-- inject into) the SET statement. On failure, libpq's own error message is
-- appended to the exception, in the same way 'doError' reports query errors.
pgOpen' schema connStr =
  bracketOnError (liftIO $ connectdb connStr) (liftIO . finish) $ \conn -> do
    st <- liftIO $ status conn
    case st of
      ConnectionOk -> do
        let backend = pgBackend conn
        _ <- liftIO $ runStmt backend "SET client_min_messages TO WARNING;" []
        for_ schema $ \schema' ->
          liftIO $ runStmt backend
            ("SET search_path TO '" <> escapeLiteral schema' <> "';") []
        newConnection backend (decodeUtf8 connStr)
      failure -> do
        details <- liftIO $ errorMessage conn
        connFailed failure details
  where
    -- Escape text for use inside a single-quoted SQL string literal.
    escapeLiteral = T.replace "'" "''"
    connFailed f details = throwM $ DbError $ concat
      [ "unable to connect to postgres server: " ++ show f
      , maybe "" ((": " ++) . BS.unpack) details
      ]
-- PostgreSQL variant of the default pretty-printer configuration:
-- Postgres column types and attributes, the type/attribute hooks defined
-- below, @DEFAULT@ as the inserted value for auto-increment columns, and a
-- " USING <method>" suffix on index definitions.
pgPPConfig = defPPConfig
    { ppType = pgColType defPPConfig
    , ppTypeHook = pgTypeHook
    , ppTypePK = pgColTypePK defPPConfig
    , ppAutoIncInsert = "DEFAULT"
    , ppColAttrs = \attrs -> T.unwords (map pgColAttr attrs)
    , ppColAttrsHook = pgColAttrsHook
    , ppIndexMethodHook = \method -> " USING " <> indexMethodName method
    }
  where
    indexMethodName method = case method of
      BTreeIndex -> "btree"
      HashIndex -> "hash"
-- | Column-type hook. Integer columns that are auto-incrementing required
-- primary keys get the primary-key type of 'TRowID' (BIGSERIAL). All other
-- columns go through 'pgTypeRenameHook'.
pgTypeHook :: SqlTypeRep -> [ColAttr] -> (SqlTypeRep -> T.Text) -> T.Text
pgTypeHook ty attrs fun =
  if isGenericIntPrimaryKey ty attrs
    then pgColTypePK pgPPConfig TRowID
    else pgTypeRenameHook fun ty

-- | Use the time-zone-aware Postgres names for date-time and time columns;
-- defer to the supplied renderer for everything else.
pgTypeRenameHook :: (SqlTypeRep -> T.Text) -> SqlTypeRep -> T.Text
pgTypeRenameHook fallback ty = case ty of
  TDateTime -> "timestamp with time zone"
  TTime -> "time with time zone"
  _ -> fallback ty
-- | Column-attribute hook. For the generic integer primary key case (see
-- 'isGenericIntPrimaryKey'), the attributes are replaced by just
-- @[AutoPrimary Strong]@ before rendering. Other columns are rendered as-is.
pgColAttrsHook :: SqlTypeRep -> [ColAttr] -> ([ColAttr] -> T.Text) -> T.Text
pgColAttrsHook ty attrs render = render effectiveAttrs
  where
    effectiveAttrs
      | isGenericIntPrimaryKey ty attrs = [AutoPrimary Strong]
      | otherwise = attrs
-- | Attributes that together mark a column as a required, strongly
-- auto-incrementing primary key.
bigserialQue :: [ColAttr]
bigserialQue = [AutoPrimary Strong, Required]

-- | True for a 'TInt' column whose attributes include every element of
-- 'bigserialQue'. 'pgTypeHook' emits such columns with the 'TRowID'
-- primary-key type.
isGenericIntPrimaryKey :: SqlTypeRep -> [ColAttr] -> Bool
isGenericIntPrimaryKey ty attrs =
  ty == TInt && all (`elem` attrs) bigserialQue
-- | Assemble the Selda backend record for an open libpq connection.
-- Plain statements and statements returning the last inserted id both go
-- through 'pgQueryRunner'. That function yields 'Right' when asked for rows
-- and 'Left' when asked for the last id, so the other side is treated as
-- impossible.
pgBackend :: Connection
          -> SeldaBackend PG
pgBackend conn = SeldaBackend
  { runStmt = \sql params ->
      expectRight <$> pgQueryRunner conn False sql params
  , runStmtWithPK = \sql params ->
      expectLeft <$> pgQueryRunner conn True sql params
  , prepareStmt = pgPrepare conn
  , runPrepared = pgRun conn
  , getTableInfo = pgGetTableInfo conn . rawTableName
  , backendId = PostgreSQL
  , ppConfig = pgPPConfig
  , closeConnection = const (finish conn)
  , disableForeignKeys = disableFKs conn
  }
  where
    expectLeft = either id (const (error "impossible"))
    expectRight = either (const (error "impossible")) id
-- | Temporarily remove (@True@) or restore (@False@) foreign key constraints
-- in the current schema.
--
-- With @True@:
--
-- * open a transaction;
-- * create the bookkeeping table @__selda_dropped_fks@ if needed;
-- * for every foreign key constraint in the current schema, record the
--   @ALTER TABLE ... ADD CONSTRAINT@ statement that recreates it, then drop
--   the constraint.
--
-- With @False@:
--
-- * replay the recorded statements in insertion order;
-- * drop the bookkeeping table;
-- * commit.
--
-- Note: @conrelid::regclass@ already renders as a properly quoted (and, if
-- needed, schema-qualified) identifier, so it is used verbatim. Passing it
-- through @quote_ident@ double-quoted mixed-case or qualified names. The
-- DROP then failed, and the recorded @alter table if exists@ silently
-- skipped the table, losing the constraint.
disableFKs :: Connection -> Bool -> IO ()
disableFKs c True = do
  void $ pgQueryRunner c False "BEGIN TRANSACTION;" []
  void $ pgQueryRunner c False create []
  void $ pgQueryRunner c False dropTbl []
  where
    create = mconcat
      [ "create table if not exists __selda_dropped_fks ("
      , " seq bigserial primary key,"
      , " sql text"
      , ");"
      ]
    dropTbl = mconcat
      [ "do $$ declare t record;"
      , "begin"
      , " for t in select conrelid::regclass::varchar table_name, conname constraint_name,"
      , " pg_catalog.pg_get_constraintdef(r.oid, true) constraint_definition"
      , " from pg_catalog.pg_constraint r"
      , " where r.contype = 'f'"
      , " and r.connamespace = (select n.oid from pg_namespace n where n.nspname = current_schema())"
      , " loop"
      , " insert into __selda_dropped_fks (sql) values ("
      , " format('alter table if exists %s add constraint %s %s',"
      , " t.table_name, quote_ident(t.constraint_name), t.constraint_definition));"
      , " execute format('alter table %s drop constraint %s', t.table_name, quote_ident(t.constraint_name));"
      , " end loop;"
      , "end $$;"
      ]
disableFKs c False = do
  void $ pgQueryRunner c False restore []
  void $ pgQueryRunner c False "DROP TABLE __selda_dropped_fks;" []
  void $ pgQueryRunner c False "COMMIT;" []
  where
    -- Re-add each saved constraint in the order it was recorded.
    restore = mconcat
      [ "do $$ declare t record;"
      , "begin"
      , " for t in select * from __selda_dropped_fks order by seq loop"
      , " execute t.sql;"
      , " delete from __selda_dropped_fks where seq = t.seq;"
      , " end loop;"
      , "end $$;"
      ]
-- | Introspect table @tbl@ using information_schema and pg_catalog.
--
-- If information_schema lists no columns for the table, an empty
-- 'TableInfo' is returned. Otherwise the result describes:
--
-- * its columns;
-- * its primary key;
-- * its groups of (non-primary) unique columns;
-- * its foreign keys;
-- * which columns have non-unique indexes.
--
-- The table name is spliced into SQL text, so it is escaped first. Single
-- quotes are doubled for string literals. For the @::regclass@ casts, where
-- the name is written as a double-quoted identifier, embedded double quotes
-- are doubled as well. Without this, a name containing either character
-- produced malformed SQL.
pgGetTableInfo :: Connection -> T.Text -> IO TableInfo
pgGetTableInfo c tbl = do
  Right (_, vals) <- pgQueryRunner c False tableinfo []
  if null vals
    then pure $ TableInfo [] [] []
    else do
      Right (_, pkInfo) <- pgQueryRunner c False pkquery []
      Right (_, us) <- pgQueryRunner c False uniquequery []
      Right (_, fks) <- pgQueryRunner c False fkquery []
      Right (_, ixs) <- pgQueryRunner c False ixquery []
      colInfos <- mapM (describe fks (map toText ixs)) vals
      pure $ TableInfo
        { tableColumnInfos = colInfos
        , tableUniqueGroups = map (map mkColName . splitNames) us
        , tablePrimaryKey = [mkColName pk | [SqlString pk] <- pkInfo]
        }
  where
    -- Escape text for use inside a single-quoted SQL string literal.
    sqlString = T.replace "'" "''"
    -- tbl as the contents of a plain string literal.
    tblLit = sqlString tbl
    -- tbl as a double-quoted identifier inside a string literal
    -- (the '"..."'::regclass form).
    tblRegclass = sqlString (T.replace "\"" "\"\"" tbl)

    -- uniquequery joins the column names of each unique index with '"'.
    splitNames = breakNames . toText
    breakNames s =
      case T.break (== '"') s of
        (n, ns) | T.null n -> []
                | T.null ns -> [n]
                | otherwise -> n : breakNames (T.tail ns)
    toText [SqlString s] = s
    toText _ = error "toText: unreachable"

    tableinfo = mconcat
      [ "SELECT column_name, data_type, is_nullable, column_default LIKE 'nextval(%' "
      , "FROM information_schema.columns "
      , "WHERE table_name = '", tblLit, "';"
      ]
    pkquery = mconcat
      [ "SELECT a.attname "
      , "FROM pg_index i "
      , "JOIN pg_attribute a ON a.attrelid = i.indrelid "
      , " AND a.attnum = ANY(i.indkey) "
      , "WHERE i.indrelid = '\"", tblRegclass, "\"'::regclass "
      , " AND i.indisprimary;"
      ]
    uniquequery = mconcat
      [ "SELECT string_agg(a.attname, '\"') "
      , "FROM pg_index i "
      , "JOIN pg_attribute a ON a.attrelid = i.indrelid "
      , " AND a.attnum = ANY(i.indkey) "
      , "WHERE i.indrelid = '\"", tblRegclass, "\"'::regclass "
      , " AND i.indisunique "
      , " AND NOT i.indisprimary "
      , "GROUP BY i.indkey;"
      ]
    fkquery = mconcat
      [ "SELECT kcu.column_name, ccu.table_name, ccu.column_name "
      , "FROM information_schema.table_constraints AS tc "
      , "JOIN information_schema.key_column_usage AS kcu "
      , " ON tc.constraint_name = kcu.constraint_name "
      , "JOIN information_schema.constraint_column_usage AS ccu "
      , " ON ccu.constraint_name = tc.constraint_name "
      , "WHERE constraint_type = 'FOREIGN KEY' AND tc.table_name='", tblLit, "';"
      ]
    -- Columns covered by indexes that are neither unique nor primary.
    ixquery = mconcat
      [ "select a.attname as column_name "
      , "from pg_class t, pg_class i, pg_index ix, pg_attribute a "
      , "where "
      , "t.oid = ix.indrelid "
      , "and i.oid = ix.indexrelid "
      , "and a.attrelid = t.oid "
      , "and a.attnum = ANY(ix.indkey) "
      , "and t.relkind = 'r' "
      , "and not ix.indisunique "
      , "and not ix.indisprimary "
      , "and t.relname = '", tblLit, "';"
      ]

    -- Turn one row of tableinfo into a ColumnInfo; the data type is
    -- lower-cased before being mapped back to a Selda type.
    describe fks ixs [SqlString name, SqlString ty, SqlString nullable, auto] =
      return $ ColumnInfo
        { colName = mkColName name
        , colType = mkTypeRep ty'
        , colIsAutoPrimary = isAuto auto
        , colIsNullable = readBool nullable
        , colHasIndex = name `elem` ixs
        , colFKs =
            [ (mkTableName tblname, mkColName col)
            | [SqlString cname, SqlString tblname, SqlString col] <- fks
            , name == cname
            ]
        }
      where
        ty' = T.toLower ty
        -- NULL column_default makes the LIKE yield NULL, i.e. not auto.
        isAuto (SqlBool x) = x
        isAuto _ = False
    describe _ _ results =
      throwM $ SqlError $ "bad result from table info query: " ++ show results
-- | Execute a parameterised query with binary results.
-- When @return_lastid@ is set, @ RETURNING LASTVAL();@ is appended to the
-- query, and the first value of the first row is returned as 'Left' (0 if
-- absent). Otherwise the affected-row count and result rows are returned as
-- 'Right'.
pgQueryRunner :: Connection -> Bool -> T.Text -> [Param] -> IO (Either Int (Int, [[SqlValue]]))
pgQueryRunner c return_lastid q ps = do
  mres <- execParams c (encodeUtf8 sql) [fromSqlValue p | Param p <- ps] Binary
  unlessError c errmsg mres handleResult
  where
    sql
      | return_lastid = q <> " RETURNING LASTVAL();"
      | otherwise = q
    errmsg = "error executing query `" ++ T.unpack sql ++ "'"
    handleResult res
      | return_lastid = Left . maybe 0 readInt <$> getvalue res 0 0
      | otherwise = Right <$> getRows res
-- | Execute a statement previously prepared by 'pgPrepare'.
-- The handle is the 'Dynamic'-wrapped 'StmtID' that 'pgPrepare' returns; its
-- shown form is the server-side statement name. A handle holding anything
-- else raises 'DbError' with an explicit message, instead of the opaque
-- pattern-match failure of an irrefutable @Just@ binding.
pgRun :: Connection -> Dynamic -> [Param] -> IO (Int, [[SqlValue]])
pgRun c hdl ps =
  case (fromDynamic hdl :: Maybe StmtID) of
    Nothing ->
      throwM $ DbError "pgRun: prepared statement handle is not a PostgreSQL StmtID"
    Just sid -> do
      mres <- execPrepared c (BS.pack $ show sid) (map mkParam ps) Binary
      unlessError c errmsg mres getRows
  where
    errmsg = "error executing prepared statement"
    -- execPrepared takes only value and format; the type OID is dropped.
    mkParam (Param p) = dropOid <$> fromSqlValue p
    dropOid (_, val, fmt) = (val, fmt)
-- | Decode every row of a result, paired with the affected-row count from
-- the command status. An absent or empty count is reported as 0.
getRows :: Result -> IO (Int, [[SqlValue]])
getRows res = do
  nRows <- ntuples res
  nCols <- nfields res
  colTypes <- mapM (ftype res) [0 .. nCols - 1]
  affected <- cmdTuples res
  decoded <- mapM (getRow res colTypes nCols) [0 .. nRows - 1]
  pure (maybe 0 parseCount affected, decoded)
  where
    -- Decimal ASCII digits to a number; an empty string yields 0.
    parseCount = BS.foldl' (\acc w -> acc * 10 + fromIntegral w - 48) 0
-- | Decode one result row, pairing each column index with its type OID.
getRow :: Result -> [Oid] -> Column -> Row -> IO [SqlValue]
getRow res types cols row =
  traverse (uncurry (getCol res row)) (zip [0 .. cols - 1] types)
-- | Decode a single cell using its column type OID; NULL becomes 'SqlNull'.
getCol :: Result -> Row -> Column -> Oid -> IO SqlValue
getCol res row col t =
  maybe SqlNull (toSqlValue t) <$> getvalue res row col
-- | Prepare @q@ on the server under the name @show sid@, declaring the given
-- parameter types. Returns the statement id wrapped in a 'Dynamic', which
-- is what 'pgRun' expects as its handle.
pgPrepare :: Connection -> StmtID -> [SqlTypeRep] -> T.Text -> IO Dynamic
pgPrepare c sid types q = do
  let stmtName = BS.pack (show sid)
      paramTypes = Just (map fromSqlType types)
  mres <- prepare c stmtName (encodeUtf8 q) paramTypes
  unlessError c errmsg mres (const (return (toDyn sid)))
  where
    errmsg = "error preparing query `" ++ T.unpack q ++ "'"
-- | Run the continuation on a successful result.
-- A missing result raises 'DbError'. A result whose status is
-- 'BadResponse', 'FatalError' or 'NonfatalError' is reported through
-- 'doError' with the given message.
unlessError :: Connection -> String -> Maybe Result -> (Result -> IO a) -> IO a
unlessError c msg mres m =
  maybe (throwM $ DbError "unable to submit query to server") check mres
  where
    check res = do
      st <- resultStatus res
      if isFailure st then doError c msg else m res
    isFailure BadResponse = True
    isFailure FatalError = True
    isFailure NonfatalError = True
    isFailure _ = False
-- | Throw 'SqlError' with the given message, followed by libpq's
-- connection-level error message when one is available.
doError :: Connection -> String -> IO a
doError c msg = do
  details <- errorMessage c
  let suffix = maybe "" (\d -> ": " ++ BS.unpack d) details
  throwM $ SqlError (msg ++ suffix)
-- | Map a (lower-case) Postgres type name, as reported by information_schema,
-- back to a Selda type. Unrecognised names are returned unchanged in 'Left'.
mkTypeRep :: T.Text -> Either T.Text SqlTypeRep
mkTypeRep typ = maybe (Left typ) Right (lookup typ knownTypes)
  where
    knownTypes =
      [ ("bigserial", TRowID)
      , ("int8", TInt)
      , ("bigint", TInt)
      , ("float8", TFloat)
      , ("double precision", TFloat)
      , ("timestamp with time zone", TDateTime)
      , ("bytea", TBlob)
      , ("text", TText)
      , ("boolean", TBool)
      , ("date", TDate)
      , ("time with time zone", TTime)
      , ("uuid", TUUID)
      , ("jsonb", TJSON)
      ]
-- | Postgres column type names for the Selda types that need them. Any other
-- type is rendered by the given configuration's 'ppType'.
pgColType :: PPConfig -> SqlTypeRep -> T.Text
pgColType cfg ty = case ty of
  TRowID -> "BIGINT"
  TInt -> "INT8"
  TFloat -> "FLOAT8"
  TDateTime -> "TIMESTAMP"
  TBlob -> "BYTEA"
  TUUID -> "UUID"
  TJSON -> "JSONB"
  _ -> ppType cfg ty
-- | SQL text for a single column attribute. 'Primary' and 'Indexed' render
-- as nothing here.
pgColAttr :: ColAttr -> T.Text
pgColAttr attr = case attr of
  Primary -> ""
  AutoPrimary _ -> "PRIMARY KEY"
  Required -> "NOT NULL"
  Optional -> "NULL"
  Unique -> "UNIQUE"
  Indexed _ -> ""
-- | Column type for primary keys: 'TRowID' becomes BIGSERIAL; everything
-- else is rendered as by 'pgColType'.
pgColTypePK :: PPConfig -> SqlTypeRep -> T.Text
pgColTypePK cfg ty = case ty of
  TRowID -> "BIGSERIAL"
  _ -> pgColType cfg ty
#endif