{-# LANGUAGE OverloadedStrings #-}

module Network.TLS.Handshake.Client (
    handshakeClient,
    handshakeClientWith,
    postHandshakeAuthClientWith,
) where

import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Client.ClientHello
import Network.TLS.Handshake.Client.ServerHello
import Network.TLS.Handshake.Client.TLS12
import Network.TLS.Handshake.Client.TLS13
import Network.TLS.Handshake.Common13
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.IO
import Network.TLS.Imports
import Network.TLS.Measurement
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct

----------------------------------------------------------------

-- | Re-enter the client handshake in response to a handshake message
-- received outside of a running handshake.
--
-- Only a 'HelloRequest' restarts the full client handshake (via
-- 'handshakeClient'); any other message is rejected with a
-- @handshake_failure@ protocol error.
handshakeClientWith :: ClientParams -> Context -> Handshake -> IO ()
handshakeClientWith cparams ctx msg = case msg of
    HelloRequest -> handshakeClient cparams ctx
    _ ->
        throwCore $
            Error_Protocol
                "unexpected handshake message received in handshakeClientWith"
                HandshakeFailure

-- | Run the client side of a handshake: send the client's messages and
-- process the server's responses in turn.
--
-- The group list passed to 'handshake' is built as follows:
--
-- * With no session to resume, it is every group in 'ctxSupported'.
-- * When resuming a session that recorded a group, that group is moved
--   to the front (the head of the list is used for the initial key
--   share).
-- * When resuming a session with no recorded group (TLS 1.2 or
--   earlier), the list is empty.
handshakeClient :: ClientParams -> Context -> IO ()
handshakeClient cparams ctx = handshake cparams ctx offeredGroups Nothing
  where
    offeredGroups :: [Group]
    offeredGroups =
        maybe allGroups (resumeGroups . snd) (clientWantSessionResume cparams)

    -- No recorded group means a TLS 1.2 or earlier session.
    resumeGroups sdata = maybe [] preferGroup (sessionGroup sdata)

    preferGroup :: Group -> [Group]
    preferGroup grp = grp : filter (/= grp) allGroups

    allGroups :: [Group]
    allGroups = supportedGroups (ctxSupported ctx)

-- | Perform one ClientHello / ServerHello exchange and continue with the
-- negotiated protocol.
--
-- https://tools.ietf.org/html/rfc8446#section-4.1.2 says:
-- "The client will also send a ClientHello when the server has
-- responded to its ClientHello with a HelloRetryRequest.  In that case,
-- the client MUST send the same ClientHello without modification,
-- except as follows:"
--
-- So the ClientRandom of the first ClientHello has to be kept, and a
-- retry passes it back in through @mparams@.
--
-- * @groups@: groups in preference order; the head (if any) is the group
--   used for the initial key share.
-- * @mparams@: 'Nothing' for the first ClientHello; on a retry, the first
--   ClientHello's random, the TLS 1.3 session and the negotiated version.
handshake
    :: ClientParams
    -> Context
    -> [Group]
    -> Maybe (ClientRandom, Session, Version)
    -> IO ()
handshake cparams ctx groups mparams = do
    --------------------------------
    -- Sending ClientHello
    pskinfo@(_, _, earlyData) <- getPreSharedKeyInfo cparams ctx
    when earlyData $
        modifyTLS13State ctx (\st -> st{tls13st0RTT = True})
    -- With early data outside QUIC, the ServerHello is not received inline
    -- below; asyncServerHello13 is registered instead.
    let receiveAsync = earlyData && not (ctxQUICMode ctx)
    when receiveAsync $
        getCurrentTimeFromBase >>= asyncServerHello13 cparams ctx keyShareGroup
    updateMeasure ctx incrementNbHandshakes
    crand <- sendClientHello cparams ctx groups mparams pskinfo
    --------------------------------
    -- Receiving ServerHello
    unless receiveAsync $
        receiveServerHello cparams ctx mparams >>= afterServerHello earlyData crand
  where
    keyShareGroup = listToMaybe groups

    -- Switching to HRR, TLS 1.2 or TLS 1.3.  The retry excludes the head
    -- group, whose key share has already been sent.
    afterServerHello earlyData crand (ver, hss, hrr) = case ver of
        TLS13
            | hrr -> helloRetry cparams ctx mparams ver crand (drop 1 groups)
            | otherwise -> do
                recvServerSecondFlight13 cparams ctx keyShareGroup
                sendClientSecondFlight13 cparams ctx
        _
            | earlyData ->
                throwCore $
                    Error_Protocol
                        "server denied TLS 1.3 when connecting with early data"
                        HandshakeFailure
            | otherwise -> do
                recvServerFirstFlight12 cparams ctx hss
                sendClientSecondFlight12 cparams ctx
                recvServerSecondFlight12 ctx

-- | Wait for the server's answer to a ClientHello.
--
-- Returns the negotiated version, the handshake messages received, and
-- whether the answer was a HelloRetryRequest.
--
-- A timestamp taken just before waiting is passed to 'setRTT' once the
-- ServerHello has arrived.  On a retry (@mparams@ is 'Just'), the
-- negotiated version must equal the one carried over from the first
-- exchange; otherwise an @illegal_parameter@ protocol error is thrown.
receiveServerHello
    :: ClientParams
    -> Context
    -> Maybe (ClientRandom, Session, Version)
    -> IO (Version, [Handshake], Bool)
receiveServerHello cparams ctx mparams = do
    sentAt <- getCurrentTimeFromBase
    msgs <- recvServerHello cparams ctx
    setRTT ctx sentAt
    negotiated <- usingState_ ctx getVersion
    case mparams of
        Just (_, _, previous)
            | not (previous == negotiated) ->
                throwCore $
                    Error_Protocol "version changed after hello retry" IllegalParameter
        _ -> pure ()
    -- recvServerHello sets TLS13HRR according to the server random.
    -- For the 1st ServerHello, getTLS13HRR returns True if it is an HRR and
    -- False otherwise.  For the 2nd ServerHello, getTLS13HRR returns
    -- False since it is NOT an HRR.
    isHRR <- usingState_ ctx getTLS13HRR
    pure (negotiated, msgs, isHRR)

----------------------------------------------------------------

-- | Handle a HelloRetryRequest and restart the handshake with the group
-- selected by the server.
--
-- * @mparams@: 'Nothing' when answering the first ClientHello; 'Just'
--   when the ClientHello being answered was itself a retry.
-- * @ver@: the version returned by 'getVersion' for this exchange.
-- * @crand@: random of the first ClientHello, reused by the retry
--   (RFC 8446 §4.1.2).
-- * @groups@: the caller's preference list without its head, i.e.
--   without the group whose key share was already sent.
--
-- Throws a protocol error with:
--
-- * @unexpected_message@ on a second HelloRetryRequest (RFC 8446 §4.1.4);
-- * @illegal_parameter@ when no groups remain, or the server selected a
--   group not in @groups@;
-- * @handshake_failure@ when no key share was recorded for the HRR.
helloRetry
    :: ClientParams
    -> Context
    -> Maybe a
    -> Version
    -> ClientRandom
    -> [Group]
    -> IO ()
helloRetry cparams ctx mparams ver crand groups = do
    -- This must be checked before `null groups`.  On a retry the caller
    -- passes `drop 1 [selectedGroup]`, which is always empty.  Checking
    -- emptiness first made this branch unreachable and answered a second
    -- HRR with illegal_parameter instead of unexpected_message.
    when (isJust mparams) $
        throwCore $
            Error_Protocol "server sent too many hello retries" UnexpectedMessage
    when (null groups) $
        throwCore $
            Error_Protocol "group is exhausted in the client side" IllegalParameter
    mks <- usingState_ ctx getTLS13KeyShare
    case mks of
        Just (KeyShareHRR selectedGroup)
            | selectedGroup `elem` groups -> do
                usingHState ctx $ setTLS13HandshakeMode HelloRetryRequest
                clearTxRecordState ctx
                -- Early data is disabled for the retried handshake.
                let cparams' = cparams{clientUseEarlyData = False}
                runPacketFlight ctx $ sendChangeCipherSpec13 ctx
                clientSession <- tls13stSession <$> getTLS13State ctx
                handshake
                    cparams'
                    ctx
                    [selectedGroup]
                    (Just (crand, clientSession, ver))
            | otherwise ->
                throwCore $
                    Error_Protocol "server-selected group is not supported" IllegalParameter
        -- NOTE(review): presumably unreachable if an HRR key_share always
        -- decodes as KeyShareHRR.  If it can happen, a throwCore protocol
        -- error would be preferable to 'error'.
        Just _ -> error "handshake: invalid KeyShare value"
        Nothing ->
            throwCore $
                Error_Protocol
                    "key exchange not implemented in HRR, expected key_share extension"
                    HandshakeFailure