{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | Glue between the QUIC configuration types and the @tls@ package's
-- QUIC-mode handshake ('tlsQUICClient' / 'tlsQUICServer').
module Network.QUIC.TLS (
    clientHandshaker,
    serverHandshaker,
) where

import Control.Applicative ((<|>))
import Data.Default.Class
import Network.TLS hiding (Version)
import Network.TLS.QUIC
import System.X509

import Network.QUIC.Config
import Network.QUIC.Parameters
import Network.QUIC.Types

-- | A session manager whose only customised operation is
-- 'sessionEstablish'; every other operation is inherited from
-- 'noSessionManager'.
sessionManager :: SessionEstablish -> SessionManager
sessionManager establish =
    noSessionManager
        { sessionEstablish = establish
        }

-- | Run the client side of the TLS 1.3 handshake for a QUIC connection.
--
-- * When 'ccValidate' is set, the system certificate store is loaded and
--   the default validation cache is used. Otherwise an empty CA store is
--   used together with a validation cache that reports every certificate
--   as 'ValidationCachePass'.
-- * Our transport parameters are sent in the ClientHello. They carry
--   @myAuthCIDs@ and are passed through the 'onTransportParametersCreated'
--   and 'onTLSExtensionCreated' hooks.
-- * ALPN values from 'ccALPN' take precedence. The user's
--   'onSuggestALPN' TLS hook is consulted as a fallback when 'ccALPN'
--   yields 'Nothing'.
clientHandshaker
    :: QUICCallbacks
    -- ^ callbacks used by the tls QUIC record layer
    -> ClientConfig
    -- ^ client configuration
    -> Version
    -- ^ QUIC version (selects the transport-parameter extension ID and ALPN)
    -> AuthCIDs
    -- ^ our connection IDs to advertise in the transport parameters
    -> SessionEstablish
    -- ^ called when TLS establishes a resumable session
    -> Bool
    -- ^ whether to attempt 0-RTT early data
    -> IO ()
clientHandshaker callbacks ClientConfig{..} ver myAuthCIDs establish use0RTT = do
    caStore <- if ccValidate then getSystemCertificateStore else return mempty
    tlsQUICClient (cparams caStore) callbacks
  where
    cparams store =
        (defaultParamsClient ccServerName "")
            { clientShared = cshared store
            , clientHooks = hook
            , clientSupported = supported
            , clientDebug = debug
            , clientWantSessionResume = resumptionSession ccResumption
            , clientUseEarlyData = use0RTT
            }
    convTP = onTransportParametersCreated ccHooks
    params = convTP $ setCIDsToParameters myAuthCIDs ccParameters
    convExt = onTLSExtensionCreated ccHooks
    -- Used only when validation is disabled: every lookup "hits", so tls
    -- skips certificate-chain validation entirely.
    skipValidation =
        ValidationCache
            (\_ _ _ -> return ValidationCachePass)
            (\_ _ _ -> return ())
    cshared store =
        def
            { sharedValidationCache = if ccValidate then def else skipValidation
            , sharedCAStore = store
            , sharedHelloExtensions = convExt $ parametersToExtensionRaw ver params
            , sharedSessionManager = sessionManager establish
            }
    -- (<|>) on the Maybe results is left-biased: ccALPN wins when it
    -- returns Just.
    hook =
        ccTlsHooks
            { onSuggestALPN = (<|>) <$> ccALPN ver <*> onSuggestALPN ccTlsHooks
            }
    -- QUIC runs exclusively over TLS 1.3 (RFC 9001, Section 4.2: clients
    -- MUST NOT offer older versions). Pin it explicitly, as serverHandshaker
    -- does, instead of inheriting defaultSupported's version list.
    supported =
        defaultSupported
            { supportedVersions = [TLS13]
            , supportedCiphers = ccCiphers
            , supportedGroups = ccGroups
            }
    debug =
        def
            { debugKeyLogger = ccKeyLog
            }

-- | Encode transport parameters as a single raw TLS extension. The
-- extension ID depends on the QUIC version.
parametersToExtensionRaw :: Version -> Parameters -> [ExtensionRaw]
parametersToExtensionRaw ver params = [ExtensionRaw tpId eParams]
  where
    tpId = extensionIDForTtransportParameter ver
    eParams = encodeParameters params

-- | Run the server side of the TLS 1.3 handshake for a QUIC connection.
--
-- The transport parameters are obtained by running @getParams@ inside the
-- 'onEncryptedExtensionsCreating' hook, i.e. at the point the
-- EncryptedExtensions message is built. They are passed through
-- 'onTransportParametersCreated' and 'onTLSExtensionCreated', and the
-- resulting extension is placed before any extensions produced by the
-- user's own TLS hook.
serverHandshaker
    :: QUICCallbacks
    -- ^ callbacks used by the tls QUIC record layer
    -> ServerConfig
    -- ^ server configuration
    -> Version
    -- ^ QUIC version (selects the transport-parameter extension ID and ALPN)
    -> IO Parameters
    -- ^ action producing our transport parameters
    -> IO ()
serverHandshaker callbacks ServerConfig{..} ver getParams =
    tlsQUICServer sparams callbacks
  where
    sparams =
        def
            { serverShared = sshared
            , serverHooks = hook
            , serverSupported = supported
            , serverDebug = debug
            , -- Advertise early-data capacity only when 0-RTT is enabled.
              serverEarlyDataSize = if scUse0RTT then quicMaxEarlyDataSize else 0
            , serverTicketLifetime = scTicketLifetime
            }
    convTP = onTransportParametersCreated scHooks
    convExt = onTLSExtensionCreated scHooks
    sshared =
        def
            { sharedCredentials = scCredentials
            , sharedSessionManager = scSessionManager
            }
    hook =
        scTlsHooks
            { -- scALPN, when configured, replaces the user's TLS ALPN hook.
              onALPNClientSuggest = case scALPN of
                Nothing -> onALPNClientSuggest scTlsHooks
                Just io -> Just $ io ver
            , onEncryptedExtensionsCreating = \exts0 -> do
                exts0' <- onEncryptedExtensionsCreating scTlsHooks exts0
                params <- getParams
                let exts = convExt $ parametersToExtensionRaw ver $ convTP params
                return $ exts ++ exts0'
            }
    supported =
        def
            { supportedVersions = [TLS13]
            , supportedCiphers = scCiphers
            , supportedGroups = scGroups
            }
    debug =
        def
            { debugKeyLogger = scKeyLog
            }