{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Network.QUIC.Handshake where

import Data.List (intersect)
import qualified Network.TLS as TLS
import Network.TLS.QUIC
import qualified UnliftIO.Exception as E

import Network.QUIC.Config
import Network.QUIC.Connection
import Network.QUIC.Connector
import Network.QUIC.Crypto
import Network.QUIC.Imports
import Network.QUIC.Info
import Network.QUIC.Logger
import Network.QUIC.Parameters
import Network.QUIC.Qlog
import Network.QUIC.Recovery
import Network.QUIC.TLS
import Network.QUIC.Types

----------------------------------------------------------------

-- | Per-handshake bookkeeping shared by the TLS send/receive callbacks.
newtype HndState = HndState
    { hsRecvCnt :: Int  -- number of 'recv' calls since last 'send'
    }

-- | Allocate a fresh handshake state whose receive counter starts at zero.
newHndStateRef :: IO (IORef HndState)
newHndStateRef = newIORef (HndState 0)

-- | Record that a send happened: the receive counter is reset to zero.
sendCompleted :: IORef HndState -> IO ()
sendCompleted hsr = atomicModifyIORef' hsr reset
  where
    reset hs = (hs { hsRecvCnt = 0 }, ())

-- | Record a completed receive. Returns the counter value as it was
-- /before/ the increment, i.e. how many receives preceded this one since
-- the last reset.
recvCompleted :: IORef HndState -> IO Int
recvCompleted hsr = atomicModifyIORef' hsr bump
  where
    bump hs@HndState { hsRecvCnt = before } =
        (hs { hsRecvCnt = before + 1 }, before)

-- | Called when the receive encryption level changes; resets the receive
-- counter exactly like a completed send does.
rxLevelChanged :: IORef HndState -> IO ()
rxLevelChanged hsr = sendCompleted hsr

----------------------------------------------------------------

-- | Hand outgoing handshake/control output to the connection via 'putOutput'.
sendCryptoData :: Connection -> Output -> IO ()
sendCryptoData conn out = putOutput conn out

-- | Fetch the next chunk of received CRYPTO data via 'takeCrypto'.
recvCryptoData :: Connection -> IO Crypto
recvCryptoData conn = takeCrypto conn

-- | TLS 'quicRecv' callback: fetch the next CRYPTO payload for the TLS
-- level the library is asking about.
--
-- * An empty payload is reported as 'TLS.Error_EOF'.
-- * A payload that arrived at an encryption level other than the one
--   corresponding to @level@ is rejected with an internal protocol error.
-- * The pre-TLS-1.3 and early-secret levels are never valid here.
--
-- On the client, every accepted payload bumps the receive counter in the
-- 'HndState'. When the value read before the bump is 1 mod 3, an empty
-- Handshake-level control output is queued.
recvTLS :: Connection -> IORef HndState -> CryptLevel -> IO (Either TLS.TLSError ByteString)
recvTLS conn hsr level = case level of
    CryptInitial           -> recvAt InitialLevel
    CryptHandshakeSecret   -> recvAt HandshakeLevel
    CryptApplicationSecret -> recvAt RTT1Level
    CryptMasterSecret      -> reject "QUIC does not receive data < TLS 1.3"
    CryptEarlySecret       -> reject "QUIC does not send early data with TLS library"
  where
    reject :: String -> IO (Either TLS.TLSError ByteString)
    reject msg = return (Left (internalError msg))

    recvAt :: EncryptionLevel -> IO (Either TLS.TLSError ByteString)
    recvAt expected = do
        InpHandshake actual bs <- recvCryptoData conn
        classify expected actual bs

    classify :: EncryptionLevel -> EncryptionLevel -> ByteString
             -> IO (Either TLS.TLSError ByteString)
    classify expected actual bs
      | bs == "" = return (Left TLS.Error_EOF)
      | actual /= expected =
          reject $ concat [ "encryption level mismatch: expected "
                          , show expected
                          , " but got "
                          , show actual
                          ]
      | otherwise = do
          when (isClient conn) ackOnThirdRecv
          return (Right bs)

    -- Sending ACKs for three times rule
    ackOnThirdRecv :: IO ()
    ackOnThirdRecv = do
        n <- recvCompleted hsr
        when (n `mod` 3 == 1) $
            sendCryptoData conn (OutControl HandshakeLevel [] (return ()))

-- | TLS 'quicSend' callback. It maps each TLS crypt level to the matching
-- QUIC encryption level and queues the whole batch as a single
-- 'OutHandshake'. It then resets the receive counter in the 'HndState'.
--
-- The pre-TLS-1.3 and early-secret levels are rejected via 'errorTLS'.
sendTLS :: Connection -> IORef HndState -> [(CryptLevel, ByteString)] -> IO ()
sendTLS conn hsr x = do
    batch <- traverse toEncryptionLevel x
    sendCryptoData conn (OutHandshake batch)
    sendCompleted hsr
  where
    toEncryptionLevel :: (CryptLevel, b) -> IO (EncryptionLevel, b)
    toEncryptionLevel (lvl, payload) = case lvl of
        CryptInitial           -> return (InitialLevel, payload)
        CryptHandshakeSecret   -> return (HandshakeLevel, payload)
        CryptApplicationSecret -> return (RTT1Level, payload)
        CryptMasterSecret      -> errorTLS "QUIC does not send data < TLS 1.3"
        CryptEarlySecret       -> errorTLS "QUIC does not receive early data with TLS library"

-- | Build the 'TLS.TLSError' used for locally detected protocol problems:
-- a protocol error carrying the given message and the 'TLS.InternalError'
-- alert. The constructor's shape differs between tls versions, hence CPP.
internalError :: String -> TLS.TLSError
#if MIN_VERSION_tls(1,9,0)
internalError = flip TLS.Error_Protocol TLS.InternalError
#else
internalError = \msg -> TLS.Error_Protocol (msg, True, TLS.InternalError)
#endif

----------------------------------------------------------------

-- | Prepare the client-side handshake. This logs the local transport
-- parameters to qlog, reads the connection's current version and allocates
-- a fresh 'HndState'. The returned action runs the TLS handshake itself.
handshakeClient :: ClientConfig -> Connection -> AuthCIDs -> IO (IO ())
handshakeClient conf conn myAuthCIDs = do
    qlogParamsSet conn (ccParameters conf, "local") -- fixme
    ver <- getVersion conn
    hsr <- newHndStateRef
    return (handshakeClient' conf conn myAuthCIDs ver hsr)

-- | Run the client TLS handshake, wiring the TLS library's QUIC callbacks to
-- this connection. Any 'TLSException' escaping the handshake is handed to
-- 'sendCCTLSError'.
handshakeClient' :: ClientConfig -> Connection -> AuthCIDs -> Version -> IORef HndState -> IO ()
handshakeClient' conf conn myAuthCIDs ver hsr =
    clientHandshaker callbacks conf ver myAuthCIDs
                     (setResumptionSession conn) (ccUse0RTT conf)
        `E.catch` sendCCTLSError
  where
    callbacks :: QUICCallbacks
    callbacks = QUICCallbacks
        { quicSend             = sendTLS conn hsr
        , quicRecv             = recvTLS conn hsr
        , quicInstallKeys      = installKeys
        , quicNotifyExtensions = setPeerParams conn
        , quicDone             = onDone
        }

    -- Install the keys produced by each TLS key-schedule event.
    installKeys :: TLS.Context -> KeyScheduleEvent -> IO ()
    installKeys ctx ev = case ev of
        InstallEarlyKeys Nothing -> return ()
        InstallEarlyKeys (Just (EarlySecretInfo cphr cts)) -> do
            -- Only the client traffic secret exists for 0-RTT; the server
            -- half is filled with an empty secret.
            setCipher conn RTT0Level cphr
            initializeCoder conn RTT0Level (cts, ServerTrafficSecret "")
            setConnection0RTTReady conn
        InstallHandshakeKeys (HandshakeSecretInfo cphr tss) -> do
            setCipher conn HandshakeLevel cphr
            setCipher conn RTT1Level cphr
            initializeCoder conn HandshakeLevel tss
            setEncryptionLevel conn HandshakeLevel
            rxLevelChanged hsr
        InstallApplicationKeys appSecInf@(ApplicationSecretInfo tss) -> do
            storeNegotiated conn ctx appSecInf
            initializeCoder1RTT conn tss
            setEncryptionLevel conn RTT1Level
            rxLevelChanged hsr
            setConnection1RTTReady conn
            cidInfo <- getNewMyCID conn
            putOutput conn (OutHandshake []) -- for h3spec testing
            sendFrames conn RTT1Level [NewConnectionID cidInfo 0]

    -- Handshake finished: validate the peer's chosen version (if it sent
    -- version information) against the version in use, then log the
    -- connection info.
    onDone :: TLS.Context -> IO ()
    onDone _ctx = do
        -- Validating Chosen Version
        mPeerVerInfo <- versionInformation <$> getPeerParameters conn
        mapM_ validateChosen mPeerVerInfo
        info <- getConnectionInfo conn
        connDebugLog conn (bhow info)

    validateChosen :: VersionInfo -> IO ()
    validateChosen peerVerInfo = do
        hdrVer <- getVersion conn
        when (hdrVer /= chosenVersion peerVerInfo) sendCCVNError

----------------------------------------------------------------

-- | Prepare the server-side handshake. This reads the connection's current
-- version and allocates a fresh 'HndState'. It also stores the server's
-- configured parameters, with our authentication CIDs filled in, in a new
-- 'IORef'. The returned action runs the TLS handshake itself.
handshakeServer :: ServerConfig -> Connection -> AuthCIDs -> IO (IO ())
handshakeServer conf conn myAuthCIDs = do
    ver      <- getVersion conn
    hsRef    <- newHndStateRef
    paramRef <- newIORef $ setCIDsToParameters myAuthCIDs (scParameters conf)
    return (handshakeServer' conf conn ver hsRef paramRef)

-- | Run the server TLS handshake, wiring the TLS library's QUIC callbacks to
-- this connection. When the TLS library asks for local transport
-- parameters, they are read from @paramRef@ with the connection's current
-- version information filled in. Any 'TLSException' escaping the handshake
-- is handed to 'sendCCTLSError'.
handshakeServer' :: ServerConfig -> Connection -> Version -> IORef HndState -> IORef Parameters -> IO ()
handshakeServer' conf conn ver hsRef paramRef =
    serverHandshaker callbacks conf ver getParams `E.catch` sendCCTLSError
  where
    callbacks :: QUICCallbacks
    callbacks = QUICCallbacks
        { quicSend             = sendTLS conn hsRef
        , quicRecv             = recvTLS conn hsRef
        , quicInstallKeys      = installKeys
        , quicNotifyExtensions = setPeerParams conn
        , quicDone             = onDone
        }

    -- Install the keys produced by each TLS key-schedule event.
    installKeys :: TLS.Context -> KeyScheduleEvent -> IO ()
    installKeys ctx ev = case ev of
        InstallEarlyKeys Nothing -> return ()
        InstallEarlyKeys (Just (EarlySecretInfo cphr cts)) -> do
            -- Only the client traffic secret exists for 0-RTT; the server
            -- half is filled with an empty secret.
            setCipher conn RTT0Level cphr
            initializeCoder conn RTT0Level (cts, ServerTrafficSecret "")
            setConnection0RTTReady conn
        InstallHandshakeKeys (HandshakeSecretInfo cphr tss) -> do
            setCipher conn HandshakeLevel cphr
            setCipher conn RTT1Level cphr
            initializeCoder conn HandshakeLevel tss
            setEncryptionLevel conn HandshakeLevel
            rxLevelChanged hsRef
        InstallApplicationKeys appSecInf@(ApplicationSecretInfo tss) -> do
            storeNegotiated conn ctx appSecInf
            initializeCoder1RTT conn tss
            -- will switch to RTT1Level after client Finished
            -- is received and verified

    -- Handshake finished: switch to 1-RTT, record the client certificate
    -- chain, and schedule cleanup of the earlier packet number spaces. Then
    -- mark the connection ready/established and log the connection info.
    onDone :: TLS.Context -> IO ()
    onDone ctx = do
        setEncryptionLevel conn RTT1Level
        TLS.getClientCertificateChain ctx >>= setCertificateChain conn
        -- NOTE(review): presumably 'fire' runs the action after 100 ms;
        -- confirm against its definition.
        fire conn (Microseconds 100000) discardEarlySpaces
        setConnection1RTTReady conn
        setConnectionEstablished conn
        -- sendFrames conn RTT1Level [HandshakeDone]
        info <- getConnectionInfo conn
        connDebugLog conn (bhow info)

    -- Drop 0-RTT and Handshake secrets unless getAndSetPacketNumberSpaceDiscarded
    -- reports the space as already discarded. Then clear the Handshake and
    -- 1-RTT crypto streams.
    discardEarlySpaces :: IO ()
    discardEarlySpaces = do
        let ldcc = connLDCC conn
        zeroRttGone <- getAndSetPacketNumberSpaceDiscarded ldcc RTT0Level
        unless zeroRttGone $ dropSecrets conn RTT0Level
        handshakeGone <- getAndSetPacketNumberSpaceDiscarded ldcc HandshakeLevel
        unless handshakeGone $ do
            dropSecrets conn HandshakeLevel
            onPacketNumberSpaceDiscarded ldcc HandshakeLevel
        clearCryptoStream conn HandshakeLevel
        clearCryptoStream conn RTT1Level

    -- Local parameters with the connection's current version information.
    getParams :: IO Parameters
    getParams = withVersion <$> readIORef paramRef <*> getVersionInfo conn
      where
        withVersion params verInfo = params { versionInformation = Just verInfo }

----------------------------------------------------------------

-- | TLS 'quicNotifyExtensions' callback. It locates the peer's QUIC
-- transport-parameters extension, whose ID depends on the connection's
-- version, then decodes, validates and applies it to the connection.
--
-- Failures are reported through:
--
-- * 'sendCCTLSAlert' with 'TLS.MissingExtension' when the extension is absent;
-- * 'sendCCParamError' when it cannot be decoded or carries invalid values;
-- * 'sendCCVNError' when the version information is inconsistent.
--
-- On the server side, the peer's version information may also trigger
-- compatible version negotiation.
setPeerParams :: Connection -> TLS.Context -> [ExtensionRaw] -> IO ()
setPeerParams conn _ctx peerExts = do
    tpId <- extensionIDForTtransportParameter <$> getVersion conn
    case find (\(ExtensionRaw extid _) -> extid == tpId) peerExts of
      Nothing                  -> sendCCTLSAlert TLS.MissingExtension "QUIC transport parameters are missing"
      Just (ExtensionRaw _ bs) -> setPP bs
  where
    setPP :: ByteString -> IO ()
    setPP bs = case decodeParameters bs of
      Nothing     -> sendCCParamError
      Just params -> do
          checkAuthCIDs params
          checkInvalid params
          setParams params
          qlogParamsSet conn (params, "remote")
          when (isServer conn) $
              serverVersionNegotiation $ versionInformation params

    -- Connection-ID authentication: the CIDs the peer declares in its
    -- transport parameters must match the ones recorded for it
    -- (getPeerAuthCIDs). The original-destination and retry-source CIDs are
    -- only checked by a client.
    checkAuthCIDs :: Parameters -> IO ()
    checkAuthCIDs params = do
        peerAuthCIDs <- getPeerAuthCIDs conn
        ensure (initialSourceConnectionId params) $ initSrcCID peerAuthCIDs
        when (isClient conn) $ do
            ensure (originalDestinationConnectionId params) $ origDstCID peerAuthCIDs
            ensure (retrySourceConnectionId params) $ retrySrcCID peerAuthCIDs

    -- If no value was recorded on our side, there is nothing to compare.
    ensure :: Eq a => Maybe a -> Maybe a -> IO ()
    ensure _ Nothing = return ()
    ensure v0 v1
      | v0 == v1  = return ()
      | otherwise = sendCCParamError

    -- Value-range checks on the peer's parameters (limits per RFC 9000
    -- §18.2), plus version-information checks.
    checkInvalid :: Parameters -> IO ()
    checkInvalid params = do
        when (maxUdpPayloadSize params < 1200) sendCCParamError
        when (ackDelayExponent params > 20) sendCCParamError
        when (maxAckDelay params >= 2^(14 :: Int)) sendCCParamError
        -- We are the server, so these come from a client, which must not
        -- send any of them.
        when (isServer conn) $ do
            when (isJust $ originalDestinationConnectionId params) sendCCParamError
            when (isJust $ preferredAddress params) sendCCParamError
            when (isJust $ retrySourceConnectionId params) sendCCParamError
            when (isJust $ statelessResetToken params) sendCCParamError
        -- Absent version information is treated as Version1 only.
        let vi = fromMaybe (VersionInfo Version1 [Version1]) $ versionInformation params
        when (vi == brokenVersionInfo) sendCCParamError
        when (Negotiation `elem` otherVersions vi) sendCCParamError
        -- Always False for servers
        isICVN <- getIncompatibleVN conn
        when isICVN $ do
            -- Validating Other Version fields.
            verInfo <- getVersionInfo conn
            let myVer    = chosenVersion verInfo
                myVers   = filter (not . isGreasingVersion) $ otherVersions verInfo
                peerVers = otherVersions vi
            case myVers `intersect` peerVers of
              ver:_ | ver == myVer -> return ()
              _                    -> sendCCVNError

    -- Store the validated parameters and propagate the flow-control,
    -- idle-timeout and ACK-delay limits to the connection.
    setParams :: Parameters -> IO ()
    setParams params = do
        setPeerParameters conn params
        mapM_ (setPeerStatelessResetToken conn) $ statelessResetToken params
        setTxMaxData conn $ initialMaxData params
        setMinIdleTimeout conn $ milliToMicro $ maxIdleTimeout params
        setMaxAckDaley (connLDCC conn) $ milliToMicro $ maxAckDelay params
        setTxMaxStreams conn $ initialMaxStreamsBidi params
        setTxUniMaxStreams conn $ initialMaxStreamsUni params

    -- Server only: choose the first of our (non-greasing) versions that the
    -- client also offers. If it differs from the current one, switch to it
    -- and re-derive the Initial keys for the new version.
    serverVersionNegotiation :: Maybe VersionInfo -> IO ()
    serverVersionNegotiation Nothing = return ()
    serverVersionNegotiation (Just peerVerInfo) = do
        myVerInfo <- getVersionInfo conn
        let clientVer = chosenVersion myVerInfo
            myVers    = filter (not . isGreasingVersion) $ otherVersions myVerInfo
            peerVers  = otherVersions peerVerInfo
        -- Server's preference should be preferred.
        case myVers `intersect` peerVers of
          vers@(serverVer:_)
            | clientVer /= serverVer -> do
                setVersionInfo conn $ VersionInfo serverVer vers
                dcid <- getClientDstCID conn
                initializeCoder conn InitialLevel $ initialSecrets serverVer dcid
          _ -> return ()

-- | Record what the TLS handshake negotiated on the connection. This covers
-- the TLS 1.3 handshake mode (defaulting to 'FullHandshake' when the
-- context reports none), the ALPN protocol and the application secrets.
storeNegotiated :: Connection -> TLS.Context -> ApplicationSecretInfo -> IO ()
storeNegotiated conn ctx appSecInf = do
    alpn  <- TLS.getNegotiatedProtocol ctx
    minfo <- TLS.contextGetInformation ctx
    let mode = case minfo >>= TLS.infoTLS13HandshakeMode of
                 Just m  -> m
                 Nothing -> FullHandshake
    setNegotiated conn mode alpn appSecInf

----------------------------------------------------------------

-- | Abort with a 'WrongTransportParameter' internal-control exception.
sendCCParamError :: IO ()
sendCCParamError = E.throwIO $ WrongTransportParameter

-- | Abort with a 'WrongVersionInformation' internal-control exception.
sendCCVNError :: IO ()
sendCCVNError = E.throwIO $ WrongVersionInformation

-- | Close the connection in response to a TLS exception from the handshake.
--
-- Two 'TLS.Error_Misc' messages are mapped to dedicated transport errors.
-- NOTE(review): presumably these are how the TLS library wraps the
-- exceptions thrown by 'sendCCParamError' / 'sendCCVNError' inside our
-- callbacks; confirm.
-- Every other exception becomes a CRYPTO_ERROR derived from the TLS alert.
sendCCTLSError :: TLS.TLSException -> IO ()
sendCCTLSError e = case e of
    TLS.HandshakeFailed (TLS.Error_Misc "WrongTransportParameter") ->
        closeConnection TransportParameterError "Transport parameter error"
    TLS.HandshakeFailed (TLS.Error_Misc "WrongVersionInformation") ->
        closeConnection VersionNegotiationError "Version negotiation error"
    _ ->
        let cause = getErrorCause e
        in closeConnection (cryptoError (errorToAlertDescription cause))
                           (shortpack (errorToAlertMessage cause))

-- | Close the connection with the CRYPTO_ERROR corresponding to a TLS alert.
sendCCTLSAlert :: TLS.AlertDescription -> ReasonPhrase -> IO ()
sendCCTLSAlert a = closeConnection (cryptoError a)

-- | Extract the underlying 'TLS.TLSError' from a 'TLS.TLSException'.
--
-- Exception forms that carry no embedded error are mapped to an internal
-- protocol error mentioning the exception. The version-dependent error
-- construction is delegated to 'internalError' so it lives in one place.
-- With tls >= 1.8 every constructor is matched above, so the final
-- equation is only reachable on older tls versions.
getErrorCause :: TLS.TLSException -> TLS.TLSError
getErrorCause (TLS.Terminated _ _ e)   = e
getErrorCause (TLS.HandshakeFailed e)  = e
#if MIN_VERSION_tls(1,8,0)
getErrorCause (TLS.PostHandshake e)    = e
getErrorCause (TLS.Uncontextualized e) = e
#endif
getErrorCause e = internalError $ "unexpected TLS exception: " ++ show e