{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif

{-|
Module      : MongoDB TLS
Copyright   : (c)	Yuras Shumovich, 2016
License     : Apache 2.0
Maintainer  : Victor Denisov denisovenator@gmail.com
Stability   : experimental
Portability : POSIX

This module is for connecting to TLS-enabled MongoDB servers.
ATTENTION!!! Be aware that this module is highly experimental and is
barely tested. The current implementation does not verify the server's
identity. It only allows you to connect to a MongoDB server using the TLS
protocol.
-}

module Database.MongoDB.Transport.Tls
( connect
, connectWithTlsParams
)
where

import Data.IORef
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as Lazy.ByteString
import Data.Default.Class (def)
import Control.Exception (bracketOnError)
import Control.Monad (when, unless)
import System.IO
import Database.MongoDB.Internal.Protocol (Pipe, newPipeWith)
import Database.MongoDB.Transport (Transport(Transport))
import qualified Database.MongoDB.Transport as T
import System.IO.Error (mkIOError, eofErrorType)
import Database.MongoDB.Internal.Network (connectTo, HostName, PortID)
import qualified Network.TLS as TLS
import qualified Network.TLS.Extra.Cipher as TLS
import Database.MongoDB.Query (access, slaveOk, retrieveServerData)

-- | Connect to mongodb using TLS.
--
-- Builds client parameters from 'TLS.defaultParamsClient' for the given
-- host, restricts the supported ciphers to 'TLS.ciphersuite_default', and
-- delegates to 'connectWithTlsParams'.
--
-- WARNING: the certificate hook installed here ignores all of its arguments
-- and reports no validation failures, so the server's certificate chain is
-- accepted unconditionally. Use 'connectWithTlsParams' with your own
-- 'TLS.ClientParams' if the server's identity must be verified.
connect :: HostName -> PortID -> IO Pipe
connect host port = connectWithTlsParams params host port
  where
    params :: TLS.ClientParams
    params = (TLS.defaultParamsClient host "")
        { TLS.clientSupported = def
            { TLS.supportedCiphers = TLS.ciphersuite_default }
        , TLS.clientHooks = def
            { -- An empty list of failure reasons means "certificate accepted".
              TLS.onServerCertificate = \_ _ _ _ -> return [] }
        }

-- | Connect to mongodb using TLS using provided TLS client parameters.
--
-- Opens a connection to @host@:@port@, performs the TLS handshake with the
-- supplied 'TLS.ClientParams', and wraps the resulting TLS context in a
-- 'Pipe'. If anything fails after the handle has been opened (context
-- creation, handshake, pipe setup, server-data retrieval), the underlying
-- handle is closed before the exception propagates.
connectWithTlsParams :: TLS.ClientParams -> HostName -> PortID -> IO Pipe
connectWithTlsParams clientParams host port =
  bracketOnError (connectTo host port) hClose $ \handle -> do
    context <- TLS.contextNew handle clientParams
    TLS.handshake context

    conn <- tlsConnection context
    -- The pipe is constructed from the server data, and the server data is
    -- itself fetched through that same pipe, hence the recursive binding.
    rec
      p <- newPipeWith sd conn
      sd <- access p slaveOk "admin" retrieveServerData
    return p

-- | Wrap an established TLS context in a 'Transport'.
--
-- 'T.read' returns exactly the requested number of bytes. 'TLS.recvData'
-- takes no length argument, so a received chunk may be shorter or longer
-- than what was asked for: short chunks are accumulated, and any surplus
-- is stashed in an 'IORef' and served first on the next read.
--
-- An empty chunk is treated as end of input and raises an 'IOError' of
-- type 'eofErrorType'. A request for zero (or fewer) bytes returns an
-- empty string immediately without touching the connection.
tlsConnection :: TLS.Context -> IO Transport
tlsConnection ctx = do
  -- Bytes received from the peer but not yet handed out to a reader.
  restRef <- newIORef mempty
  let
      -- Take the buffered leftover if there is one, otherwise receive a
      -- fresh chunk from the TLS context. The buffer is emptied either way.
      readSome = do
        rest <- readIORef restRef
        writeIORef restRef mempty
        if ByteString.null rest
          then TLS.recvData ctx
          else return rest

      -- Push surplus bytes back in front of whatever is buffered.
      unread rest =
        modifyIORef restRef (rest <>)

      eof = mkIOError eofErrorType "Database.MongoDB.Transport"
              Nothing Nothing

      -- Accumulate chunks (as a lazy ByteString) until @n@ more bytes
      -- have been collected.
      go acc n = do
        chunk <- readSome
        when (ByteString.null chunk) $
          ioError eof
        let len = ByteString.length chunk
        if len >= n
          then do
            let (res, rest) = ByteString.splitAt n chunk
            unless (ByteString.null rest) $
              unread rest
            return (acc <> Lazy.ByteString.fromStrict res)
          else go (acc <> Lazy.ByteString.fromStrict chunk) (n - len)

      readExactly count
        -- Nothing to read: do not block on the network (or fail with EOF)
        -- just to hand back an empty result.
        | count <= 0 = return ByteString.empty
        | otherwise = Lazy.ByteString.toStrict <$> go mempty count

  return Transport
    { T.read = readExactly
    , T.write = TLS.sendData ctx . Lazy.ByteString.fromStrict
    , T.flush = TLS.contextFlush ctx
    , T.close = TLS.contextClose ctx
    }