{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.State13
( getTxState
, getRxState
, setTxState
, setRxState
, clearTxState
, clearRxState
, setHelloParameters13
, transcriptHash
, wrapAsMessageHash13
, PendingAction(..)
, setPendingActions
, popPendingAction
) where
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Util
-- | Read the context's transmit-side record state ('ctxTxState') and return
-- the hash, cipher and secret it currently holds. See 'getXState'.
getTxState :: Context -> IO (Hash, Cipher, ByteString)
getTxState = flip getXState ctxTxState
-- | Read the context's receive-side record state ('ctxRxState') and return
-- the hash, cipher and secret it currently holds. See 'getXState'.
getRxState :: Context -> IO (Hash, Cipher, ByteString)
getRxState = flip getXState ctxRxState
-- | Read the record state selected by @func@ and return a triple made of
-- the hash of the state's cipher, the cipher itself, and the value stored
-- in the crypt state's 'cstMacSecret' field.
--
-- The cipher is extracted lazily: when the record state carries no cipher
-- (e.g. after 'clearTxState' / 'clearRxState' set 'stCipher' to 'Nothing'),
-- the failure only surfaces if the hash or cipher component of the result
-- is forced; the secret component stays usable.
getXState :: Context
          -> (Context -> MVar RecordState)
          -> IO (Hash, Cipher, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    -- Use the labelled 'fromJust' helper (the same one 'transcriptHash'
    -- relies on) instead of a partial @let Just ... =@ binding, so that a
    -- missing cipher is reported with a label rather than as an anonymous
    -- irrefutable-pattern failure, and no incomplete-pattern warning fires.
    let usedCipher = fromJust "used-cipher" (stCipher tx)
        usedHash   = cipherHash usedCipher
        secret     = cstMacSecret $ stCryptState tx
    return (usedHash, usedCipher, secret)
-- | Install a fresh transmit-side record state built from the given hash,
-- cipher and secret, with the bulk cipher initialised for encryption
-- ('BulkEncrypt'). See 'setXState'.
setTxState :: Context -> Hash -> Cipher -> ByteString -> IO ()
setTxState ctx h cipher secret =
    setXState ctxTxState BulkEncrypt ctx h cipher secret
-- | Install a fresh receive-side record state built from the given hash,
-- cipher and secret, with the bulk cipher initialised for decryption
-- ('BulkDecrypt'). See 'setXState'.
setRxState :: Context -> Hash -> Cipher -> ByteString -> IO ()
setRxState ctx h cipher secret =
    setXState ctxRxState BulkDecrypt ctx h cipher secret
-- | Overwrite the record state selected by @func@ with a newly built one.
--
-- The previous contents of the 'MVar' are discarded. The new state is
-- assembled as follows:
--
-- * key and IV are both derived from @secret@ with 'hkdfExpandLabel',
--   using the labels @"key"@ and @"iv"@ and an empty context;
-- * the key length is the bulk cipher's 'bulkKeySize';
-- * the IV length is @bulkIVSize + bulkExplicitIV@, but never less than 8;
-- * @secret@ itself is kept in 'cstMacSecret' (it is what 'getXState'
--   hands back);
-- * the sequence number starts at 0 and compression is 'nullCompression'.
setXState :: (Context -> MVar RecordState) -> BulkDirection
          -> Context -> Hash -> Cipher -> ByteString
          -> IO ()
setXState func encOrDec ctx h cipher secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulk = cipherBulk cipher

    -- Shared derivation step: only the label and output length differ
    -- between the key and the IV.
    derive label = hkdfExpandLabel h secret label ""

    trafficKey = derive "key" (bulkKeySize bulk)
    trafficIV  = derive "iv" (max 8 (bulkIVSize bulk + bulkExplicitIV bulk))

    freshState = RecordState
        { stCryptState = CryptState
            { cstKey       = bulkInit bulk encOrDec trafficKey
            , cstIV        = trafficIV
            , cstMacSecret = secret
            }
        , stMacState    = MacState { msSequence = 0 }
        , stCipher      = Just cipher
        , stCompression = nullCompression
        }
-- | Remove the cipher from the transmit-side record state, leaving the
-- rest of that state untouched. See 'clearXState'.
clearTxState :: Context -> IO ()
clearTxState ctx = clearXState ctxTxState ctx
-- | Remove the cipher from the receive-side record state, leaving the
-- rest of that state untouched. See 'clearXState'.
clearRxState :: Context -> IO ()
clearRxState ctx = clearXState ctxRxState ctx
-- | Set 'stCipher' to 'Nothing' in the record state selected by @func@.
-- All other fields of the record state are preserved.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) dropCipher
  where
    dropCipher st = return (st { stCipher = Nothing })
-- | Record the cipher chosen for the handshake in the handshake state.
--
-- * If no pending cipher is stored yet, store @cipher@, set the pending
--   compression to 'nullCompression', and convert the handshake messages
--   buffered so far into a running digest context for the cipher's hash.
--   Returns @Right ()@.
-- * If a pending cipher is already stored and equals @cipher@, nothing is
--   changed and @Right ()@ is returned.
-- * If a different pending cipher is already stored, an 'Error_Protocol'
--   with 'IllegalParameter' is returned in @Left@.
--
-- Calls 'error' if the first case is reached while the handshake digest is
-- already a 'HandshakeDigestContext'.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher =
    get >>= \hst -> maybe (install hst) verify (hstPendingCipher hst)
  where
    -- First time a cipher is seen: remember it and start the digest.
    install hst = do
        put hst
            { hstPendingCipher      = Just cipher
            , hstPendingCompression = nullCompression
            , hstHandshakeDigest    = startDigest (hstHandshakeDigest hst)
            }
        return (Right ())

    -- A cipher was stored earlier: it must match the new one.
    verify oldcipher
        | cipher == oldcipher = return (Right ())
        | otherwise           = return (Left cipherChanged)

    cipherChanged =
        Error_Protocol ("TLS 1.3 cipher changed after hello retry", True, IllegalParameter)

    -- The buffered list is consumed from its last element to its first
    -- (a right fold with flipped arguments is the same as a left fold over
    -- the reversed list).
    startDigest (HandshakeMessages bytes) =
        HandshakeDigestContext (foldr (flip hashUpdate) (hashInit (cipherHash cipher)) bytes)
    startDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
-- | Apply 'foldHandshakeDigest' with the pending cipher's hash and a
-- function that frames a digest as: the byte 254, two zero bytes, one byte
-- holding the digest length, then the digest itself.
--
-- NOTE(review): this framing looks like the TLS 1.3 @message_hash@
-- synthetic handshake header used after a HelloRetryRequest
-- (RFC 8446, section 4.4.1) -- confirm against 'foldHandshakeDigest'.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) frame
  where
    frame dig = "\254\0\0" `B.append` B.cons digLen dig
      where
        -- Length is narrowed to a single byte.
        digLen = fromIntegral (B.length dig)
-- | Finalise the running handshake digest of the context and return the
-- resulting hash value.
--
-- Fails (via the labelled 'fromJust') when the context has no handshake
-- state, and calls 'error' when the handshake digest is still a plain
-- 'HandshakeMessages' buffer rather than a 'HandshakeDigestContext'.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx =
    getHState ctx >>= finish . hstHandshakeDigest . fromJust "HState"
  where
    finish (HandshakeDigestContext hashCtx) = return (hashFinal hashCtx)
    finish (HandshakeMessages _)            = error "un-initialized handshake digest"
-- | Replace the context's list of pending actions with @actions@.
-- Whatever was stored in 'ctxPendingActions' before is overwritten.
setPendingActions :: Context -> [PendingAction] -> IO ()
setPendingActions ctx actions = writeIORef (ctxPendingActions ctx) actions
-- | Take the first pending action off the context's list.
--
-- Returns 'Nothing' (leaving the reference untouched) when the list is
-- empty; otherwise stores the tail back and returns the head in 'Just'.
--
-- The read and the write are two separate 'IORef' operations, so this is
-- not atomic with respect to other threads touching 'ctxPendingActions'.
popPendingAction :: Context -> IO (Maybe PendingAction)
popPendingAction ctx = readIORef ref >>= step
  where
    ref = ctxPendingActions ctx

    step []          = return Nothing
    step (next:rest) = do
        writeIORef ref rest
        return (Just next)