-- | Thin bindings to PostgreSQL's libpq, exposed through the 'DatabaseT'
-- monad transformer.
--
-- NOTE(review): this module uses generalized newtype deriving,
-- multi-parameter type classes, flexible/undecidable/overlapping instances
-- and the FFI; presumably these are enabled by the build flags — confirm.
module Database.PostgreSQL where

-- Control.Applicative is needed for the Applicative class name on
-- compilers whose Prelude does not export it.
import Control.Applicative
import Control.Monad.Error
import Control.Monad.State
import Control.Monad.Trans
import Data.List
import Data.Maybe
import Foreign
import Foreign.C

#include <libpq-fe.h>

-- Generic stuff

-- | Monad transformer that carries the open connection handle as state.
-- Functor/Applicative are derived too: Monad requires them as superclasses
-- on GHC >= 7.10.
newtype DatabaseT m a = DatabaseT (StateT DatabaseHandle m a)
    deriving (Functor, Applicative, Monad, MonadIO, MonadTrans)

-- | Monads from which the current connection can be fetched.
-- Failures are reported through 'MonadError' with error type @e@.
class (MonadIO m, Error e, MonadError e m) => MonadDatabase e m where
    getConnection :: m DatabaseHandle

-- | 'DatabaseT' itself: the connection is its state.
instance (MonadIO m, Error e, MonadError e m) => MonadDatabase e (DatabaseT m) where
    getConnection = DatabaseT get

-- | Any other transformer stacked over a database monad lifts
-- 'getConnection' from the layer below.
instance ( MonadIO (t m), MonadError e (t m), MonadTrans t
         , MonadDatabase e m, Monad (t m) )
      => MonadDatabase e (t m) where
    getConnection = lift getConnection

-- | Errors are thrown and caught in the wrapped 'StateT' layer.
instance MonadError e m => MonadError e (DatabaseT m) where
    throwError x = DatabaseT (throwError x)
    catchError (DatabaseT f) g =
        DatabaseT $ catchError f (\x -> case g x of DatabaseT y -> y)

-- Postgres stuff

-- | Opaque libpq connection object.
data PGconn

-- | Handle wrapping a pointer to a libpq connection.
newtype DatabaseHandle = DatabaseHandle (Ptr PGconn)

foreign import ccall unsafe "static libpq-fe.h PQconnectdb"
    pqConnectDB :: CString -> IO (Ptr PGconn)

type ConnStatusType = #type ConnStatusType

-- XXX Incomplete list:
connection_OK :: ConnStatusType
connection_OK = #const CONNECTION_OK
connection_bad :: ConnStatusType
connection_bad = #const CONNECTION_BAD

foreign import ccall unsafe "static libpq-fe.h PQstatus"
    pqStatus :: Ptr PGconn -> IO ConnStatusType

foreign import ccall unsafe "static libpq-fe.h PQerrorMessage"
    pqErrorMessage :: Ptr PGconn -> IO CString

-- | Opaque libpq query result object.
data PGresult

-- | Run @x@ to obtain a result, pass it to @f@, then free it with 'clear'.
--
-- The result is also freed when @f@ fails via 'throwError'; the error is
-- re-thrown afterwards. Exceptions raised in 'IO' (rather than through
-- 'MonadError') still bypass the 'clear'.
withPGResults :: MonadDatabase e m => m (Ptr PGresult) -> (Ptr PGresult -> m a) -> m a
withPGResults x f = do
    p <- x
    res <- f p `catchError` (\err -> clear p >> throwError err)
    clear p
    return res

-- | Run @x@ to obtain a result and free it immediately.
withoutPGResults :: MonadDatabase e m => m (Ptr PGresult) -> m ()
withoutPGResults x = x >>= clear

foreign import ccall unsafe "static libpq-fe.h PQexec"
    pqExec :: Ptr PGconn -> CString -> IO (Ptr PGresult)

-- | Execute an SQL command on the current connection.
-- Fails via 'checkResultStatus' unless the status is COMMAND_OK or
-- TUPLES_OK. The returned result must eventually be freed with 'clear'
-- (as 'withExec' does).
exec :: MonadDatabase e m => String -> m (Ptr PGresult)
exec sql = do
    DatabaseHandle dbh <- getConnection
    checkResultStatus "execute" $ withCString sql $ pqExec dbh

-- | 'exec' and hand the result to @f@, freeing it afterwards.
withExec :: MonadDatabase e m => String -> (Ptr PGresult -> m a) -> m a
withExec sql f = withPGResults (exec sql) f

-- | 'exec' for its effect only; the result is freed immediately.
withoutExec :: MonadDatabase e m => String -> m ()
withoutExec sql = withoutPGResults (exec sql)

type Oid = #type Oid

-- | Marshal a list of strings into a temporary C array of C strings.
-- The array and the strings are only valid during the callback.
withCStrings :: [String] -> (Ptr CString -> IO a) -> IO a
withCStrings all_xs f = go [] all_xs
    where go acc [] = withArray (reverse acc) f
          go acc (x:xs) = withCString x $ \s -> go (s:acc) xs

foreign import ccall unsafe "static libpq-fe.h PQexecParams"
    pqExecParams :: Ptr PGconn   -- Connection
                 -> CString      -- command
                 -> CInt         -- nParams
                 -> Ptr Oid      -- paramTypes
                 -> Ptr CString  -- paramValues
                 -> Ptr CInt     -- paramLengths
                 -> Ptr CInt     -- paramFormats
                 -> CInt         -- resultFormat
                 -> IO (Ptr PGresult)

-- | Execute an SQL command with separately supplied parameters.
-- All parameters and results use text format. Parameter types are left
-- for the server to infer (null Oid array).
execParams :: MonadDatabase e m => String -> [String] -> m (Ptr PGresult)
execParams sql params = do
    let nparams = genericLength params
        -- We don't currently let the user tell us which Oid they want
        -- to use
        oids = nullPtr
        -- XXX We should really use binary mode, in which case we'd
        -- need to give lengths, but for now we are using text mode,
        -- so we don't
        lengths = nullPtr
    DatabaseHandle dbh <- getConnection
    checkResultStatus "execParams" $
        withCString sql $ \sql' ->
        withCStrings params $ \params' ->
        -- XXX For now we use text mode (0), but we really ought to
        -- use binary
        withArray (genericReplicate nparams 0) $ \formats ->
        -- XXX Again, should use binary (1) rather than text (0)
        pqExecParams dbh sql' nparams oids params' lengths formats 0

-- | 'execParams' and hand the result to @f@, freeing it afterwards.
withExecParams :: MonadDatabase e m => String -> [String] -> (Ptr PGresult -> m a) -> m a
withExecParams sql params f = withPGResults (execParams sql params) f

-- | 'execParams' for its effect only; the result is freed immediately.
withoutExecParams :: MonadDatabase e m => String -> [String] -> m ()
withoutExecParams sql params = withoutPGResults (execParams sql params)

-- Docs in http://www.postgresql.org/docs/7.4/interactive/libpq-exec.html
-- PGresult *PQexecPrepared(PGconn *conn,
--                          const char *stmtName,
--                          int nParams,
--                          const char * const *paramValues,
--                          const int *paramLengths,
--                          const int *paramFormats,
--                          int resultFormat);
-- Uses statements created with an SQL "PREPARE" command

type ExecStatusType = #type ExecStatusType

-- XXX Incomplete list:
pgres_empty_query :: ExecStatusType
pgres_empty_query = #const PGRES_EMPTY_QUERY
pgres_command_OK :: ExecStatusType
pgres_command_OK = #const PGRES_COMMAND_OK
-- can be 0 rows:
pgres_tuples_OK :: ExecStatusType
pgres_tuples_OK = #const PGRES_TUPLES_OK

foreign import ccall unsafe "static libpq-fe.h PQresultStatus"
    pqResultStatus :: Ptr PGresult -> IO ExecStatusType

-- PGRES_COPY_OUT
-- PGRES_COPY_IN
-- PGRES_BAD_RESPONSE
-- PGRES_NONFATAL_ERROR (not returned directly XXX)
-- PGRES_FATAL_ERROR (null equiv to this)

foreign import ccall unsafe "static libpq-fe.h PQresStatus"
    pqResStatus :: ExecStatusType -> IO CString

foreign import ccall unsafe "static libpq-fe.h PQresultErrorMessage"
    pqResultErrorMessage :: Ptr PGresult -> IO CString

-- | Run @f@ to obtain a result and check its status.
--
-- Any status other than COMMAND_OK or TUPLES_OK is reported through
-- 'throwError'. The message is built from @s@, the status name and the
-- error text.
--
-- On failure the result is freed before throwing, since the caller never
-- receives it. If @f@ returned a NULL result, the error text is taken from
-- the connection instead of the (non-existent) result.
checkResultStatus :: MonadDatabase e m => String -> IO (Ptr PGresult) -> m (Ptr PGresult)
checkResultStatus s f = do
    res <- liftIO f
    status <- liftIO $ pqResultStatus res
    when (status `notElem` [pgres_command_OK, pgres_tuples_OK]) $ do
        -- Copy both messages into Haskell strings before the result is freed.
        err_msg <- if res == nullPtr
                       then do DatabaseHandle dbh <- getConnection
                               liftIO $ pqErrorMessage dbh >>= peekCString
                       else liftIO $ pqResultErrorMessage res >>= peekCString
        err_code <- liftIO $ pqResStatus status >>= peekCString
        clear res
        let err = s ++ " failed (" ++ err_code ++ "): " ++ err_msg
        throwError $ strMsg err
    return res

foreign import ccall unsafe "static libpq-fe.h PQclear"
    pqClear :: Ptr PGresult -> IO ()

-- | Free a query result.
clear :: MonadDatabase e m => Ptr PGresult -> m ()
clear res = liftIO $ pqClear res

foreign import ccall unsafe "static libpq-fe.h PQntuples"
    pqNTuples :: Ptr PGresult -> IO CInt

-- | Number of rows in a result.
nTuples :: MonadDatabase e m => Ptr PGresult -> m CInt
nTuples res = liftIO $ pqNTuples res

foreign import ccall unsafe "static libpq-fe.h PQnfields"
    pqNFields :: Ptr PGresult -> IO CInt

-- | Number of columns in a result.
nFields :: MonadDatabase e m => Ptr PGresult -> m CInt
nFields res = liftIO $ pqNFields res

foreign import ccall unsafe "static libpq-fe.h PQgetvalue"
    pqGetValue :: Ptr PGresult -> CInt -> CInt -> IO CString

-- | Value at (@row@, @col@), copied into a Haskell 'String'.
getValue :: MonadDatabase e m => Ptr PGresult -> CInt -> CInt -> m String
getValue res row col = do
    cstr <- liftIO $ pqGetValue res row col
    liftIO $ peekCString cstr

-- int PQgetisnull(const PGresult *res, int row_number, int column_number);
-- int PQgetlength(const PGresult *res, int row_number, int column_number);

-- XXX Other functions for async requests
-- http://www.postgresql.org/docs/7.4/interactive/libpq-async.html

foreign import ccall unsafe "static libpq-fe.h PQfinish"
    pqFinish :: Ptr PGconn -> IO ()

-- | Open a connection from a raw libpq conninfo string, run the action with
-- it, and close the connection afterwards.
--
-- A failed connection is also closed with PQfinish before 'error' is
-- raised with libpq's message.
--
-- XXX Should use bracket or somesuch, but we have the old IO/MonadIO
-- problem: if the action itself fails, the connection is not finished.
withDatabaseRaw :: MonadIO m => String -> DatabaseT m a -> m a
withDatabaseRaw conninfo (DatabaseT f) = do
    dbh <- liftIO $ withCString conninfo pqConnectDB
    if dbh == nullPtr
        then error "XXX dbh was NULL - can't happen?"
        else do
            stat <- liftIO $ pqStatus dbh
            if stat /= connection_OK
                then do
                    -- peekCString copies the message, so it stays valid
                    -- after the connection object is freed.
                    err <- liftIO $ pqErrorMessage dbh >>= peekCString
                    liftIO $ pqFinish dbh
                    error err -- XXX
                else do
                    res <- evalStateT f (DatabaseHandle dbh)
                    liftIO $ pqFinish dbh
                    return res

-- | Connection parameters. Each field, when 'Just', is passed to libpq
-- under the keyword of the same name.
data ConnectionInfo = ConnectionInfo {
    host :: Maybe String,
    hostaddr :: Maybe String,
    port :: Maybe String,
    dbname :: Maybe String,
    user :: Maybe String,
    password :: Maybe String,
    connect_timeout :: Maybe String,
    options :: Maybe String,
    sslmode :: Maybe String,
    service :: Maybe String
}

-- | All fields 'Nothing', so no keywords are emitted.
defaultConnectionInfo :: ConnectionInfo
defaultConnectionInfo = ConnectionInfo {
    host = Nothing,
    hostaddr = Nothing,
    port = Nothing,
    dbname = Nothing,
    user = Nothing,
    password = Nothing,
    connect_timeout = Nothing,
    options = Nothing,
    sslmode = Nothing,
    service = Nothing
}

-- | Build a conninfo string and run the action via 'withDatabaseRaw'.
--
-- The string has the form @key='value'@ for every field that is 'Just',
-- separated by single spaces. Single quotes and backslashes inside values
-- are backslash-escaped.
withDatabase :: MonadIO m => ConnectionInfo -> DatabaseT m a -> m a
withDatabase conninfo f = withDatabaseRaw conninfo' f
    where conninfo' = unwords $ mapMaybe (uncurry mkSetting)
              [ ("host", host)
              , ("hostaddr", hostaddr)
              , ("port", port)
              , ("dbname", dbname)
              , ("user", user)
              , ("password", password)
              , ("connect_timeout", connect_timeout)
              , ("options", options)
              , ("sslmode", sslmode)
              , ("service", service)
              ]
          mkSetting name extract =
              fmap (\val -> name ++ "='" ++ escape val ++ "'") (extract conninfo)
          escape = concatMap escapeChar
          escapeChar '\'' = "\\'"
          escapeChar '\\' = "\\\\"
          escapeChar c = [c]