module Database.PostgreSQL where
import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Trans
import Data.List
import Data.Maybe
import Foreign
import Foreign.C
-- | Monad transformer that threads the open 'DatabaseHandle' through a
-- computation as 'StateT' state layered over the base monad @m@.
newtype DatabaseT m a = DatabaseT (StateT DatabaseHandle m a)
deriving (Monad, MonadIO, MonadTrans)
-- | Monads that can hand out the current database connection. Every such
-- monad must also support 'IO' actions and raising/catching errors of type
-- @e@.
class (MonadIO m, Error e, MonadError e m)
=> MonadDatabase e m where
-- | Return the handle of the connection this computation is running against.
getConnection :: m DatabaseHandle
-- | Base instance: 'DatabaseT' answers 'getConnection' by reading the
-- 'DatabaseHandle' held in its underlying 'StateT' state.
instance (MonadIO m, Error e, MonadError e m)
=> MonadDatabase e (DatabaseT m) where
getConnection = DatabaseT get
-- | Lifting instance: any transformer @t@ stacked on top of a
-- 'MonadDatabase' monad is itself a 'MonadDatabase', obtaining the
-- connection by 'lift'ing the inner 'getConnection'.
--
-- NOTE(review): the instance head @t m@ also matches @DatabaseT m@, so this
-- appears to rely on overlapping-instance resolution being enabled by the
-- build -- confirm.
instance (MonadIO (t m), MonadError e (t m), MonadTrans t,
MonadDatabase e m, Monad (t m))
=> MonadDatabase e (t m) where
getConnection = lift getConnection
-- | Error handling for 'DatabaseT' is delegated to the wrapped 'StateT'
-- layer: errors are thrown there, and handlers are unwrapped so they run
-- in that same layer.
instance MonadError e m => MonadError e (DatabaseT m) where
  throwError = DatabaseT . throwError
  catchError (DatabaseT action) handler =
    DatabaseT (action `catchError` (unwrap . handler))
    where
      -- Strip the newtype so the handler's result can run in the inner monad.
      unwrap (DatabaseT inner) = inner
-- | Uninhabited tag type standing for libpq's connection object; it is only
-- ever used behind a 'Ptr'.
data PGconn
-- | Handle to an open connection: a wrapper around the raw @Ptr PGconn@
-- obtained from libpq.
newtype DatabaseHandle = DatabaseHandle (Ptr PGconn)
-- | Raw binding to libpq's @PQconnectdb@: opens a connection described by a
-- conninfo string and returns the (possibly failed) connection object.
--
-- Imported @safe@ rather than @unsafe@: establishing a connection blocks on
-- network I/O, and an @unsafe@ foreign call that blocks prevents the GHC
-- runtime from running other Haskell threads or garbage collecting until it
-- returns.
foreign import ccall safe "static libpq-fe.h PQconnectdb"
  pqConnectDB :: CString -> IO (Ptr PGconn)
-- | Connection status code as returned by 'pqStatus'.
type ConnStatusType = Word32
-- | Status value that 'withDatabaseRaw' treats as a usable connection.
connection_OK :: ConnStatusType
connection_OK = 0
-- | Status value for a bad connection.
-- NOTE(review): presumably mirrors libpq's CONNECTION_BAD -- confirm against
-- libpq-fe.h.
connection_bad :: ConnStatusType
connection_bad = 1
-- | Raw binding to libpq's @PQstatus@: reports the status code of a
-- connection object.
foreign import ccall unsafe "static libpq-fe.h PQstatus"
pqStatus :: Ptr PGconn -> IO ConnStatusType
-- | Raw binding to libpq's @PQerrorMessage@: returns a C string describing
-- the most recent error on the connection.
-- NOTE(review): the string is presumably owned by the connection object, so
-- copy it (e.g. with 'peekCString') before the connection is finished.
foreign import ccall unsafe "static libpq-fe.h PQerrorMessage"
pqErrorMessage :: Ptr PGconn -> IO CString
-- | Uninhabited tag type standing for libpq's query-result object; it is
-- only ever used behind a 'Ptr'.
data PGresult
-- | Run an action that produces a result set, pass the result to a
-- callback, and release the result afterwards.
--
-- The result is cleared both when the callback returns normally and when it
-- aborts via 'throwError'; in the latter case the original error is rethrown
-- after cleanup, so the result object is not leaked on the error path.
--
-- The first argument produces the @Ptr PGresult@ (e.g. 'exec'); the second
-- consumes it. The pointer must not be used after the callback returns.
withPGResults :: MonadDatabase e m
              => m (Ptr PGresult) -> (Ptr PGresult -> m a) -> m a
withPGResults x f
  = do p <- x
       -- Free the result before propagating an error raised by the callback.
       res <- f p `catchError` \err -> clear p >> throwError err
       clear p
       return res
-- | Run an action that produces a result set and immediately discard the
-- result, for statements whose output is not needed.
withoutPGResults :: MonadDatabase e m => m (Ptr PGresult) -> m ()
withoutPGResults query = do
  result <- query
  clear result
-- | Raw binding to libpq's @PQexec@: submits a SQL command on the connection
-- and returns its result object.
--
-- Imported @safe@ rather than @unsafe@: the call waits for the server to
-- answer, and a blocking @unsafe@ foreign call prevents the GHC runtime from
-- running other Haskell threads or garbage collecting until it returns.
foreign import ccall safe "static libpq-fe.h PQexec"
  pqExec :: Ptr PGconn -> CString -> IO (Ptr PGresult)
-- | Execute a SQL string on the current connection and return the raw
-- result. The result status is validated by 'checkResultStatus' (labelled
-- @"execute"@), which raises an error in the monad on failure.
exec :: MonadDatabase e m => String -> m (Ptr PGresult)
exec sql =
  getConnection >>= \(DatabaseHandle conn) ->
    checkResultStatus "execute" (withCString sql (pqExec conn))
-- | Execute a SQL string and hand its result to a callback; the result is
-- released by 'withPGResults' once the callback is done.
withExec :: MonadDatabase e m
         => String -> (Ptr PGresult -> m a) -> m a
withExec = withPGResults . exec
-- | Execute a SQL string purely for its effect, discarding the result.
withoutExec :: MonadDatabase e m => String -> m ()
withoutExec = withoutPGResults . exec
-- | Object identifier as passed to 'pqExecParams' for parameter types.
type Oid = Word32
-- | Marshal a list of Haskell strings into a temporary C array of C strings
-- (in the same order) and run the callback on it.
--
-- Both the array and the individual strings are only valid while the
-- callback is running.
withCStrings :: [String] -> (Ptr CString -> IO a) -> IO a
withCStrings all_xs f = foldr allocate finish all_xs []
  where
    -- All strings are allocated: restore their order and build the array.
    finish ptrs = withArray (reverse ptrs) f
    -- Allocate one string, then continue with the remaining ones nested
    -- inside its scope so that it stays alive until the callback is done.
    allocate str next ptrs = withCString str (\cstr -> next (cstr : ptrs))
-- | Raw binding to libpq's @PQexecParams@: submits a SQL command with
-- out-of-line parameters and returns its result object.
--
-- Imported @safe@ rather than @unsafe@: the call waits for the server to
-- answer, and a blocking @unsafe@ foreign call prevents the GHC runtime from
-- running other Haskell threads or garbage collecting until it returns.
foreign import ccall safe "static libpq-fe.h PQexecParams"
  pqExecParams :: Ptr PGconn    -- connection
               -> CString       -- command text
               -> CInt          -- number of parameters
               -> Ptr Oid       -- parameter type OIDs
               -> Ptr CString   -- parameter values
               -> Ptr CInt      -- parameter lengths
               -> Ptr CInt      -- parameter formats
               -> CInt          -- result format
               -> IO (Ptr PGresult)
-- | Execute a parameterised SQL string on the current connection.
--
-- Every parameter is passed as a C string with format code 0; the type-OID
-- and length arrays are passed as null pointers and the result format is 0.
-- The result status is validated by 'checkResultStatus' (labelled
-- @"execParams"@), which raises an error in the monad on failure.
execParams :: MonadDatabase e m => String -> [String] -> m (Ptr PGresult)
execParams sql params = do
  DatabaseHandle conn <- getConnection
  checkResultStatus "execParams" (runQuery conn)
  where
    paramCount :: CInt
    paramCount = genericLength params

    -- One format code (0) per parameter.
    paramFormats :: [CInt]
    paramFormats = genericReplicate paramCount 0

    runQuery conn =
      withCString sql $ \sqlC ->
      withCStrings params $ \valuesC ->
      withArray paramFormats $ \formatsC ->
        pqExecParams conn sqlC paramCount nullPtr valuesC nullPtr formatsC 0
-- | Execute a parameterised SQL string and hand its result to a callback;
-- the result is released by 'withPGResults' once the callback is done.
withExecParams :: MonadDatabase e m
               => String -> [String] -> (Ptr PGresult -> m a) -> m a
withExecParams sql params = withPGResults (execParams sql params)
-- | Execute a parameterised SQL string purely for its effect, discarding
-- the result.
withoutExecParams :: MonadDatabase e m => String -> [String] -> m ()
withoutExecParams sql = withoutPGResults . execParams sql
-- | Result status code as returned by 'pqResultStatus'.
type ExecStatusType = Word32
-- | Status for an empty query string; 'checkResultStatus' does not accept
-- it as success.
pgres_empty_query :: ExecStatusType
pgres_empty_query = 0
-- | Success status accepted by 'checkResultStatus'.
-- NOTE(review): presumably libpq's PGRES_COMMAND_OK (command returning no
-- rows) -- confirm against libpq-fe.h.
pgres_command_OK :: ExecStatusType
pgres_command_OK = 1
-- | Success status accepted by 'checkResultStatus'.
-- NOTE(review): presumably libpq's PGRES_TUPLES_OK (query returning rows) --
-- confirm against libpq-fe.h.
pgres_tuples_OK :: ExecStatusType
pgres_tuples_OK = 2
-- | Raw binding to libpq's @PQresultStatus@: reports the status code of a
-- result object.
foreign import ccall unsafe "static libpq-fe.h PQresultStatus"
pqResultStatus :: Ptr PGresult -> IO ExecStatusType
-- | Raw binding to libpq's @PQresStatus@: converts a status code into a C
-- string naming it.
foreign import ccall unsafe "static libpq-fe.h PQresStatus"
pqResStatus :: ExecStatusType -> IO CString
-- | Raw binding to libpq's @PQresultErrorMessage@: returns the error message
-- associated with a result object as a C string.
-- NOTE(review): the string is presumably owned by the result object, so copy
-- it before the result is cleared.
foreign import ccall unsafe "static libpq-fe.h PQresultErrorMessage"
pqResultErrorMessage :: Ptr PGresult -> IO CString
-- | Run an 'IO' action that yields a result object and verify that its
-- status is one of 'pgres_command_OK' or 'pgres_tuples_OK'.
--
-- On success the result is returned and the caller owns it (it must
-- eventually be passed to 'clear'). On any other status an error built with
-- 'strMsg' is raised via 'throwError'; its text is
-- @\<label\> failed (\<status name\>): \<result error message\>@.
--
-- The failed result is freed here before the error is raised: the caller
-- never receives the pointer on that path, so nobody else could release it.
checkResultStatus :: MonadDatabase e m
                  => String -> IO (Ptr PGresult) -> m (Ptr PGresult)
checkResultStatus s f
  = do res <- liftIO f
       status <- liftIO $ pqResultStatus res
       when (status `notElem` [pgres_command_OK, pgres_tuples_OK]) $ do
         err <- liftIO $ do
           -- Copy both messages into Haskell strings first; the result's
           -- error message must not be read after the result is cleared.
           err_msg <- pqResultErrorMessage res >>= peekCString
           err_code <- pqResStatus status >>= peekCString
           pqClear res
           return (s ++ " failed (" ++ err_code ++ "): " ++ err_msg)
         throwError $ strMsg err
       return res
-- | Raw binding to libpq's @PQclear@: releases a result object. The pointer
-- must not be used afterwards.
foreign import ccall unsafe "static libpq-fe.h PQclear"
pqClear :: Ptr PGresult -> IO ()
-- | Release a result object from within a database monad (lifted
-- 'pqClear').
clear :: MonadDatabase e m => Ptr PGresult -> m ()
clear = liftIO . pqClear
-- | Raw binding to libpq's @PQntuples@: the number of tuples (rows) in a
-- result object.
foreign import ccall unsafe "static libpq-fe.h PQntuples"
pqNTuples :: Ptr PGresult -> IO CInt
-- | Number of tuples in a result, from within a database monad (lifted
-- 'pqNTuples').
nTuples :: MonadDatabase e m => Ptr PGresult -> m CInt
nTuples = liftIO . pqNTuples
-- | Raw binding to libpq's @PQnfields@: the number of fields (columns) in a
-- result object.
foreign import ccall unsafe "static libpq-fe.h PQnfields"
pqNFields :: Ptr PGresult -> IO CInt
-- | Number of fields in a result, from within a database monad (lifted
-- 'pqNFields').
nFields :: MonadDatabase e m => Ptr PGresult -> m CInt
nFields = liftIO . pqNFields
-- | Raw binding to libpq's @PQgetvalue@: the value of one field of a result,
-- addressed by the two 'CInt' indices (row, then column, as used by
-- 'getValue'), returned as a C string.
foreign import ccall unsafe "static libpq-fe.h PQgetvalue"
pqGetValue :: Ptr PGresult -> CInt -> CInt -> IO CString
-- | Fetch the field at the given row and column of a result and copy it
-- into a Haskell 'String' with 'peekCString'.
getValue :: MonadDatabase e m => Ptr PGresult -> CInt -> CInt -> m String
getValue res row col = liftIO (pqGetValue res row col >>= peekCString)
-- | Raw binding to libpq's @PQfinish@: closes a connection and releases the
-- connection object. The pointer must not be used afterwards.
foreign import ccall unsafe "static libpq-fe.h PQfinish"
pqFinish :: Ptr PGconn -> IO ()
-- | Open a connection from a raw libpq conninfo string, run a 'DatabaseT'
-- computation against it, and close the connection afterwards.
--
-- Failure to connect is reported with 'error': either a fixed message when
-- libpq returns a null connection object, or libpq's own error message when
-- the connection status is not 'connection_OK'. In the latter case the
-- connection object is still passed to 'pqFinish' (after its error message
-- has been copied), so a failed attempt does not leak it.
--
-- NOTE(review): only @MonadIO m@ is available here, so the connection is not
-- closed if the computation itself aborts in @m@.
withDatabaseRaw :: MonadIO m => String -> DatabaseT m a -> m a
withDatabaseRaw conninfo (DatabaseT f)
  = do dbh <- liftIO $ withCString conninfo pqConnectDB
       if dbh == nullPtr
         then error "XXX dbh was NULL - can't happen?"
         else do
           stat <- liftIO $ pqStatus dbh
           if stat /= connection_OK
             then do
               -- Copy the message first: it is read from the connection
               -- object that is released on the next line.
               err <- liftIO $ pqErrorMessage dbh >>= peekCString
               liftIO $ pqFinish dbh
               error err
             else do
               res <- evalStateT f (DatabaseHandle dbh)
               liftIO $ pqFinish dbh
               return res
-- | Optional connection settings. Each field corresponds to the conninfo
-- keyword of the same name that 'withDatabase' emits; a field left as
-- 'Nothing' is omitted from the generated conninfo string.
data ConnectionInfo = ConnectionInfo { host :: Maybe String,
hostaddr :: Maybe String,
port :: Maybe String,
dbname :: Maybe String,
user :: Maybe String,
password :: Maybe String,
connect_timeout :: Maybe String,
options :: Maybe String,
sslmode :: Maybe String,
service :: Maybe String }
-- | A 'ConnectionInfo' with every setting unset. Intended as the starting
-- point for record updates, e.g. @defaultConnectionInfo { dbname = Just "x" }@.
defaultConnectionInfo :: ConnectionInfo
defaultConnectionInfo =
  ConnectionInfo
    { host            = Nothing
    , hostaddr        = Nothing
    , port            = Nothing
    , dbname          = Nothing
    , user            = Nothing
    , password        = Nothing
    , connect_timeout = Nothing
    , options         = Nothing
    , sslmode         = Nothing
    , service         = Nothing
    }
-- | Open a connection described by a 'ConnectionInfo', run a 'DatabaseT'
-- computation against it, and close it again (via 'withDatabaseRaw').
--
-- The conninfo string is built from the fields that are set, as
-- space-separated @keyword='value'@ pairs in a fixed order. Inside each
-- value, single quotes and backslashes are escaped with a backslash.
withDatabase :: MonadIO m => ConnectionInfo -> DatabaseT m a -> m a
withDatabase conninfo f = withDatabaseRaw (unwords settings) f
  where
    -- Keyword/accessor pairs, in the order they appear in the output.
    keywords =
      [ ("host", host)
      , ("hostaddr", hostaddr)
      , ("port", port)
      , ("dbname", dbname)
      , ("user", user)
      , ("password", password)
      , ("connect_timeout", connect_timeout)
      , ("options", options)
      , ("sslmode", sslmode)
      , ("service", service)
      ]

    -- Keep only the settings that are present, rendered as keyword='value'.
    settings = mapMaybe (\(name, extract) -> fmap (render name) (extract conninfo)) keywords

    render name val = name ++ "='" ++ concatMap escapeChar val ++ "'"

    escapeChar '\'' = "\\'"
    escapeChar '\\' = "\\\\"
    escapeChar c    = [c]