{-# LANGUAGE
FlexibleContexts
, GADTs
, LambdaCase
, OverloadedStrings
, ScopedTypeVariables
, TypeApplications
#-}
module Squeal.PostgreSQL.Session.Result
( Result (..)
, getRow
, firstRow
, getRows
, nextRow
, cmdStatus
, cmdTuples
, ntuples
, nfields
, resultStatus
, okResult
, resultErrorMessage
, resultErrorCode
, liftResult
) where
import Control.Exception (throw)
import Control.Monad (when, (<=<))
import Control.Monad.IO.Class
import Data.ByteString (ByteString)
import Data.Text (Text)
import Data.Traversable (for)
import Text.Read (readMaybe)
import UnliftIO (throwIO)
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Char8 as Char8
import qualified Data.Text.Encoding as Text
import qualified Database.PostgreSQL.LibPQ as LibPQ
import qualified Generics.SOP as SOP
import Squeal.PostgreSQL.Session.Decode
import Squeal.PostgreSQL.Session.Exception
-- | The result of running a statement: a raw 'LibPQ.Result' paired with a
-- 'DecodeRow' that turns one of its rows into a @y@.
--
-- The row type @row@ is existentially hidden by the constructor. Its
-- 'SOP.SListI' dictionary is packed alongside, so code that pattern matches
-- on 'Result' can still build an @NP (K (Maybe ByteString)) row@ from a list
-- of column values (the row decoders in this module do so with
-- 'SOP.fromList').
data Result y where
  Result
    :: SOP.SListI row
    => DecodeRow row y
    -> LibPQ.Result
    -> Result y
-- | Mapping over a 'Result' maps over its row decoder; the underlying
-- 'LibPQ.Result' is carried along unchanged.
instance Functor Result where
  fmap g = \case
    Result decoder raw -> Result (g <$> decoder) raw
-- | Decode the row at index @r@ of a 'Result'.
--
-- Row indices are zero-based, so a valid index satisfies
-- @0 <= r < ntuples@.
--
-- Throws:
--
-- * 'RowsException' if @r@ is outside that range;
-- * 'ColumnsException' if the number of columns does not match the
--   decoder's row shape;
-- * 'DecodingException' if the decoder rejects the row.
getRow :: MonadIO io => LibPQ.Row -> Result y -> io y
getRow r (Result decode result) = liftIO $ do
  numRows <- LibPQ.ntuples result
  numCols <- LibPQ.nfields result
  -- The last valid row is numRows - 1, so r == numRows is already out of
  -- range. Negative indices are rejected as well, rather than being handed
  -- to LibPQ.getvalue.
  when (r < 0 || r >= numRows) $
    throwIO $ RowsException "getRow" r numRows
  row' <- traverse (LibPQ.getvalue result r) [0 .. numCols - 1]
  case SOP.fromList row' of
    Nothing -> throwIO $ ColumnsException "getRow" numCols
    Just row -> case execDecodeRow decode row of
      Left parseError -> throwIO $ DecodingException "getRow" parseError
      Right y -> return y
-- | One step of a row-by-row walk over a 'Result'.
--
-- @total@ is an exclusive upper bound on row indices and @r@ is the current
-- index. Once @r >= total@ this returns 'Nothing'. Otherwise it decodes row
-- @r@ and returns it paired with the next index, @r + 1@.
--
-- Throws 'ColumnsException' if the number of columns does not match the
-- decoder's row shape, and 'DecodingException' if the decoder rejects the
-- row.
nextRow
  :: MonadIO io
  => LibPQ.Row
  -> Result y
  -> LibPQ.Row
  -> io (Maybe (LibPQ.Row, y))
nextRow total (Result decode result) r
  | r >= total = return Nothing
  | otherwise = liftIO $ do
      numCols <- LibPQ.nfields result
      row' <- traverse (LibPQ.getvalue result r) [0 .. numCols - 1]
      case SOP.fromList row' of
        Nothing -> throwIO $ ColumnsException "nextRow" numCols
        Just row -> case execDecodeRow decode row of
          Left parseError -> throwIO $ DecodingException "nextRow" parseError
          Right y -> return $ Just (r + 1, y)
-- | Decode every row of a 'Result', returning them in row-index order.
--
-- Throws 'ColumnsException' if the number of columns does not match the
-- decoder's row shape, and 'DecodingException' if the decoder rejects any
-- row.
getRows :: MonadIO io => Result y -> io [y]
getRows (Result decode result) = liftIO $ do
  numCols <- LibPQ.nfields result
  numRows <- LibPQ.ntuples result
  -- The column indices are the same for every row, so build the list once.
  let columns = [0 .. numCols - 1]
  for [0 .. numRows - 1] $ \ r -> do
    row' <- traverse (LibPQ.getvalue result r) columns
    case SOP.fromList row' of
      Nothing -> throwIO $ ColumnsException "getRows" numCols
      Just row -> case execDecodeRow decode row of
        Left parseError -> throwIO $ DecodingException "getRows" parseError
        Right y -> return y
-- | Decode the first row of a 'Result', or return 'Nothing' if it has no
-- rows.
--
-- Throws 'ColumnsException' if the number of columns does not match the
-- decoder's row shape, and 'DecodingException' if the decoder rejects the
-- row.
firstRow :: MonadIO io => Result y -> io (Maybe y)
firstRow (Result decode result) = liftIO $ do
  numRows <- LibPQ.ntuples result
  numCols <- LibPQ.nfields result
  if numRows <= 0 then return Nothing else do
    row' <- traverse (LibPQ.getvalue result 0) [0 .. numCols - 1]
    case SOP.fromList row' of
      Nothing -> throwIO $ ColumnsException "firstRow" numCols
      Just row -> case execDecodeRow decode row of
        Left parseError -> throwIO $ DecodingException "firstRow" parseError
        Right y -> return $ Just y
-- | Run a raw LibPQ action against the 'LibPQ.Result' inside a 'Result'.
-- The row decoder is ignored.
liftResult
  :: MonadIO io
  => (LibPQ.Result -> IO x)
  -> Result y -> io x
liftResult f = \case
  Result _ raw -> liftIO (f raw)
-- | Number of rows in the underlying 'LibPQ.Result'.
ntuples :: MonadIO io => Result y -> io LibPQ.Row
ntuples (Result _ raw) = liftIO (LibPQ.ntuples raw)
-- | Number of columns in the underlying 'LibPQ.Result'.
nfields :: MonadIO io => Result y -> io LibPQ.Column
nfields (Result _ raw) = liftIO (LibPQ.nfields raw)
-- | Execution status reported by LibPQ for the underlying result.
resultStatus :: MonadIO io => Result y -> io LibPQ.ExecStatus
resultStatus (Result _ raw) = liftIO (LibPQ.resultStatus raw)
-- | The command status string for the result, decoded as UTF-8.
--
-- Throws 'ConnectionException' if LibPQ returns no status.
cmdStatus :: MonadIO io => Result y -> io Text
cmdStatus = liftResult $ \ raw ->
  LibPQ.cmdStatus raw >>= maybe
    (throwIO $ ConnectionException "LibPQ.cmdStatus")
    (return . Text.decodeUtf8)
-- | Number of rows affected by the command, as reported by LibPQ.
--
-- Returns 'Nothing' when LibPQ reports an empty string, or when the string
-- does not parse as an integer. Throws 'ConnectionException' if LibPQ
-- returns no value at all.
cmdTuples :: MonadIO io => Result y -> io (Maybe LibPQ.Row)
cmdTuples = liftResult $ \ raw ->
  LibPQ.cmdTuples raw >>= maybe
    (throwIO $ ConnectionException "LibPQ.cmdTuples")
    (return . parseTuples)
  where
    parseTuples :: ByteString -> Maybe LibPQ.Row
    parseTuples bytes
      | ByteString.null bytes = Nothing
      | otherwise = fromInteger <$> readMaybe (Char8.unpack bytes)
-- | Succeed if the result status is 'LibPQ.CommandOk' or 'LibPQ.TuplesOk'.
--
-- For any other status, throw a 'SQLException' carrying the status, the
-- SQLSTATE code ('LibPQ.DiagSqlstate') and LibPQ's error message. If either
-- of those two fields is missing, throw a 'ConnectionException' naming the
-- LibPQ function that returned nothing.
okResult_ :: MonadIO io => LibPQ.Result -> io ()
okResult_ result = liftIO $ do
  status <- LibPQ.resultStatus result
  case status of
    LibPQ.CommandOk -> return ()
    LibPQ.TuplesOk -> return ()
    _ -> do
      stateCodeMaybe <- LibPQ.resultErrorField result LibPQ.DiagSqlstate
      case stateCodeMaybe of
        Nothing -> throwIO $ ConnectionException "LibPQ.resultErrorField"
        Just stateCode -> do
          msgMaybe <- LibPQ.resultErrorMessage result
          case msgMaybe of
            Nothing -> throwIO $ ConnectionException "LibPQ.resultErrorMessage"
            Just msg -> throwIO . SQLException $ SQLState status stateCode msg
-- | Check that a 'Result' has a successful status; see 'okResult_' for the
-- exceptions thrown otherwise.
okResult :: MonadIO io => Result y -> io ()
okResult (Result _ raw) = okResult_ raw
-- | LibPQ's error message for the result, if any.
resultErrorMessage
  :: MonadIO io => Result y -> io (Maybe ByteString)
resultErrorMessage (Result _ raw) = liftIO (LibPQ.resultErrorMessage raw)
-- | The SQLSTATE error code ('LibPQ.DiagSqlstate') of the result, if any.
resultErrorCode
  :: MonadIO io
  => Result y
  -> io (Maybe ByteString)
resultErrorCode (Result _ raw) =
  liftIO $ LibPQ.resultErrorField raw LibPQ.DiagSqlstate
-- | Apply a row decoder to a row of nullable raw column values. Returns the
-- decoder's error text on failure.
execDecodeRow
  :: DecodeRow row y
  -> SOP.NP (SOP.K (Maybe ByteString)) row
  -> Either Text y
execDecodeRow = runDecodeRow