{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Transport.Tls
(connect)
where
import Data.IORef
import Data.Monoid
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Applicative ((<$>))
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Database.MongoDB.Internal.Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)
-- | Open a TLS-encrypted connection to a MongoDB server and wrap it in a 'Pipe'.
--
-- The TCP handle is closed if any later setup step throws. Those steps are
-- the TLS handshake, building the transport, and the initial
-- server-data query against the @admin@ database.
--
-- NOTE(review): the 'TLS.onServerCertificate' hook below reports no failure
-- reasons for any certificate. The server's identity is therefore never
-- verified, which leaves the connection open to man-in-the-middle attacks.
connect :: HostName -> PortID -> IO Pipe
connect host port = bracketOnError (connectTo host port) hClose setup
  where
    setup handle = do
      tlsCtx <- TLS.contextNew handle clientParams
      TLS.handshake tlsCtx
      transport <- tlsConnection tlsCtx
      -- The pipe needs the server data, but that data can only be fetched
      -- through the pipe itself. 'rec' ties this knot lazily.
      rec
        pipe <- newPipeWith serverData transport
        serverData <- access pipe slaveOk "admin" retrieveServerData
      return pipe

    -- Client settings: the default cipher suites, plus a certificate hook
    -- that accepts everything (see the NOTE above).
    clientParams = (TLS.defaultParamsClient host "")
      { TLS.clientSupported = def
          { TLS.supportedCiphers = TLS.ciphersuite_default }
      , TLS.clientHooks = def
          { TLS.onServerCertificate = acceptAnyCertificate }
      }

    -- Returning no failure reasons means "certificate accepted".
    acceptAnyCertificate _ _ _ _ = return []
-- | Wrap an established TLS 'TLS.Context' as a MongoDB 'Transport'.
--
-- 'TLS.recvData' hands back data in whatever chunk sizes it receives. Those
-- sizes need not match the byte counts 'T.read' asks for. Any surplus bytes
-- from an over-read are therefore kept in a leftover buffer and served first
-- on the next read.
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
  leftoverRef <- newIORef mempty
  return Transport
    { T.read  = readExactly leftoverRef (TLS.recvData ctx)
    , T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
    , T.flush = TLS.contextFlush ctx
    , T.close = TLS.contextClose ctx
    }

-- | Read exactly @count@ bytes. Leftover bytes are drained first; after that
-- the @recv@ action is called repeatedly.
--
-- Bytes received beyond @count@ are pushed back into the leftover buffer.
-- An empty chunk from @recv@ is treated as end of stream and raises an
-- EOF 'IOError'. A non-positive @count@ returns an empty string at once,
-- without touching the buffer or the network.
readExactly
  :: IORef ByteString.ByteString  -- ^ leftover bytes from an earlier over-read
  -> IO ByteString.ByteString     -- ^ action fetching the next chunk
  -> Int                          -- ^ number of bytes wanted
  -> IO ByteString.ByteString
readExactly leftoverRef recv count = ByteString.concat . reverse <$> go [] count
  where
    -- Serve buffered leftover bytes, if any, before going to the network.
    readSome = do
      leftover <- readIORef leftoverRef
      writeIORef leftoverRef mempty
      if ByteString.null leftover then recv else return leftover

    -- Chunks are collected newest-first and concatenated once by the caller.
    -- This avoids re-walking the accumulator on every append.
    go acc n
      | n <= 0 = return acc
      | otherwise = do
          chunk <- readSome
          when (ByteString.null chunk) $ ioError eof
          let len = ByteString.length chunk
          if len >= n
            then do
              let (wanted, surplus) = ByteString.splitAt n chunk
              unless (ByteString.null surplus) $
                modifyIORef leftoverRef (surplus <>)
              return (wanted : acc)
            else go (chunk : acc) (n - len)

    eof = mkIOError eofErrorType "Database.MongoDB.Transport" Nothing Nothing