{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
module Network.TLS.Handshake.State
    ( HandshakeState(..)
    , HandshakeDigest(..)
    , HandshakeMode13(..)
    , RTT0Status(..)
    , CertReqCBdata
    , HandshakeM
    , newEmptyHandshake
    , runHandshake
    
    , setPublicKey
    , setPublicPrivateKeys
    , getLocalPublicPrivateKeys
    , getRemotePublicKey
    , setServerDHParams
    , getServerDHParams
    , setServerECDHParams
    , getServerECDHParams
    , setDHPrivate
    , getDHPrivate
    , setGroupPrivate
    , getGroupPrivate
    
    , setClientCertSent
    , getClientCertSent
    , setCertReqSent
    , getCertReqSent
    , setClientCertChain
    , getClientCertChain
    , setCertReqToken
    , getCertReqToken
    , setCertReqCBdata
    , getCertReqCBdata
    , setCertReqSigAlgsCert
    , getCertReqSigAlgsCert
    
    , addHandshakeMessage
    , updateHandshakeDigest
    , getHandshakeMessages
    , getHandshakeMessagesRev
    , getHandshakeDigest
    , foldHandshakeDigest
    
    , setMasterSecret
    , setMasterSecretFromPre
    
    , getPendingCipher
    , setServerHelloParameters
    , setExtendedMasterSec
    , getExtendedMasterSec
    , setNegotiatedGroup
    , getNegotiatedGroup
    , setTLS13HandshakeMode
    , getTLS13HandshakeMode
    , setTLS13RTT0Status
    , getTLS13RTT0Status
    , setTLS13EarlySecret
    , getTLS13EarlySecret
    , setTLS13ResumptionSecret
    , getTLS13ResumptionSecret
    , setCCS13Sent
    , getCCS13Sent
    ) where
import Control.Monad.State.Strict
import Data.ByteArray (ByteArrayAccess)
import Data.List (foldl')
import Data.X509 (CertificateChain)

import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
import Network.TLS.Util
-- | Public/private key material collected while the handshake runs.
-- Both fields start as 'Nothing' (see 'newEmptyHandshake').
data HandshakeKeyState = HandshakeKeyState
    { hksRemotePublicKey :: !(Maybe PubKey)
      -- ^ the remote peer's public key, set by 'setPublicKey'
    , hksLocalPublicPrivateKeys :: !(Maybe (PubKey, PrivKey))
      -- ^ our own key pair, set by 'setPublicPrivateKeys'
    } deriving (Show)
-- | Running digest of the handshake transcript.
data HandshakeDigest = HandshakeMessages [ByteString]
                       -- ^ raw messages buffered newest-first, used until
                       -- 'setServerHelloParameters' picks the hash algorithm
                     | HandshakeDigestContext HashCtx
                       -- ^ incremental hash context over the transcript
                     deriving (Show)
-- | Everything tracked while a handshake is in progress.
--
-- NOTE(review): most fields are strict, but the message log, the pending
-- record states, pending cipher/compression, the extended-master-secret
-- flag, the negotiated group, the TLS 1.3 mode and the TLS 1.3 secrets are
-- lazy. Confirm whether that is deliberate before changing it.
data HandshakeState = HandshakeState
    { hstClientVersion       :: !Version
    , hstClientRandom        :: !ClientRandom
    , hstServerRandom        :: !(Maybe ServerRandom)
    , hstMasterSecret        :: !(Maybe ByteString)
    , hstKeyState            :: !HandshakeKeyState
    , hstServerDHParams      :: !(Maybe ServerDHParams)
    , hstDHPrivate           :: !(Maybe DHPrivate)
    , hstServerECDHParams    :: !(Maybe ServerECDHParams)
    , hstGroupPrivate        :: !(Maybe GroupPrivate)
    , hstHandshakeDigest     :: !HandshakeDigest
        -- ^ transcript digest; buffered messages until
        -- 'setServerHelloParameters' turns it into a hash context
    , hstHandshakeMessages   :: [ByteString]
        -- ^ raw message log, newest first (see 'addHandshakeMessage')
    , hstCertReqToken        :: !(Maybe ByteString)
        -- ^ token attached to a certificate request, if one was recorded
        -- (presumably the TLS 1.3 request context — confirm with callers)
    , hstCertReqCBdata       :: !(Maybe CertReqCBdata)
        -- ^ certificate-request data, 'Just' once a request was recorded
    , hstCertReqSigAlgsCert  :: !(Maybe [HashAndSignatureAlgorithm])
        -- ^ signature algorithms carried with a certificate request, kept
        -- apart from the list inside 'hstCertReqCBdata'
        -- NOTE(review): the exact version-dependent meaning is not visible
        -- here; confirm against the code that sets it
    , hstClientCertSent      :: !Bool
        -- ^ whether a client certificate chain was sent
    , hstCertReqSent         :: !Bool
        -- ^ whether a certificate request was sent
    , hstClientCertChain     :: !(Maybe CertificateChain)
    , hstPendingTxState      :: Maybe RecordState
    , hstPendingRxState      :: Maybe RecordState
    , hstPendingCipher       :: Maybe Cipher
    , hstPendingCompression  :: Compression
    , hstExtendedMasterSec   :: Bool
    , hstNegotiatedGroup     :: Maybe Group
    , hstTLS13HandshakeMode  :: HandshakeMode13
    , hstTLS13RTT0Status     :: !RTT0Status
    , hstTLS13EarlySecret    :: Maybe (BaseSecret EarlySecret)
    , hstTLS13ResumptionSecret :: Maybe (BaseSecret ResumptionSecret)
    , hstCCS13Sent           :: !Bool
    } deriving (Show)
-- | Data carried by a certificate request: the accepted certificate types,
-- optional hash/signature algorithms, and a list of distinguished names
-- (presumably the acceptable certificate authorities, and presumably handed
-- to a certificate-selection callback given the "CB" — confirm).
type CertReqCBdata =
     ( [CertificateType]
     , Maybe [HashAndSignatureAlgorithm]
     , [DistinguishedName] )
-- | The handshake monad: a newtype wrapper around a 'State' computation
-- (from "Control.Monad.State.Strict") over 'HandshakeState'.
newtype HandshakeM a = HandshakeM { runHandshakeM :: State HandshakeState a }
    deriving (Functor, Applicative, Monad)

-- | Each method just lifts the matching method of the wrapped 'State'.
instance MonadState HandshakeState HandshakeM where
    get = HandshakeM get
    put = HandshakeM . put
#if MIN_VERSION_mtl(2,1,0)
    state = HandshakeM . state
#endif
-- | The state a handshake starts from.
--
-- Only the client version and client random are supplied by the caller.
-- Every optional field starts as 'Nothing' and every flag as 'False'. The
-- transcript is an empty message buffer and the pending compression is
-- 'nullCompression'. The TLS 1.3 mode starts as 'FullHandshake' and the
-- 0-RTT status as 'RTT0None'.
newEmptyHandshake :: Version -> ClientRandom -> HandshakeState
newEmptyHandshake ver crand = HandshakeState
    { -- supplied by the caller
      hstClientVersion         = ver
    , hstClientRandom          = crand
      -- transcript: nothing buffered or hashed yet
    , hstHandshakeDigest       = HandshakeMessages []
    , hstHandshakeMessages     = []
      -- key exchange material
    , hstServerRandom          = Nothing
    , hstMasterSecret          = Nothing
    , hstKeyState              = HandshakeKeyState { hksRemotePublicKey        = Nothing
                                                   , hksLocalPublicPrivateKeys = Nothing }
    , hstServerDHParams        = Nothing
    , hstDHPrivate             = Nothing
    , hstServerECDHParams      = Nothing
    , hstGroupPrivate          = Nothing
    , hstNegotiatedGroup       = Nothing
    , hstExtendedMasterSec     = False
      -- pending record-layer parameters
    , hstPendingTxState        = Nothing
    , hstPendingRxState        = Nothing
    , hstPendingCipher         = Nothing
    , hstPendingCompression    = nullCompression
      -- certificate requests and client certificates
    , hstCertReqToken          = Nothing
    , hstCertReqCBdata         = Nothing
    , hstCertReqSigAlgsCert    = Nothing
    , hstCertReqSent           = False
    , hstClientCertSent        = False
    , hstClientCertChain       = Nothing
      -- TLS 1.3 specifics
    , hstTLS13HandshakeMode    = FullHandshake
    , hstTLS13RTT0Status       = RTT0None
    , hstTLS13EarlySecret      = Nothing
    , hstTLS13ResumptionSecret = Nothing
    , hstCCS13Sent             = False
    }
-- | Run a 'HandshakeM' action from the given starting state, returning the
-- action's result paired with the final state.
runHandshake :: HandshakeState -> HandshakeM a -> (a, HandshakeState)
runHandshake hst = (`runState` hst) . runHandshakeM
-- | Record the remote peer's public key in the key sub-state.
setPublicKey :: PubKey -> HandshakeM ()
setPublicKey pk = modify $ \hst ->
    let keys = hstKeyState hst
     in hst { hstKeyState = keys { hksRemotePublicKey = Just pk } }
-- | Record our own public/private key pair in the key sub-state.
setPublicPrivateKeys :: (PubKey, PrivKey) -> HandshakeM ()
setPublicPrivateKeys keys = modify $ \hst ->
    let ks = hstKeyState hst
     in hst { hstKeyState = ks { hksLocalPublicPrivateKeys = Just keys } }
-- | The remote peer's public key. Partial: fails via 'fromJust' when
-- 'setPublicKey' has not been called.
getRemotePublicKey :: HandshakeM PubKey
getRemotePublicKey = gets (fromJust "remote public key" . hksRemotePublicKey . hstKeyState)
-- | Our own key pair. Partial: fails via 'fromJust' when
-- 'setPublicPrivateKeys' has not been called.
getLocalPublicPrivateKeys :: HandshakeM (PubKey, PrivKey)
getLocalPublicPrivateKeys = gets localKeys
  where localKeys = fromJust "local public/private key" . hksLocalPublicPrivateKeys . hstKeyState
-- | Record the server's Diffie-Hellman parameters.
setServerDHParams :: ServerDHParams -> HandshakeM ()
setServerDHParams params = modify store
  where store hst = hst { hstServerDHParams = Just params }
-- | The recorded server Diffie-Hellman parameters. Partial: fails via
-- 'fromJust' when none were recorded.
getServerDHParams :: HandshakeM ServerDHParams
getServerDHParams = gets (fromJust "server DH params" . hstServerDHParams)
-- | Record the server's ECDH parameters.
setServerECDHParams :: ServerECDHParams -> HandshakeM ()
setServerECDHParams params = modify store
  where store hst = hst { hstServerECDHParams = Just params }
-- | The recorded server ECDH parameters. Partial: fails via 'fromJust' when
-- none were recorded.
getServerECDHParams :: HandshakeM ServerECDHParams
getServerECDHParams = gets (fromJust "server ECDH params" . hstServerECDHParams)
-- | Record the Diffie-Hellman private value.
setDHPrivate :: DHPrivate -> HandshakeM ()
setDHPrivate priv = modify store
  where store hst = hst { hstDHPrivate = Just priv }
-- | The recorded Diffie-Hellman private value. Partial: fails via
-- 'fromJust' when none was recorded.
getDHPrivate :: HandshakeM DHPrivate
getDHPrivate = gets (fromJust "server DH private" . hstDHPrivate)
-- | Record the private value for the negotiated key-exchange group.
setGroupPrivate :: GroupPrivate -> HandshakeM ()
setGroupPrivate priv = modify store
  where store hst = hst { hstGroupPrivate = Just priv }
-- | The recorded group private value. Partial: fails via 'fromJust' when
-- none was recorded.
getGroupPrivate :: HandshakeM GroupPrivate
getGroupPrivate = gets (fromJust "server ECDH private" . hstGroupPrivate)
-- | Set whether the extended master secret derivation is in use.
setExtendedMasterSec :: Bool -> HandshakeM ()
setExtendedMasterSec enabled = modify flag
  where flag hst = hst { hstExtendedMasterSec = enabled }
-- | Whether the extended master secret derivation is in use.
getExtendedMasterSec :: HandshakeM Bool
getExtendedMasterSec = hstExtendedMasterSec <$> get
-- | Record the negotiated key-exchange group.
setNegotiatedGroup :: Group -> HandshakeM ()
setNegotiatedGroup grp = modify store
  where store hst = hst { hstNegotiatedGroup = Just grp }
-- | The negotiated key-exchange group, if one has been recorded.
getNegotiatedGroup :: HandshakeM (Maybe Group)
getNegotiatedGroup = hstNegotiatedGroup <$> get
-- | Kind of TLS 1.3 handshake in progress (stored in
-- 'hstTLS13HandshakeMode'; a new handshake starts as 'FullHandshake').
-- The per-constructor notes below are inferred from the names; confirm
-- against the code that sets the mode.
data HandshakeMode13 =
      -- | Full handshake.
      FullHandshake
      -- | Full handshake involving a HelloRetryRequest round trip.
    | HelloRetryRequest
      -- | Handshake resumed with a pre-shared key.
    | PreSharedKey
      -- | Pre-shared key handshake with 0-RTT early data.
    | RTT0
    deriving (Show,Eq)
-- | Record which kind of TLS 1.3 handshake is in progress.
setTLS13HandshakeMode :: HandshakeMode13 -> HandshakeM ()
setTLS13HandshakeMode mode = modify store
  where store hst = hst { hstTLS13HandshakeMode = mode }
-- | Which kind of TLS 1.3 handshake is in progress.
getTLS13HandshakeMode :: HandshakeM HandshakeMode13
getTLS13HandshakeMode = hstTLS13HandshakeMode <$> get
-- | Progress of TLS 1.3 0-RTT early data (stored in 'hstTLS13RTT0Status';
-- a new handshake starts as 'RTT0None'). The notes below are inferred from
-- the constructor names.
data RTT0Status = RTT0None      -- ^ no early data involved
                | RTT0Sent      -- ^ early data has been sent
                | RTT0Accepted  -- ^ early data was accepted
                | RTT0Rejected  -- ^ early data was rejected
                deriving (Show,Eq)
-- | Record the 0-RTT early data status.
setTLS13RTT0Status :: RTT0Status -> HandshakeM ()
setTLS13RTT0Status status = modify store
  where store hst = hst { hstTLS13RTT0Status = status }
-- | The current 0-RTT early data status.
getTLS13RTT0Status :: HandshakeM RTT0Status
getTLS13RTT0Status = hstTLS13RTT0Status <$> get
-- | Store the TLS 1.3 early secret.
setTLS13EarlySecret :: BaseSecret EarlySecret -> HandshakeM ()
setTLS13EarlySecret secret = modify store
  where store hst = hst { hstTLS13EarlySecret = Just secret }
-- | The stored TLS 1.3 early secret, if any.
getTLS13EarlySecret :: HandshakeM (Maybe (BaseSecret EarlySecret))
getTLS13EarlySecret = hstTLS13EarlySecret <$> get
-- | Store the TLS 1.3 resumption secret.
setTLS13ResumptionSecret :: BaseSecret ResumptionSecret -> HandshakeM ()
setTLS13ResumptionSecret secret = modify store
  where store hst = hst { hstTLS13ResumptionSecret = Just secret }
-- | The stored TLS 1.3 resumption secret, if any.
getTLS13ResumptionSecret :: HandshakeM (Maybe (BaseSecret ResumptionSecret))
getTLS13ResumptionSecret = hstTLS13ResumptionSecret <$> get
-- | Record whether the TLS 1.3 ChangeCipherSpec has been sent.
setCCS13Sent :: Bool -> HandshakeM ()
setCCS13Sent sent = modify flag
  where flag hst = hst { hstCCS13Sent = sent }
-- | Whether the TLS 1.3 ChangeCipherSpec has been sent.
getCCS13Sent :: HandshakeM Bool
getCCS13Sent = hstCCS13Sent <$> get
-- | Record whether a certificate request was sent.
setCertReqSent :: Bool -> HandshakeM ()
setCertReqSent sent = modify flag
  where flag hst = hst { hstCertReqSent = sent }
-- | Whether a certificate request was sent.
getCertReqSent :: HandshakeM Bool
getCertReqSent = hstCertReqSent <$> get
-- | Record whether a client certificate was sent.
setClientCertSent :: Bool -> HandshakeM ()
setClientCertSent sent = modify flag
  where flag hst = hst { hstClientCertSent = sent }
-- | Whether a client certificate was sent.
getClientCertSent :: HandshakeM Bool
getClientCertSent = hstClientCertSent <$> get
-- | Record the client certificate chain.
setClientCertChain :: CertificateChain -> HandshakeM ()
setClientCertChain chain = modify store
  where store hst = hst { hstClientCertChain = Just chain }
-- | The recorded client certificate chain, if any.
getClientCertChain :: HandshakeM (Maybe CertificateChain)
getClientCertChain = hstClientCertChain <$> get
-- | Replace the certificate-request token ('Nothing' clears it).
setCertReqToken :: Maybe ByteString -> HandshakeM ()
setCertReqToken token = modify store
  where store hst = hst { hstCertReqToken = token }
-- | The current certificate-request token, if any.
getCertReqToken :: HandshakeM (Maybe ByteString)
getCertReqToken = hstCertReqToken <$> get
-- | Replace the certificate-request data ('Nothing' clears it).
setCertReqCBdata :: Maybe CertReqCBdata -> HandshakeM ()
setCertReqCBdata cbdata = modify store
  where store hst = hst { hstCertReqCBdata = cbdata }
-- | The current certificate-request data, if any.
getCertReqCBdata :: HandshakeM (Maybe CertReqCBdata)
getCertReqCBdata = hstCertReqCBdata <$> get
-- | Replace the signature algorithms recorded with a certificate request
-- ('Nothing' clears them).
setCertReqSigAlgsCert :: Maybe [HashAndSignatureAlgorithm] -> HandshakeM ()
setCertReqSigAlgsCert algs = modify store
  where store hst = hst { hstCertReqSigAlgsCert = algs }
-- | The signature algorithms recorded with a certificate request, if any.
getCertReqSigAlgsCert :: HandshakeM (Maybe [HashAndSignatureAlgorithm])
getCertReqSigAlgsCert = hstCertReqSigAlgsCert <$> get
-- | The cipher chosen for the pending record state. Partial: fails via
-- 'fromJust' before 'setServerHelloParameters' has run.
getPendingCipher :: HandshakeM Cipher
getPendingCipher = gets (fromJust "pending cipher" . hstPendingCipher)
-- | Add a raw handshake message to the message log. Messages are consed
-- onto the front, so the log is kept newest-first.
addHandshakeMessage :: ByteString -> HandshakeM ()
addHandshakeMessage content = modify prepend
  where prepend hs = hs { hstHandshakeMessages = content : hstHandshakeMessages hs }
-- | The message log in insertion order (oldest first).
getHandshakeMessages :: HandshakeM [ByteString]
getHandshakeMessages = reverse <$> gets hstHandshakeMessages
-- | The message log as stored: newest message first.
getHandshakeMessagesRev :: HandshakeM [ByteString]
getHandshakeMessagesRev = hstHandshakeMessages <$> get
-- | Feed one handshake message into the transcript digest. While the
-- digest is still a message buffer the message is prepended; once it is a
-- hash context the message is hashed in.
updateHandshakeDigest :: ByteString -> HandshakeM ()
updateHandshakeDigest content = modify $ \hs ->
    hs { hstHandshakeDigest = absorb (hstHandshakeDigest hs) }
  where
    absorb (HandshakeMessages bytes)        = HandshakeMessages (content : bytes)
    absorb (HandshakeDigestContext hashCtx) = HandshakeDigestContext (hashUpdate hashCtx content)
-- | Collapse the transcript so far into a single synthetic message.
--
-- The current transcript is hashed with @hashAlg@ (when it is still a
-- message buffer) or finalised (when it is already a hash context), and
-- the result is passed through @f@. The output @folded@ then becomes the
-- sole entry of both the raw message log and the transcript:
--
-- * a buffered transcript stays buffered, as @[folded]@;
-- * a hash context is replaced by a fresh @hashAlg@ context seeded with
--   @folded@.
foldHandshakeDigest :: Hash -> (ByteString -> ByteString) -> HandshakeM ()
foldHandshakeDigest hashAlg f = modify $ \hs ->
    case hstHandshakeDigest hs of
        HandshakeMessages bytes ->
            -- The buffer is newest-first (see 'updateHandshakeDigest'), so
            -- reverse it to hash in insertion order. foldl' keeps the hash
            -- context evaluated at each step instead of building a thunk
            -- chain that is only forced by 'hashFinal'.
            let hashCtx  = foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
                !folded  = f (hashFinal hashCtx)
             in hs { hstHandshakeDigest   = HandshakeMessages [folded]
                   , hstHandshakeMessages = [folded]
                   }
        HandshakeDigestContext hashCtx ->
            let !folded  = f (hashFinal hashCtx)
                hashCtx' = hashUpdate (hashInit hashAlg) folded
             in hs { hstHandshakeDigest   = HandshakeDigestContext hashCtx'
                   , hstHandshakeMessages = [folded]
                   }
-- | Finalised hash of the transcript so far, as used for the extended
-- master secret. Errors if the transcript is still a message buffer, i.e.
-- no hash context has been set up yet.
getSessionHash :: HandshakeM ByteString
getSessionHash = gets (sessionHashOf . hstHandshakeDigest)
  where
    sessionHashOf (HandshakeDigestContext hashCtx) = hashFinal hashCtx
    sessionHashOf (HandshakeMessages _)            = error "un-initialized session hash"
-- | Finished verification data for the given role.
--
-- Uses 'generateClientFinished' for 'ClientRole' and
-- 'generateServerFinished' otherwise, fed with the pending cipher, the
-- master secret and the transcript hash context. Errors if the transcript
-- is still a message buffer. The cipher and master secret are fetched with
-- 'fromJust' and fail if unset.
getHandshakeDigest :: Version -> Role -> HandshakeM ByteString
getHandshakeDigest ver role = gets finishedFor
  where
    finishedFor hst
        | HandshakeDigestContext hashCtx <- hstHandshakeDigest hst =
            generateFinish ver
                           (fromJust "cipher" $ hstPendingCipher hst)
                           (fromJust "master secret" $ hstMasterSecret hst)
                           hashCtx
        | otherwise =
            error "un-initialized handshake digest"
    generateFinish
        | role == ClientRole = generateClientFinished
        | otherwise          = generateServerFinished
-- | Derive the master secret from the pre-master secret, install it (which
-- also computes the pending record states, see 'setMasterSecret'), and
-- return it.
--
-- When the extended master secret flag is set, the derivation uses the
-- transcript hash ('getSessionHash'). Otherwise it uses the client and
-- server randoms.
setMasterSecretFromPre :: ByteArrayAccess preMaster
                       => Version   -- ^ negotiated protocol version
                       -> Role      -- ^ our role in the handshake
                       -> preMaster -- ^ pre-master secret
                       -> HandshakeM ByteString
setMasterSecretFromPre ver role premasterSecret = do
    useExtended <- getExtendedMasterSec
    secret <- if useExtended
                  then extendedSecret =<< get
                  else gets classicSecret
    setMasterSecret ver role secret
    return secret
  where
    cipherOf hst = fromJust "cipher" (hstPendingCipher hst)
    classicSecret hst =
        generateMasterSecret ver (cipherOf hst)
                             premasterSecret
                             (hstClientRandom hst)
                             (fromJust "server random" $ hstServerRandom hst)
    extendedSecret hst =
        generateExtendedMasterSec ver (cipherOf hst) premasterSecret
            <$> getSessionHash
-- | Store the master secret and the pending transmit/receive record states
-- derived from it by 'computeKeyBlock'.
setMasterSecret :: Version -> Role -> ByteString -> HandshakeM ()
setMasterSecret ver role masterSecret = modify install
  where
    install hst =
        let keys = computeKeyBlock hst masterSecret ver role
         in hst { hstMasterSecret   = Just masterSecret
                , hstPendingTxState = Just (fst keys)
                , hstPendingRxState = Just (snd keys) }
-- | Derive the pending (transmit, receive) record states from the master
-- secret.
--
-- A key block of 'cipherKeyBlockSize' bytes is generated from the master
-- secret and both randoms. It is then split, in order, into: client MAC
-- secret, server MAC secret, client write key, server write key, client
-- write IV, server write IV. The MAC secret length is 0 when the bulk
-- cipher reports no MAC ('hasMAC').
--
-- The 'Role' argument (@cc@) decides orientation. For 'ClientRole' the
-- client-side material forms the transmit state and the server-side
-- material the receive state; for any other role it is the reverse. Both
-- states start at MAC sequence number 0 and carry the pending cipher and
-- compression.
computeKeyBlock :: HandshakeState -> ByteString -> Version -> Role -> (RecordState, RecordState)
computeKeyBlock hst masterSecret ver cc = (pendingTx, pendingRx)
  where cipher       = fromJust "cipher" $ hstPendingCipher hst
        keyblockSize = cipherKeyBlockSize cipher
        bulk         = cipherBulk cipher
        -- no separate MAC key material for bulk ciphers without a MAC
        digestSize   = if hasMAC (bulkF bulk) then hashDigestSize (cipherHash cipher)
                                              else 0
        keySize      = bulkKeySize bulk
        ivSize       = bulkIVSize bulk
        kb           = generateKeyBlock ver cipher (hstClientRandom hst)
                                        (fromJust "server random" $ hstServerRandom hst)
                                        masterSecret keyblockSize
        -- split the key block; presumably partition6 yields Nothing when kb
        -- is too short, which then fails in fromJust "p6"
        (cMACSecret, sMACSecret, cWriteKey, sWriteKey, cWriteIV, sWriteIV) =
                    fromJust "p6" $ partition6 kb (digestSize, digestSize, keySize, keySize, ivSize, ivSize)
        -- client write key: encrypts on the client, decrypts on the server
        cstClient = CryptState { cstKey        = bulkInit bulk (BulkEncrypt `orOnServer` BulkDecrypt) cWriteKey
                               , cstIV         = cWriteIV
                               , cstMacSecret  = cMACSecret }
        -- server write key: decrypts on the client, encrypts on the server
        cstServer = CryptState { cstKey        = bulkInit bulk (BulkDecrypt `orOnServer` BulkEncrypt) sWriteKey
                               , cstIV         = sWriteIV
                               , cstMacSecret  = sMACSecret }
        msClient = MacState { msSequence = 0 }
        msServer = MacState { msSequence = 0 }
        -- what we send uses our own side's material
        pendingTx = RecordState
                  { stCryptState  = if cc == ClientRole then cstClient else cstServer
                  , stMacState    = if cc == ClientRole then msClient else msServer
                  , stCipher      = Just cipher
                  , stCompression = hstPendingCompression hst
                  }
        -- what we receive uses the peer's material
        pendingRx = RecordState
                  { stCryptState  = if cc == ClientRole then cstServer else cstClient
                  , stMacState    = if cc == ClientRole then msServer else msClient
                  , stCipher      = Just cipher
                  , stCompression = hstPendingCompression hst
                  }
        -- picks f for ClientRole, g for any other role
        orOnServer f g = if cc == ClientRole then f else g
-- | Record the parameters chosen in ServerHello: the server random, the
-- pending cipher and the pending compression.
--
-- This also switches the transcript from a message buffer to a hash
-- context, using the algorithm selected by 'getHash'. Errors if the
-- transcript is already a hash context.
setServerHelloParameters :: Version      -- ^ negotiated protocol version
                         -> ServerRandom
                         -> Cipher
                         -> Compression
                         -> HandshakeM ()
setServerHelloParameters ver sran cipher compression =
    modify $ \hst -> hst
                { hstServerRandom       = Just sran
                , hstPendingCipher      = Just cipher
                , hstPendingCompression = compression
                , hstHandshakeDigest    = updateDigest $ hstHandshakeDigest hst
                }
  where hashAlg = getHash ver cipher
        -- The buffer is newest-first, so reverse it to hash in insertion
        -- order. foldl' keeps the context evaluated at each step rather
        -- than accumulating thunks.
        updateDigest (HandshakeMessages bytes)  = HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
        updateDigest (HandshakeDigestContext _) = error "cannot initialize digest with another digest"
-- | Hash algorithm used for the handshake transcript (see
-- 'setServerHelloParameters').
--
-- * Versions below TLS 1.2 use 'SHA1_MD5'.
-- * Otherwise, a cipher whose minimum version is TLS 1.2 or later uses its
--   own 'cipherHash'.
-- * Any other cipher (no minimum version, or one below TLS 1.2) uses
--   'SHA256'.
getHash :: Version -> Cipher -> Hash
getHash ver ciph
    | ver < TLS12 = SHA1_MD5
    | otherwise   =
        case cipherMinVer ciph of
            Just minVer | minVer >= TLS12 -> cipherHash ciph
            _                             -> SHA256