{-# LANGUAGE LambdaCase #-}

module Network.TLS.Handshake.Client.Common (
    throwMiscErrorOnException,
    doServerKeyExchange,
    doCertificate,
    getLocalHashSigAlg,
    clientChain,
    sigAlgsToCertTypes,
    setALPN,
    contextSync,
) where

import Control.Exception (SomeException)
import Control.Monad.State.Strict
import Data.X509 (ExtKeyUsageFlag (..))

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Credentials
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet hiding (getExtensions)
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util (catchException)
import Network.TLS.X509

----------------------------------------------------------------

-- | Turn an arbitrary exception into a TLS 'Error_Misc' and throw it
-- with 'throwCore'.
--
-- The error text is the supplied prefix, followed by @": "@ and the
-- 'show' rendering of the exception.  Intended to be used as an
-- exception handler, hence the polymorphic result.
throwMiscErrorOnException :: String -> SomeException -> IO a
throwMiscErrorOnException msg e = throwCore (Error_Misc description)
  where
    description = msg ++ ": " ++ show e

----------------------------------------------------------------

-- | Handle the key-exchange data sent by the server.
--
-- The payload is matched against the key-exchange type of the pending
-- cipher.  For the signed DHE and ECDHE variants the signature over the
-- parameters is checked with the remote public key and, on success, the
-- parameters are recorded in the handshake state.  A payload that is
-- still unparsed is decoded for the negotiated version and key-exchange
-- type and then handled again.  Any other combination is reported as an
-- 'Error_Protocol' with 'HandshakeFailure'.
doServerKeyExchange :: Context -> ServerKeyXchgAlgorithmData -> IO ()
doServerKeyExchange ctx origSkx =
    usingHState ctx getPendingCipher >>= \cipher -> dispatch cipher origSkx
  where
    -- Select the handler from the (cipher key exchange, payload) pair.
    dispatch cipher skx = case (cipherKeyExchange cipher, skx) of
        (CipherKeyExchange_DHE_RSA, SKX_DHE_RSA params sig) ->
            verifyDHE params sig KX_RSA
        (CipherKeyExchange_DHE_DSA, SKX_DHE_DSA params sig) ->
            verifyDHE params sig KX_DSA
        (CipherKeyExchange_ECDHE_RSA, SKX_ECDHE_RSA params sig) ->
            verifyECDHE params sig KX_RSA
        (CipherKeyExchange_ECDHE_ECDSA, SKX_ECDHE_ECDSA params sig) ->
            verifyECDHE params sig KX_ECDSA
        -- The payload has not been decoded yet: decode it now that the
        -- version and key-exchange type are known, then dispatch again
        -- on the decoded value.
        (kx, SKX_Unparsed raw) -> do
            version <- usingState_ ctx getVersion
            either
                (const (unexpected kx))
                (dispatch cipher)
                (decodeReallyServerKeyXchgAlgorithmData version kx raw)
        (kx, _) -> unexpected kx

    -- Shared failure for payloads that do not fit the pending cipher.
    unexpected kx =
        protocolError
            ("unknown server key exchange received, expecting: " ++ show kx)
            HandshakeFailure

    protocolError msg alert = throwCore (Error_Protocol msg alert)

    -- Check the signature over the DH parameters and store them.
    -- The finite-field group chosen by the server is not checked here;
    -- it is verified later, when the client key exchange is generated.
    verifyDHE params sig kxsAlg = do
        serverKey <- signerKey kxsAlg
        ok <- digitallySignDHParamsVerify ctx params serverKey sig
        unless ok $
            decryptError $
                "bad "
                    ++ pubkeyType serverKey
                    ++ " signature for dhparams "
                    ++ show params
        usingHState ctx (setServerDHParams params)

    -- Check the signature over the ECDH parameters and store them.
    -- The elliptic-curve group chosen by the server is not checked here;
    -- it is verified later, when the client key exchange is generated.
    verifyECDHE params sig kxsAlg = do
        serverKey <- signerKey kxsAlg
        ok <- digitallySignECDHParamsVerify ctx params serverKey sig
        unless ok $
            decryptError $
                "bad " ++ pubkeyType serverKey ++ " signature for ecdhparams"
        usingHState ctx (setServerECDHParams params)

    -- Fetch the remote public key and make sure it can be used with the
    -- given signature algorithm, the negotiated version and the locally
    -- supported groups.  Each failed check throws an 'Error_Protocol'.
    signerKey kxsAlg = do
        serverKey <- usingHState ctx getRemotePublicKey
        unless (isKeyExchangeSignatureKey kxsAlg serverKey) $
            protocolError
                ("server public key algorithm is incompatible with " ++ show kxsAlg)
                HandshakeFailure
        version <- usingState_ ctx getVersion
        unless (serverKey `versionCompatible` version) $
            protocolError
                (show version ++ " has no support for " ++ pubkeyType serverKey)
                IllegalParameter
        let inLocalGroups grp = grp `elem` supportedGroups (ctxSupported ctx)
        unless (satisfiesEcPredicate inLocalGroups serverKey) $
            protocolError
                "server public key has unsupported elliptic curve"
                IllegalParameter
        return serverKey

----------------------------------------------------------------

-- | Process the certificate chain presented by the server.
--
-- An empty chain is rejected with an 'Error_Protocol'.  Otherwise the
-- certificate-received hook is run, then the 'onServerCertificate'
-- client hook validates the chain; an exception raised during
-- validation is handled by 'rejectOnException'.  An accepted chain is
-- finally checked against the key usage required by the pending cipher.
doCertificate :: ClientParams -> Context -> CertificateChain -> IO ()
doCertificate cparams ctx certs = do
    when (isNullCertificateChain certs) $
        throwCore (Error_Protocol "server certificate missing" DecodeError)
    -- Notify the hook about the received chain before validating it.
    ctxWithHooks ctx $ \hooks -> hookRecvCertificates hooks certs
    verdict <-
        (wrapCertificateChecks <$> validateChain)
            `catchException` rejectOnException
    case verdict of
        CertificateUsageReject why -> certificateRejected why
        CertificateUsageAccept -> enforceLeafKeyUsage
  where
    -- Run the application-supplied validation callback.
    validateChain =
        onServerCertificate
            (clientHooks cparams)
            (sharedCAStore store)
            (sharedValidationCache store)
            (clientServerIdentification cparams)
            certs
      where
        store = clientShared cparams

    -- The optional key usage of the leaf certificate must also be
    -- compatible with the intended key exchange.  This is done here
    -- rather than through x509-validation 'checkLeafKeyUsage' because
    -- it depends on the negotiated cipher, which is not among the
    -- onServerCertificate parameters.  Moreover, with a single shared
    -- ValidationCache, x509-validation would cache a result obtained
    -- for one key usage and reuse it for another.
    enforceLeafKeyUsage = do
        flags <- requiredCertKeyUsage <$> usingHState ctx getPendingCipher
        unless (null flags) $ verifyLeafKeyUsage flags certs

-- Key usage flags acceptable for the server certificate, given the key
-- exchange of the cipher.  An empty result means no restriction;
-- otherwise the certificate must be allowed for at least one of the
-- returned flags.  RSA-based key exchanges get a relaxed set so that
-- certificates with an incomplete extension are not rejected.
requiredCertKeyUsage :: Cipher -> [ExtKeyUsageFlag]
requiredCertKeyUsage = usageFor . cipherKeyExchange
  where
    -- unrestricted
    usageFor CipherKeyExchange_DH_Anon = []
    -- RSA certificates: relaxed constraints
    usageFor CipherKeyExchange_RSA = rsaCompatibility
    usageFor CipherKeyExchange_DHE_RSA = rsaCompatibility
    usageFor CipherKeyExchange_ECDHE_RSA = rsaCompatibility
    usageFor CipherKeyExchange_DH_RSA = rsaCompatibility
    usageFor CipherKeyExchange_ECDH_RSA = rsaCompatibility
    -- signature required
    usageFor CipherKeyExchange_DHE_DSA = [KeyUsage_digitalSignature]
    usageFor CipherKeyExchange_ECDHE_ECDSA = [KeyUsage_digitalSignature]
    usageFor CipherKeyExchange_TLS13 = [KeyUsage_digitalSignature]
    -- key agreement required
    usageFor CipherKeyExchange_DH_DSA = [KeyUsage_keyAgreement]
    usageFor CipherKeyExchange_ECDH_ECDSA = [KeyUsage_keyAgreement]

    rsaCompatibility =
        [ KeyUsage_digitalSignature
        , KeyUsage_keyEncipherment
        , KeyUsage_keyAgreement
        ]

----------------------------------------------------------------

-- | Map signature algorithms to the 'CertificateType' values they imply.
--
-- Algorithms without a corresponding certificate type are skipped, as
-- are types greater than 'lastSupportedCertificateType'.  Duplicates are
-- removed; the order of first occurrence in the input is preserved.
supportedCtypes
    :: [HashAndSignatureAlgorithm]
    -> [CertificateType]
supportedCtypes hashAlgs =
    nub
        [ cType
        | Just cType <- map hashSigToCertType hashAlgs
        , cType <= lastSupportedCertificateType
        ]

-- | Certificate types implied by the signature algorithms listed in the
-- context's 'Supported' parameters.
clientSupportedCtypes
    :: Context
    -> [CertificateType]
clientSupportedCtypes =
    supportedCtypes . supportedHashSignatures . ctxSupported

-- | Certificate types supported locally (see the context's 'Supported'
-- parameters) that are also implied by the given signature algorithms.
-- The result keeps the local ordering.
sigAlgsToCertTypes
    :: Context
    -> [HashAndSignatureAlgorithm]
    -> [CertificateType]
sigAlgsToCertTypes ctx hashSigs =
    [cType | cType <- clientSupportedCtypes ctx, cType `elem` allowed]
  where
    allowed = supportedCtypes hashSigs

----------------------------------------------------------------

-- | When the server requests a client certificate, we try to
-- obtain a suitable certificate chain and private key via the
-- callback in the client parameters.  It is OK for the callback
-- to return an empty chain, in many cases the client certificate
-- is optional.  If the client wishes to abort the handshake for
-- lack of a suitable certificate, it can throw an exception in
-- the callback.
--
-- The return value is 'Nothing' when no @CertificateRequest@ was
-- received and no @Certificate@ message needs to be sent. An empty
-- chain means that an empty @Certificate@ message needs to be sent
-- to the server, naturally without a @CertificateVerify@.  A non-empty
-- 'CertificateChain' is the chain to send to the server along with
-- a corresponding 'CertificateVerify'.
--
-- With TLS < 1.2 the server's @CertificateRequest@ does not carry
-- a signature algorithm list.  It has a list of supported public
-- key signing algorithms in the @certificate_types@ field.  The
-- hash is implicit.  It is 'SHA1' for DSA and 'SHA1_MD5' for RSA.
--
-- With TLS == 1.2 the server's @CertificateRequest@ always has a
-- @supported_signature_algorithms@ list, as a fixed component of
-- the structure.  This list is (wrongly) overloaded to also limit
-- X.509 signatures in the client's certificate chain.  The BCP
-- strategy is to find a compatible chain if possible, but else
-- ignore the constraint, and let the server verify the chain as it
-- sees fit.  The @supported_signature_algorithms@ field is only
-- obligatory with respect to signatures on TLS messages, in this
-- case the @CertificateVerify@ message.  The @certificate_types@
-- field is still included.
--
-- With TLS 1.3 the server's @CertificateRequest@ has a mandatory
-- @signature_algorithms@ extension.  The @signature_algorithms_cert@
-- extension, which is optional, carries a list of algorithms the
-- server promises to support in verifying the certificate chain.
-- As with TLS 1.2, the client makes a /best-effort/ attempt to deliver
-- a compatible certificate chain where all the CA signatures are
-- known to be supported, but it should not abort the connection
-- just because the chain might not work out, just send the best
-- chain you have and let the server worry about the rest.  The
-- supported public key algorithms are now inferred from the
-- @signature_algorithms@ extension and @certificate_types@ is
-- gone.
--
-- With TLS 1.3, we synthesize and store a @certificate_types@
-- field at the time that the server's @CertificateRequest@
-- message is received.  This is then present across all the
-- protocol versions, and can be used to determine whether
-- a @CertificateRequest@ was received or not.
--
-- If @signature_algorithms@ is 'Nothing', then we're doing
-- TLS 1.0 or 1.1.  The @signature_algorithms_cert@ extension
-- is optional in TLS 1.3, and so the application callback
-- will not be able to distinguish between TLS 1.[01] and
-- TLS 1.3 with no certificate algorithm hints, but this
-- just simplifies the chain selection process, all CA
-- signatures are OK.
clientChain :: ClientParams -> Context -> IO (Maybe CertificateChain)
clientChain cparams ctx = do
    -- No stored request data means no CertificateRequest was received,
    -- so there is nothing to send.
    mRequest <- usingHState ctx getCertReqCBdata
    case mRequest of
        Nothing -> return Nothing
        Just request -> Just <$> chainFor request
  where
    noCertificates = CertificateChain []

    -- Ask the application for a credential matching the request.  An
    -- exception escaping the callback is rethrown as an 'Error_Misc'.
    chainFor request = do
        answer <-
            onCertificateRequest (clientHooks cparams) request
                `catchException` throwMiscErrorOnException
                    "certificate request callback failed"
        case answer of
            -- No credential, or a credential with an empty chain: reply
            -- with an empty chain.
            Nothing -> return noCertificates
            Just (CertificateChain [], _) -> return noCertificates
            -- A usable credential: store it (which also checks it
            -- against the requested certificate types) and send its
            -- chain.
            Just cred@(chain, _) -> do
                let (cTypes, _, _) = request
                storePrivInfoClient ctx cTypes cred
                return chain

-- | Store the credential through 'storePrivInfo', then check that its
-- public key is compatible with the given 'CertificateType' values and
-- with the current protocol version.  A failed check throws an
-- 'Error_Protocol' with 'InternalError'.
storePrivInfoClient
    :: Context
    -> [CertificateType]
    -> Credential
    -> IO ()
storePrivInfoClient ctx cTypes (chain, key) = do
    pubkey <- storePrivInfo ctx chain key
    -- Every rejection message starts with the key type.
    let reject why =
            throwCore $ Error_Protocol (pubkeyType pubkey ++ why) InternalError
    unless (certificateCompatible pubkey cTypes) $
        reject " credential does not match allowed certificate types"
    version <- usingState_ ctx getVersion
    unless (versionCompatible pubkey version) $
        reject (" credential is not supported at version " ++ show version)

----------------------------------------------------------------

-- | Return the most preferred 'HashAndSignatureAlgorithm' that is
-- compatible with the local key and the server's signature algorithms
-- (both already saved).  Must only be called for TLS versions 1.2 and
-- up, with compatibility function 'signatureCompatible' or
-- 'signatureCompatible13' based on version.
--
-- The values in the server's @signature_algorithms@ extension are
-- in descending order of preference.  However here the algorithms
-- are selected by client preference in @cHashSigs@.
--
-- Throws an 'Error_Protocol' with 'HandshakeFailure' when no algorithm
-- is acceptable to both sides, and an 'Error_Protocol' with
-- 'InternalError' when no signature algorithm list from the server has
-- been saved (instead of failing with a pattern-match error).
getLocalHashSigAlg
    :: Context
    -> (PubKey -> HashAndSignatureAlgorithm -> Bool)
    -> [HashAndSignatureAlgorithm]
    -> PubKey
    -> IO HashAndSignatureAlgorithm
getLocalHashSigAlg ctx isCompatible cHashSigs pubKey = do
    mCertReq <- usingHState ctx getCertReqCBdata
    case mCertReq of
        -- Must be present with TLS 1.2 and up.
        Just (_, Just hashSigs, _) -> choose hashSigs
        _ ->
            throwCore $
                Error_Protocol
                    "no signature algorithms from the server are available"
                    InternalError
  where
    -- First client-preferred algorithm that fits the key and was also
    -- offered by the server.
    choose hashSigs =
        maybe
            (throwCore $ Error_Protocol (keyerr pubKey) HandshakeFailure)
            return
            (find want cHashSigs)
      where
        want alg = isCompatible pubKey alg && alg `elem` hashSigs

    keyerr k = "no " ++ pubkeyType k ++ " hash algorithm in common with the server"

----------------------------------------------------------------

-- | Record the protocol selected through the ALPN extension.
--
-- The extension is looked up in the given list and decoded for the
-- given message type.  The state is updated only when the decoded
-- extension names exactly one protocol and that protocol is among the
-- client's ALPN suggestions; in every other case nothing happens.
setALPN :: Context -> MessageType -> [ExtensionRaw] -> IO ()
setALPN ctx msgt exts =
    case selection of
        Just (ApplicationLayerProtocolNegotiation [proto]) ->
            usingState_ ctx (accept proto)
        _ -> return ()
  where
    selection =
        extensionDecode msgt
            =<< extensionLookup EID_ApplicationLayerProtocolNegotiation exts

    -- Accept the protocol only if the client actually suggested it.
    accept proto = do
        suggested <- getClientALPNSuggest
        when (maybe False (proto `elem`) suggested) $ do
            setExtensionALPN True
            setNegotiatedProtocol proto

----------------------------------------------------------------

-- | Run the client-side callback of the context's 'HandshakeSync' with
-- the context and the given 'ClientState'.  The server-side callback is
-- ignored.
contextSync :: Context -> ClientState -> IO ()
contextSync ctx ctl = syncClient ctx ctl
  where
    HandshakeSync syncClient _ = ctxHandshakeSync ctx