{-# LANGUAGE OverloadedStrings, GeneralizedNewtypeDeriving, BangPatterns #-}
module Network.TLS.Handshake.Common13
       ( makeFinished
       , checkFinished
       , makeServerKeyShare
       , makeClientKeyShare
       , fromServerKeyShare
       , makeCertVerify
       , checkCertVerify
       , makePSKBinder
       , replacePSKBinder
       , sendChangeCipherSpec13
       , handshakeTerminate13
       , makeCertRequest
       , createTLS13TicketInfo
       , ageToObfuscatedAge
       , isAgeValid
       , getAge
       , checkFreshness
       , getCurrentTimeFromBase
       , getSessionData13
       , ensureNullCompression
       , isHashSignatureValid13
       , safeNonNegative32
       , RecvHandshake13M
       , runRecvHandshake13
       , recvHandshake13
       , recvHandshake13hash
       , CipherChoice(..)
       , makeCipherChoice
       , initEarlySecret
       , calculateEarlySecret
       , calculateHandshakeSecret
       , calculateApplicationSecret
       , calculateResumptionSecret
       , derivePSK
       ) where
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Data.Hourglass
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Cipher
import Network.TLS.Crypto
import qualified Network.TLS.Crypto.IES as IES
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate (extractCAname)
import Network.TLS.Handshake.Process (processHandshake13)
import Network.TLS.Handshake.Common (unexpected)
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Handshake.Signature
import Network.TLS.Imports
import Network.TLS.KeySchedule
import Network.TLS.MAC
import Network.TLS.Parameters
import Network.TLS.IO
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
import Network.TLS.Wire
import Time.System
import Control.Concurrent.MVar
import Control.Monad.State.Strict
makeFinished :: MonadIO m => Context -> Hash -> ByteString -> m Handshake13
makeFinished ctx usedHash baseKey =
    Finished13 . makeVerifyData usedHash baseKey <$> transcriptHash ctx
checkFinished :: MonadIO m => Hash -> ByteString -> ByteString -> ByteString -> m ()
checkFinished usedHash baseKey hashValue verifyData = do
    let verifyData' = makeVerifyData usedHash baseKey hashValue
    unless (verifyData' == verifyData) $ decryptError "cannot verify finished"
makeVerifyData :: Hash -> ByteString -> ByteString -> ByteString
makeVerifyData usedHash baseKey = hmac usedHash finishedKey
  where
    hashSize = hashDigestSize usedHash
    finishedKey = hkdfExpandLabel usedHash baseKey "finished" "" hashSize
makeServerKeyShare :: Context -> KeyShareEntry -> IO (ByteString, KeyShareEntry)
makeServerKeyShare ctx (KeyShareEntry grp wcpub) = case ecpub of
  Left  e    -> throwCore $ Error_Protocol (show e, True, IllegalParameter)
  Right cpub -> do
      ecdhePair <- generateECDHEShared ctx cpub
      case ecdhePair of
          Nothing -> throwCore $ Error_Protocol (msgInvalidPublic, True, IllegalParameter)
          Just (spub, share) ->
              let wspub = IES.encodeGroupPublic spub
                  serverKeyShare = KeyShareEntry grp wspub
               in return (BA.convert share, serverKeyShare)
  where
    ecpub = IES.decodeGroupPublic grp wcpub
    msgInvalidPublic = "invalid client " ++ show grp ++ " public key"
makeClientKeyShare :: Context -> Group -> IO (IES.GroupPrivate, KeyShareEntry)
makeClientKeyShare ctx grp = do
    (cpri, cpub) <- generateECDHE ctx grp
    let wcpub = IES.encodeGroupPublic cpub
        clientKeyShare = KeyShareEntry grp wcpub
    return (cpri, clientKeyShare)
fromServerKeyShare :: KeyShareEntry -> IES.GroupPrivate -> IO ByteString
fromServerKeyShare (KeyShareEntry grp wspub) cpri = case espub of
  Left  e    -> throwCore $ Error_Protocol (show e, True, IllegalParameter)
  Right spub -> case IES.groupGetShared spub cpri of
    Just shared -> return $ BA.convert shared
    Nothing     -> throwCore $ Error_Protocol ("cannot generate a shared secret on (EC)DH", True, IllegalParameter)
  where
    espub = IES.decodeGroupPublic grp wspub
serverContextString :: ByteString
serverContextString = "TLS 1.3, server CertificateVerify"
clientContextString :: ByteString
clientContextString = "TLS 1.3, client CertificateVerify"
makeCertVerify :: MonadIO m => Context -> PubKey -> HashAndSignatureAlgorithm -> ByteString -> m Handshake13
makeCertVerify ctx pub hs hashValue = do
    cc <- liftIO $ usingState_ ctx isClientContext
    let ctxStr | cc == ClientRole = clientContextString
               | otherwise        = serverContextString
        target = makeTarget ctxStr hashValue
    CertVerify13 hs <$> sign ctx pub hs target
checkCertVerify :: MonadIO m => Context -> PubKey -> HashAndSignatureAlgorithm -> Signature -> ByteString -> m Bool
checkCertVerify ctx pub hs signature hashValue
    | pub `signatureCompatible13` hs = liftIO $ do
        cc <- usingState_ ctx isClientContext
        let ctxStr | cc == ClientRole = serverContextString 
                | otherwise        = clientContextString
            target = makeTarget ctxStr hashValue
            sigParams = signatureParams pub (Just hs)
        checkHashSignatureValid13 hs
        checkSupportedHashSignature ctx (Just hs)
        verifyPublic ctx sigParams target signature
    | otherwise = return False
makeTarget :: ByteString -> ByteString -> ByteString
makeTarget contextString hashValue = runPut $ do
    putBytes $ B.replicate 64 32
    putBytes contextString
    putWord8 0
    putBytes hashValue
sign :: MonadIO m => Context -> PubKey -> HashAndSignatureAlgorithm -> ByteString -> m Signature
sign ctx pub hs target = liftIO $ do
    cc <- usingState_ ctx isClientContext
    let sigParams = signatureParams pub (Just hs)
    signPrivate ctx cc sigParams target
makePSKBinder :: Context -> BaseSecret EarlySecret -> Hash -> Int -> Maybe ByteString -> IO ByteString
makePSKBinder ctx (BaseSecret sec) usedHash truncLen mch = do
    rmsgs0 <- usingHState ctx getHandshakeMessagesRev 
    let rmsgs = case mch of
          Just ch -> trunc ch : rmsgs0
          Nothing -> trunc (head rmsgs0) : tail rmsgs0
        hChTruncated = hash usedHash $ B.concat $ reverse rmsgs
        binderKey = deriveSecret usedHash sec "res binder" (hash usedHash "")
    return $ makeVerifyData usedHash binderKey hChTruncated
  where
    trunc x = B.take takeLen x
      where
        totalLen = B.length x
        takeLen = totalLen - truncLen
replacePSKBinder :: ByteString -> ByteString -> ByteString
replacePSKBinder pskz binder = identities `B.append` binders
  where
    bindersSize = B.length binder + 3
    identities  = B.take (B.length pskz - bindersSize) pskz
    binders     = runPut $ putOpaque16 $ runPut $ putOpaque8 binder
sendChangeCipherSpec13 :: Context -> PacketFlightM ()
sendChangeCipherSpec13 ctx = do
    sent <- usingHState ctx $ do
                b <- getCCS13Sent
                unless b $ setCCS13Sent True
                return b
    unless sent $ loadPacket13 ctx ChangeCipherSpec13
handshakeTerminate13 :: Context -> IO ()
handshakeTerminate13 ctx = do
    
    liftIO $ modifyMVar_ (ctxHandshake ctx) $ \ mhshake ->
        case mhshake of
            Nothing -> return Nothing
            Just hshake ->
                return $ Just (newEmptyHandshake (hstClientVersion hshake) (hstClientRandom hshake))
                    { hstServerRandom = hstServerRandom hshake
                    , hstMasterSecret = hstMasterSecret hshake
                    , hstNegotiatedGroup = hstNegotiatedGroup hshake
                    , hstHandshakeDigest = hstHandshakeDigest hshake
                    , hstTLS13HandshakeMode = hstTLS13HandshakeMode hshake
                    , hstTLS13RTT0Status = hstTLS13RTT0Status hshake
                    , hstTLS13ResumptionSecret = hstTLS13ResumptionSecret hshake
                    }
    
    usingState_ ctx $ do
        setTLS13KeyShare Nothing
        setTLS13PreSharedKey Nothing
    
    setEstablished ctx Established
makeCertRequest :: ServerParams -> Context -> CertReqContext -> Handshake13
makeCertRequest sparams ctx certReqCtx =
    let sigAlgs = extensionEncode $ SignatureAlgorithms $ supportedHashSignatures $ ctxSupported ctx
        caDns = map extractCAname $ serverCACertificates sparams
        caDnsEncoded = extensionEncode $ CertificateAuthorities caDns
        caExtension
            | null caDns = []
            | otherwise  = [ExtensionRaw extensionID_CertificateAuthorities caDnsEncoded]
        crexts = ExtensionRaw extensionID_SignatureAlgorithms sigAlgs : caExtension
     in CertRequest13 certReqCtx crexts
createTLS13TicketInfo :: Second -> Either Context Second -> Maybe Millisecond -> IO TLS13TicketInfo
createTLS13TicketInfo life ecw mrtt = do
    
    
    bTime <- getCurrentTimeFromBase
    add <- case ecw of
        Left ctx -> B.foldl' (*+) 0 <$> getStateRNG ctx 4
        Right ad -> return ad
    return $ TLS13TicketInfo life add bTime mrtt
  where
    x *+ y = x * 256 + fromIntegral y
ageToObfuscatedAge :: Second -> TLS13TicketInfo -> Second
ageToObfuscatedAge age tinfo = obfage
  where
    !obfage = age + ageAdd tinfo
obfuscatedAgeToAge :: Second -> TLS13TicketInfo -> Second
obfuscatedAgeToAge obfage tinfo = age
  where
    !age = obfage - ageAdd tinfo
isAgeValid :: Second -> TLS13TicketInfo -> Bool
isAgeValid age tinfo = age <= lifetime tinfo * 1000
getAge :: TLS13TicketInfo -> IO Second
getAge tinfo = do
    let clientReceiveTime = txrxTime tinfo
    clientSendTime <- getCurrentTimeFromBase
    return $! fromIntegral (clientSendTime - clientReceiveTime) 
checkFreshness :: TLS13TicketInfo -> Second -> IO Bool
checkFreshness tinfo obfAge = do
    serverReceiveTime <- getCurrentTimeFromBase
    let freshness = if expectedArrivalTime > serverReceiveTime
                    then expectedArrivalTime - serverReceiveTime
                    else serverReceiveTime - expectedArrivalTime
    
    
    let tolerance = max 2000 rtt
        isFresh = freshness < tolerance
    return $ isAlive && isFresh
  where
    serverSendTime = txrxTime tinfo
    Just rtt = estimatedRTT tinfo
    age = obfuscatedAgeToAge obfAge tinfo
    expectedArrivalTime = serverSendTime + rtt + fromIntegral age
    isAlive = isAgeValid age tinfo
getCurrentTimeFromBase :: IO Millisecond
getCurrentTimeFromBase = millisecondsFromBase <$> timeCurrentP
millisecondsFromBase :: ElapsedP -> Millisecond
millisecondsFromBase d = fromIntegral ms
  where
    ElapsedP (Elapsed (Seconds s)) (NanoSeconds ns) = d - timeConvert base
    ms = s * 1000 + ns `div` 1000000
    base = Date 2017 January 1
getSessionData13 :: Context -> Cipher -> TLS13TicketInfo -> Int -> ByteString -> IO SessionData
getSessionData13 ctx usedCipher tinfo maxSize psk = do
    ver   <- usingState_ ctx getVersion
    malpn <- usingState_ ctx getNegotiatedProtocol
    sni   <- usingState_ ctx getClientSNI
    mgrp  <- usingHState ctx getNegotiatedGroup
    return SessionData {
        sessionVersion     = ver
      , sessionCipher      = cipherID usedCipher
      , sessionCompression = 0
      , sessionClientSNI   = sni
      , sessionSecret      = psk
      , sessionGroup       = mgrp
      , sessionTicketInfo  = Just tinfo
      , sessionALPN        = malpn
      , sessionMaxEarlyDataSize = maxSize
      , sessionFlags       = []
      }
ensureNullCompression :: MonadIO m => CompressionID -> m ()
ensureNullCompression compression =
    when (compression /= compressionID nullCompression) $
        throwCore $ Error_Protocol ("compression is not allowed in TLS 1.3", True, IllegalParameter)
safeNonNegative32 :: (Num a, Ord a, FiniteBits a) => a -> a
safeNonNegative32 x
  | x <= 0                = 0
  | finiteBitSize x <= 32 = x
  | otherwise             = x `min` fromIntegral (maxBound :: Word32)
newtype RecvHandshake13M m a = RecvHandshake13M (StateT [Handshake13] m a)
    deriving (Functor, Applicative, Monad, MonadIO)
recvHandshake13 :: MonadIO m
                => Context
                -> (Handshake13 -> RecvHandshake13M m a)
                -> RecvHandshake13M m a
recvHandshake13 ctx f = getHandshake13 ctx >>= f
recvHandshake13hash :: MonadIO m
                    => Context
                    -> (ByteString -> Handshake13 -> RecvHandshake13M m a)
                    -> RecvHandshake13M m a
recvHandshake13hash ctx f = do
    d <- transcriptHash ctx
    getHandshake13 ctx >>= f d
getHandshake13 :: MonadIO m => Context -> RecvHandshake13M m Handshake13
getHandshake13 ctx = RecvHandshake13M $ do
    currentState <- get
    case currentState of
        (h:hs) -> found h hs
        []     -> recvLoop
  where
    found h hs = liftIO (processHandshake13 ctx h) >> put hs >> return h
    recvLoop = do
        epkt <- recvPacket13 ctx
        case epkt of
            Right (Handshake13 [])     -> error "invalid recvPacket13 result"
            Right (Handshake13 (h:hs)) -> found h hs
            Right ChangeCipherSpec13   -> recvLoop
            Right x                    -> unexpected (show x) (Just "handshake 13")
            Left err                   -> throwCore err
runRecvHandshake13 :: MonadIO m => RecvHandshake13M m a -> m a
runRecvHandshake13 (RecvHandshake13M f) = do
    (result, new) <- runStateT f []
    unless (null new) $ unexpected "spurious handshake 13" Nothing
    return result
checkHashSignatureValid13 :: HashAndSignatureAlgorithm -> IO ()
checkHashSignatureValid13 hs =
    unless (isHashSignatureValid13 hs) $
        let msg = "invalid TLS13 hash and signature algorithm: " ++ show hs
         in throwCore $ Error_Protocol (msg, True, IllegalParameter)
isHashSignatureValid13 :: HashAndSignatureAlgorithm -> Bool
isHashSignatureValid13 (HashIntrinsic, s) =
    s `elem` [ SignatureRSApssRSAeSHA256
             , SignatureRSApssRSAeSHA384
             , SignatureRSApssRSAeSHA512
             , SignatureEd25519
             , SignatureEd448
             , SignatureRSApsspssSHA256
             , SignatureRSApsspssSHA384
             , SignatureRSApsspssSHA512
             ]
isHashSignatureValid13 (h, SignatureECDSA) =
    h `elem` [ HashSHA256, HashSHA384, HashSHA512 ]
isHashSignatureValid13 _ = False
data CipherChoice = CipherChoice {
    cVersion :: Version
  , cCipher  :: Cipher
  , cHash    :: Hash
  , cZero    :: !ByteString
  }
makeCipherChoice :: Version -> Cipher -> CipherChoice
makeCipherChoice ver cipher = CipherChoice ver cipher h zero
  where
    h = cipherHash cipher
    zero = B.replicate (hashDigestSize h) 0
calculateEarlySecret :: Context -> CipherChoice
                     -> Either ByteString (BaseSecret EarlySecret)
                     -> Bool -> IO (SecretPair EarlySecret)
calculateEarlySecret ctx choice maux initialized = do
    hCh <- if initialized then
               transcriptHash ctx
             else do
               hmsgs <- usingHState ctx getHandshakeMessages
               return $ hash usedHash $ B.concat hmsgs
    let earlySecret = case maux of
          Right (BaseSecret sec) -> sec
          Left  psk              -> hkdfExtract usedHash zero psk
        clientEarlySecret = deriveSecret usedHash earlySecret "c e traffic" hCh
        cets = ClientTrafficSecret clientEarlySecret :: ClientTrafficSecret EarlySecret
    logKey ctx cets
    return $ SecretPair (BaseSecret earlySecret) cets
  where
    usedHash = cHash choice
    zero = cZero choice
initEarlySecret :: CipherChoice -> Maybe ByteString -> BaseSecret EarlySecret
initEarlySecret choice mpsk = BaseSecret sec
  where
    sec = hkdfExtract usedHash zero zeroOrPSK
    usedHash = cHash choice
    zero = cZero choice
    zeroOrPSK = case mpsk of
      Just psk -> psk
      Nothing  -> zero
calculateHandshakeSecret :: Context -> CipherChoice -> BaseSecret EarlySecret -> ByteString
                         -> IO (SecretTriple HandshakeSecret)
calculateHandshakeSecret ctx choice (BaseSecret sec) ecdhe = do
        hChSh <- transcriptHash ctx
        let handshakeSecret = hkdfExtract usedHash (deriveSecret usedHash sec "derived" (hash usedHash "")) ecdhe
        let clientHandshakeSecret = deriveSecret usedHash handshakeSecret "c hs traffic" hChSh
            serverHandshakeSecret = deriveSecret usedHash handshakeSecret "s hs traffic" hChSh
        let shts = ServerTrafficSecret serverHandshakeSecret :: ServerTrafficSecret HandshakeSecret
            chts = ClientTrafficSecret clientHandshakeSecret :: ClientTrafficSecret HandshakeSecret
        logKey ctx shts
        logKey ctx chts
        return $ SecretTriple (BaseSecret handshakeSecret) chts shts
  where
    usedHash = cHash choice
calculateApplicationSecret :: Context -> CipherChoice -> BaseSecret HandshakeSecret -> ByteString
                           -> IO (SecretTriple ApplicationSecret)
calculateApplicationSecret ctx choice (BaseSecret sec) hChSf = do
    let applicationSecret = hkdfExtract usedHash (deriveSecret usedHash sec "derived" (hash usedHash "")) zero
    let clientApplicationSecret0 = deriveSecret usedHash applicationSecret "c ap traffic" hChSf
        serverApplicationSecret0 = deriveSecret usedHash applicationSecret "s ap traffic" hChSf
        exporterMasterSecret = deriveSecret usedHash applicationSecret "exp master" hChSf
    usingState_ ctx $ setExporterMasterSecret exporterMasterSecret
    let sts0 = ServerTrafficSecret serverApplicationSecret0 :: ServerTrafficSecret ApplicationSecret
    let cts0 = ClientTrafficSecret clientApplicationSecret0 :: ClientTrafficSecret ApplicationSecret
    logKey ctx sts0
    logKey ctx cts0
    return $ SecretTriple (BaseSecret applicationSecret) cts0 sts0
  where
    usedHash = cHash choice
    zero = cZero choice
calculateResumptionSecret :: Context -> CipherChoice -> BaseSecret ApplicationSecret
                          -> IO (BaseSecret ResumptionSecret)
calculateResumptionSecret ctx choice (BaseSecret sec) = do
    hChCf <- transcriptHash ctx
    let resumptionMasterSecret = deriveSecret usedHash sec "res master" hChCf
    return $ BaseSecret resumptionMasterSecret
  where
    usedHash = cHash choice
derivePSK :: CipherChoice -> BaseSecret ResumptionSecret -> ByteString -> ByteString
derivePSK choice (BaseSecret sec) nonce =
    hkdfExpandLabel usedHash sec "resumption" nonce hashSize
  where
    usedHash = cHash choice
    hashSize = hashDigestSize usedHash