-- |
-- Module      : Network.TLS.Handshake.Process
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- process handshake message received
--
module Network.TLS.Handshake.Process
    ( processHandshake
    , processHandshake13
    , startHandshake
    ) where

import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.ErrT
import Network.TLS.Extension
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.Signature
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Parameters
import Network.TLS.Sending
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role(..), invertRole, MasterSecret(..))
import Network.TLS.Util

import Control.Concurrent.MVar
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State.Strict (gets)
import Data.X509 (CertificateChain(..), Certificate(..), getCertificate)
import Data.IORef (writeIORef)

-- | Process a received handshake message and update the context accordingly.
--
-- The work is done in three steps:
--
-- 1. Message-specific processing that must happen /before/ the message is
--    fed to 'updateHandshake':
--
--      * 'ClientHello' (only when the local role is 'ServerRole'): every raw
--        extension is run through the renegotiation check, the renegotiation
--        SCSV is honoured, and a fresh handshake state is installed with
--        'startHandshake' unless the TLS 1.3 HelloRetryRequest flag is set.
--      * 'Certificates': the public key of the first certificate is recorded.
--      * 'Finished': the peer's finished data is verified.
--
-- 2. The message is passed to 'updateHandshake' (after wrapping the
--    transcript with 'wrapAsMessageHash13' when the message is a
--    HelloRetryRequest).
--
-- 3. 'ClientKeyXchg' is processed (only when the local role is 'ServerRole')
--    /after/ the update above.
processHandshake :: Context -> Handshake -> IO ()
processHandshake ctx hs = do
    role <- usingState_ ctx isClientContext
    case hs of
        ClientHello cver ran _ cids _ ex _ -> when (role == ServerRole) $ do
            mapM_ (usingState_ ctx . processClientExtension) ex
            -- RFC 5746: secure renegotiation
            -- TLS_EMPTY_RENEGOTIATION_INFO_SCSV: {0x00, 0xFF}
            when (secureRenegotiation && (0xff `elem` cids)) $
                usingState_ ctx $ setSecureRenegotiation True
            -- Do not reset the handshake state when the HRR flag is set.
            hrr <- usingState_ ctx getTLS13HRR
            unless hrr $ startHandshake ctx cver ran
        Certificates certs -> processCertificates role certs
        Finished fdata     -> processClientFinished ctx fdata
        _                  -> return ()
    when (isHRR hs) $ usingHState ctx wrapAsMessageHash13
    void $ updateHandshake ctx ServerRole hs
    case hs of
        ClientKeyXchg content -> when (role == ServerRole) $
            processClientKeyXchg ctx content
        _                     -> return ()
  where
    -- Whether secure renegotiation is enabled in the context's 'Supported'
    -- parameters.
    secureRenegotiation :: Bool
    secureRenegotiation = supportedSecureRenegotiation $ ctxSupported ctx

    -- RFC 5746: secure renegotiation.
    -- The renegotiation_info extension has identifier 0xff01.  Its content
    -- must equal the encoding of the client verified data currently stored
    -- in the state, otherwise a fatal 'HandshakeFailure' is raised.
    processClientExtension :: ExtensionRaw -> TLSSt ()
    processClientExtension (ExtensionRaw 0xff01 content) | secureRenegotiation = do
        v <- getVerifiedData ClientRole
        let bs = extensionEncode (SecureRenegotiation v Nothing)
        unless (bs `bytesEq` content) $
            throwError $ Error_Protocol
                ( "client verified data not matching: " ++ show v ++ ":" ++ show content
                , True
                , HandshakeFailure
                )
        setSecureRenegotiation True
    -- Unknown extensions are ignored.
    processClientExtension _ = return ()

    -- Record the peer's public key from the first certificate of the chain.
    -- An empty chain is accepted when the local role is 'ServerRole' and is
    -- a fatal 'HandshakeFailure' when the local role is 'ClientRole'.
    processCertificates :: Role -> CertificateChain -> IO ()
    processCertificates ServerRole (CertificateChain []) = return ()
    processCertificates ClientRole (CertificateChain []) =
        throwCore $ Error_Protocol ("server certificate missing", True, HandshakeFailure)
    processCertificates _ (CertificateChain (c:_)) =
        usingHState ctx $ setPublicKey pubkey
      where
        pubkey = certPubKey $ getCertificate c

    -- A HelloRetryRequest is a 'ServerHello' carrying version 'TLS12' whose
    -- random satisfies 'isHelloRetryRequest'.
    isHRR :: Handshake -> Bool
    isHRR (ServerHello TLS12 srand _ _ _ _) = isHelloRetryRequest srand
    isHRR _                                 = False

-- | Process a received TLS 1.3 handshake message: the message is handed to
-- 'updateHandshake13' and the 'ByteString' that call returns is discarded.
processHandshake13 :: Context -> Handshake13 -> IO ()
processHandshake13 ctx = void . updateHandshake13 ctx

-- | Process the client key exchange message and derive the master secret,
-- which is then handed to 'logKey'.
--
-- * 'CKX_RSA': the protocol expects the premaster secret to carry the
--   initial client version received in the ClientHello, not the negotiated
--   version.  If decryption fails, the premaster secret cannot be decoded,
--   or the version does not match, a random 48-byte premaster secret is used
--   instead of reporting an error.
-- * 'CKX_DH': the client public value is validated against the server DH
--   parameters; an invalid value raises a fatal 'IllegalParameter'.
-- * 'CKX_ECDH': the client public key is decoded for the server's group;
--   a decoding failure or a failure to compute the shared secret raises a
--   fatal 'IllegalParameter'.
processClientKeyXchg :: Context -> ClientKeyXchgAlgorithmData -> IO ()
processClientKeyXchg ctx (CKX_RSA encryptedPremaster) = do
    -- The random replacement is generated up front, before the outcome of
    -- the decryption is known.
    (rver, role, random) <- usingState_ ctx $
        (,,) <$> getVersion <*> isClientContext <*> genRandom 48
    ePremaster <- decryptRSA ctx encryptedPremaster
    masterSecret <- usingHState ctx $ do
        expectedVer <- gets hstClientVersion
        -- Shared fallback for every failure mode below.
        let useRandom = setMasterSecretFromPre rver role random
        case ePremaster of
            Left _          -> useRandom
            Right premaster -> case decodePreMasterSecret premaster of
                Right (ver, _)
                    | ver == expectedVer -> setMasterSecretFromPre rver role premaster
                _                        -> useRandom
    liftIO $ logKey ctx (MasterSecret masterSecret)

processClientKeyXchg ctx (CKX_DH clientDHValue) = do
    rver <- usingState_ ctx getVersion
    role <- usingState_ ctx isClientContext

    serverParams <- usingHState ctx getServerDHParams
    let params = serverDHParamsToParams serverParams
    unless (dhValid params $ dhUnwrapPublic clientDHValue) $
        throwCore $ Error_Protocol ("invalid client public key", True, IllegalParameter)

    dhpriv <- usingHState ctx getDHPrivate
    let premaster = dhGetShared params dhpriv clientDHValue
    masterSecret <- usingHState ctx $ setMasterSecretFromPre rver role premaster
    liftIO $ logKey ctx (MasterSecret masterSecret)

processClientKeyXchg ctx (CKX_ECDH bytes) = do
    ServerECDHParams grp _ <- usingHState ctx getServerECDHParams
    case decodeGroupPublic grp bytes of
        Left _ ->
            throwCore $ Error_Protocol ("client public key cannot be decoded", True, IllegalParameter)
        Right clipub -> do
            srvpri <- usingHState ctx getGroupPrivate
            case groupGetShared clipub srvpri of
                Just premaster -> do
                    rver <- usingState_ ctx getVersion
                    role <- usingState_ ctx isClientContext
                    masterSecret <- usingHState ctx $ setMasterSecretFromPre rver role premaster
                    liftIO $ logKey ctx (MasterSecret masterSecret)
                Nothing ->
                    throwCore $ Error_Protocol ("cannot generate a shared secret on ECDH", True, IllegalParameter)

-- | Verify the finished data received from the peer.
--
-- The expected value is the handshake digest computed for the current
-- version and the inverse of the local role.  On mismatch 'decryptError' is
-- raised with the message @"cannot verify finished"@; on success the
-- received data is stored in 'ctxPeerFinished'.
processClientFinished :: Context -> FinishedData -> IO ()
processClientFinished ctx fdata = do
    (cc, ver) <- usingState_ ctx $ (,) <$> isClientContext <*> getVersion
    expected <- usingHState ctx $ getHandshakeDigest ver $ invertRole cc
    -- Use 'bytesEq', the same comparison this module applies to the
    -- renegotiation verify data, instead of the derived 'Eq' instance.
    -- NOTE(review): 'bytesEq' is presumably the constant-time comparison
    -- from Network.TLS.Util -- confirm.
    unless (expected `bytesEq` fdata) $ decryptError "cannot verify finished"
    writeIORef (ctxPeerFinished ctx) $ Just fdata

-- | Initialize a new handshake context (initial handshake or renegotiation).
--
-- A fresh state built by 'newEmptyHandshake' from the given version and
-- client random replaces whatever is stored in 'ctxHandshake'; the previous
-- content of the 'MVar' is discarded.
startHandshake :: Context -> Version -> ClientRandom -> IO ()
startHandshake ctx ver crand =
    liftIO $ void $ swapMVar (ctxHandshake ctx) freshState
  where
    freshState = Just $ newEmptyHandshake ver crand