{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.TLS.Handshake.State
( HandshakeState(..)
, HandshakeDigest(..)
, HandshakeMode13(..)
, RTT0Status(..)
, CertReqCBdata
, HandshakeM
, newEmptyHandshake
, runHandshake
, setPublicKey
, setPublicPrivateKeys
, getLocalPublicPrivateKeys
, getRemotePublicKey
, setServerDHParams
, getServerDHParams
, setServerECDHParams
, getServerECDHParams
, setDHPrivate
, getDHPrivate
, setGroupPrivate
, getGroupPrivate
, setClientCertSent
, getClientCertSent
, setCertReqSent
, getCertReqSent
, setClientCertChain
, getClientCertChain
, setCertReqToken
, getCertReqToken
, setCertReqCBdata
, getCertReqCBdata
, setCertReqSigAlgsCert
, getCertReqSigAlgsCert
, addHandshakeMessage
, updateHandshakeDigest
, getHandshakeMessages
, getHandshakeMessagesRev
, getHandshakeDigest
, foldHandshakeDigest
, setMasterSecret
, setMasterSecretFromPre
, getPendingCipher
, setServerHelloParameters
, setNegotiatedGroup
, getNegotiatedGroup
, setTLS13HandshakeMode
, getTLS13HandshakeMode
, setTLS13RTT0Status
, getTLS13RTT0Status
, setTLS13EarlySecret
, getTLS13EarlySecret
, setTLS13ResumptionSecret
, getTLS13ResumptionSecret
, setCCS13Sent
, getCCS13Sent
) where
import Control.Monad.State.Strict
import Data.ByteArray (ByteArrayAccess)
import Data.List (foldl')
import Data.X509 (CertificateChain)

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Util
-- | Public-key material tracked while a handshake is in progress.
-- 'newEmptyHandshake' initialises both fields to 'Nothing'.
data HandshakeKeyState = HandshakeKeyState
    { hksRemotePublicKey :: !(Maybe PubKey)
      -- ^ remote public key, stored by 'setPublicKey'
    , hksLocalPublicPrivateKeys :: !(Maybe (PubKey, PrivKey))
      -- ^ local key pair, stored by 'setPublicPrivateKeys'
    } deriving (Show)
-- | State of the handshake transcript digest.
--
-- The digest starts out as a plain list of message bytes and is later turned
-- into an incremental hash context (see 'setServerHelloParameters').
data HandshakeDigest = HandshakeMessages [ByteString]
                       -- ^ raw message bytes; 'updateHandshakeDigest'
                       -- prepends, so the newest message is first
                     | HandshakeDigestContext HashCtx
                       -- ^ running hash context, extended with 'hashUpdate'
                     deriving (Show)
-- | Everything recorded about a handshake in progress. A fresh value is built
-- by 'newEmptyHandshake' and then threaded through the 'HandshakeM' monad.
data HandshakeState = HandshakeState
    { hstClientVersion :: !Version
      -- ^ version passed to 'newEmptyHandshake'
    , hstClientRandom :: !ClientRandom
      -- ^ client random passed to 'newEmptyHandshake'; feeds master-secret
      -- and key-block generation
    , hstServerRandom :: !(Maybe ServerRandom)
      -- ^ set by 'setServerHelloParameters'
    , hstMasterSecret :: !(Maybe ByteString)
      -- ^ set by 'setMasterSecret'
    , hstKeyState :: !HandshakeKeyState
      -- ^ remote public key and local key pair
    , hstServerDHParams :: !(Maybe ServerDHParams)
      -- ^ set by 'setServerDHParams'
    , hstDHPrivate :: !(Maybe DHPrivate)
      -- ^ set by 'setDHPrivate'
    , hstServerECDHParams :: !(Maybe ServerECDHParams)
      -- ^ set by 'setServerECDHParams'
    , hstGroupPrivate :: !(Maybe GroupPrivate)
      -- ^ set by 'setGroupPrivate'
    , hstHandshakeDigest :: !HandshakeDigest
      -- ^ transcript digest state, see 'updateHandshakeDigest'
    , hstHandshakeMessages :: [ByteString]
      -- ^ messages recorded by 'addHandshakeMessage', newest first
    , hstCertReqToken :: !(Maybe ByteString)
      -- ^ set by 'setCertReqToken'
    , hstCertReqCBdata :: !(Maybe CertReqCBdata)
      -- ^ set by 'setCertReqCBdata'
    , hstCertReqSigAlgsCert :: !(Maybe [HashAndSignatureAlgorithm])
      -- ^ set by 'setCertReqSigAlgsCert'
    , hstClientCertSent :: !Bool
      -- ^ flag set by 'setClientCertSent'; initially 'False'
    , hstCertReqSent :: !Bool
      -- ^ flag set by 'setCertReqSent'; initially 'False'
    , hstClientCertChain :: !(Maybe CertificateChain)
      -- ^ set by 'setClientCertChain'
    , hstPendingTxState :: Maybe RecordState
      -- ^ transmit record state computed by 'setMasterSecret'
    , hstPendingRxState :: Maybe RecordState
      -- ^ receive record state computed by 'setMasterSecret'
    , hstPendingCipher :: Maybe Cipher
      -- ^ set by 'setServerHelloParameters'
    , hstPendingCompression :: Compression
      -- ^ set by 'setServerHelloParameters'; initially 'nullCompression'
    , hstNegotiatedGroup :: Maybe Group
      -- ^ set by 'setNegotiatedGroup'
    , hstTLS13HandshakeMode :: HandshakeMode13
      -- ^ set by 'setTLS13HandshakeMode'; initially 'FullHandshake'
    , hstTLS13RTT0Status :: !RTT0Status
      -- ^ set by 'setTLS13RTT0Status'; initially 'RTT0None'
    , hstTLS13EarlySecret :: Maybe (BaseSecret EarlySecret)
      -- ^ set by 'setTLS13EarlySecret'
    , hstTLS13ResumptionSecret :: Maybe (BaseSecret ResumptionSecret)
      -- ^ set by 'setTLS13ResumptionSecret'
    , hstCCS13Sent :: !Bool
      -- ^ flag set by 'setCCS13Sent'; initially 'False'
    } deriving (Show)
-- | Triple stored with 'setCertReqCBdata' and read back with
-- 'getCertReqCBdata': certificate types, optional hash/signature algorithms,
-- and distinguished names.
--
-- NOTE(review): presumably these are the values carried by a certificate
-- request and handed to a callback; confirm against the callers.
type CertReqCBdata =
     ( [CertificateType]
     , Maybe [HashAndSignatureAlgorithm]
     , [DistinguishedName] )
-- | Monad for reading and updating a 'HandshakeState'; a thin newtype over
-- 'State'. Run it with 'runHandshake'.
newtype HandshakeM a = HandshakeM { runHandshakeM :: State HandshakeState a }
    deriving (Functor, Applicative, Monad)
-- | State access for 'HandshakeM' delegates to the wrapped 'State' monad.
instance MonadState HandshakeState HandshakeM where
    get = HandshakeM get
    put = HandshakeM . put
#if MIN_VERSION_mtl(2,1,0)
    state = HandshakeM . state
#endif
-- | Build the initial 'HandshakeState' for a new handshake.
--
-- Only the client version and client random are known up front; every
-- optional field starts as 'Nothing', every flag as 'False', the digest as an
-- empty message list, and the compression as 'nullCompression'.
newEmptyHandshake :: Version -> ClientRandom -> HandshakeState
newEmptyHandshake clientVersion clientRandom =
    HandshakeState
        { hstClientVersion         = clientVersion
        , hstClientRandom          = clientRandom
        , hstServerRandom          = Nothing
        , hstMasterSecret          = Nothing
        , hstKeyState              = emptyKeys
        , hstServerDHParams        = Nothing
        , hstDHPrivate             = Nothing
        , hstServerECDHParams      = Nothing
        , hstGroupPrivate          = Nothing
        , hstHandshakeDigest       = HandshakeMessages []
        , hstHandshakeMessages     = []
        , hstCertReqToken          = Nothing
        , hstCertReqCBdata         = Nothing
        , hstCertReqSigAlgsCert    = Nothing
        , hstClientCertSent        = False
        , hstCertReqSent           = False
        , hstClientCertChain       = Nothing
        , hstPendingTxState        = Nothing
        , hstPendingRxState        = Nothing
        , hstPendingCipher         = Nothing
        , hstPendingCompression    = nullCompression
        , hstNegotiatedGroup       = Nothing
        , hstTLS13HandshakeMode    = FullHandshake
        , hstTLS13RTT0Status       = RTT0None
        , hstTLS13EarlySecret      = Nothing
        , hstTLS13ResumptionSecret = Nothing
        , hstCCS13Sent             = False
        }
  where
    -- No remote key and no local key pair yet.
    emptyKeys = HandshakeKeyState Nothing Nothing
-- | Run a 'HandshakeM' action against the given state, returning the result
-- together with the updated state.
runHandshake :: HandshakeState -> HandshakeM a -> (a, HandshakeState)
runHandshake initial action = runState (runHandshakeM action) initial
-- | Record the remote public key in the handshake key state.
setPublicKey :: PubKey -> HandshakeM ()
setPublicKey pk = modify $ \st ->
    st { hstKeyState = (hstKeyState st) { hksRemotePublicKey = Just pk } }
-- | Record the local public/private key pair in the handshake key state.
setPublicPrivateKeys :: (PubKey, PrivKey) -> HandshakeM ()
setPublicPrivateKeys keys = modify $ \st ->
    st { hstKeyState = (hstKeyState st) { hksLocalPublicPrivateKeys = Just keys } }
-- | Fetch the remote public key. The result is produced with 'fromJust', so
-- the key must have been stored with 'setPublicKey' beforehand.
getRemotePublicKey :: HandshakeM PubKey
getRemotePublicKey =
    gets (fromJust "remote public key" . hksRemotePublicKey . hstKeyState)
-- | Fetch the local key pair. The result is produced with 'fromJust', so the
-- pair must have been stored with 'setPublicPrivateKeys' beforehand.
getLocalPublicPrivateKeys :: HandshakeM (PubKey, PrivKey)
getLocalPublicPrivateKeys =
    gets (fromJust "local public/private key" . hksLocalPublicPrivateKeys . hstKeyState)
-- | Store the server DH parameters.
setServerDHParams :: ServerDHParams -> HandshakeM ()
setServerDHParams params = modify $ \st -> st { hstServerDHParams = Just params }
-- | Fetch the server DH parameters; they must have been set
-- (the value is unwrapped with 'fromJust').
getServerDHParams :: HandshakeM ServerDHParams
getServerDHParams = gets (fromJust "server DH params" . hstServerDHParams)
-- | Store the server ECDH parameters.
setServerECDHParams :: ServerECDHParams -> HandshakeM ()
setServerECDHParams params = modify $ \st -> st { hstServerECDHParams = Just params }
-- | Fetch the server ECDH parameters; they must have been set
-- (the value is unwrapped with 'fromJust').
getServerECDHParams :: HandshakeM ServerECDHParams
getServerECDHParams = gets (fromJust "server ECDH params" . hstServerECDHParams)
-- | Store the DH private value.
setDHPrivate :: DHPrivate -> HandshakeM ()
setDHPrivate priv = modify $ \st -> st { hstDHPrivate = Just priv }
-- | Fetch the DH private value; it must have been set
-- (the value is unwrapped with 'fromJust').
getDHPrivate :: HandshakeM DHPrivate
getDHPrivate = gets (fromJust "server DH private" . hstDHPrivate)
-- | Fetch the group private value; it must have been set
-- (the value is unwrapped with 'fromJust').
getGroupPrivate :: HandshakeM GroupPrivate
getGroupPrivate = gets (fromJust "server ECDH private" . hstGroupPrivate)
-- | Store the group private value.
setGroupPrivate :: GroupPrivate -> HandshakeM ()
setGroupPrivate priv = modify $ \st -> st { hstGroupPrivate = Just priv }
-- | Record the negotiated group.
setNegotiatedGroup :: Group -> HandshakeM ()
setNegotiatedGroup grp = modify $ \st -> st { hstNegotiatedGroup = Just grp }
-- | The negotiated group, or 'Nothing' if none has been recorded.
getNegotiatedGroup :: HandshakeM (Maybe Group)
getNegotiatedGroup = hstNegotiatedGroup <$> get
-- | Handshake mode recorded with 'setTLS13HandshakeMode'.
-- 'newEmptyHandshake' starts in 'FullHandshake'.
data HandshakeMode13 =
      FullHandshake
    | HelloRetryRequest
    | PreSharedKey
    | RTT0
    deriving (Show,Eq)
-- | Record the TLS 1.3 handshake mode.
setTLS13HandshakeMode :: HandshakeMode13 -> HandshakeM ()
setTLS13HandshakeMode mode = modify $ \st -> st { hstTLS13HandshakeMode = mode }
-- | The currently recorded TLS 1.3 handshake mode.
getTLS13HandshakeMode :: HandshakeM HandshakeMode13
getTLS13HandshakeMode = hstTLS13HandshakeMode <$> get
-- | 0-RTT status recorded with 'setTLS13RTT0Status'.
-- 'newEmptyHandshake' starts with 'RTT0None'.
data RTT0Status = RTT0None
                | RTT0Sent
                | RTT0Accepted
                | RTT0Rejected
                deriving (Show,Eq)
-- | Record the TLS 1.3 0-RTT status.
setTLS13RTT0Status :: RTT0Status -> HandshakeM ()
setTLS13RTT0Status status = modify $ \st -> st { hstTLS13RTT0Status = status }
-- | The currently recorded TLS 1.3 0-RTT status.
getTLS13RTT0Status :: HandshakeM RTT0Status
getTLS13RTT0Status = hstTLS13RTT0Status <$> get
-- | Store the TLS 1.3 early secret.
setTLS13EarlySecret :: BaseSecret EarlySecret -> HandshakeM ()
setTLS13EarlySecret early = modify $ \st -> st { hstTLS13EarlySecret = Just early }
-- | The stored TLS 1.3 early secret, if any.
getTLS13EarlySecret :: HandshakeM (Maybe (BaseSecret EarlySecret))
getTLS13EarlySecret = hstTLS13EarlySecret <$> get
-- | Store the TLS 1.3 resumption secret.
setTLS13ResumptionSecret :: BaseSecret ResumptionSecret -> HandshakeM ()
setTLS13ResumptionSecret resumption =
    modify $ \st -> st { hstTLS13ResumptionSecret = Just resumption }
-- | The stored TLS 1.3 resumption secret, if any.
getTLS13ResumptionSecret :: HandshakeM (Maybe (BaseSecret ResumptionSecret))
getTLS13ResumptionSecret = hstTLS13ResumptionSecret <$> get
-- | Set the 'hstCCS13Sent' flag.
setCCS13Sent :: Bool -> HandshakeM ()
setCCS13Sent flag = modify $ \st -> st { hstCCS13Sent = flag }
-- | Read the 'hstCCS13Sent' flag.
getCCS13Sent :: HandshakeM Bool
getCCS13Sent = hstCCS13Sent <$> get
-- | Set the 'hstCertReqSent' flag.
setCertReqSent :: Bool -> HandshakeM ()
setCertReqSent flag = modify $ \st -> st { hstCertReqSent = flag }
-- | Read the 'hstCertReqSent' flag.
getCertReqSent :: HandshakeM Bool
getCertReqSent = hstCertReqSent <$> get
-- | Set the 'hstClientCertSent' flag.
setClientCertSent :: Bool -> HandshakeM ()
setClientCertSent flag = modify $ \st -> st { hstClientCertSent = flag }
-- | Read the 'hstClientCertSent' flag.
getClientCertSent :: HandshakeM Bool
getClientCertSent = hstClientCertSent <$> get
-- | Store the client certificate chain.
setClientCertChain :: CertificateChain -> HandshakeM ()
setClientCertChain chain = modify $ \st -> st { hstClientCertChain = Just chain }
-- | The stored client certificate chain, if any.
getClientCertChain :: HandshakeM (Maybe CertificateChain)
getClientCertChain = hstClientCertChain <$> get
-- | Replace the stored certificate-request token; 'Nothing' clears it.
setCertReqToken :: Maybe ByteString -> HandshakeM ()
setCertReqToken token = modify replaceToken
  where
    replaceToken st = st { hstCertReqToken = token }
-- | The stored certificate-request token, if any.
getCertReqToken :: HandshakeM (Maybe ByteString)
getCertReqToken = hstCertReqToken <$> get
-- | Replace the stored 'CertReqCBdata'; 'Nothing' clears it.
setCertReqCBdata :: Maybe CertReqCBdata -> HandshakeM ()
setCertReqCBdata cbData = modify $ \st -> st { hstCertReqCBdata = cbData }
-- | The stored 'CertReqCBdata', if any.
getCertReqCBdata :: HandshakeM (Maybe CertReqCBdata)
getCertReqCBdata = hstCertReqCBdata <$> get
-- | Replace the stored certificate signature algorithms; 'Nothing' clears
-- them.
setCertReqSigAlgsCert :: Maybe [HashAndSignatureAlgorithm] -> HandshakeM ()
setCertReqSigAlgsCert algs = modify replaceAlgs
  where
    replaceAlgs st = st { hstCertReqSigAlgsCert = algs }
-- | The stored certificate signature algorithms, if any.
getCertReqSigAlgsCert :: HandshakeM (Maybe [HashAndSignatureAlgorithm])
getCertReqSigAlgsCert = hstCertReqSigAlgsCert <$> get
-- | Fetch the pending cipher; it must have been set (for example by
-- 'setServerHelloParameters'), since the value is unwrapped with 'fromJust'.
getPendingCipher :: HandshakeM Cipher
getPendingCipher = gets (fromJust "pending cipher" . hstPendingCipher)
-- | Record a handshake message. Messages are prepended, so the stored list is
-- in reverse order of arrival.
addHandshakeMessage :: ByteString -> HandshakeM ()
addHandshakeMessage content = modify prepend
  where
    prepend st = st { hstHandshakeMessages = content : hstHandshakeMessages st }
-- | The recorded handshake messages, oldest first.
getHandshakeMessages :: HandshakeM [ByteString]
getHandshakeMessages = reverse . hstHandshakeMessages <$> get
-- | The recorded handshake messages exactly as stored, newest first.
getHandshakeMessagesRev :: HandshakeM [ByteString]
getHandshakeMessagesRev = hstHandshakeMessages <$> get
-- | Feed one more handshake message into the transcript digest.
--
-- While the digest is still a list of raw messages the content is prepended;
-- once a hash context exists the content is hashed into it.
updateHandshakeDigest :: ByteString -> HandshakeM ()
updateHandshakeDigest content = modify $ \st ->
    st { hstHandshakeDigest = extend (hstHandshakeDigest st) }
  where
    extend (HandshakeMessages pending) =
        HandshakeMessages (content : pending)
    extend (HandshakeDigestContext ctx) =
        HandshakeDigestContext (hashUpdate ctx content)
-- | Collapse the transcript so far into a single synthetic message.
--
-- The current transcript is hashed with @hashAlg@, the digest is passed
-- through @f@, and the result replaces both the recorded message list and the
-- digest state:
--
-- * when the digest is still a list of raw messages, it becomes a list
--   holding just the folded value;
-- * when it is already a hash context, a fresh context for @hashAlg@ is
--   started and seeded with the folded value.
foldHandshakeDigest :: Hash -> (ByteString -> ByteString) -> HandshakeM ()
foldHandshakeDigest hashAlg f = modify $ \hs ->
    case hstHandshakeDigest hs of
        HandshakeMessages bytes ->
            -- Messages are stored newest-first, so restore arrival order
            -- before hashing. The strict left fold forces each intermediate
            -- context as it is produced rather than chaining thunks.
            let hashCtx = foldl' hashUpdate (hashInit hashAlg) (reverse bytes)
                !folded = f (hashFinal hashCtx)
             in hs { hstHandshakeDigest = HandshakeMessages [folded]
                   , hstHandshakeMessages = [folded]
                   }
        HandshakeDigestContext hashCtx ->
            let !folded = f (hashFinal hashCtx)
                hashCtx' = hashUpdate (hashInit hashAlg) folded
             in hs { hstHandshakeDigest = HandshakeDigestContext hashCtx'
                   , hstHandshakeMessages = [folded]
                   }
-- | Compute the Finished verification data for the given role from the
-- current transcript hash context.
--
-- Requires the digest to be a hash context and both the master secret and the
-- pending cipher to be set; otherwise the result is an 'error' value.
getHandshakeDigest :: Version -> Role -> HandshakeM ByteString
getHandshakeDigest ver role = gets $ \hst ->
    case hstHandshakeDigest hst of
        HandshakeMessages _ ->
            error "un-initialized handshake digest"
        HandshakeDigestContext hashCtx ->
            finished ver
                     (fromJust "cipher" $ hstPendingCipher hst)
                     (fromJust "master secret" $ hstMasterSecret hst)
                     hashCtx
  where
    -- Client and server use different Finished generators.
    finished = if role == ClientRole
                   then generateClientFinished
                   else generateServerFinished
-- | Derive the master secret from a premaster secret, install it with
-- 'setMasterSecret', and return it.
--
-- The pending cipher and the server random must already be set.
setMasterSecretFromPre :: ByteArrayAccess preMaster
                       => Version
                       -> Role
                       -> preMaster
                       -> HandshakeM ByteString
setMasterSecretFromPre ver role premasterSecret = do
    hst <- get
    let cipher = fromJust "cipher" $ hstPendingCipher hst
        srand  = fromJust "server random" $ hstServerRandom hst
        secret = generateMasterSecret ver cipher premasterSecret
                                      (hstClientRandom hst) srand
    setMasterSecret ver role secret
    return secret
-- | Install the master secret and the pending transmit/receive record states
-- derived from it by 'computeKeyBlock'.
setMasterSecret :: Version -> Role -> ByteString -> HandshakeM ()
setMasterSecret ver role masterSecret = modify install
  where
    install hst =
        hst { hstMasterSecret   = Just masterSecret
            , hstPendingTxState = Just txState
            , hstPendingRxState = Just rxState
            }
      where
        (txState, rxState) = computeKeyBlock hst masterSecret ver role
-- | Derive the pending (transmit, receive) record states from the master
-- secret.
--
-- The key block is generated from the client/server randoms and the master
-- secret, then split into client/server MAC secrets, write keys and IVs.
-- Which half is used for transmitting and which for receiving depends on the
-- role. The pending cipher and server random must already be set.
computeKeyBlock :: HandshakeState -> ByteString -> Version -> Role -> (RecordState, RecordState)
computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
  where
    -- Pick the first alternative when acting as client, the second otherwise.
    forRole :: a -> a -> a
    forRole onClient onServer
        | cc == ClientRole = onClient
        | otherwise        = onServer

    cipher = fromJust "cipher" $ hstPendingCipher hst
    bulk   = cipherBulk cipher

    -- Segment sizes of the key block; no MAC secret when the bulk has no MAC.
    macLen
        | hasMAC (bulkF bulk) = hashDigestSize (cipherHash cipher)
        | otherwise           = 0
    keyLen = bulkKeySize bulk
    ivLen  = bulkIVSize bulk

    keyBlock = generateKeyBlock ver cipher (hstClientRandom hst)
                                (fromJust "server random" $ hstServerRandom hst)
                                masterSecret (cipherKeyBlockSize cipher)

    (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) =
        fromJust "p6" $
            partition6 keyBlock (macLen, macLen, keyLen, keyLen, ivLen, ivLen)

    -- The client write key encrypts on the client and decrypts on the server;
    -- the server write key is the mirror image.
    clientCrypt = CryptState
        { cstKey       = bulkInit bulk (forRole BulkEncrypt BulkDecrypt) cWriteKey
        , cstIV        = cWriteIV
        , cstMacSecret = cMACSecret
        }
    serverCrypt = CryptState
        { cstKey       = bulkInit bulk (forRole BulkDecrypt BulkEncrypt) sWriteKey
        , cstIV        = sWriteIV
        , cstMacSecret = sMACSecret
        }

    -- Both directions start with a zero sequence number.
    freshMac = MacState { msSequence = 0 }

    mkRecordState crypt = RecordState
        { stCryptState  = crypt
        , stMacState    = freshMac
        , stCipher      = Just cipher
        , stCompression = hstPendingCompression hst
        }

    pendingTx = mkRecordState (forRole clientCrypt serverCrypt)
    pendingRx = mkRecordState (forRole serverCrypt clientCrypt)
-- | Record the parameters fixed by the server hello: server random, pending
-- cipher and pending compression.
--
-- This also converts the transcript digest from a list of raw messages into a
-- hash context for the algorithm chosen by 'getHash', hashing all messages
-- recorded so far in arrival order. The digest must still be in its
-- message-list form; calling this once a hash context exists yields an
-- 'error' value for the digest.
setServerHelloParameters :: Version      -- ^ chosen version
                         -> ServerRandom
                         -> Cipher
                         -> Compression
                         -> HandshakeM ()
setServerHelloParameters ver sran cipher compression =
    modify $ \hst -> hst
        { hstServerRandom       = Just sran
        , hstPendingCipher      = Just cipher
        , hstPendingCompression = compression
        , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
        }
  where
    hashAlg = getHash ver cipher
    -- Messages are stored newest-first, hence the 'reverse'. The strict fold
    -- forces each intermediate context instead of accumulating thunks.
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $
            foldl' hashUpdate (hashInit hashAlg) (reverse bytes)
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
-- | Select the transcript hash algorithm for a version and cipher.
--
-- Versions below TLS 1.2 use 'SHA1_MD5'. From TLS 1.2 on, a cipher with no
-- minimum version, or one whose minimum version is below TLS 1.2, uses
-- 'SHA256'; any other cipher uses its own hash.
getHash :: Version -> Cipher -> Hash
getHash ver ciph =
    if ver < TLS12
        then SHA1_MD5
        else case cipherMinVer ciph of
            Nothing -> SHA256
            Just minVer
                | minVer < TLS12 -> SHA256
                | otherwise      -> cipherHash ciph