module Database.PostgreSQL.Typed.Protocol (
PGDatabase(..)
, defaultPGDatabase
, PGConnection
, PGError(..)
, pgErrorCode
, pgConnectionDatabase
, pgTypeEnv
, pgConnect
, pgDisconnect
, pgReconnect
, pgDescribe
, pgSimpleQuery
, pgSimpleQueries_
, pgPreparedQuery
, pgPreparedLazyQuery
, pgCloseStatement
, pgBegin
, pgCommit
, pgRollback
, pgCommitAll
, pgRollbackAll
, pgTransaction
, pgDisconnectOnce
, pgRun
, PGPreparedStatement
, pgPrepare
, pgClose
, PGColDescription(..)
, PGRowDescription
, pgBind
, pgFetch
, PGNotification(..)
, pgGetNotifications
, pgGetNotification
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>), (<$))
#endif
import Control.Arrow ((&&&), first, second)
import Control.Exception (Exception, throwIO, onException, finally)
import Control.Monad (void, liftM2, replicateM, when, unless)
#ifdef VERSION_cryptonite
import qualified Crypto.Hash as Hash
import qualified Data.ByteArray.Encoding as BA
#endif
import qualified Data.Binary.Get as G
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Char8 as BSC
import Data.ByteString.Internal (w2c)
import qualified Data.ByteString.Lazy as BSL
import qualified Data.ByteString.Lazy.Char8 as BSLC
import Data.ByteString.Lazy.Internal (smallChunkSize)
import qualified Data.Foldable as Fold
import Data.IORef (IORef, newIORef, writeIORef, readIORef, atomicModifyIORef, atomicModifyIORef', modifyIORef')
import Data.Int (Int32, Int16)
import qualified Data.Map.Lazy as Map
import Data.Maybe (fromMaybe)
import Data.Monoid ((<>))
#if !MIN_VERSION_base(4,8,0)
import Data.Monoid (mempty)
#endif
import Data.Tuple (swap)
import Data.Typeable (Typeable)
#if !MIN_VERSION_base(4,8,0)
import Data.Word (Word)
#endif
import Data.Word (Word32)
import Network (HostName, PortID(..), connectTo)
import System.IO (Handle, hFlush, hClose, stderr, hPutStrLn, hSetBuffering, BufferMode(BlockBuffering))
import System.IO.Error (mkIOError, eofErrorType, ioError)
import System.IO.Unsafe (unsafeInterleaveIO)
import Text.Read (readMaybe)
import Database.PostgreSQL.Typed.Types
import Database.PostgreSQL.Typed.Dynamic
-- | Client-side view of the protocol synchronization state of a connection.
data PGState
  = StateUnsync -- ^ Messages were sent (by 'pgSend') without a trailing 'Sync'
  | StatePending -- ^ A 'Sync' or 'SimpleQuery' was sent; a ReadyForQuery is still expected
  | StateIdle -- ^ Server reported ready status @I@
  | StateTransaction -- ^ Server reported ready status @T@
  | StateTransactionFailed -- ^ Server reported ready status @E@
  | StateClosed -- ^ 'Terminate' was sent or end of input was reached
  deriving (Show, Eq)
-- | Settings used by 'pgConnect' to open a connection.
data PGDatabase = PGDatabase
  { pgDBHost :: HostName -- ^ Server host, passed to 'connectTo'
  , pgDBPort :: PortID -- ^ Server port, passed to 'connectTo'
  , pgDBName :: BS.ByteString -- ^ Sent as the @database@ startup parameter
  , pgDBUser, pgDBPass :: BS.ByteString -- ^ Sent as the @user@ startup parameter, and as the password when the server asks for one
  , pgDBParams :: [(BS.ByteString, BS.ByteString)] -- ^ Extra startup parameters, sent after the built-in ones
  , pgDBDebug :: Bool -- ^ When set, every sent and received message is printed to stdout
  , pgDBLogMessage :: MessageFields -> IO () -- ^ Callback receiving server notices and unexpected-message reports
  }
-- | Two settings records are equal when they describe the same connection
-- target and credentials; the debug flag and log callback are ignored
-- (functions cannot be compared anyway).
instance Eq PGDatabase where
  a == b =
    and
      [ pgDBHost a == pgDBHost b
      , pgDBPort a == pgDBPort b
      , pgDBName a == pgDBName b
      , pgDBUser a == pgDBUser b
      , pgDBPass a == pgDBPass b
      , pgDBParams a == pgDBParams b
      ]
-- | Handle for a server-side prepared statement. The wrapped counter value
-- is rendered as the statement's name by 'preparedStatementName'.
newtype PGPreparedStatement = PGPreparedStatement Integer
  deriving (Eq, Show)
-- | Server-side name of a prepared statement: its counter in decimal.
preparedStatementName :: PGPreparedStatement -> BS.ByteString
preparedStatementName ps = case ps of
  PGPreparedStatement counter -> BSC.pack (show counter)
-- | An open database connection. The 'IORef' fields hold mutable state that
-- is shared by every copy of the record (for example the copy returned by
-- 'pgReconnect').
data PGConnection = PGConnection
  { connHandle :: Handle -- ^ Socket handle obtained from 'connectTo'
  , connDatabase :: !PGDatabase -- ^ Settings the connection was opened (or reconnected) with
  , connPid :: !Word32 -- ^ First value of the server's BackendKeyData (0 until received)
  , connKey :: !Word32 -- ^ Second value of the server's BackendKeyData (0 until received)
  , connTypeEnv :: PGTypeEnv -- ^ Built from server parameters when 'pgConnect' finishes
  , connParameters :: IORef (Map.Map BS.ByteString BS.ByteString) -- ^ Latest ParameterStatus values reported by the server
  , connPreparedStatementCount :: IORef Integer -- ^ Counter used to name newly prepared statements
  , connPreparedStatementMap :: IORef (Map.Map (BS.ByteString, [OID]) PGPreparedStatement) -- ^ Statements already parsed for 'pgPreparedQuery'/'pgPreparedLazyQuery', keyed by SQL and parameter types
  , connState :: IORef PGState -- ^ Client-side protocol state
  , connInput :: IORef (G.Decoder PGBackendMessage) -- ^ Decoder holding any partially received message
  , connTransaction :: IORef Word -- ^ Nesting depth of 'pgBegin' transactions/savepoints
  , connNotifications :: IORef (Queue PGNotification) -- ^ Notifications received but not yet handed to the caller
  }
-- | One column of a RowDescription message, in the order the fields are
-- read by 'getMessageBody'.
data PGColDescription = PGColDescription
  { pgColName :: BS.ByteString -- ^ Column name
  , pgColTable :: !OID -- ^ Source table OID; 'pgDescribe' skips its nullability lookup when this is 0
  , pgColNumber :: !Int16 -- ^ Column number within the source table
  , pgColType :: !OID -- ^ Type OID of the column
  , pgColSize :: !Int16 -- ^ Type size as reported by the server
  , pgColModifier :: !Int32 -- ^ Type modifier as reported by the server
  , pgColBinary :: !Bool -- ^ Whether the column's format code is binary (1) rather than text (0)
  } deriving (Show)
-- | Descriptions of the columns of a result row, in column order.
type PGRowDescription = [PGColDescription]
-- | Error/notice fields keyed by their one-character code; 'displayMessage'
-- renders the @S@, @C@, @M@ and @D@ entries and 'pgErrorCode' reads @C@.
type MessageFields = Map.Map Char BS.ByteString
-- | An asynchronous notification delivered by a NotificationResponse.
data PGNotification = PGNotification
  { pgNotificationPid :: !Word32 -- ^ Process ID reported with the notification
  , pgNotificationChannel :: !BS.ByteString -- ^ Channel name
  , pgNotificationPayload :: BSL.ByteString -- ^ Payload string
  } deriving (Show)
-- | A two-list FIFO queue: a back list holding newly added elements
-- (newest first) and a front list holding elements ready to be removed
-- (oldest first).
data Queue a = Queue [a] [a]

-- | A queue with no elements.
emptyQueue :: Queue a
emptyQueue = Queue [] []

-- | Add an element at the back of the queue.
enQueue :: a -> Queue a -> Queue a
enQueue x (Queue back front) = Queue (x : back) front

-- | Remove the oldest element, returning the rest of the queue alongside
-- it; an empty queue is returned unchanged together with 'Nothing'.
deQueue :: Queue a -> (Queue a, Maybe a)
deQueue q@(Queue back front) =
  case front of
    x : rest -> (Queue back rest, Just x)
    [] ->
      -- Front exhausted: move the back list over, restoring oldest-first order.
      case reverse back of
        x : rest -> (Queue [] rest, Just x)
        [] -> (q, Nothing)

-- | Every queued element, oldest first.
queueToList :: Queue a -> [a]
queueToList (Queue back front) = front ++ reverse back
-- | Messages sent from client to server; 'messageBody' encodes each one.
data PGFrontendMessage
  = StartupMessage [(BS.ByteString, BS.ByteString)] -- ^ Untagged; protocol version plus startup key/value parameters
  | CancelRequest !Word32 !Word32 -- ^ Untagged; process ID and secret key of the session to cancel
  | Bind { portalName :: BS.ByteString, statementName :: BS.ByteString, bindParameters :: PGValues, binaryColumns :: [Bool] } -- ^ Bind parameter values to a statement, creating a portal; @binaryColumns@ selects result formats
  | CloseStatement { statementName :: BS.ByteString } -- ^ Close a named prepared statement
  | ClosePortal { portalName :: BS.ByteString } -- ^ Close a named portal
  | DescribeStatement { statementName :: BS.ByteString } -- ^ Ask for a statement's parameter and row descriptions
  | DescribePortal { portalName :: BS.ByteString } -- ^ Ask for a portal's row description
  | Execute { portalName :: BS.ByteString, executeRows :: !Word32 } -- ^ Run a portal, returning at most @executeRows@ rows (callers pass 0 when they want every row)
  | Flush -- ^ Request delivery of pending output without a 'Sync' (used by 'pgPreparedLazyQuery')
  | Parse { statementName :: BS.ByteString, queryString :: BSL.ByteString, parseTypes :: [OID] } -- ^ Parse SQL into a named statement with optional parameter type OIDs
  | PasswordMessage BS.ByteString -- ^ Password reply to an authentication request
  | SimpleQuery { queryString :: BSL.ByteString } -- ^ Simple-protocol query; leaves a ReadyForQuery pending
  | Sync -- ^ End of an extended-query sequence; leaves a ReadyForQuery pending
  | Terminate -- ^ Close the session
  deriving (Show)
-- | Messages received from the server, as decoded by 'getMessageBody'
-- (the wire type tag is given for each constructor).
data PGBackendMessage
  = AuthenticationOk -- ^ @R@ with code 0
  | AuthenticationCleartextPassword -- ^ @R@ with code 3
  | AuthenticationMD5Password BS.ByteString -- ^ @R@ with code 5, carrying a 4-byte salt
  | BackendKeyData Word32 Word32 -- ^ @K@: process ID and secret key
  | BindComplete -- ^ @2@
  | CloseComplete -- ^ @3@
  | CommandComplete BS.ByteString -- ^ @C@: command tag (see 'rowsAffected')
  | DataRow PGValues -- ^ @D@: one result row (values decoded as text or NULL)
  | EmptyQueryResponse -- ^ @I@
  | ErrorResponse { messageFields :: MessageFields } -- ^ @E@
  | NoData -- ^ @n@
  | NoticeResponse { messageFields :: MessageFields } -- ^ @N@
  | NotificationResponse PGNotification -- ^ @A@
  | ParameterDescription [OID] -- ^ @t@: parameter type OIDs
  | ParameterStatus BS.ByteString BS.ByteString -- ^ @S@: parameter name and value
  | ParseComplete -- ^ @1@
  | PortalSuspended -- ^ @s@: row limit of an 'Execute' reached
  | ReadyForQuery PGState -- ^ @Z@: server ready, with its transaction status
  | RowDescription PGRowDescription -- ^ @T@
  deriving (Show)
-- | Exception thrown when the server answers with an ErrorResponse where
-- the caller cannot tolerate one; carries the raw error fields.
newtype PGError = PGError { pgErrorFields :: MessageFields }
  deriving (Typeable)

-- | Errors are shown in the same human-readable form used for logging.
instance Show PGError where
  show = displayMessage . pgErrorFields

instance Exception PGError
-- | Render error/notice fields as one string:
-- @PG<severity> [<code>]: <message>@ followed by the detail on its own line.
-- The bracketed code and the detail line are left out when those fields
-- are missing or empty.
displayMessage :: MessageFields -> String
displayMessage m =
  "PG" ++ field 'S' ++ codePart ++ field 'M' ++ detailPart
  where
    field k = maybe "" BSC.unpack (Map.lookup k m)
    code = field 'C'
    detail = field 'D'
    codePart
      | null code = ": "
      | otherwise = " [" ++ code ++ "]: "
    detailPart
      | null detail = detail
      | otherwise = '\n' : detail
-- | Build a field map holding only a message (@M@) and a detail (@D@).
makeMessage :: BS.ByteString -> BS.ByteString -> MessageFields
makeMessage message detail = Map.fromList [('M', message), ('D', detail)]
-- | The error code (field @C@) of a server error, or empty if absent.
pgErrorCode :: PGError -> BS.ByteString
pgErrorCode = Map.findWithDefault BS.empty 'C' . pgErrorFields
-- | Default log callback: print the rendered message to stderr.
defaultLogMessage :: MessageFields -> IO ()
defaultLogMessage fields = hPutStrLn stderr (displayMessage fields)
-- | Default settings:
--
-- * server at @localhost@, port 5432;
-- * database and user @postgres@ with an empty password;
-- * no extra startup parameters;
-- * debugging off, with messages logged to stderr by 'defaultLogMessage'.
defaultPGDatabase :: PGDatabase
defaultPGDatabase = PGDatabase
  { pgDBUser = "postgres"
  , pgDBPass = BS.empty
  , pgDBName = "postgres"
  , pgDBHost = "localhost"
  , pgDBPort = PortNumber 5432
  , pgDBParams = []
  , pgDBLogMessage = defaultLogMessage
  , pgDBDebug = False
  }
-- | Whether message tracing is enabled for this connection.
connDebug :: PGConnection -> Bool
connDebug c = pgDBDebug (connDatabase c)

-- | The log callback configured for this connection.
connLogMessage :: PGConnection -> MessageFields -> IO ()
connLogMessage c = pgDBLogMessage (connDatabase c)

-- | The settings the connection was opened (or last reconnected) with.
pgConnectionDatabase :: PGConnection -> PGDatabase
pgConnectionDatabase PGConnection{ connDatabase = db } = db

-- | The type environment derived from server parameters by 'pgConnect'.
pgTypeEnv :: PGConnection -> PGTypeEnv
pgTypeEnv PGConnection{ connTypeEnv = env } = env
#ifdef VERSION_cryptonite
-- | Base16 (hex) encoding of the MD5 digest of a string, used to build the
-- reply to an MD5 password challenge.
md5 :: BS.ByteString -> BS.ByteString
md5 input = BA.convertToBase BA.Base16 digest
  where
    digest :: Hash.Digest Hash.MD5
    digest = Hash.hash input
#endif
-- | A single zero byte, the string terminator used on the wire.
nul :: B.Builder
nul = B.word8 0x00

-- | A strict string followed by a zero terminator.
byteStringNul :: BS.ByteString -> B.Builder
byteStringNul = (<> nul) . B.byteString

-- | A lazy string followed by a zero terminator.
lazyByteStringNul :: BSL.ByteString -> B.Builder
lazyByteStringNul = (<> nul) . B.lazyByteString
-- | Encode a frontend message as its optional one-byte type tag and its
-- payload. The length prefix is not included; 'pgSend' adds it. Startup and
-- cancel requests are the only untagged messages.
messageBody :: PGFrontendMessage -> (Maybe Char, B.Builder)
messageBody (StartupMessage kv) = (Nothing, B.word32BE 0x30000
  <> Fold.foldMap (\(k, v) -> byteStringNul k <> byteStringNul v) kv <> nul)
messageBody (CancelRequest pid key) = (Nothing, B.word32BE 80877102
  <> B.word32BE pid <> B.word32BE key)
messageBody Bind{ portalName = d, statementName = n, bindParameters = p, binaryColumns = bc } = (Just 'B',
  byteStringNul d
  <> byteStringNul n
  -- Parameter format codes: listed per parameter only when some value is
  -- binary; a count of zero means every parameter uses text format.
  <> (if any isBinary p
    then B.word16BE (fromIntegral $ length p) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum . isBinary) p
    else B.word16BE 0)
  <> B.word16BE (fromIntegral $ length p) <> Fold.foldMap val p
  -- Result column format codes, likewise sent only when some column is binary.
  <> (if or bc
    then B.word16BE (fromIntegral $ length bc) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum) bc
    else B.word16BE 0))
  where
  isBinary (PGBinaryValue _) = True
  isBinary _ = False
  -- Each value is length-prefixed; NULL is a length of -1 with no value
  -- bytes (a length of 1 would make the server read one phantom byte).
  val PGNullValue = B.int32BE (-1)
  val (PGTextValue v) = sized v
  val (PGBinaryValue v) = sized v
  sized v = B.word32BE (fromIntegral $ BS.length v) <> B.byteString v
messageBody CloseStatement{ statementName = n } = (Just 'C',
  B.char7 'S' <> byteStringNul n)
messageBody ClosePortal{ portalName = n } = (Just 'C',
  B.char7 'P' <> byteStringNul n)
messageBody DescribeStatement{ statementName = n } = (Just 'D',
  B.char7 'S' <> byteStringNul n)
messageBody DescribePortal{ portalName = n } = (Just 'D',
  B.char7 'P' <> byteStringNul n)
messageBody Execute{ portalName = n, executeRows = r } = (Just 'E',
  byteStringNul n <> B.word32BE r)
messageBody Flush = (Just 'H', mempty)
messageBody Parse{ statementName = n, queryString = s, parseTypes = t } = (Just 'P',
  byteStringNul n <> lazyByteStringNul s
  <> B.word16BE (fromIntegral $ length t) <> Fold.foldMap B.word32BE t)
messageBody (PasswordMessage s) = (Just 'p',
  B.byteString s <> nul)
messageBody SimpleQuery{ queryString = s } = (Just 'Q',
  lazyByteStringNul s)
messageBody Sync = (Just 'S', mempty)
messageBody Terminate = (Just 'X', mempty)
-- | Write one frontend message to the connection's handle and update the
-- tracked protocol state. The message is the optional type tag, a 32-bit
-- length that counts itself, and the payload. Nothing is flushed
-- explicitly; see 'pgFlush'. When debugging is on, the message is echoed
-- to stdout.
pgSend :: PGConnection -> PGFrontendMessage -> IO ()
pgSend c msg = do
  modifyIORef' (connState c) (transition msg)
  when (connDebug c) $ putStrLn ("> " ++ show msg)
  B.hPutBuilder h (Fold.foldMap B.char7 tag <> B.word32BE (fromIntegral (4 + BS.length payload)))
  BS.hPut h payload
  where
    h = connHandle c
    (tag, builder) = messageBody msg
    payload = BSL.toStrict (B.toLazyByteString builder)
    -- A closed connection stays closed. Sync and SimpleQuery leave a
    -- ReadyForQuery outstanding; any other message leaves the stream
    -- unsynchronized.
    transition _ StateClosed = StateClosed
    transition Sync _ = StatePending
    transition SimpleQuery{} _ = StatePending
    transition Terminate _ = StateClosed
    transition _ _ = StateUnsync
-- | Push any buffered outgoing messages to the server.
pgFlush :: PGConnection -> IO ()
pgFlush c = hFlush (connHandle c)
-- | Read a zero-terminated string as a strict ByteString, consuming the
-- terminator.
getByteStringNul :: G.Get BS.ByteString
getByteStringNul = BSL.toStrict <$> G.getLazyByteStringNul
-- | Parse error/notice fields: (code byte, zero-terminated value) pairs
-- ending with a zero code byte. If a code repeats, its first occurrence
-- wins.
getMessageFields :: G.Get MessageFields
getMessageFields = do
  code <- w2c <$> G.getWord8
  if code == '\0'
    then return Map.empty
    else Map.insert code <$> getByteStringNul <*> getMessageFields
-- | Parse the body of a backend message, dispatching on its one-byte type
-- tag. Unknown tags, unsupported authentication codes and unknown ready
-- states make the parser fail.
getMessageBody :: Char -> G.Get PGBackendMessage
-- Authentication request: 0 = ok, 3 = cleartext password, 5 = MD5 (4-byte salt follows).
getMessageBody 'R' = auth =<< G.getWord32be where
  auth 0 = return AuthenticationOk
  auth 3 = return AuthenticationCleartextPassword
  auth 5 = AuthenticationMD5Password <$> G.getByteString 4
  auth op = fail $ "pgGetMessage: unsupported authentication type: " ++ show op
-- Parameter description: a 16-bit count, then that many 32-bit type OIDs.
getMessageBody 't' = do
  numParams <- G.getWord16be
  ParameterDescription <$> replicateM (fromIntegral numParams) G.getWord32be
-- Row description: a 16-bit count, then that many column descriptors.
getMessageBody 'T' = do
  numFields <- G.getWord16be
  RowDescription <$> replicateM (fromIntegral numFields) getField where
  -- Fields are read in wire order: name, table OID, column number, type
  -- OID, type size, type modifier, format code.
  getField = do
    name <- getByteStringNul
    oid <- G.getWord32be
    col <- G.getWord16be
    typ' <- G.getWord32be
    siz <- G.getWord16be
    tmod <- G.getWord32be
    fmt <- G.getWord16be
    return $ PGColDescription
      { pgColName = name
      , pgColTable = oid
      , pgColNumber = fromIntegral col
      , pgColType = typ'
      , pgColSize = fromIntegral siz
      , pgColModifier = fromIntegral tmod
      , pgColBinary = toEnum (fromIntegral fmt)
      }
-- Ready for query: one status byte, I / T / E.
getMessageBody 'Z' = ReadyForQuery <$> (rs . w2c =<< G.getWord8) where
  rs 'I' = return StateIdle
  rs 'T' = return StateTransaction
  rs 'E' = return StateTransactionFailed
  rs s = fail $ "pgGetMessage: unknown ready state: " ++ show s
getMessageBody '1' = return ParseComplete
getMessageBody '2' = return BindComplete
getMessageBody '3' = return CloseComplete
getMessageBody 'C' = CommandComplete <$> getByteStringNul
getMessageBody 'S' = liftM2 ParameterStatus getByteStringNul getByteStringNul
-- Data row: a 16-bit count, then per value a 32-bit length (0xFFFFFFFF
-- meaning NULL) and that many bytes. Values are tagged as text here;
-- callers re-tag them with 'fixBinary' where binary columns were requested.
getMessageBody 'D' = do
  numFields <- G.getWord16be
  DataRow <$> replicateM (fromIntegral numFields) (getField =<< G.getWord32be) where
  getField 0xFFFFFFFF = return PGNullValue
  getField len = PGTextValue <$> G.getByteString (fromIntegral len)
getMessageBody 'K' = liftM2 BackendKeyData G.getWord32be G.getWord32be
getMessageBody 'E' = ErrorResponse <$> getMessageFields
getMessageBody 'I' = return EmptyQueryResponse
getMessageBody 'n' = return NoData
getMessageBody 's' = return PortalSuspended
getMessageBody 'N' = NoticeResponse <$> getMessageFields
-- Notification: process ID, channel name, payload.
getMessageBody 'A' = NotificationResponse <$> do
  PGNotification
    <$> G.getWord32be
    <*> getByteStringNul
    <*> G.getLazyByteStringNul
getMessageBody t = fail $ "pgGetMessage: unknown message type: " ++ show t
-- | Incremental decoder for a single backend message. The wire layout is:
--
-- * a one-byte type tag;
-- * a big-endian 32-bit length that counts its own four bytes but not the tag;
-- * the body.
--
-- The body is parsed with 'G.isolate', so its parser is confined to exactly
-- the announced number of bytes.
getMessage :: G.Decoder PGBackendMessage
getMessage = G.runGetIncremental $ do
  typ <- G.getWord8
  len <- G.getWord32be
  -- Subtract the four bytes of the length field itself to get the body size.
  G.isolate (fromIntegral len - 4) $ getMessageBody (w2c typ)
-- | Result types for 'pgRecv'. Each method decides how one kind of backend
-- event is handled while waiting for a value of type @m@. A @Just@ result
-- ends the wait with that value; @Nothing@ keeps reading.
class Show m => RecvMsg m where
  -- | Obtain more raw input from the server; a @Left@ result ends the wait
  -- immediately. The default blocks for up to 'smallChunkSize' bytes. On
  -- end of input it marks the connection closed, closes the handle and
  -- raises an EOF 'IOError'.
  recvMsgData :: PGConnection -> IO (Either m BS.ByteString)
  recvMsgData c = do
    r <- BS.hGetSome (connHandle c) smallChunkSize
    if BS.null r
      then do
        writeIORef (connState c) StateClosed
        hClose (connHandle c)
        ioError $ mkIOError eofErrorType "PGConnection" (Just (connHandle c)) Nothing
      else
        return (Right r)
  -- | Result to produce when a ReadyForQuery completes a pending sync; by
  -- default the sync is not reported and reading continues.
  recvMsgSync :: Maybe m
  recvMsgSync = Nothing
  -- | Handle an asynchronous notification; by default it is appended to
  -- 'connNotifications' and reading continues.
  recvMsgNotif :: PGConnection -> PGNotification -> IO (Maybe m)
  recvMsgNotif c n = Nothing <$
    modifyIORef' (connNotifications c) (enQueue n)
  -- | Handle an ErrorResponse; by default its fields are passed to the log
  -- callback and reading continues.
  recvMsgErr :: PGConnection -> MessageFields -> IO (Maybe m)
  recvMsgErr c m = Nothing <$
    connLogMessage c m
  -- | Handle any other message; by default it is logged as unexpected and
  -- reading continues.
  recvMsg :: PGConnection -> PGBackendMessage -> IO (Maybe m)
  recvMsg c m = Nothing <$
    connLogMessage c (makeMessage (BSC.pack $ "Unexpected server message: " ++ show m) "Each statement should only contain a single query")
-- | Non-blocking receive mode: 'pgRecv' stops with 'RecvNonBlock' as soon
-- as a non-blocking read returns no data.
data RecvNonBlock = RecvNonBlock deriving (Show)

instance RecvMsg RecvNonBlock where
  recvMsgData c = do
    chunk <- BS.hGetNonBlocking (connHandle c) smallChunkSize
    return $ if BS.null chunk then Left RecvNonBlock else Right chunk
-- | Receive mode that waits until a ReadyForQuery completes the pending
-- sync. Every other message goes through the default handlers.
data RecvSync = RecvSync deriving (Show)

instance RecvMsg RecvSync where
  recvMsgSync = return RecvSync
-- | Wait for the next asynchronous notification.
instance RecvMsg PGNotification where
  recvMsgNotif _ n = return (Just n)

-- | Wait for the next backend message that 'pgRecv' does not consume
-- itself; error responses are thrown as 'PGError'.
instance RecvMsg PGBackendMessage where
  recvMsgErr _ fields = throwIO (PGError fields)
  recvMsg _ m = return (Just m)

-- | Like the 'PGBackendMessage' instance, but a completed pending sync is
-- reported as @Right RecvSync@.
instance RecvMsg (Either PGBackendMessage RecvSync) where
  recvMsgSync = Just (Right RecvSync)
  recvMsgErr _ fields = throwIO (PGError fields)
  recvMsg _ m = return (Just (Left m))
-- | Read backend messages until the 'RecvMsg' instance for the result type
-- accepts one.
--
-- Bookkeeping messages are handled along the way:
--
-- * ParameterStatus updates 'connParameters' and is skipped.
-- * NoticeResponse is passed to the log callback and skipped.
-- * ErrorResponse, NotificationResponse and any other message are handed to
--   'recvMsgErr', 'recvMsgNotif' and 'recvMsg' respectively.
-- * ReadyForQuery stores the server's state. If the previous state was
--   'StatePending' the result is 'recvMsgSync', otherwise 'recvMsg'.
-- * AuthenticationOk sets the state to 'StatePending' before 'recvMsg'.
--
-- The decoder state is written back to 'connInput' before returning, and
-- also when reading or handling throws, so a later call resumes from there.
pgRecv :: RecvMsg m => PGConnection -> IO m
pgRecv c@PGConnection{ connInput = dr, connState = sr } =
  rcv =<< readIORef dr where
  -- Save the decoder for the next call.
  next = writeIORef dr
  -- Start decoding a fresh message from already-received leftover bytes.
  new = G.pushChunk getMessage
  rcv (G.Done b _ m) = do
    when (connDebug c) $ putStrLn $ "< " ++ show m
    got (new b) m
  -- Undecodable input: reset to an empty decoder and fail.
  rcv (G.Fail _ _ r) = next (new BS.empty) >> fail r
  -- Need more input; a Left from recvMsgData ends the wait with that value.
  rcv d@(G.Partial r) = recvMsgData c `onException` next d >>=
    either (<$ next d) (rcv . r . Just)
  msg (ParameterStatus k v) = Nothing <$
    modifyIORef' (connParameters c) (Map.insert k v)
  msg (NoticeResponse m) = Nothing <$
    connLogMessage c m
  msg (ErrorResponse m) =
    recvMsgErr c m
  msg m@(ReadyForQuery s) = do
    s' <- atomicModifyIORef' sr (s, )
    if s' == StatePending
      then return recvMsgSync
      else recvMsg c m
  msg (NotificationResponse n) =
    recvMsgNotif c n
  msg m@AuthenticationOk = do
    writeIORef sr StatePending
    recvMsg c m
  msg m = recvMsg c m
  -- Handle a decoded message with the remaining input in decoder d: keep
  -- reading on Nothing, otherwise save d and return the accepted value.
  got d m = msg m `onException` next d >>=
    maybe (rcv d) (<$ next d)
-- | Open a connection described by the given settings.
--
-- The handshake:
--
-- * connects to 'pgDBHost'/'pgDBPort';
-- * sends the startup parameters: fixed client settings, followed by
--   'pgDBParams';
-- * answers cleartext password requests, and MD5 requests when built with
--   cryptonite;
-- * waits until the server is ready.
--
-- The returned connection's type environment comes from the
-- @integer_datetimes@ and @server_version@ parameters the server reported.
--
-- If the handshake fails (for example the server rejects the credentials,
-- or sends an unexpected message), the socket is closed before the
-- exception propagates, rather than being leaked.
pgConnect :: PGDatabase -> IO PGConnection
pgConnect db = do
  param <- newIORef Map.empty
  state <- newIORef StateUnsync
  prepc <- newIORef 0
  prepm <- newIORef Map.empty
  input <- newIORef getMessage
  tr <- newIORef 0
  notif <- newIORef emptyQueue
  h <- connectTo (pgDBHost db) (pgDBPort db)
  let c = PGConnection
        { connHandle = h
        , connDatabase = db
        , connPid = 0
        , connKey = 0
        , connParameters = param
        , connPreparedStatementCount = prepc
        , connPreparedStatementMap = prepm
        , connState = state
        , connTypeEnv = unknownPGTypeEnv
        , connInput = input
        , connTransaction = tr
        , connNotifications = notif
        }
  handshake c `onException` hClose h
  where
  -- Everything after the socket is open: configure buffering, send the
  -- startup message and run the authentication exchange.
  handshake c = do
    hSetBuffering (connHandle c) (BlockBuffering Nothing)
    pgSend c $ StartupMessage $
      [ ("user", pgDBUser db)
      , ("database", pgDBName db)
      , ("client_encoding", "UTF8")
      , ("standard_conforming_strings", "on")
      , ("bytea_output", "hex")
      , ("DateStyle", "ISO, YMD")
      , ("IntervalStyle", "iso_8601")
      ] ++ pgDBParams db
    pgFlush c
    conn c
  conn c = pgRecv c >>= msg c
  -- Server is ready: build the type environment from reported parameters.
  msg c (Right RecvSync) = do
    cp <- readIORef (connParameters c)
    return c
      { connTypeEnv = PGTypeEnv
        { pgIntegerDatetimes = fmap ("on" ==) $ Map.lookup "integer_datetimes" cp
        , pgServerVersion = Map.lookup "server_version" cp
        }
      }
  msg c (Left (BackendKeyData p k)) = conn c{ connPid = p, connKey = k }
  msg c (Left AuthenticationOk) = conn c
  msg c (Left AuthenticationCleartextPassword) = do
    pgSend c $ PasswordMessage $ pgDBPass db
    pgFlush c
    conn c
#ifdef VERSION_cryptonite
  msg c (Left (AuthenticationMD5Password salt)) = do
    pgSend c $ PasswordMessage $ "md5" `BS.append` md5 (md5 (pgDBPass db <> pgDBUser db) `BS.append` salt)
    pgFlush c
    conn c
#endif
  msg _ (Left m) = fail $ "pgConnect: unexpected response: " ++ show m
-- | Send a Terminate message and close the connection's handle. The handle
-- is closed even if sending fails.
pgDisconnect :: PGConnection
  -> IO ()
pgDisconnect c =
  finally (pgSend c Terminate) (hClose (connHandle c))
-- | Like 'pgDisconnect', but does nothing if the connection is already
-- marked closed.
pgDisconnectOnce :: PGConnection
  -> IO ()
pgDisconnectOnce c = do
  closed <- (== StateClosed) <$> readIORef (connState c)
  unless closed (pgDisconnect c)
-- | Reuse a connection when possible. If it is still open and its settings
-- equal the new ones (per the 'Eq' instance), the same connection is
-- returned, carrying the new settings record. Otherwise it is disconnected
-- (if still open) and a fresh connection is opened.
pgReconnect :: PGConnection -> PGDatabase -> IO PGConnection
pgReconnect c d = do
  s <- readIORef (connState c)
  if connDatabase c /= d || s == StateClosed
    then pgDisconnectOnce c >> pgConnect d
    else return c{ connDatabase = d }
-- | Bring the connection to a synchronized state before a new exchange.
--
-- * A closed connection is an error.
-- * If messages were sent without a Sync, one is sent.
-- * Any outstanding ReadyForQuery is then awaited.
-- * When the server is already ready, nothing happens.
pgSync :: PGConnection -> IO ()
pgSync c = readIORef (connState c) >>= settle
  where
    settle StateClosed = fail "pgSync: operation on closed connection"
    settle StatePending = awaitReady
    settle StateUnsync = pgSend c Sync >> pgFlush c >> awaitReady
    settle _ = return ()
    awaitReady = do
      RecvSync <- pgRecv c
      return ()
-- | Column descriptions from the answer to a describe request: the columns
-- of a RowDescription, or none for NoData. Any other message is a fatal
-- error.
rowDescription :: PGBackendMessage -> PGRowDescription
rowDescription reply = case reply of
  RowDescription cols -> cols
  NoData -> []
  _ -> error $ "describe: unexpected response: " ++ show reply
-- | Describe a SQL statement without executing it. The statement is parsed
-- as the unnamed statement and described; the result is its parameter type
-- OIDs and, for each result column, its name, type OID and whether it may
-- be NULL.
--
-- When nullability guessing is enabled and a column comes from a table
-- (table OID /= 0), its @pg_attribute.attnotnull@ flag is looked up.
-- Otherwise the column is assumed nullable.
pgDescribe :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> [OID] -- ^ Optional parameter type OIDs
  -> Bool -- ^ Guess nullability from the catalog; otherwise every column is reported nullable
  -> IO ([OID], [(BS.ByteString, OID, Bool)]) -- ^ Parameter types, and result column names, types and nullability
pgDescribe h sql types nulls = do
  pgSync h
  pgSend h Parse{ queryString = sql, statementName = BS.empty, parseTypes = types }
  pgSend h DescribeStatement{ statementName = BS.empty }
  pgSend h Sync
  pgFlush h
  ParseComplete <- pgRecv h
  ParameterDescription ps <- pgRecv h
  (,) ps <$> (mapM desc . rowDescription =<< pgRecv h)
  where
  desc (PGColDescription{ pgColName = name, pgColTable = tab, pgColNumber = col, pgColType = typ }) = do
    n <- nullable tab col
    return (name, typ, n)
  nullable oid col
    | nulls && oid /= 0 = do
      -- Parameter types 26 and 21 accompany the OID and Int16 values bound
      -- below (presumably the oid and int2 type OIDs).
      (_, r) <- pgPreparedQuery h "SELECT attnotnull FROM pg_catalog.pg_attribute WHERE attrelid = $1 AND attnum = $2" [26, 21] [pgEncodeRep (oid :: OID), pgEncodeRep (col :: Int16)] []
      case r of
        [[s]] -> return $ not $ pgDecodeRep s
        [] -> return True
        _ -> fail $ "Failed to determine nullability of column #" ++ show col
    | otherwise = return True
-- | Row count carried by a CommandComplete tag: the tag's last
-- whitespace-separated word read as a number (e.g. @INSERT 0 5@ gives 5).
--
-- Returns -1 when the tag has no numeric count (e.g. @BEGIN@). That keeps
-- "unknown" distinguishable from a real single-row result, which the
-- previous default of 1 was not.
rowsAffected :: (Integral i, Read i) => BS.ByteString -> i
rowsAffected tag = case reverse (BSC.words tag) of
  lastWord : _ -> fromMaybe (-1) (readMaybe (BSC.unpack lastWord))
  [] -> -1
-- | Re-tag received values according to per-column binary flags:
-- 'pgPreparedQuery' and friends call it because 'getMessageBody' tags
-- every non-NULL value as text.
--
-- * A binary column's text value becomes a binary value.
-- * A text column's binary value becomes a text value.
-- * NULLs, and values beyond the end of the flag list, are left unchanged.
fixBinary :: [Bool] -> PGValues -> PGValues
fixBinary (flag : flags) (v : vs) = retag flag v : fixBinary flags vs
  where
    retag False (PGBinaryValue x) = PGTextValue x
    retag True (PGTextValue x) = PGBinaryValue x
    retag _ other = other
fixBinary _ vs = vs
-- | Run a query with the simple protocol and return the result of its
-- first statement: the affected-row count (see 'rowsAffected') and the
-- rows, re-tagged per the column formats in the RowDescription.
--
-- The trailing ReadyForQuery is not awaited here. The connection is left
-- 'StatePending', and the next 'pgSync' consumes it; any further
-- statements' results are logged as unexpected at that point.
pgSimpleQuery :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> IO (Int, [PGValues]) -- ^ The number of rows affected and the result rows
pgSimpleQuery h sql = do
  pgSync h
  pgSend h $ SimpleQuery sql
  pgFlush h
  go start where
  go = (pgRecv h >>=)
  start (RowDescription rd) = go $ row (map pgColBinary rd) id
  start (CommandComplete c) = got c []
  start EmptyQueryResponse = return (0, [])
  start m = fail $ "pgSimpleQuery: unexpected response: " ++ show m
  -- Rows are accumulated in a difference list to keep them in order.
  row bc r (DataRow fs) = go $ row bc (r . (fixBinary bc fs :))
  row _ r (CommandComplete c) = got c (r [])
  row _ _ m = fail $ "pgSimpleQuery: unexpected row: " ++ show m
  got c r = return (rowsAffected c, r)
-- | Run a simple-protocol query string, which may hold several statements,
-- discarding every result. Returns once the server reports it is ready
-- again; any message other than row descriptions, rows, command
-- completions and empty-query responses is an error.
pgSimpleQueries_ :: PGConnection -> BSL.ByteString -- ^ SQL string
  -> IO ()
pgSimpleQueries_ h sql = do
  pgSync h
  pgSend h $ SimpleQuery sql
  pgFlush h
  drain
  where
    drain = pgRecv h >>= dispatch
    dispatch reply = case reply of
      Right RecvSync -> return ()
      Left (RowDescription _) -> drain
      Left (CommandComplete _) -> drain
      Left EmptyQueryResponse -> drain
      Left (DataRow _) -> drain
      m -> fail $ "pgSimpleQueries_: unexpected response: " ++ show m
-- | Queue a Bind of the given values to the cached prepared statement for
-- (SQL, parameter types). If no statement is cached yet, a new name is
-- allocated from 'connPreparedStatementCount' and a Parse is queued first.
--
-- Nothing is flushed. The returned action must be run after the caller has
-- flushed: it consumes ParseComplete (recording the statement in
-- 'connPreparedStatementMap' only then) and BindComplete.
pgPreparedBind :: PGConnection -> BS.ByteString -> [OID] -> PGValues -> [Bool] -> IO (IO ())
pgPreparedBind c sql types bind bc = do
  pgSync c
  m <- readIORef (connPreparedStatementMap c)
  -- p: whether the statement was already cached; n: its handle.
  (p, n) <- maybe
    (atomicModifyIORef' (connPreparedStatementCount c) (succ &&& (,) False . PGPreparedStatement))
    (return . (,) True) $ Map.lookup key m
  unless p $
    pgSend c Parse{ queryString = BSL.fromStrict sql, statementName = preparedStatementName n, parseTypes = types }
  pgSend c Bind{ portalName = BS.empty, statementName = preparedStatementName n, bindParameters = bind, binaryColumns = bc }
  let
    go = pgRecv c >>= start
    start ParseComplete = do
      modifyIORef' (connPreparedStatementMap c) $
        Map.insert key n
      go
    start BindComplete = return ()
    start r = fail $ "pgPrepared: unexpected response: " ++ show r
  return go
  where key = (sql, types)
-- | Run a statement through the prepared-statement cache and fetch all of
-- its rows. The statement is parsed on first use, then bound and executed
-- with no row limit. Returns the affected-row count (see 'rowsAffected'),
-- or 0 for an empty query, together with the rows re-tagged by 'fixBinary'.
pgPreparedQuery :: PGConnection -> BS.ByteString -- ^ SQL statement with placeholders
  -> [OID] -- ^ Parameter type OIDs; together with the SQL they form the cache key
  -> PGValues -- ^ Values bound to the placeholders
  -> [Bool] -- ^ Requested binary format for each result column
  -> IO (Int, [PGValues])
pgPreparedQuery c sql types bind bc = do
  start <- pgPreparedBind c sql types bind bc
  pgSend c Execute{ portalName = BS.empty, executeRows = 0 }
  pgSend c Sync
  pgFlush c
  -- Consume ParseComplete/BindComplete before the rows.
  start
  go id
  where
  go r = pgRecv c >>= row r
  row r (DataRow fs) = go (r . (fixBinary bc fs :))
  row r (CommandComplete d) = return (rowsAffected d, r [])
  row r EmptyQueryResponse = return (0, r [])
  row _ m = fail $ "pgPreparedQuery: unexpected row: " ++ show m
-- | Like 'pgPreparedQuery', but returns the rows lazily. Execution is
-- deferred with 'unsafeInterleaveIO' and rows are fetched @count@ at a time
-- (Execute followed by Flush), the next chunk being requested only when
-- the list is forced past a PortalSuspended.
--
-- NOTE(review): because reads happen while the list is consumed, presumably
-- the connection should not be used for anything else until the list has
-- been fully evaluated — confirm with callers.
pgPreparedLazyQuery :: PGConnection -> BS.ByteString -> [OID] -> PGValues -> [Bool] -> Word32 -- ^ Chunk size passed as the Execute row limit
  -> IO [PGValues]
pgPreparedLazyQuery c sql types bind bc count = do
  start <- pgPreparedBind c sql types bind bc
  unsafeInterleaveIO $ do
    execute
    start
    go id
  where
  execute = do
    pgSend c Execute{ portalName = BS.empty, executeRows = count }
    pgSend c Flush
    pgFlush c
  go r = pgRecv c >>= row r
  row r (DataRow fs) = go (r . (fixBinary bc fs :))
  row r PortalSuspended = r <$> unsafeInterleaveIO (execute >> go id)
  row r (CommandComplete _) = return (r [])
  row r EmptyQueryResponse = return (r [])
  row _ m = fail $ "pgPreparedLazyQuery: unexpected row: " ++ show m
-- | Remove the statement for (SQL, parameter types) from the
-- prepared-statement cache and, if one was cached, close it on the server.
pgCloseStatement :: PGConnection -> BS.ByteString -> [OID] -> IO ()
pgCloseStatement c sql types = do
  cached <- atomicModifyIORef (connPreparedStatementMap c) $ \m ->
    (Map.delete key m, Map.lookup key m)
  Fold.mapM_ (pgClose c) cached
  where
    key = (sql, types)
-- | Start a transaction. If one is already open, a savepoint named after
-- the current nesting depth is created instead. The depth is incremented.
pgBegin :: PGConnection -> IO ()
pgBegin c = do
  depth <- atomicModifyIORef' (connTransaction c) (\t -> (succ t, t))
  let command
        | depth == 0 = "BEGIN"
        | otherwise = "SAVEPOINT pgt" ++ show depth
  _ <- pgSimpleQuery c (BSLC.pack command)
  return ()
-- | Decrement a transaction depth, returning the new depth both as the
-- stored value and as the result. At depth 0 the stored value stays 0 and
-- the result is an error thunk that fires when inspected.
predTransaction :: Word -> (Word, Word)
predTransaction depth
  | depth == 0 = (0, error "pgTransaction: no transactions")
  | otherwise = let lower = depth - 1 in (lower, lower)
-- | Roll back the innermost transaction level. At the outermost level this
-- is a ROLLBACK; otherwise it rolls back to the matching savepoint. The
-- depth is decremented, and an error is raised if none is open.
pgRollback :: PGConnection -> IO ()
pgRollback c = do
  depth <- atomicModifyIORef' (connTransaction c) predTransaction
  let command
        | depth == 0 = "ROLLBACK"
        | otherwise = "ROLLBACK TO SAVEPOINT pgt" ++ show depth
  _ <- pgSimpleQuery c (BSLC.pack command)
  return ()
-- | Commit the innermost transaction level. At the outermost level this is
-- a COMMIT; otherwise it releases the matching savepoint. The depth is
-- decremented, and an error is raised if none is open.
pgCommit :: PGConnection -> IO ()
pgCommit c = do
  depth <- atomicModifyIORef' (connTransaction c) predTransaction
  let command
        | depth == 0 = "COMMIT"
        | otherwise = "RELEASE SAVEPOINT pgt" ++ show depth
  _ <- pgSimpleQuery c (BSLC.pack command)
  return ()
-- | Reset the transaction depth to 0 and roll back the whole transaction.
pgRollbackAll :: PGConnection -> IO ()
pgRollbackAll c = do
  writeIORef (connTransaction c) 0
  _ <- pgSimpleQuery c (BSLC.pack "ROLLBACK")
  return ()
-- | Reset the transaction depth to 0 and commit the whole transaction.
pgCommitAll :: PGConnection -> IO ()
pgCommitAll c = do
  writeIORef (connTransaction c) 0
  _ <- pgSimpleQuery c (BSLC.pack "COMMIT")
  return ()
-- | Run an action inside a transaction, or inside a savepoint when already
-- nested in one. If the action throws, that level is rolled back and the
-- exception re-raised; otherwise the level is committed and the action's
-- result returned.
--
-- The rollback handler deliberately covers only the action. 'pgCommit'
-- decrements the nesting depth before it sends COMMIT/RELEASE, so rolling
-- back after a failed commit would target the wrong level. At the outermost
-- level it would raise "pgTransaction: no transactions", hiding the
-- server's error.
pgTransaction :: PGConnection -> IO a -> IO a
pgTransaction c f = do
  pgBegin c
  r <- f `onException` pgRollback c
  pgCommit c
  return r
-- | Parse, bind and execute a statement using the unnamed statement and
-- portal, with a row limit of 1; returned rows are discarded.
--
-- * 'Nothing': the portal was suspended, i.e. more rows were available.
-- * @Just n@: the command completed, with @n@ from 'rowsAffected'.
-- * @Just 0@: the query was empty.
pgRun :: PGConnection -> BSL.ByteString -> [OID] -> PGValues -> IO (Maybe Integer)
pgRun c sql types bind = do
  pgSync c
  pgSend c Parse{ queryString = sql, statementName = BS.empty, parseTypes = types }
  pgSend c Bind{ portalName = BS.empty, statementName = BS.empty, bindParameters = bind, binaryColumns = [] }
  pgSend c Execute{ portalName = BS.empty, executeRows = 1 }
  pgSend c Sync
  pgFlush c
  go where
  go = pgRecv c >>= res
  res ParseComplete = go
  res BindComplete = go
  res (DataRow _) = go
  res PortalSuspended = return Nothing
  res (CommandComplete d) = return (Just $ rowsAffected d)
  res EmptyQueryResponse = return (Just 0)
  res m = fail $ "pgRun: unexpected response: " ++ show m
-- | Parse a statement under a freshly allocated name and wait for
-- ParseComplete. The statement is not added to the cache used by
-- 'pgPreparedQuery'; release it with 'pgClose'.
pgPrepare :: PGConnection -> BSL.ByteString -> [OID] -> IO PGPreparedStatement
pgPrepare c sql types = do
  n <- atomicModifyIORef' (connPreparedStatementCount c) (succ &&& PGPreparedStatement)
  pgSync c
  pgSend c Parse{ queryString = sql, statementName = preparedStatementName n, parseTypes = types }
  pgSend c Sync
  pgFlush c
  ParseComplete <- pgRecv c
  return n
-- | Close the portal and the prepared statement that share this
-- statement's name, waiting for both CloseComplete replies.
pgClose :: PGConnection -> PGPreparedStatement -> IO ()
pgClose c n = do
  pgSync c
  pgSend c ClosePortal{ portalName = preparedStatementName n }
  pgSend c CloseStatement{ statementName = preparedStatementName n }
  pgSend c Sync
  pgFlush c
  CloseComplete <- pgRecv c
  CloseComplete <- pgRecv c
  return ()
-- | Bind values to a prepared statement, creating a portal with the
-- statement's name for later 'pgFetch' calls.
--
-- Any previous portal of that name is closed first. No binary result
-- columns are requested (@binaryColumns = []@). Returns the portal's row
-- description.
pgBind :: PGConnection -> PGPreparedStatement -> PGValues -> IO PGRowDescription
pgBind c n bind = do
  pgSync c
  pgSend c ClosePortal{ portalName = sn }
  pgSend c Bind{ portalName = sn, statementName = sn, bindParameters = bind, binaryColumns = [] }
  pgSend c DescribePortal{ portalName = sn }
  pgSend c Sync
  pgFlush c
  CloseComplete <- pgRecv c
  BindComplete <- pgRecv c
  rowDescription <$> pgRecv c
  where sn = preparedStatementName n
-- | Fetch up to @count@ rows from the portal created by 'pgBind'.
--
-- * If the portal is suspended, the rows are returned with 'Nothing'.
-- * If the command completed, the portal is closed and the rows are
--   returned with @Just@ the affected-row count.
-- * An empty query gives @Just 0@.
pgFetch :: PGConnection -> PGPreparedStatement -> Word32 -- ^ Maximum number of rows to return
  -> IO ([PGValues], Maybe Integer)
pgFetch c n count = do
  pgSync c
  pgSend c Execute{ portalName = preparedStatementName n, executeRows = count }
  pgSend c Sync
  pgFlush c
  go where
  go = pgRecv c >>= res
  res (DataRow v) = first (v :) <$> go
  res PortalSuspended = return ([], Nothing)
  res (CommandComplete d) = do
    -- Consume the pending ReadyForQuery, then close the finished portal.
    pgSync c
    pgSend c ClosePortal{ portalName = preparedStatementName n }
    pgSend c Sync
    pgFlush c
    CloseComplete <- pgRecv c
    return ([], Just $ rowsAffected d)
  res EmptyQueryResponse = return ([], Just 0)
  res m = fail $ "pgFetch: unexpected response: " ++ show m
-- | Read whatever input is available without blocking, then return (and
-- clear) every queued notification, oldest first.
pgGetNotifications :: PGConnection -> IO [PGNotification]
pgGetNotifications c = do
  RecvNonBlock <- pgRecv c
  pending <- atomicModifyIORef' (connNotifications c) (\q -> (emptyQueue, q))
  return (queueToList pending)
-- | Return the oldest queued notification. If none is queued, block reading
-- from the server until one arrives.
pgGetNotification :: PGConnection -> IO PGNotification
pgGetNotification c = do
  queued <- atomicModifyIORef' (connNotifications c) deQueue
  case queued of
    Just n -> return n
    Nothing -> pgRecv c