{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Metro.TP.TLS
( TLS
, module Metro.TP.TLSSetting
, tlsConfig
) where
import Control.Exception (SomeException, bracketOnError, catch)
import qualified Data.ByteString.Char8 as B (append, length, null)
import qualified Data.ByteString.Lazy as BL (fromStrict)
import Metro.Class (Transport (..))
import Metro.TP.TLSSetting
import Network.TLS (Context, TLSParams)
import qualified Network.TLS as TLS
-- | A transport backed by a "Network.TLS" 'Context'.
--
-- Values are produced by 'newTransport' from a configuration built with
-- 'tlsConfig'.
newtype TLS = TLS Context
-- | Runs TLS on top of any other 'Transport'.
instance Transport TLS where
  -- | TLS parameters paired with the configuration of the transport that
  -- carries the encrypted traffic.  Both types are hidden existentially.
  data TransportConfig TLS
    = forall params tp. (Transport tp, TLSParams params)
    => TLSConfig params (TransportConfig tp)

  -- Creates the inner transport, builds a TLS context over it and performs
  -- the handshake.  If creating the context or the handshake raises,
  -- 'bracketOnError' runs 'closeTLS' on the context before re-raising.
  newTransport (TLSConfig params config) = do
    inner <- newTransport config
    let acquire = TLS.contextNew (transportBackend inner) params
    bracketOnError acquire closeTLS $ \ctx ->
      TLS ctx <$ TLS.handshake ctx

  -- The requested byte count is ignored; whatever 'TLS.recvData' yields is
  -- returned as-is.
  recvData (TLS ctx) _ = TLS.recvData ctx

  -- 'TLS.sendData' takes a lazy ByteString, so the strict payload is converted.
  sendData (TLS ctx) payload = TLS.sendData ctx (BL.fromStrict payload)

  closeTransport (TLS ctx) = closeTLS ctx
-- | Adapt an arbitrary 'Transport' into the 'TLS.Backend' record that
-- "Network.TLS" uses for its raw I/O.
--
-- Flushing is a no-op; closing and sending are forwarded to the transport
-- unchanged.  Receiving keeps reading until the requested number of bytes
-- has been gathered, or until the transport hands back an empty chunk
-- (presumably end of stream), in which case whatever was collected so far
-- is returned.
transportBackend :: Transport tp => tp -> TLS.Backend
transportBackend transport = TLS.Backend
  { TLS.backendFlush = return ()
  , TLS.backendClose = closeTransport transport
  , TLS.backendSend  = sendData transport
  , TLS.backendRecv  = recvUpTo
  }
  where
    -- Read repeatedly, asking each time only for the bytes still missing.
    recvUpTo wanted = do
      chunk <- recvData transport wanted
      if B.null chunk || B.length chunk >= wanted
        then return chunk
        else B.append chunk <$> recvUpTo (wanted - B.length chunk)
-- | Shut a TLS session down on a best-effort basis.
--
-- First attempts the closing notification via 'TLS.bye', then closes the
-- context via 'TLS.contextClose'.  The two steps are guarded separately: if
-- they shared one handler, an exception from 'TLS.bye' would skip
-- 'TLS.contextClose' and leave the context (and the backend it closes)
-- open.  Exceptions from either step are deliberately discarded, as this
-- is used from cleanup paths where nothing useful can be done with them.
closeTLS :: Context -> IO ()
closeTLS ctx = do
  quietly (TLS.bye ctx)
  quietly (TLS.contextClose ctx)
  where
    -- Run an action and drop any exception it raises.
    quietly :: IO () -> IO ()
    quietly action = action `catch` (\(_ :: SomeException) -> return ())
-- | Build the configuration for a 'TLS' transport.
--
-- Takes the TLS parameters and the configuration of the transport that TLS
-- should run over, and packs them into a 'TransportConfig' for 'TLS'.
tlsConfig
  :: (Transport tp, TLSParams params)
  => params
  -> TransportConfig tp
  -> TransportConfig TLS
tlsConfig params config = TLSConfig params config