{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.QUIC.TLS (
    clientHandshaker,
    serverHandshaker,
) where

import Data.Default.Class
import Network.TLS hiding (Version)
import Network.TLS.QUIC
import System.X509

import Network.QUIC.Config
import Network.QUIC.Parameters
import Network.QUIC.Types

-- | Build a 'SessionManager' whose only active hook is the supplied
-- establish callback.
--
-- * 'sessionEstablish' is the given callback.
-- * 'sessionResume' and 'sessionResumeOnlyOnce' always answer 'Nothing'
--   (no stored session is ever found through this manager).
-- * 'sessionInvalidate' does nothing.
--
-- NOTE(review): presumably client-side resumption is driven by the
-- caller instead (e.g. through 'clientWantSessionResume'), which is why
-- the lookup hooks can stay empty — confirm against callers.
sessionManager :: SessionEstablish -> SessionManager
sessionManager establish =
    SessionManager
        { sessionEstablish = establish
        , sessionResume = const (return Nothing)
        , sessionResumeOnlyOnce = const (return Nothing)
        , sessionInvalidate = const (return ())
        }

-- | Run the TLS handshake for a QUIC client.
--
-- The handshake itself is performed by 'tlsQUICClient'. This function only
-- assembles the 'ClientParams' from the 'ClientConfig' and the
-- connection-specific arguments.
clientHandshaker
    :: QUICCallbacks
    -- ^ QUIC-side callbacks, handed unchanged to 'tlsQUICClient'
    -> ClientConfig
    -- ^ client configuration (server name, validation, ALPN, ciphers, ...)
    -> Version
    -- ^ QUIC version; selects the transport-parameter extension ID and is
    -- passed to 'ccALPN'
    -> AuthCIDs
    -- ^ our connection IDs, written into the transport parameters
    -> SessionEstablish
    -- ^ installed as the establish hook of the TLS session manager
    -> Bool
    -- ^ whether to request early data (0-RTT)
    -> IO ()
clientHandshaker callbacks ClientConfig{..} ver myAuthCIDs establish use0RTT = do
    -- The system CA store is loaded only when certificates are validated;
    -- otherwise an empty store is paired with 'skipValidation' below.
    caStore <-
        if ccValidate
            then getSystemCertificateStore
            else return mempty
    tlsQUICClient (cparams caStore) callbacks
  where
    cparams caStore =
        (defaultParamsClient ccServerName "")
            { clientShared = cshared caStore
            , clientHooks = hook
            , clientSupported = supported
            , clientDebug = debug
            , clientWantSessionResume = resumptionSession ccResumption
            -- An empty payload only signals that early data is wanted.
            -- NOTE(review): presumably the real 0-RTT payload is carried by
            -- the QUIC layer rather than by TLS — confirm.
            , clientEarlyData = if use0RTT then Just "" else Nothing
            }

    -- Our CIDs are merged into the configured transport parameters, then
    -- the user hooks get a chance to rewrite parameters and extensions.
    convTP = onTransportParametersCreated ccHooks
    params = convTP $ setCIDsToParameters myAuthCIDs ccParameters
    convExt = onTLSExtensionCreated ccHooks

    -- Cache whose query always answers 'ValidationCachePass' and whose add
    -- callback does nothing; used only when 'ccValidate' is False.
    skipValidation =
        ValidationCache
            (\_ _ _ -> return ValidationCachePass)
            (\_ _ _ -> return ())

    cshared caStore =
        def
            { sharedValidationCache =
                if ccValidate then def else skipValidation
            , sharedCAStore = caStore
            -- The transport parameters travel as a hello extension.
            , sharedHelloExtensions =
                convExt $ parametersToExtensionRaw ver params
            , sharedSessionManager = sessionManager establish
            }

    hook = def{onSuggestALPN = ccALPN ver}

    -- QUIC requires TLS 1.3 (RFC 9001, Section 4.2: clients MUST NOT offer
    -- older versions), so pin it explicitly, as the server side does.
    supported =
        defaultSupported
            { supportedVersions = [TLS13]
            , supportedCiphers = ccCiphers
            , supportedGroups = ccGroups
            }

    debug = def{debugKeyLogger = ccKeyLog}

-- | Encode QUIC transport parameters as the single raw TLS extension that
-- carries them.
--
-- The extension ID is chosen from the QUIC version via
-- 'extensionIDForTtransportParameter'. The payload is 'encodeParameters'
-- applied to the parameters. The result always has exactly one element.
parametersToExtensionRaw :: Version -> Parameters -> [ExtensionRaw]
parametersToExtensionRaw ver params =
    [ExtensionRaw (extensionIDForTtransportParameter ver) (encodeParameters params)]

-- | Run the TLS handshake for a QUIC server.
--
-- The handshake itself is performed by 'tlsQUICServer'. This function only
-- assembles the 'ServerParams' from the 'ServerConfig'.
serverHandshaker
    :: QUICCallbacks
    -- ^ QUIC-side callbacks, handed unchanged to 'tlsQUICServer'
    -> ServerConfig
    -- ^ server configuration (credentials, ALPN, ciphers, 0-RTT, ...)
    -> Version
    -- ^ QUIC version; selects the transport-parameter extension ID and is
    -- passed to 'scALPN'
    -> IO Parameters
    -- ^ action producing our transport parameters; run each time the
    -- EncryptedExtensions hook fires
    -> IO ()
serverHandshaker callbacks ServerConfig{..} ver getParams =
    tlsQUICServer sparams callbacks
  where
    sparams =
        def
            { serverShared = sshared
            , serverHooks = hook
            , serverSupported = supported
            , serverDebug = debug
            -- A size of 0 means no early data is accepted.
            , serverEarlyDataSize =
                if scUse0RTT then quicMaxEarlyDataSize else 0
            }

    convTP = onTransportParametersCreated scHooks
    convExt = onTLSExtensionCreated scHooks

    sshared =
        def
            { sharedCredentials = scCredentials
            , sharedSessionManager = scSessionManager
            }

    hook =
        def
            { -- Specialise the configured ALPN selector to this QUIC
              -- version, if one is configured.
              onALPNClientSuggest = fmap ($ ver) scALPN
            , -- Our transport parameters (after the user hooks) are placed
              -- ahead of the extensions TLS has already prepared.
              onEncryptedExtensionsCreating = \exts0 -> do
                params <- getParams
                let exts = convExt . parametersToExtensionRaw ver . convTP $ params
                return (exts ++ exts0)
            }

    -- QUIC runs only over TLS 1.3 (RFC 9001, Section 4.2).
    supported =
        def
            { supportedVersions = [TLS13]
            , supportedCiphers = scCiphers
            , supportedGroups = scGroups
            }

    debug = def{debugKeyLogger = scKeyLog}