{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}

module Network.TLS.Handshake.Server.ClientHello13 (
    processClientHello13,
    sendHRR,
) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13

-- | Process a ClientHello for a TLS 1.3 (or later) handshake.
--
-- The function performs, in order:
--
-- 1. a check that a @pre_shared_key@ extension, if present, is the last
--    extension of the ClientHello;
-- 2. cipher selection among the server ciphers allowed for TLS 1.3 that the
--    client also offered, delegated to the 'onCipherChoosing' server hook;
-- 3. detection of an @early_data@ extension, in which case the context is
--    marked with @EarlyDataNotAllowed 3@ through 'setEstablished';
-- 4. decoding of the @key_share@ extension and lookup of an entry matching
--    one of the groups in @supportedGroups (ctxSupported ctx)@.
--
-- The result is @(mshare, (usedCipher, usedHash, rtt0))@ where @mshare@ is
-- the key share found by 'findKeyShare' ('Nothing' when no offered share
-- matches a server group), @usedHash@ is @cipherHash usedCipher@ and @rtt0@
-- tells whether an @early_data@ extension was found and decoded.
--
-- Errors are raised with 'throwCore' as 'Error_Protocol':
--
-- * 'IllegalParameter' when @pre_shared_key@ is not the last extension;
-- * 'HandshakeFailure' when no TLS 1.3 cipher is shared with the client;
-- * 'MissingExtension' when there is no @key_share@ extension;
-- * 'DecodeError' when the @key_share@ extension cannot be decoded.
processClientHello13
    :: ServerParams
    -> Context
    -> CH
    -> IO (Maybe KeyShareEntry, (Cipher, Hash, Bool))
processClientHello13 sparams ctx CH{..} = do
    -- pre_shared_key is only acceptable in the final position, so it must
    -- not appear among the extensions preceding the last one.  'init' is
    -- partial and would raise a Prelude exception on an empty extension
    -- list; dropping the head of the reversed list is total, and the order
    -- of the elements is irrelevant to 'any'.
    when (any isPreSharedKey $ drop 1 $ reverse chExtensions) $
        throwCore $
            Error_Protocol "extension pre_shared_key must be last" IllegalParameter
    -- Deciding cipher.
    -- The shared cipherlist can become empty after filtering for compatible
    -- creds, check now before calling onCipherChoosing, which does not handle
    -- empty lists.
    when (null ciphersFilteredVersion) $
        throwCore $
            Error_Protocol "no cipher in common with the TLS 1.3 client" HandshakeFailure
    let usedCipher = onCipherChoosing (serverHooks sparams) TLS13 ciphersFilteredVersion
        usedHash = cipherHash usedCipher
        -- True only when an early_data extension is present and decodes.
        rtt0 = case extensionLookup EID_EarlyData chExtensions
            >>= extensionDecode MsgTClientHello of
            Just (EarlyDataIndication _) -> True
            Nothing -> False
    -- mark a 0-RTT attempt before a possible HRR, and before updating the
    -- status again if 0-RTT successful
    when rtt0 $
        setEstablished ctx (EarlyDataNotAllowed 3) -- hardcoding
    -- Deciding key exchange from key shares
    keyShares <- case extensionLookup EID_KeyShare chExtensions of
        Nothing ->
            throwCore $
                Error_Protocol
                    "key exchange not implemented, expected key_share extension"
                    MissingExtension
        Just kss -> case extensionDecode MsgTClientHello kss of
            Just (KeyShareClientHello kses) -> return kses
            -- A successfully decoded value of any other KeyShare form.
            -- NOTE(review): presumably unreachable when decoding with
            -- MsgTClientHello -- confirm against the KeyShare decoder.
            Just _ ->
                error "processClientHello13: invalid KeyShare value"
            _ ->
                throwCore $ Error_Protocol "broken key_share" DecodeError
    mshare <- findKeyShare keyShares serverGroups
    return (mshare, (usedCipher, usedHash, rtt0))
  where
    isPreSharedKey (ExtensionRaw eid _) = eid == EID_PreSharedKey
    -- Server ciphers usable with TLS 1.3 that the client offered as well,
    -- kept in the server's order.
    ciphersFilteredVersion = filter ((`elem` chCiphers) . cipherID) serverCiphers
    serverCiphers =
        filter
            (cipherAllowedForVersion TLS13)
            (supportedCiphers $ serverSupported sparams)
    serverGroups = supportedGroups (ctxSupported ctx)

-- | Select the key share entry to use among those offered by the client.
--
-- The groups given as second argument are tried in order.  For each group
-- the offered entries belonging to that group are collected:
--
-- * no entry: the next group is tried;
-- * exactly one entry: it is returned, provided 'checkKeyShareKeyLength'
--   accepts it; otherwise @Error_Protocol "broken key_share"
--   IllegalParameter@ is thrown;
-- * several entries: @Error_Protocol "duplicated key_share"
--   IllegalParameter@ is thrown.
--
-- 'Nothing' is returned when none of the groups has an offered entry.
-- Duplicates are only detected for a group that is actually examined.
findKeyShare :: [KeyShareEntry] -> [Group] -> IO (Maybe KeyShareEntry)
findKeyShare offered = foldr tryGroup (return Nothing)
  where
    -- @tryRest@ is the search over the remaining groups; it only runs when
    -- the current group has no offered entry.
    tryGroup grp tryRest =
        case [entry | entry <- offered, grp == keyShareEntryGroup entry] of
            [] -> tryRest
            [entry]
                | checkKeyShareKeyLength entry -> return (Just entry)
                | otherwise ->
                    throwCore $ Error_Protocol "broken key_share" IllegalParameter
            _ -> throwCore $ Error_Protocol "duplicated key_share" IllegalParameter

-- | Send a HelloRetryRequest in answer to the given ClientHello.
--
-- Only the cipher of the triple is used; its other two components are
-- ignored.
--
-- Steps, in order:
--
-- 1. fail with @Error_Protocol "Hello retry not allowed again"
--    HandshakeFailure@ if 'getTLS13HRR' reports that the flag is already
--    set, otherwise set it with 'setTLS13HRR';
-- 2. record the cipher with 'setHelloParameters13', failing on a 'Left'
--    result through 'failOnEitherError';
-- 3. intersect @supportedGroups (ctxSupported ctx)@ with the groups of the
--    client's @supported_groups@ extension (no groups when the extension is
--    absent or does not decode) and take the first common group, in the
--    order of the server list; fail with 'HandshakeFailure' when there is
--    none;
-- 4. switch the handshake mode to 'HelloRetryRequest' and send a
--    'ServerHello13' built from 'hrrRandom', the client's session, the
--    cipher ID and the @key_share@ / @supported_versions@ extensions,
--    followed by 'sendChangeCipherSpec13'.
sendHRR :: Context -> (Cipher, a, b) -> CH -> IO ()
sendHRR ctx (usedCipher, _, _) CH{..} = do
    alreadyRetried <- usingState_ ctx getTLS13HRR
    when alreadyRetried $
        throwCore $
            Error_Protocol "Hello retry not allowed again" HandshakeFailure
    usingState_ ctx (setTLS13HRR True)
    failOnEitherError (usingHState ctx (setHelloParameters13 usedCipher))
    case commonGroups of
        [] ->
            throwCore $
                Error_Protocol "no group in common with the client for HRR" HandshakeFailure
        grp : _ -> do
            usingHState ctx (setTLS13HandshakeMode HelloRetryRequest)
            runPacketFlight ctx $ do
                loadPacket13 ctx (Handshake13 [helloRetry grp])
                sendChangeCipherSpec13 ctx
  where
    serverGroups = supportedGroups (ctxSupported ctx)

    -- Groups listed in the client's supported_groups extension, or none
    -- when the extension is missing or cannot be decoded.
    clientGroups =
        maybe [] groupList $
            extensionLookup EID_SupportedGroups chExtensions
                >>= extensionDecode MsgTClientHello
    groupList (SupportedGroups gs) = gs

    -- Common groups, in the order of the server list.
    commonGroups = serverGroups `intersect` clientGroups

    -- The HelloRetryRequest message asking the client for a share of @grp@.
    helloRetry grp =
        ServerHello13
            hrrRandom
            chSession
            (cipherID usedCipher)
            [ ExtensionRaw EID_KeyShare (extensionEncode (KeyShareHRR grp))
            , ExtensionRaw
                EID_SupportedVersions
                (extensionEncode (SupportedVersionsServerHello TLS13))
            ]