module Database.PostgreSQL.Typed.Protocol (
PGDatabase(..)
, defaultPGDatabase
, PGConnection
, PGError(..)
, pgMessageCode
, pgTypeEnv
, pgConnect
, pgDisconnect
, pgReconnect
, pgDescribe
, pgSimpleQuery
, pgPreparedQuery
, pgPreparedLazyQuery
, pgCloseStatement
) where
import Control.Applicative ((<$>), (<$))
import Control.Arrow (second)
import Control.Exception (Exception, bracketOnError, finally, throwIO)
import Control.Monad (liftM2, replicateM, when, unless)
#ifdef USE_MD5
import qualified Crypto.Hash as Hash
#endif
import qualified Data.Binary.Get as G
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as B
import qualified Data.ByteString.Char8 as BSC
import Data.ByteString.Internal (w2c)
import qualified Data.ByteString.Lazy as BSL
import Data.ByteString.Lazy.Internal (smallChunkSize)
import qualified Data.ByteString.Lazy.UTF8 as BSLU
import qualified Data.ByteString.UTF8 as BSU
import qualified Data.Foldable as Fold
import Data.IORef (IORef, newIORef, writeIORef, readIORef, atomicModifyIORef, atomicModifyIORef', modifyIORef)
import qualified Data.Map.Lazy as Map
import Data.Maybe (fromMaybe)
import Data.Monoid (mempty, (<>))
import qualified Data.Sequence as Seq
import Data.Typeable (Typeable)
import Data.Word (Word32)
import Network (HostName, PortID(..), connectTo)
import System.IO (Handle, hFlush, hClose, stderr, hPutStrLn)
import System.IO.Unsafe (unsafeInterleaveIO)
import Text.Read (readMaybe)
import Database.PostgreSQL.Typed.Types
-- | What this module believes the server-side session is doing.
data PGState
  = StateUnknown           -- ^ set whenever a message is sent or an error is received
  | StateIdle              -- ^ ReadyForQuery with status 'I'
  | StateTransaction       -- ^ ReadyForQuery with status 'T'
  | StateTransactionFailed -- ^ ReadyForQuery with status 'E'
  | StateClosed            -- ^ disconnected, or end of input was seen
  deriving (Eq, Show)
-- | Settings used by 'pgConnect' to reach and authenticate with a server.
data PGDatabase = PGDatabase
  { pgDBHost :: HostName   -- ^ host passed to 'connectTo'
  , pgDBPort :: PortID     -- ^ port passed to 'connectTo'
  , pgDBName :: String     -- ^ sent as the "database" startup parameter
  , pgDBUser :: String     -- ^ sent as the "user" startup parameter
  , pgDBPass :: String     -- ^ only sent if the server requests a password
  , pgDBDebug :: Bool      -- ^ print every sent/received message to stdout
  , pgDBLogMessage :: MessageFields -> IO ()
    -- ^ receives notices and errors that are logged rather than thrown
  }
-- | Two settings are equal when they would connect to the same place as the
-- same user; the debug flag and log callback are deliberately ignored.
instance Eq PGDatabase where
  a == b =
    pgDBHost a == pgDBHost b
      && pgDBPort a == pgDBPort b
      && pgDBName a == pgDBName b
      && pgDBUser a == pgDBUser b
      && pgDBPass a == pgDBPass b
-- | An open connection, as created by 'pgConnect'.
data PGConnection = PGConnection
  { connHandle :: Handle -- ^ socket all protocol messages are written to / read from
  , connDatabase :: !PGDatabase -- ^ settings the connection was opened with
  , connPid :: !Word32 -- ^ process id from BackendKeyData
  , connKey :: !Word32 -- ^ secret key from BackendKeyData
  , connParameters :: Map.Map String String -- ^ ParameterStatus values received during startup
  , connTypeEnv :: PGTypeEnv
  , connPreparedStatements :: IORef (Integer, Map.Map (String, [OID]) Integer)
    -- ^ next statement number, and the statement number already prepared for each (query, parameter types)
  , connState :: IORef PGState
  , connInput :: IORef (G.Decoder PGBackendMessage) -- ^ partially decoded input carried between reads
  }
-- | One result column from a RowDescription message (see 'getMessageBody').
data ColDescription = ColDescription
  { colName :: String
  , colTable :: !OID -- ^ source table OID; 'pgDescribe' treats 0 as "not from a table"
  , colNumber :: !Int -- ^ column number within that table (used as attnum by 'pgDescribe')
  , colType :: !OID -- ^ type OID of the column
  , colModifier :: !Word32 -- ^ type modifier
  , colBinary :: !Bool -- ^ format code 1 (binary) decoded as True, 0 as False
  } deriving (Show)
-- | Fields of an error or notice message keyed by their one-character field
-- code (for example 'C', which 'pgMessageCode' reads).
type MessageFields = Map.Map Char String
-- | Messages this client sends to the server; serialised by 'messageBody'.
data PGFrontendMessage
  = StartupMessage [(String, String)] -- ^ startup parameters as name/value pairs
  | CancelRequest !Word32 !Word32 -- ^ process id and secret key
  | Bind { statementName :: String, bindParameters :: PGValues, binaryColumns :: [Bool] }
    -- ^ bind parameters to a prepared statement into the empty-named portal;
    -- binaryColumns requests per-column result formats
  | Close { statementName :: String } -- ^ close a prepared statement
  | Describe { statementName :: String } -- ^ describe a prepared statement
  | Execute !Word32 -- ^ execute the empty-named portal with this row limit
  | Flush
  | Parse { statementName :: String, queryString :: String, parseTypes :: [OID] }
  | PasswordMessage BS.ByteString -- ^ cleartext or md5-hashed password
  | SimpleQuery { queryString :: String }
  | Sync
  | Terminate
  deriving (Show)
-- | Messages received from the server; decoded by 'getMessageBody'.
data PGBackendMessage
  = AuthenticationOk
  | AuthenticationCleartextPassword
  | AuthenticationMD5Password BS.ByteString -- ^ 4-byte salt
  | BackendKeyData Word32 Word32 -- ^ process id and secret key
  | BindComplete
  | CloseComplete
  | CommandComplete BS.ByteString -- ^ command tag, parsed by 'rowsAffected'
  | DataRow PGValues -- ^ column values, decoded as text or NULL
  | EmptyQueryResponse
  | ErrorResponse { messageFields :: MessageFields }
  | NoData
  | NoticeResponse { messageFields :: MessageFields }
  | ParameterDescription [OID] -- ^ parameter type OIDs
  | ParameterStatus String String -- ^ parameter name and value
  | ParseComplete
  | PortalSuspended
  | ReadyForQuery PGState -- ^ transaction status reported by the server
  | RowDescription [ColDescription]
  deriving (Show)
-- | Exception thrown by 'pgReceive' when the server sends an ErrorResponse;
-- carries the response's fields.
data PGError = PGError MessageFields
  deriving Typeable

instance Show PGError where
  show (PGError fields) = displayMessage fields

instance Exception PGError
-- | Render message fields as @PG<S> [<C>]: <M>@ followed by a newline and
-- field 'D'. Missing fields render as empty strings. In PostgreSQL's
-- error-field scheme these are severity, code, message and detail.
displayMessage :: MessageFields -> String
displayMessage m = concat ["PG", field 'S', " [", field 'C', "]: ", field 'M', "\n", field 'D']
  where field c = Map.findWithDefault "" c m
-- | Build locally generated message fields from a message ('M') and a
-- detail ('D') string, suitable for 'connLogMessage'.
makeMessage :: String -> String -> MessageFields
makeMessage msg detail = Map.fromList [('M', msg), ('D', detail)]
-- | The 'C' (code) field of a message, or the empty string if it is absent.
pgMessageCode :: MessageFields -> String
pgMessageCode fields = Map.findWithDefault "" 'C' fields
-- | Default 'pgDBLogMessage': print the rendered message to standard error.
defaultLogMessage :: MessageFields -> IO ()
defaultLogMessage fields = hPutStrLn stderr (displayMessage fields)
-- | Connect to database "postgres" on localhost:5432 as user "postgres" with
-- an empty password, debugging off, logging to stderr.
defaultPGDatabase :: PGDatabase
defaultPGDatabase = PGDatabase
  { pgDBHost = "localhost"
  , pgDBPort = PortNumber 5432
  , pgDBName = "postgres"
  , pgDBUser = "postgres"
  , pgDBPass = ""
  , pgDBDebug = False
  , pgDBLogMessage = defaultLogMessage
  }
-- | Whether protocol tracing was requested for this connection.
connDebug :: PGConnection -> Bool
connDebug c = pgDBDebug (connDatabase c)

-- | The connection's message logging callback.
connLogMessage :: PGConnection -> MessageFields -> IO ()
connLogMessage c = pgDBLogMessage (connDatabase c)

-- | Type-decoding environment established when the connection was opened.
pgTypeEnv :: PGConnection -> PGTypeEnv
pgTypeEnv c = connTypeEnv c
#ifdef USE_MD5
-- | Hexadecimal MD5 digest of a byte string, used for md5 password auth.
md5 :: BS.ByteString -> BS.ByteString
md5 input = Hash.digestToHexByteString digest
  where digest = Hash.hash input :: Hash.Digest Hash.MD5
#endif
-- | A single zero byte, the protocol's string/list terminator.
nul :: B.Builder
nul = B.word8 0x00

-- | A UTF-8 encoded, NUL-terminated protocol string.
pgString :: String -> B.Builder
pgString = (<> nul) . B.stringUtf8
-- | Serialise a frontend message as its optional one-byte type tag and its
-- body. The 4-byte length prefix is added by 'pgSend'. 'StartupMessage' and
-- 'CancelRequest' carry no type tag.
messageBody :: PGFrontendMessage -> (Maybe Char, B.Builder)
messageBody (StartupMessage kv) =
  ( Nothing
  , B.word32BE 0x30000 -- protocol version: major 3, minor 0
    <> Fold.foldMap (\(k, v) -> pgString k <> pgString v) kv
    <> nul
  )
messageBody (CancelRequest pid key) =
  ( Nothing
  , B.word32BE 80877102 -- cancel request code
    <> B.word32BE pid <> B.word32BE key
  )
messageBody Bind{ statementName = n, bindParameters = p, binaryColumns = bc } =
  ( Just 'B'
  , nul -- empty destination portal name
    <> pgString n
    <> paramFormats
    <> B.word16BE (fromIntegral $ length p) <> Fold.foldMap val p
    <> resultFormats
  )
  where
    -- Send zero format codes (all text) unless some parameter is binary.
    paramFormats
      | any fmt p = B.word16BE (fromIntegral $ length p) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum . fmt) p
      | otherwise = B.word16BE 0
    -- Likewise for result columns: only list formats if any column is binary.
    resultFormats
      | or bc = B.word16BE (fromIntegral $ length bc) <> Fold.foldMap (B.word16BE . fromIntegral . fromEnum) bc
      | otherwise = B.word16BE 0
    fmt (PGBinaryValue _) = True
    fmt _ = False
    -- A NULL parameter is encoded as length -1 followed by no value bytes.
    val PGNullValue = B.int32BE (-1)
    val (PGTextValue v) = B.word32BE (fromIntegral $ BS.length v) <> B.byteString v
    val (PGBinaryValue v) = B.word32BE (fromIntegral $ BS.length v) <> B.byteString v
messageBody Close{ statementName = n } = (Just 'C', B.char7 'S' <> pgString n)
messageBody Describe{ statementName = n } = (Just 'D', B.char7 'S' <> pgString n)
messageBody (Execute r) = (Just 'E', nul <> B.word32BE r)
messageBody Flush = (Just 'H', mempty)
messageBody Parse{ statementName = n, queryString = s, parseTypes = t } =
  ( Just 'P'
  , pgString n <> pgString s
    <> B.word16BE (fromIntegral $ length t) <> Fold.foldMap B.word32BE t
  )
messageBody (PasswordMessage s) = (Just 'p', B.byteString s <> nul)
messageBody SimpleQuery{ queryString = s } = (Just 'Q', pgString s)
messageBody Sync = (Just 'S', mempty)
messageBody Terminate = (Just 'X', mempty)
-- | Write one frontend message to the connection's handle (not flushed).
-- Marks the connection state as 'StateUnknown' and, in debug mode, prints
-- the message prefixed with "> ".
pgSend :: PGConnection -> PGFrontendMessage -> IO ()
pgSend c msg = do
  writeIORef (connState c) StateUnknown
  when (connDebug c) $ putStrLn ("> " ++ show msg)
  B.hPutBuilder hdl header
  BS.hPut hdl body
  where
    hdl = connHandle c
    (tag, builder) = messageBody msg
    body = BSL.toStrict (B.toLazyByteString builder)
    -- The length field counts its own 4 bytes plus the body, but not the tag.
    header = maybe mempty B.char7 tag <> B.word32BE (fromIntegral (4 + BS.length body))
-- | Flush buffered output on the connection's handle.
pgFlush :: PGConnection -> IO ()
pgFlush c = hFlush (connHandle c)
-- | Read a NUL-terminated string and decode it as UTF-8.
getPGString :: G.Get String
getPGString = fmap BSLU.toString G.getLazyByteStringNul
-- | Read a NUL-terminated byte string as a strict 'BS.ByteString'
-- (the terminator is consumed but not returned).
getByteStringNul :: G.Get BS.ByteString
getByteStringNul = BSL.toStrict <$> G.getLazyByteStringNul
-- | Read (code byte, NUL-terminated UTF-8 value) pairs until a zero code
-- byte. If a code repeats, the first occurrence is kept.
getMessageFields :: G.Get MessageFields
getMessageFields = G.getWord8 >>= field . w2c
  where
    field '\0' = return Map.empty
    field code = do
      value <- getByteStringNul
      rest <- getMessageFields
      return (Map.insert code (BSU.toString value) rest)
-- | Decode the body of a backend message given its one-byte type tag.
-- The length prefix has already been read by 'getMessage'; unsupported tags,
-- authentication methods and ready states fail the decoder.
getMessageBody :: Char -> G.Get PGBackendMessage
getMessageBody tag = case tag of
  'R' -> G.getWord32be >>= authentication
  't' -> do
    count <- G.getWord16be
    ParameterDescription <$> replicateM (fromIntegral count) G.getWord32be
  'T' -> do
    count <- G.getWord16be
    RowDescription <$> replicateM (fromIntegral count) columnDescription
  'Z' -> ReadyForQuery <$> (G.getWord8 >>= readyState . w2c)
  '1' -> return ParseComplete
  '2' -> return BindComplete
  '3' -> return CloseComplete
  'C' -> CommandComplete <$> getByteStringNul
  'S' -> liftM2 ParameterStatus getPGString getPGString
  'D' -> do
    count <- G.getWord16be
    DataRow <$> replicateM (fromIntegral count) (G.getWord32be >>= dataField)
  'K' -> liftM2 BackendKeyData G.getWord32be G.getWord32be
  'E' -> ErrorResponse <$> getMessageFields
  'I' -> return EmptyQueryResponse
  'n' -> return NoData
  's' -> return PortalSuspended
  'N' -> NoticeResponse <$> getMessageFields
  _ -> fail $ "pgGetMessage: unknown message type: " ++ show tag
  where
    authentication 0 = return AuthenticationOk
    authentication 3 = return AuthenticationCleartextPassword
    authentication 5 = AuthenticationMD5Password <$> G.getByteString 4
    authentication op = fail $ "pgGetMessage: unsupported authentication type: " ++ show op
    -- One RowDescription entry; the type size field is read and discarded.
    columnDescription = do
      name <- getPGString
      table <- G.getWord32be
      number <- G.getWord16be
      typeOid <- G.getWord32be
      _ <- G.getWord16be
      modifier <- G.getWord32be
      format <- G.getWord16be
      return ColDescription
        { colName = name
        , colTable = table
        , colNumber = fromIntegral number
        , colType = typeOid
        , colModifier = modifier
        , colBinary = toEnum (fromIntegral format)
        }
    readyState 'I' = return StateIdle
    readyState 'T' = return StateTransaction
    readyState 'E' = return StateTransactionFailed
    readyState s = fail $ "pgGetMessage: unknown ready state: " ++ show s
    -- A length of 0xFFFFFFFF (-1 as a signed value) marks a NULL column.
    dataField 0xFFFFFFFF = return PGNullValue
    dataField len = PGTextValue <$> G.getByteString (fromIntegral len)
-- | Incremental decoder for one backend message: a one-byte type tag, a
-- big-endian 32-bit length that counts itself but not the tag, then the
-- body. Body bytes left unread by 'getMessageBody' are skipped; reading past
-- the declared length is a decode failure.
getMessage :: G.Decoder PGBackendMessage
getMessage = G.runGetIncremental $ do
  typ <- G.getWord8
  s <- G.bytesRead
  len <- G.getWord32be
  msg <- getMessageBody (w2c typ)
  e <- G.bytesRead
  -- Declared length minus everything consumed since the length field began.
  let r = fromIntegral len - fromIntegral (e - s)
  when (r > 0) $ G.skip r
  when (r < 0) $ fail "pgReceive: decoder overran message"
  return msg
-- | Receive the next backend message, keeping partially decoded input in
-- 'connInput' between calls.
--
-- With @block@ = 'False' the socket is read with 'BS.hGetNonBlocking';
-- otherwise 'BS.hGetSome' waits for data. Returns 'Nothing' when a read
-- yields no bytes (nothing available, or end of input); the partial decoder
-- is kept for the next call.
--
-- 'NoticeResponse' messages are passed to 'connLogMessage' and skipped.
-- 'ReadyForQuery' stores its state in 'connState'; 'ErrorResponse' resets it
-- to 'StateUnknown' and is returned, not thrown. A decode failure discards
-- the buffered input, resets the decoder and calls 'fail'.
pgRecv :: Bool -> PGConnection -> IO (Maybe PGBackendMessage)
pgRecv block c@PGConnection{ connHandle = h, connInput = dr } =
  go =<< readIORef dr
  where
    -- Store the decoder to resume from on the next call.
    next = writeIORef dr
    state s d = writeIORef (connState c) s >> next d
    -- A fresh message decoder primed with leftover bytes.
    new = G.pushChunk getMessage
    go (G.Done b _ m) = do
      when (connDebug c) $ putStrLn $ "< " ++ show m
      got (new b) m
    go (G.Fail _ _ r) = next (new BS.empty) >> fail r
    go d@(G.Partial r) = do
      b <- (if block then BS.hGetSome else BS.hGetNonBlocking) h smallChunkSize
      if BS.null b
        then Nothing <$ next d
        else go $ r (Just b)
    got :: G.Decoder PGBackendMessage -> PGBackendMessage -> IO (Maybe PGBackendMessage)
    got d (NoticeResponse m) = connLogMessage c m >> go d
    got d m@(ReadyForQuery s) = Just m <$ state s d
    got d m@(ErrorResponse _) = Just m <$ state StateUnknown d
    got d m = Just m <$ next d
-- | Blocking receive of the next backend message (notices are handled by
-- 'pgRecv'). An 'ErrorResponse' is thrown as 'PGError'; end of input marks
-- the connection 'StateClosed' and fails.
pgReceive :: PGConnection -> IO PGBackendMessage
pgReceive c = pgRecv True c >>= dispatch
  where
    dispatch Nothing = do
      writeIORef (connState c) StateClosed
      fail "pgReceive: connection closed"
    dispatch (Just ErrorResponse{ messageFields = fields }) = throwIO (PGError fields)
    dispatch (Just reply) = return reply
-- | Connect to a server and complete the startup handshake.
--
-- Sends a StartupMessage with the user, database and fixed session settings
-- (UTF8 client encoding, standard_conforming_strings, hex bytea output, ISO
-- date/interval styles). Answers cleartext password requests, and MD5 ones
-- when built with USE_MD5. Collects BackendKeyData and ParameterStatus
-- messages until ReadyForQuery. Server errors are thrown as 'PGError' (via
-- 'pgReceive'). If the handshake fails for any reason, the socket is closed
-- before the exception propagates.
pgConnect :: PGDatabase -> IO PGConnection
pgConnect db = do
  state <- newIORef StateUnknown
  prep <- newIORef (0, Map.empty)
  input <- newIORef getMessage
  -- bracketOnError closes the socket only if the handshake throws, so a
  -- failed login (bad password, unexpected message, ...) does not leak it.
  bracketOnError (connectTo (pgDBHost db) (pgDBPort db)) hClose $ \h -> do
    let c = PGConnection
          { connHandle = h
          , connDatabase = db
          , connPid = 0
          , connKey = 0
          , connParameters = Map.empty
          , connPreparedStatements = prep
          , connState = state
          , connTypeEnv = undefined -- replaced when ReadyForQuery arrives
          , connInput = input
          }
    pgSend c $ StartupMessage
      [ ("user", pgDBUser db)
      , ("database", pgDBName db)
      , ("client_encoding", "UTF8")
      , ("standard_conforming_strings", "on")
      , ("bytea_output", "hex")
      , ("DateStyle", "ISO, YMD")
      , ("IntervalStyle", "iso_8601")
      ]
    pgFlush c
    conn c
  where
    conn c = pgReceive c >>= msg c
    msg c (ReadyForQuery _) = return c
      { connTypeEnv = PGTypeEnv
        { pgIntegerDatetimes = (connParameters c Map.! "integer_datetimes") == "on"
        }
      }
    msg c (BackendKeyData p k) = conn c{ connPid = p, connKey = k }
    msg c (ParameterStatus k v) = conn c{ connParameters = Map.insert k v $ connParameters c }
    msg c AuthenticationOk = conn c
    msg c AuthenticationCleartextPassword = do
      pgSend c $ PasswordMessage $ BSU.fromString $ pgDBPass db
      pgFlush c
      conn c
#ifdef USE_MD5
    msg c (AuthenticationMD5Password salt) = do
      pgSend c $ PasswordMessage $ BSC.pack "md5" `BS.append` md5 (md5 (BSU.fromString (pgDBPass db ++ pgDBUser db)) `BS.append` salt)
      pgFlush c
      conn c
#endif
    msg _ m = fail $ "pgConnect: unexpected response: " ++ show m
-- | Send Terminate and close the connection's handle.
-- The state is set to 'StateClosed' and the handle closed even if writing
-- Terminate throws. Such an exception is then re-raised.
pgDisconnect :: PGConnection
  -> IO ()
pgDisconnect c@PGConnection{ connHandle = h, connState = s } =
  pgSend c Terminate `finally` (writeIORef s StateClosed >> hClose h)
-- | Reuse @c@ when it is still open and its settings equal @d@ under the
-- 'PGDatabase' 'Eq' instance. That instance ignores the debug/log settings,
-- so the returned connection adopts @d@'s. Otherwise disconnect @c@ (if
-- still open) and open a new connection with @d@.
pgReconnect :: PGConnection -> PGDatabase -> IO PGConnection
pgReconnect c d = do
  st <- readIORef (connState c)
  let stillOpen = st /= StateClosed
  if connDatabase c == d && stillOpen
    then return c{ connDatabase = d }
    else do
      when stillOpen (pgDisconnect c)
      pgConnect d
-- | Bring the connection to a known between-queries point before new work.
--
-- Fails on a 'StateClosed' connection. In 'StateUnknown', pending messages
-- are drained without blocking. Once none remain, a Sync is sent and the
-- function blocks until a 'ReadyForQuery' arrives. Error responses and
-- unexpected messages seen while draining are logged, not thrown.
pgSync :: PGConnection -> IO ()
pgSync c = do
  st <- readIORef (connState c)
  when (st == StateClosed) $ fail "pgSync: operation on closed connection"
  when (st == StateUnknown) $ drain False
  where
    drain blocking = do
      reply <- pgRecv blocking c
      case reply of
        Nothing -> do
          pgSend c Sync
          pgFlush c
          drain True
        Just ErrorResponse{ messageFields = fields } -> do
          connLogMessage c fields
          drain blocking
        Just (ReadyForQuery _) -> return ()
        Just other -> do
          connLogMessage c $ makeMessage ("Unexpected server message: " ++ show other) "Each statement should only contain a single query"
          drain blocking
-- | Describe a query without executing it.
--
-- Parses @sql@ as the empty-named statement with the given parameter type
-- OIDs. Returns the parameter type OIDs and, for each result column, its
-- name, type OID and whether it may be NULL. When the flag is set, columns
-- with a non-zero table OID are looked up in @pg_attribute@ and reported
-- non-nullable if @attnotnull@ is true. Every other column is reported as
-- nullable.
pgDescribe :: PGConnection -> String -- ^ SQL text
  -> [OID] -- ^ parameter type OIDs to pass with Parse
  -> Bool -- ^ look up nullability of table columns
  -> IO ([OID], [(String, OID, Bool)])
pgDescribe h sql types nulls = do
  pgSync h
  mapM_ (pgSend h)
    [ Parse{ queryString = sql, statementName = "", parseTypes = types }
    , Describe ""
    , Flush
    , Sync
    ]
  pgFlush h
  ParseComplete <- pgReceive h
  ParameterDescription params <- pgReceive h
  reply <- pgReceive h
  columns <- case reply of
    NoData -> return []
    RowDescription cols -> mapM describeColumn cols
    _ -> fail $ "describeStatement: unexpected response: " ++ show reply
  return (params, columns)
  where
    describeColumn ColDescription{ colName = name, colTable = tab, colNumber = col, colType = typ } =
      (,,) name typ <$> nullable tab col
    nullable tab col
      | not nulls || tab == 0 = return True
      | otherwise = do
          (_, rows) <- pgSimpleQuery h ("SELECT attnotnull FROM pg_catalog.pg_attribute WHERE attrelid = " ++ show tab ++ " AND attnum = " ++ show col)
          case Fold.toList rows of
            [[PGTextValue s]] -> return $ not $ pgDecode (PGTypeProxy :: PGTypeName "boolean") s
            [] -> return True
            _ -> fail $ "Failed to determine nullability of column #" ++ show col
-- | Row count from a CommandComplete tag: its last word parsed as an
-- integer (e.g. "INSERT 0 5" gives 5). Returns -1 when the tag is empty or
-- its last word is not a number (e.g. "CREATE TABLE"), so callers can tell
-- "unknown" apart from a real count.
rowsAffected :: BS.ByteString -> Int
rowsAffected tag = case reverse (BSC.words tag) of
  (count : _) -> fromMaybe (-1) (readMaybe (BSC.unpack count))
  [] -> -1
-- | Re-tag each value to match its column's format flag: True makes a text
-- value binary, False makes a binary value text. Values past the end of the
-- flag list are returned unchanged.
fixBinary :: [Bool] -> PGValues -> PGValues
fixBinary = go
  where
    go (wantBinary : flags) (v : vs) = adjust wantBinary v : go flags vs
    go _ vs = vs
    adjust False (PGBinaryValue x) = PGTextValue x
    adjust True (PGTextValue x) = PGBinaryValue x
    adjust _ v = v
-- | Run SQL through the simple query protocol. Returns the row count
-- parsed from the CommandComplete tag by 'rowsAffected', and all rows
-- received before it. Only the first statement's results are read here;
-- anything after it is drained by the next 'pgSync'.
pgSimpleQuery :: PGConnection -> String -- ^ SQL text
  -> IO (Int, Seq.Seq PGValues)
pgSimpleQuery h sql = do
  pgSync h
  pgSend h (SimpleQuery sql)
  pgFlush h
  pgReceive h >>= firstReply
  where
    firstReply reply = case reply of
      RowDescription cols -> collect (map colBinary cols) Seq.empty
      CommandComplete tag -> finish tag Seq.empty
      EmptyQueryResponse -> return (0, Seq.empty)
      _ -> fail $ "pgSimpleQuery: unexpected response: " ++ show reply
    collect binary rows = do
      reply <- pgReceive h
      case reply of
        DataRow fs -> collect binary (rows Seq.|> fixBinary binary fs)
        CommandComplete tag -> finish tag rows
        _ -> fail $ "pgSimpleQuery: unexpected row: " ++ show reply
    finish tag rows = return (rowsAffected tag, rows)
-- | Queue Parse (only if this query/types pair has no cached statement) and
-- Bind messages. The messages are not flushed. Returns an action that later
-- reads ParseComplete/BindComplete. On ParseComplete it records the new
-- statement number in 'connPreparedStatements'.
pgPreparedBind :: PGConnection -> String -> [OID] -> PGValues -> [Bool] -> IO (IO ())
pgPreparedBind c@PGConnection{ connPreparedStatements = psr } sql types bind bc = do
  pgSync c
  -- Reuse the cached statement number, or reserve the next one.
  (cached, n) <- atomicModifyIORef' psr $ \(i, m) ->
    case Map.lookup key m of
      Just prev -> ((i, m), (True, prev))
      Nothing -> ((succ i, m), (False, i))
  let sn = show n
  unless cached $
    pgSend c $ Parse{ queryString = sql, statementName = sn, parseTypes = types }
  pgSend c $ Bind{ statementName = sn, bindParameters = bind, binaryColumns = bc }
  let
    go = pgReceive c >>= start
    start ParseComplete = do
      -- Atomic and strict, like the reservation above and the removal in
      -- pgCloseStatement: no lost update, no thunk build-up in the IORef.
      atomicModifyIORef' psr $ \(i, m) -> ((i, Map.insert key n m), ())
      go
    start BindComplete = return ()
    start m = fail $ "pgPrepared: unexpected response: " ++ show m
  return go
  where key = (sql, types)
-- | Run a prepared statement (prepared and cached on first use) with the
-- given parameters. Returns the row count parsed by 'rowsAffected' and all
-- rows, re-tagged per the @bc@ format flags.
pgPreparedQuery :: PGConnection -> String -- ^ SQL text
  -> [OID] -- ^ parameter type OIDs
  -> PGValues -- ^ parameter values
  -> [Bool] -- ^ binary result column flags
  -> IO (Int, Seq.Seq PGValues)
pgPreparedQuery c sql types bind bc = do
  awaitBind <- pgPreparedBind c sql types bind bc
  mapM_ (pgSend c) [Execute 0, Flush, Sync]
  pgFlush c
  awaitBind
  collect Seq.empty
  where
    collect rows = do
      reply <- pgReceive c
      case reply of
        DataRow fs -> collect (rows Seq.|> fixBinary bc fs)
        CommandComplete tag -> return (rowsAffected tag, rows)
        EmptyQueryResponse -> return (0, rows)
        _ -> fail $ "pgPreparedQuery: unexpected row: " ++ show reply
-- | Like 'pgPreparedQuery', but fetches rows lazily in batches of @count@.
--
-- 'pgPreparedBind' runs immediately (sync, then queue Parse/Bind). Sending
-- Execute and reading replies is deferred with 'unsafeInterleaveIO' until
-- the result list is forced. On 'PortalSuspended' the rows read so far are
-- returned, and another Execute for the next batch is deferred until the
-- rest of the list is demanded. No Sync is sent here; the connection stays
-- in 'StateUnknown' until the next 'pgSync'.
-- NOTE(review): because later batches read from the same connection when
-- forced, using the connection for anything else before the list is fully
-- consumed presumably interleaves protocol messages — confirm callers
-- avoid that.
pgPreparedLazyQuery :: PGConnection -> String -> [OID] -> PGValues -> [Bool] -> Word32
  -> IO [PGValues]
pgPreparedLazyQuery c sql types bind bc count = do
  start <- pgPreparedBind c sql types bind bc
  unsafeInterleaveIO $ do
    execute
    start
    go Seq.empty
  where
    -- Request up to @count@ more rows and push them out to the server.
    execute = do
      pgSend c $ Execute count
      pgSend c $ Flush
      pgFlush c
    go = (pgReceive c >>=) . row
    row s (DataRow fs) = go (s Seq.|> fixBinary bc fs)
    row s PortalSuspended = (Fold.toList s ++) <$> unsafeInterleaveIO (execute >> go Seq.empty)
    row s (CommandComplete _) = return $ Fold.toList s
    row s EmptyQueryResponse = return $ Fold.toList s
    row _ m = fail $ "pgPreparedLazyQuery: unexpected row: " ++ show m
-- | Forget the statement cached for this query text and parameter types and
-- close it on the server. Does nothing if none is cached.
pgCloseStatement :: PGConnection -> String -> [OID] -> IO ()
pgCloseStatement c@PGConnection{ connPreparedStatements = psr } sql types = do
  mn <- atomicModifyIORef' psr $ \(i, m) ->
    let (n, m') = Map.updateLookupWithKey (\_ _ -> Nothing) (sql, types) m in ((i, m'), n)
  Fold.forM_ mn $ \n -> do
    pgSync c
    pgSend c $ Close{ statementName = show n }
    -- The backend holds extended-query replies until Sync (or Flush); without
    -- it CloseComplete might never arrive and pgReceive would block forever.
    pgSend c Sync
    pgFlush c
    CloseComplete <- pgReceive c
    -- Consume the Sync's ReadyForQuery so the connection state is known again.
    ReadyForQuery _ <- pgReceive c
    return ()