module Database.PostgreSQL.PQTypes.Internal.Query (
runQueryIO
) where
import Control.Applicative
import Foreign.ForeignPtr
import Foreign.Ptr
import Prelude
import qualified Control.Exception as E
import qualified Data.ByteString.Char8 as BS
import Database.PostgreSQL.PQTypes.Internal.C.Interface
import Database.PostgreSQL.PQTypes.Internal.C.Types
import Database.PostgreSQL.PQTypes.Internal.Connection
import Database.PostgreSQL.PQTypes.Internal.Error
import Database.PostgreSQL.PQTypes.Internal.Error.Code
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.Internal.QueryResult
import Database.PostgreSQL.PQTypes.Internal.State
import Database.PostgreSQL.PQTypes.Internal.Utils
import Database.PostgreSQL.PQTypes.SQL.Class
import Database.PostgreSQL.PQTypes.ToSQL
-- | Execute a query on the connection stored in the given state.
--
-- The result is a pair of a row count and an updated state whose
-- 'dbLastQuery' and 'dbQueryResult' refer to the query that was just run.
-- The row count is:
--
-- * the number parsed from @PQcmdTuples@ (0 if that string is empty) when the
--   result status is @PGRES_COMMAND_OK@,
--
-- * the value of @PQntuples@ when the status is @PGRES_TUPLES_OK@,
--
-- * 0 for any other status that is not treated as an error.
--
-- The statistics stored in the connection data are updated within the same
-- 'withConnectionData' action that runs the query.
--
-- For @PGRES_FATAL_ERROR@ and @PGRES_BAD_RESPONSE@ a 'QueryError' or
-- 'DetailedQueryError' is passed to 'rethrowWithContext'. If the string
-- returned by @PQcmdTuples@ is neither empty nor a plain integer, a
-- 'DBException' carrying an 'HPQTypesError' is thrown.
runQueryIO :: IsSQL sql => sql -> DBState m -> IO (Int, DBState m)
runQueryIO sql st = do
  (rowCount, result) <- withConnectionData (dbConnection st) "runQueryIO" $ \cd -> do
    let conn      = cdPtr cd
        oldStats  = cdStats cd
        allocator = ParamAllocator (withPGparam conn)
    -- The parameter count is read before the query is executed.
    (paramCount, res) <- withSQL sql allocator $ \param query -> do
      count   <- fromIntegral <$> c_PQparamCount param
      execRes <- c_PQparamExec conn nullPtr param query c_RESULT_BINARY
      return (count, execRes)
    outcome <- withForeignPtr res (verifyResult conn)
    let queries' = statsQueries oldStats + 1
        params'  = statsParams oldStats + paramCount
    newStats <- case outcome of
      -- No tuples in the result: only the query and parameter counters move.
      Left _ -> return oldStats
        { statsQueries = queries'
        , statsParams  = params'
        }
      -- Tuples were returned: also account for rows and individual values.
      Right rows -> do
        columns <- fromIntegral <$> withForeignPtr res c_PQnfields
        return ConnectionStats
          { statsQueries = queries'
          , statsRows    = statsRows oldStats + rows
          , statsValues  = statsValues oldStats + (rows * columns)
          , statsParams  = params'
          }
    -- Evaluate the new stats record to WHNF before it is stored
    -- (presumably so that unevaluated updates do not pile up across queries).
    newStats `seq` return (cd { cdStats = newStats }, (either id id outcome, res))
  return (rowCount, st {
      dbLastQuery = SomeSQL sql
    , dbQueryResult = Just QueryResult {
        qrSQL = SomeSQL sql
      , qrResult = result
      , qrFromRow = id
      }
    })
  where
    -- Inspect the status of a result. 'Left' carries the count obtained for
    -- results without tuples, 'Right' the number of tuples in the result.
    verifyResult :: Ptr PGconn -> Ptr PGresult -> IO (Either Int Int)
    verifyResult conn res = c_PQresultStatus res >>= interpret
      where
        interpret status
          | status == c_PGRES_COMMAND_OK   = commandCount
          | status == c_PGRES_TUPLES_OK    = Right . fromIntegral <$> c_PQntuples res
          | status == c_PGRES_FATAL_ERROR  = throwSQLError
          | status == c_PGRES_BAD_RESPONSE = throwSQLError
          | otherwise                      = return (Left 0)

        -- Parse the string returned by PQcmdTuples: an empty string counts
        -- as 0, an integer with nothing after it is used as is, anything
        -- else is a parse error.
        commandCount = do
          raw <- BS.packCString =<< c_PQcmdTuples res
          case BS.readInt raw of
            Just (n, leftover) | BS.null leftover -> return (Left n)
            Nothing            | BS.null raw      -> return (Left 0)
            _                                     -> throwParseError raw

        throwSQLError = buildError >>= rethrowWithContext sql

        -- Without a result object only the connection's error message is
        -- available; otherwise the individual error fields of the result
        -- are collected.
        buildError
          | res == nullPtr = do
              msg <- safePeekCString' =<< c_PQerrorMessage conn
              return (E.toException (QueryError msg))
          | otherwise = do
              severity    <- field c_PG_DIAG_SEVERITY
              errorCode   <- stringToErrorCode <$> field c_PG_DIAG_SQLSTATE
              primary     <- field c_PG_DIAG_MESSAGE_PRIMARY
              detail      <- mfield c_PG_DIAG_MESSAGE_DETAIL
              hint        <- mfield c_PG_DIAG_MESSAGE_HINT
              stmtPos     <- (mread =<<) <$> mfield c_PG_DIAG_STATEMENT_POSITION
              internalPos <- (mread =<<) <$> mfield c_PG_DIAG_INTERNAL_POSITION
              internalQry <- mfield c_PG_DIAG_INTERNAL_QUERY
              context     <- mfield c_PG_DIAG_CONTEXT
              sourceFile  <- mfield c_PG_DIAG_SOURCE_FILE
              sourceLine  <- (mread =<<) <$> mfield c_PG_DIAG_SOURCE_LINE
              sourceFunc  <- mfield c_PG_DIAG_SOURCE_FUNCTION
              return . E.toException $ DetailedQueryError
                severity errorCode primary detail hint stmtPos internalPos
                internalQry context sourceFile sourceLine sourceFunc

        -- An error field of the result, with a missing value read as "".
        field f = maybe "" id <$> mfield f

        -- An error field of the result, if present.
        mfield f = safePeekCString =<< c_PQresultErrorField res f

        throwParseError raw = E.throwIO DBException
          { dbeQueryContext = sql
          , dbeError = HPQTypesError ("runQuery.verifyResult: string returned by PQcmdTuples is not a valid number: " ++ show raw)
          }