{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Transport.Tls
( connect
, connectWithTlsParams
)
where
import Data.IORef
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Database.MongoDB.Internal.Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)
-- | Open a TLS-wrapped 'Pipe' to the MongoDB server at @host@:@port@.
--
-- The client parameters start from 'TLS.defaultParamsClient' for @host@.
-- They restrict the cipher list to 'TLS.ciphersuite_default' and install a
-- server-certificate hook that never reports a failure.
--
-- SECURITY NOTE(review): because the hook always returns an empty list of
-- failure reasons, any server certificate is accepted. The connection is
-- encrypted but NOT authenticated, so it is open to man-in-the-middle
-- attacks. Callers that need certificate verification should build their
-- own 'TLS.ClientParams' and call 'connectWithTlsParams' instead.
connect :: HostName -> PortID -> IO Pipe
connect host port =
  let acceptAnyCertificate _ _ _ _ = return []
      hooks = def { TLS.onServerCertificate = acceptAnyCertificate }
      supported = def { TLS.supportedCiphers = TLS.ciphersuite_default }
      baseParams = TLS.defaultParamsClient host ""
      params = baseParams
        { TLS.clientSupported = supported
        , TLS.clientHooks = hooks
        }
  in connectWithTlsParams params host port
-- | Open a TLS-wrapped 'Pipe' to @host@:@port@ using caller-supplied TLS
-- client parameters.
--
-- The raw socket handle is closed with 'hClose' if anything throws before
-- the 'Pipe' is returned: TLS context creation, the handshake, or the
-- initial server-data query against the @admin@ database. On success the
-- handle is left open for the returned 'Pipe' to use through its transport.
connectWithTlsParams :: TLS.ClientParams -> HostName -> PortID -> IO Pipe
connectWithTlsParams clientParams host port =
  bracketOnError (connectTo host port) hClose openPipe
  where
    openPipe h = do
      ctx <- TLS.contextNew h clientParams
      TLS.handshake ctx
      transport <- tlsConnection ctx
      -- Building the pipe needs the server data, and fetching the server
      -- data needs the pipe. RecursiveDo ties the knot. This relies on
      -- newPipeWith not forcing serverData while the pipe is constructed.
      rec
        pipe <- newPipeWith serverData transport
        serverData <- access pipe slaveOk "admin" retrieveServerData
      return pipe
-- | Expose an established TLS session as a MongoDB 'Transport'.
--
-- 'TLS.recvData' yields whatever the TLS layer has decrypted, which need
-- not match the byte count a caller asks for. 'T.read' therefore keeps any
-- surplus bytes in an 'IORef' and serves them before receiving more data.
--
-- 'T.read' @n@ behaves as follows:
--
-- * It returns exactly @n@ bytes.
-- * It returns the empty string immediately, without touching the network,
--   when @n <= 0@.
-- * It throws an EOF 'IOError' if the session yields an empty chunk before
--   @n@ bytes have arrived.
--
-- 'T.write', 'T.flush' and 'T.close' delegate directly to 'TLS.sendData',
-- 'TLS.contextFlush' and 'TLS.contextClose'.
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
  -- Bytes already received from TLS but not yet returned to a reader.
  restRef <- newIORef ByteString.empty
  return Transport
    { T.read = readExactly restRef
    , T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
    , T.flush = TLS.contextFlush ctx
    , T.close = TLS.contextClose ctx
    }
  where
    -- Read exactly @count@ bytes, consuming buffered leftovers first.
    readExactly :: IORef ByteString.ByteString -> Int -> IO ByteString.ByteString
    readExactly restRef count
      -- A non-positive request needs no data. Without this guard the loop
      -- below would still block on 'TLS.recvData' (or throw EOF on a closed
      -- session) only to return an empty string.
      | count <= 0 = return ByteString.empty
      | otherwise = ByteString.concat . reverse <$> go [] count
      where
        -- Next available chunk: leftover bytes if any, else a fresh TLS read.
        readSome = do
          rest <- readIORef restRef
          writeIORef restRef ByteString.empty
          if ByteString.null rest
            then TLS.recvData ctx
            else return rest

        -- Push surplus bytes back so the next read sees them first.
        unread rest = modifyIORef restRef (rest <>)

        -- Gather chunks until @n@ more bytes are collected. Chunks are kept
        -- newest-first so each step is O(1); the caller reverses and
        -- concatenates once at the end instead of repeatedly appending.
        go chunks n = do
          chunk <- readSome
          -- Treat an empty chunk as end of stream.
          when (ByteString.null chunk) $ ioError eof
          let len = ByteString.length chunk
          if len >= n
            then do
              let (res, rest) = ByteString.splitAt n chunk
              unless (ByteString.null rest) $ unread rest
              return (res : chunks)
            else go (chunk : chunks) (n - len)

        eof = mkIOError eofErrorType "Database.MongoDB.Transport" Nothing Nothing