{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

-- | TLS 1.3 handshake drivers for QUIC.
--
-- Each driver translates the QUIC client/server configuration into @tls@
-- parameters and runs the handshake through the supplied 'QUICCallbacks'.
module Network.QUIC.TLS (
    clientHandshaker
  , serverHandshaker
  ) where

import Data.Default.Class
import Network.TLS hiding (Version)
import Network.TLS.QUIC
import System.X509

import Network.QUIC.Config
import Network.QUIC.Parameters
import Network.QUIC.Types

----------------------------------------------------------------

-- | A client-side session manager whose only active hook is the given
-- 'SessionEstablish' callback. Resumption lookups always return
-- 'Nothing' and invalidation does nothing.
--
-- The record is built by updating 'noSessionManager' rather than by
-- constructing 'SessionManager' directly. Any field a newer @tls@ adds
-- therefore gets that library's default instead of being left undefined.
-- The four fields below are still set explicitly, so their behaviour does
-- not depend on those defaults.
sessionManager :: SessionEstablish -> SessionManager
sessionManager establish = noSessionManager {
    sessionEstablish      = establish
  , sessionResume         = \_ -> return Nothing
  , sessionResumeOnlyOnce = \_ -> return Nothing
  , sessionInvalidate     = \_ -> return ()
  }

----------------------------------------------------------------

-- | Run the TLS 1.3 client handshake for a QUIC connection.
--
-- * If 'ccValidate' is 'True', the system certificate store is loaded and
--   the default validation cache is used. Otherwise an empty store is used
--   together with a validation cache that accepts every certificate.
-- * Our transport parameters are put into a ClientHello extension, after
--   three steps:
--     1. The connection IDs from @myAuthCIDs@ are set.
--     2. The parameters are passed through 'onTransportParametersCreated'.
--     3. The resulting extension list is passed through 'onTLSExtensionCreated'.
-- * Early data is requested only when @use0RTT@ is 'True'.
-- * Session resumption uses whatever 'resumptionSession' extracts from
--   'ccResumption'.
clientHandshaker :: QUICCallbacks
                 -> ClientConfig
                 -> Version          -- ^ QUIC version; given to 'ccALPN' and used to pick the transport-parameter extension ID
                 -> AuthCIDs         -- ^ connection IDs written into our transport parameters
                 -> SessionEstablish -- ^ installed as the session manager's 'sessionEstablish' hook
                 -> Bool             -- ^ whether to request 0-RTT early data
                 -> IO ()
clientHandshaker callbacks ClientConfig{..} ver myAuthCIDs establish use0RTT = do
    caStore <- if ccValidate then getSystemCertificateStore else return mempty
    tlsQUICClient (cparams caStore) callbacks
  where
    cparams caStore = (defaultParamsClient ccServerName "") {
        clientShared            = cshared caStore
      , clientHooks             = hook
      , clientSupported         = supported
      , clientDebug             = debug
      , clientWantSessionResume = resumptionSession ccResumption
        -- NOTE(review): the early-data payload is empty. Presumably QUIC
        -- sends the 0-RTT data itself and TLS only needs the request flag.
      , clientEarlyData         = if use0RTT then Just "" else Nothing
      }
    convTP  = onTransportParametersCreated ccHooks
    params  = convTP $ setCIDsToParameters myAuthCIDs ccParameters
    convExt = onTLSExtensionCreated ccHooks
    -- Used only when validation is disabled: accepts any certificate.
    skipValidation = ValidationCache
        (\_ _ _ -> return ValidationCachePass)
        (\_ _ _ -> return ())
    cshared caStore = def {
        sharedValidationCache = if ccValidate then def else skipValidation
      , sharedCAStore         = caStore
      , sharedHelloExtensions = convExt $ parametersToExtensionRaw ver params
      , sharedSessionManager  = sessionManager establish
      }
    hook = def {
        onSuggestALPN = ccALPN ver
      }
    supported = defaultSupported {
        supportedCiphers = ccCiphers
      , supportedGroups  = ccGroups
      }
    debug = def {
        debugKeyLogger = ccKeyLog
      }

-- | Encode transport parameters as a single raw TLS extension. The
-- extension ID is chosen from the QUIC version.
parametersToExtensionRaw :: Version -> Parameters -> [ExtensionRaw]
parametersToExtensionRaw ver params = [ExtensionRaw tpId eParams]
  where
    tpId    = extensionIDForTtransportParameter ver
    eParams = encodeParameters params

----------------------------------------------------------------

-- | Run the TLS 1.3 server handshake for a QUIC connection.
--
-- * Only TLS 1.3 is offered.
-- * The maximum early-data size is 'quicMaxEarlyDataSize' when
--   'scUse0RTT' is set, and 0 otherwise.
-- * @getParams@ runs only when the EncryptedExtensions message is being
--   built. The result is processed in three steps:
--     1. The parameters are passed through 'onTransportParametersCreated'.
--     2. They are encoded as an extension.
--     3. That extension list is passed through 'onTLSExtensionCreated'.
--   The resulting extensions are prepended to those @tls@ already produced.
serverHandshaker :: QUICCallbacks
                 -> ServerConfig
                 -> Version       -- ^ QUIC version; given to 'scALPN' and used to pick the transport-parameter extension ID
                 -> IO Parameters -- ^ action producing our transport parameters
                 -> IO ()
serverHandshaker callbacks ServerConfig{..} ver getParams =
    tlsQUICServer sparams callbacks
  where
    sparams = def {
        serverShared        = sshared
      , serverHooks         = hook
      , serverSupported     = supported
      , serverDebug         = debug
      , serverEarlyDataSize = if scUse0RTT then quicMaxEarlyDataSize else 0
      }
    convTP  = onTransportParametersCreated scHooks
    convExt = onTLSExtensionCreated scHooks
    sshared = def {
        sharedCredentials    = scCredentials
      , sharedSessionManager = scSessionManager
      }
    hook = def {
        -- Partially apply the user's ALPN selector to the QUIC version.
        -- With no selector configured, this stays 'Nothing'.
        onALPNClientSuggest = ($ ver) <$> scALPN
      , onEncryptedExtensionsCreating = \exts0 -> do
            params <- getParams
            let tpExts = convExt $ parametersToExtensionRaw ver $ convTP params
            return (tpExts ++ exts0)
      }
    supported = def {
        supportedVersions = [TLS13]
      , supportedCiphers  = scCiphers
      , supportedGroups   = scGroups
      }
    debug = def {
        debugKeyLogger = scKeyLog
      }