{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.TLS (
    clientHandshaker,
    serverHandshaker,
) where

import Control.Applicative ((<|>))
import Data.Default.Class
import Network.TLS hiding (Version)
import Network.TLS.QUIC
import System.X509

import Network.QUIC.Config
import Network.QUIC.Parameters
import Network.QUIC.Types

-- | A TLS 'SessionManager' derived from 'noSessionManager' in which only the
-- 'sessionEstablish' operation is replaced by the supplied callback; every
-- other operation keeps the behaviour of 'noSessionManager'.
sessionManager :: SessionEstablish -> SessionManager
sessionManager onEstablish =
    noSessionManager
        { sessionEstablish = onEstablish
        }

-- | Run the TLS 1.3 handshake for a QUIC client connection.
--
-- Builds 'ClientParams' from the 'ClientConfig' and passes them, together
-- with the QUIC callbacks, to 'tlsQUICClient'.
--
-- * @ver@ selects the transport-parameter extension ID and is given to the
--   'ccALPN' callback.
-- * @myAuthCIDs@ are stored into 'ccParameters' before those parameters are
--   encoded into the ClientHello extensions.
-- * @establish@ becomes the session manager's establish callback.
-- * @use0RTT@ is passed through as 'clientUseEarlyData'.
clientHandshaker
    :: QUICCallbacks
    -> ClientConfig
    -> Version
    -> AuthCIDs
    -> SessionEstablish
    -> Bool
    -> IO ()
clientHandshaker callbacks ClientConfig{..} ver myAuthCIDs establish use0RTT = do
    -- The system trust store is only loaded when server certificates are
    -- going to be validated; otherwise an empty store is used.
    caStore <- if ccValidate then getSystemCertificateStore else return mempty
    tlsQUICClient (cparams caStore) callbacks
  where
    cparams caStore =
        (defaultParamsClient ccServerName "")
            { clientShared = cshared caStore
            , clientHooks = hook
            , clientSupported = supported
            , clientDebug = debug
            , clientWantSessionResume = resumptionSession ccResumption
            , clientUseEarlyData = use0RTT
            }
    -- User hooks that may rewrite the transport parameters and the raw TLS
    -- extensions before they are placed in the ClientHello.
    convTP = onTransportParametersCreated ccHooks
    convExt = onTLSExtensionCreated ccHooks
    params = convTP $ setCIDsToParameters myAuthCIDs ccParameters
    -- A validation cache that answers 'ValidationCachePass' for every
    -- certificate and ignores additions; only installed when 'ccValidate'
    -- is False.
    skipValidation =
        ValidationCache
            (\_ _ _ -> return ValidationCachePass)
            (\_ _ _ -> return ())
    cshared caStore =
        def
            { sharedValidationCache = if ccValidate then def else skipValidation
            , sharedCAStore = caStore
            , sharedHelloExtensions = convExt $ parametersToExtensionRaw ver params
            , sharedSessionManager = sessionManager establish
            }
    -- ALPN: the QUIC-level 'ccALPN' result takes precedence (Maybe '<|>');
    -- the TLS hook's suggestion is the fallback. Both IO actions are run.
    hook =
        ccTlsHooks
            { onSuggestALPN = (<|>) <$> ccALPN ver <*> onSuggestALPN ccTlsHooks
            }
    -- RFC 9001 Section 4.2: a QUIC client MUST NOT offer TLS versions older
    -- than 1.3. Pin the version list explicitly (as the server side does)
    -- instead of inheriting whatever 'defaultSupported' lists.
    supported =
        defaultSupported
            { supportedVersions = [TLS13]
            , supportedCiphers = ccCiphers
            , supportedGroups = ccGroups
            }
    debug = def{debugKeyLogger = ccKeyLog}

-- | Encode QUIC transport parameters as a one-element list of raw TLS
-- extensions. The extension ID is chosen from the QUIC version via
-- 'extensionIDForTtransportParameter'; the payload is 'encodeParameters'.
parametersToExtensionRaw :: Version -> Parameters -> [ExtensionRaw]
parametersToExtensionRaw quicVer tparams =
    let extId = extensionIDForTtransportParameter quicVer
        payload = encodeParameters tparams
     in [ExtensionRaw extId payload]

-- | Run the TLS 1.3 handshake for a QUIC server connection.
--
-- Builds 'ServerParams' from the 'ServerConfig' and passes them, together
-- with the QUIC callbacks, to 'tlsQUICServer'.
--
-- Each time the TLS stack builds the EncryptedExtensions, the user's own
-- 'onEncryptedExtensionsCreating' hook runs first. Then @getParams@ is run,
-- and the resulting transport parameters (after the
-- 'onTransportParametersCreated' and 'onTLSExtensionCreated' hooks) are
-- prepended to the user's extensions.
serverHandshaker
    :: QUICCallbacks
    -> ServerConfig
    -> Version
    -> IO Parameters
    -> IO ()
serverHandshaker callbacks ServerConfig{..} ver getParams =
    tlsQUICServer srvParams callbacks
  where
    srvParams =
        def
            { serverShared = srvShared
            , serverHooks = srvHooks
            , serverSupported = srvSupported
            , serverDebug = srvDebug
            , serverEarlyDataSize = earlyDataSize
            , serverTicketLifetime = scTicketLifetime
            }
    -- Early data is only accepted when 0-RTT is enabled in the config.
    earlyDataSize
        | scUse0RTT = quicMaxEarlyDataSize
        | otherwise = 0
    srvShared =
        def
            { sharedCredentials = scCredentials
            , sharedSessionManager = scSessionManager
            }
    -- Transport parameters -> raw TLS extensions, with the user hooks applied
    -- before and after encoding.
    tpExtensions tparams =
        onTLSExtensionCreated scHooks $
            parametersToExtensionRaw ver (onTransportParametersCreated scHooks tparams)
    srvHooks =
        scTlsHooks
            { -- A QUIC-level ALPN selector, if configured, replaces the TLS hook.
              onALPNClientSuggest =
                maybe (onALPNClientSuggest scTlsHooks) (\select -> Just (select ver)) scALPN
            , onEncryptedExtensionsCreating = \exts0 -> do
                userExts <- onEncryptedExtensionsCreating scTlsHooks exts0
                quicExts <- tpExtensions <$> getParams
                return (quicExts ++ userExts)
            }
    srvSupported =
        def
            { supportedVersions = [TLS13]
            , supportedCiphers = scCiphers
            , supportedGroups = scGroups
            }
    srvDebug = def{debugKeyLogger = scKeyLog}