-- This Source Code Form is subject to the terms of the Mozilla Public
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE BangPatterns    #-}
{-# LANGUAGE CPP             #-}
{-# LANGUAGE TemplateHaskell #-}

-- | A thin wrapper of the "Network.Socket" API.
--
-- A 'Socket' is either a plain TCP stream or a TCP stream with an OpenSSL
-- session on top of it; the functions in this module dispatch on that
-- distinction so callers can use both kinds uniformly.
module Database.CQL.IO.Connection.Socket
    ( Socket
    , resolve
    , open
    , send
    , recv
    , close
    , shutdown

      -- Re-exports
    , HostName
    , PortNumber
    , ShutdownCmd (..)
    ) where

import Control.Applicative
import Control.Monad
import Control.Monad.Catch
import Data.ByteString (ByteString)
import Data.ByteString.Builder
import Data.Maybe (isJust)
import Data.Monoid
import Database.CQL.IO.Cluster.Host
import Database.CQL.IO.Exception (ConnectionError (..))
import Database.CQL.IO.Timeouts (Milliseconds (..))
import Foreign.C.Types (CInt (..))
import Network.Socket (HostName, PortNumber, SockAddr (..), ShutdownCmd (..))
import Network.Socket (Family (..), AddrInfo (..), AddrInfoFlag (..))
import Network.Socket.ByteString.Lazy (sendAll)
import OpenSSL.Session (SSL, SSLContext)
import System.Timeout
import Prelude

import qualified Data.ByteString           as Bytes
import qualified Data.ByteString.Lazy      as Lazy
import qualified Network.Socket            as S
import qualified Network.Socket.ByteString as NB
import qualified OpenSSL.Session           as SSL

-- | A connected socket: either a plain TCP stream, or a TCP stream together
-- with the OpenSSL session established over it.
data Socket
    = Stream !S.Socket
    | Tls    !S.Socket !SSL

-- | Shows the numeric file descriptor of the underlying 'S.Socket'.
instance Show Socket where
    show s = show $ case s of
        Stream x -> fd x
        Tls  x _ -> fd x
      where
        fd x = let CInt n = S.fdSocket x in n

-- | Resolve a host name and port to stream-socket addresses via
-- 'S.getAddrInfo', with the 'AI_ADDRCONFIG' flag set.
--
-- The port is passed to the resolver as its decimal string (the service
-- argument). Exceptions thrown by 'S.getAddrInfo' propagate unchanged.
resolve :: HostName -> PortNumber -> IO [InetAddr]
resolve h p = do
    ais <- S.getAddrInfo (Just hints) (Just h) (Just (show p))
    return $ map (InetAddr . addrAddress) ais
  where
    hints = S.defaultHints
        { addrFlags      = [AI_ADDRCONFIG]
        , addrSocketType = S.Stream
        }

-- | Connect to the given address, optionally establishing TLS on top.
--
-- Only the TCP @connect@ is bounded by the timeout (given in milliseconds and
-- multiplied by 1000 for 'timeout', which takes microseconds). If it expires,
-- @'ConnectTimeout' a@ is thrown. When an 'SSLContext' is supplied, a TLS
-- session is created and connected over the socket. That handshake is NOT
-- covered by the timeout. If any step throws, the raw socket is closed
-- before the exception propagates.
open :: Milliseconds -> InetAddr -> Maybe SSLContext -> IO Socket
open to a ctx =
    bracketOnError (mkSock a) S.close $ \s -> do
        ok <- timeout (ms to * 1000) (S.connect s (sockAddr a))
        unless (isJust ok) $
            throwM (ConnectTimeout a)
        case ctx of
            Nothing  -> return (Stream s)
            Just set -> do
                c <- SSL.connection set s
                SSL.connect c
                return (Tls s c)

-- | Create an unconnected stream socket whose address family matches the
-- given address.
mkSock :: InetAddr -> IO S.Socket
mkSock (InetAddr a) = S.socket (familyOf a) S.Stream S.defaultProtocol
  where
    familyOf (SockAddrInet  _ _)     = AF_INET
    familyOf (SockAddrInet6 _ _ _ _) = AF_INET6
    familyOf (SockAddrUnix  _)       = AF_UNIX
#if MIN_VERSION_network(2,6,1) && !MIN_VERSION_network(3,0,0)
    familyOf (SockAddrCan   _)       = AF_CAN
#endif

-- | Close the socket.
--
-- For a TLS socket, an 'SSL.Bidirectional' TLS shutdown is attempted first.
-- The underlying socket is closed even if that shutdown throws, so the file
-- descriptor is never leaked. The shutdown exception, if any, is re-thrown
-- after the close.
close :: Socket -> IO ()
close (Stream s) = S.close s
close (Tls s c)  = SSL.shutdown c SSL.Bidirectional `finally` S.close s

-- | Shut down one or both directions of a plain TCP socket.
--
-- This is a no-op for TLS sockets.
-- NOTE(review): presumably deliberate, so the raw TCP stream is not cut
-- underneath the TLS session; confirm.
shutdown :: Socket -> ShutdownCmd -> IO ()
shutdown (Stream s) cmd = S.shutdown s cmd
shutdown _          _   = return ()

-- | @recv x a sock n@ reads exactly @n@ bytes from @sock@, requesting at most
-- @x@ bytes per underlying read.
--
-- Throws @'ConnectionClosed' a@ if a read returns no data before @n@ bytes
-- have arrived. @a@ is used only to label that error.
recv :: Int -> InetAddr -> Socket -> Int -> IO Lazy.ByteString
recv x a (Stream s) n = receive x a (NB.recv s) n
recv x a (Tls _ c)  n = receive x a (SSL.read c) n

-- | Read exactly @n@ bytes using the read action @f@, asking it for at most
-- @x@ bytes at a time and accumulating the chunks in a 'Builder'.
--
-- An empty chunk is treated as the peer closing the connection and raises
-- @'ConnectionClosed' i@. A request for zero or a negative number of bytes
-- returns an empty result without calling @f@.
receive :: Int -> InetAddr -> (Int -> IO ByteString) -> Int -> IO Lazy.ByteString
receive x i f n
    | n <= 0    = return Lazy.empty
    | otherwise = toLazyByteString <$> go n mempty
  where
    -- k: bytes still missing; bb: everything received so far.
    go !k !bb = do
        a <- f (k `min` x)
        when (Bytes.null a) $
            throwM (ConnectionClosed i)
        let b = bb <> byteString a
            m = k - Bytes.length a
        if m > 0 then go m b else return b

-- | Write the entire lazy bytestring to the socket.
--
-- Plain sockets use 'sendAll'. TLS sockets write each strict chunk of the
-- lazy bytestring, in order, through 'SSL.write'.
send :: Socket -> Lazy.ByteString -> IO ()
send (Stream s) b = sendAll s b
send (Tls _ c)  b = mapM_ (SSL.write c) (Lazy.toChunks b)