{-# LANGUAGE CPP #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE PackageImports #-} -- for doctest module Database.PostgreSQL.Pure.Internal.Connection ( connect , disconnect , withConnection ) where import qualified Database.PostgreSQL.Pure.Internal.Builder as Builder import Database.PostgreSQL.Pure.Internal.Data (Address (AddressNotResolved, AddressResolved), AuthenticationMD5Password (AuthenticationMD5Password), AuthenticationResponse (AuthenticationCleartextPasswordResponse, AuthenticationMD5PasswordResponse, AuthenticationOkResponse), BackendKey, BackendKeyData (BackendKeyData), BackendParameters, Buffer (Buffer), Config (Config, address, database, password, receptionBufferSize, sendingBufferSize, user), Connection (Connection, config, receptionBuffer, sendingBuffer, socket), ParameterStatus (ParameterStatus), Pid, ReadyForQuery (ReadyForQuery), Salt, TransactionState (Idle)) import qualified Database.PostgreSQL.Pure.Internal.Exception as Exception import qualified Database.PostgreSQL.Pure.Internal.Parser as Parser import Database.PostgreSQL.Pure.Internal.SocketIO (SocketIO, buildAndSend, receive, runSocketIO, send) import Control.Exception.Safe (assert, bracket) import Control.Monad (void) import Control.Monad.Reader (ask) import qualified Data.Attoparsec.ByteString as AP import qualified Data.ByteString as BS import qualified Data.ByteString.Base16 as B16 import qualified Data.ByteString.Internal as BSI import qualified Data.ByteString.UTF8 as BSU import qualified Data.Map.Strict as Map import qualified Network.Socket as NS #ifdef PURE_MD5 import qualified Data.Digest.Pure.MD5 as MD5 #else import qualified "cryptohash-md5" Crypto.Hash.MD5 as MD5 #endif -- | Bracket function for a connection. withConnection :: Config -> (Connection -> IO a) -> IO a withConnection config@Config { address } f = Exception.convert $ do addr <- case address of AddressResolved a -> pure $ addrInfo a AddressNotResolved h s -> resolve h s bracket (open addr) NS.close $ \sock -> do conn <- connect' sock config f conn -- | To connect to the server. connect :: Config -> IO Connection connect config@Config { address } = Exception.convert $ do addr <- case address of AddressResolved a -> pure $ addrInfo a AddressNotResolved h s -> resolve h s sock <- open addr connect' sock config connect' :: NS.Socket -> Config -> IO Connection connect' sock config@Config { sendingBufferSize, receptionBufferSize } = do sBuff <- flip Buffer sendingBufferSize <$> BSI.mallocByteString sendingBufferSize rBuff <- flip Buffer receptionBufferSize <$> BSI.mallocByteString receptionBufferSize runSocketIO sock sBuff rBuff config initializeConnection -- | To disconnect to the server. disconnect :: Connection -> IO () disconnect Connection { socket, sendingBuffer, receptionBuffer, config } = Exception.convert $ do runSocketIO socket sendingBuffer receptionBuffer config terminate NS.close socket addrInfoHints :: NS.AddrInfo addrInfoHints = NS.defaultHints { NS.addrSocketType = NS.Stream , NS.addrProtocol = 6 -- TCP , NS.addrFlags = [NS.AI_ADDRCONFIG] } addrInfo :: NS.SockAddr -> NS.AddrInfo addrInfo address = addrInfoHints { NS.addrAddress = address , NS.addrFamily = case address of NS.SockAddrInet {} -> NS.AF_INET NS.SockAddrInet6 {} -> NS.AF_INET6 NS.SockAddrUnix {} -> NS.AF_UNIX #if !MIN_VERSION_network(3,0,0) _ -> NS.AF_UNSPEC #endif } resolve :: NS.HostName -> NS.ServiceName -> IO NS.AddrInfo resolve host service = do addrs <- NS.getAddrInfo (Just addrInfoHints) (Just host) (Just service) case addrs of addr:_ -> return addr [] -> Exception.cantReachHere open :: NS.AddrInfo -> IO NS.Socket open addr = do sock <- NS.socket (NS.addrFamily addr) (NS.addrSocketType addr) (NS.addrProtocol addr) NS.connect sock $ NS.addrAddress addr return sock initializeConnection :: SocketIO Connection initializeConnection = do response <- startup (bps, pid, bk) <- authenticate response (sock, sBuff, rBuff, config) <- ask pure $ Connection sock pid bk bps sBuff rBuff config startup :: SocketIO AuthenticationResponse startup = do (_, _, _, Config { user, database }) <- ask buildAndSend $ Builder.startup user database receive Parser.authentication authenticate :: AuthenticationResponse -> SocketIO (BackendParameters, Pid, BackendKey) authenticate response = do (_, _, _, Config { user, password }) <- ask case response of AuthenticationOkResponse -> pure () AuthenticationCleartextPasswordResponse -> auth $ BSU.fromString password AuthenticationMD5PasswordResponse (AuthenticationMD5Password salt) -> auth $ hashMD5 user password salt (bps, pid, bk) <- receive $ do bps <- Map.fromList . ((\(ParameterStatus k v) -> (k, v)) <$>) <$> AP.many' Parser.parameterStatus BackendKeyData pid bk <- Parser.backendKeyData ReadyForQuery ts <- Parser.readyForQuery assert (ts == Idle) $ pure (bps, pid, bk) pure (bps, pid, bk) where auth pw = do buildAndSend $ Builder.password pw void $ receive Parser.authenticationOk terminate :: SocketIO () terminate = send Builder.terminate hashMD5 :: String -> String -> Salt -> BS.ByteString hashMD5 user password salt = let user' = BSU.fromString user password' = BSU.fromString password #ifdef PURE_MD5 hash = B16.encode . MD5.md5DigestBytes . MD5.hash' #else hash = B16.encode . MD5.hash #endif in "md5" <> hash (hash (password' <> user') <> salt)