{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.State13 (
CryptLevel (
CryptEarlySecret,
CryptHandshakeSecret,
CryptApplicationSecret
),
TrafficSecret,
getTxRecordState,
getRxRecordState,
setTxRecordState,
setRxRecordState,
getTxLevel,
getRxLevel,
clearTxRecordState,
clearRxRecordState,
setHelloParameters13,
transcriptHash,
wrapAsMessageHash13,
PendingRecvAction (..),
setPendingRecvActions,
popPendingRecvAction,
) where
import Control.Concurrent.MVar
import Control.Monad.State
import qualified Data.ByteString as B
import Data.IORef
import Data.List (foldl')
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Context.Internal
import Network.TLS.Crypto
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.KeySchedule (hkdfExpandLabel)
import Network.TLS.Record.State
import Network.TLS.Struct
import Network.TLS.Types
-- | Snapshot of the transmit-side record state: the hash of the installed
-- cipher, the cipher, the current 'CryptLevel' and the traffic secret.
getTxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getTxRecordState = flip getXState ctxTxRecordState
-- | Snapshot of the receive-side record state: the hash of the installed
-- cipher, the cipher, the current 'CryptLevel' and the traffic secret.
getRxRecordState :: Context -> IO (Hash, Cipher, CryptLevel, ByteString)
getRxRecordState = flip getXState ctxRxRecordState
-- | Read the 'RecordState' held in the 'MVar' selected by @func@.
--
-- Returns, in order:
--
--   * the hash algorithm of the installed cipher,
--   * the installed cipher,
--   * the record state's 'CryptLevel',
--   * the secret kept in 'cstMacSecret' (where 'setXState'' stores the
--     traffic secret).
--
-- The tuple components are lazy. The cipher and its hash are only needed
-- when forced. If no cipher is installed (e.g. after 'clearXState' set
-- 'stCipher' to 'Nothing'), forcing either one raises an error that names
-- the cause, rather than a generic @fromJust@ failure.
getXState
    :: Context
    -> (Context -> MVar RecordState)
    -> IO (Hash, Cipher, CryptLevel, ByteString)
getXState ctx func = do
    tx <- readMVar (func ctx)
    let usedCipher =
            -- Deliberately lazy, like the previous 'fromJust': callers that
            -- only look at the level or the secret keep working when no
            -- cipher is installed.
            maybe (error "getXState: no cipher installed in record state") id $
                stCipher tx
        usedHash = cipherHash usedCipher
        level = stCryptLevel tx
        secret = cstMacSecret $ stCryptState tx
    return (usedHash, usedCipher, level, secret)
-- | The 'CryptLevel' currently installed on the transmit side.
getTxLevel :: Context -> IO CryptLevel
getTxLevel = (`getXLevel` ctxTxRecordState)
-- | The 'CryptLevel' currently installed on the receive side.
getRxLevel :: Context -> IO CryptLevel
getRxLevel = (`getXLevel` ctxRxRecordState)
-- | Read only the 'stCryptLevel' field of the 'RecordState' held in the
-- 'MVar' selected by @func@.
getXLevel
    :: Context
    -> (Context -> MVar RecordState)
    -> IO CryptLevel
getXLevel ctx func = stCryptLevel <$> readMVar (func ctx)
-- | Wrappers that carry a traffic secret tagged with a 'CryptLevel'.
class TrafficSecret ty where
    -- | Split a wrapped secret into its level and its raw secret bytes.
    fromTrafficSecret :: ty -> (CryptLevel, ByteString)
-- | The level is taken from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (AnyTrafficSecret a) where
    fromTrafficSecret ts@(AnyTrafficSecret secret) = (level, secret)
      where
        -- The wrapper itself serves as the proxy for 'getCryptLevel'.
        level = getCryptLevel ts
-- | The level is taken from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (ClientTrafficSecret a) where
    fromTrafficSecret ts@(ClientTrafficSecret secret) = (level, secret)
      where
        -- The wrapper itself serves as the proxy for 'getCryptLevel'.
        level = getCryptLevel ts
-- | The level is taken from the phantom type parameter via 'getCryptLevel'.
instance HasCryptLevel a => TrafficSecret (ServerTrafficSecret a) where
    fromTrafficSecret ts@(ServerTrafficSecret secret) = (level, secret)
      where
        -- The wrapper itself serves as the proxy for 'getCryptLevel'.
        level = getCryptLevel ts
-- | Install a new transmit-side record state derived from the given traffic
-- secret. The bulk cipher is initialised in the 'BulkEncrypt' direction.
setTxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setTxRecordState ctx h cipher ts =
    setXState ctxTxRecordState BulkEncrypt ctx h cipher ts
-- | Install a new receive-side record state derived from the given traffic
-- secret. The bulk cipher is initialised in the 'BulkDecrypt' direction.
setRxRecordState :: TrafficSecret ty => Context -> Hash -> Cipher -> ty -> IO ()
setRxRecordState ctx h cipher ts =
    setXState ctxRxRecordState BulkDecrypt ctx h cipher ts
-- | Unwrap a 'TrafficSecret' into its level and secret bytes, then delegate
-- to 'setXState'' to build and install the new record state.
setXState
    :: TrafficSecret ty
    => (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> ty
    -> IO ()
setXState func encOrDec ctx h cipher ts =
    uncurry (setXState' func encOrDec ctx h cipher) (fromTrafficSecret ts)
-- | Replace the 'RecordState' held in the 'MVar' selected by @func@ with a
-- fresh one derived from @secret@.
--
-- The write key and IV are expanded from @secret@ with 'hkdfExpandLabel',
-- using the labels @"key"@ and @"iv"@ and an empty context. The key length
-- is the bulk cipher's key size. The IV length is the bulk cipher's IV
-- size plus its explicit-IV size, with a minimum of 8 bytes. The new state:
--
--   * starts its sequence number at 0,
--   * records @lvl@ and @cipher@,
--   * uses no compression,
--   * keeps @secret@ itself in 'cstMacSecret'.
--
-- The previous state is discarded unread.
setXState'
    :: (Context -> MVar RecordState)
    -> BulkDirection
    -> Context
    -> Hash
    -> Cipher
    -> CryptLevel
    -> ByteString
    -> IO ()
setXState' func encOrDec ctx h cipher lvl secret =
    modifyMVar_ (func ctx) (const (return freshState))
  where
    bulk = cipherBulk cipher
    -- HKDF-Expand-Label over the traffic secret with an empty context.
    deriveWith label = hkdfExpandLabel h secret label ""
    writeKey = deriveWith "key" (bulkKeySize bulk)
    writeIV = deriveWith "iv" (max 8 (bulkIVSize bulk + bulkExplicitIV bulk))
    freshCrypt =
        CryptState
            { cstKey = bulkInit bulk encOrDec writeKey
            , cstIV = writeIV
            , cstMacSecret = secret
            }
    freshState =
        RecordState
            { stCryptState = freshCrypt
            , stMacState = MacState{msSequence = 0}
            , stCryptLevel = lvl
            , stCipher = Just cipher
            , stCompression = nullCompression
            }
-- | Remove the cipher from the transmit-side record state.
clearTxRecordState :: Context -> IO ()
clearTxRecordState ctx = clearXState ctxTxRecordState ctx
-- | Remove the cipher from the receive-side record state.
clearRxRecordState :: Context -> IO ()
clearRxRecordState ctx = clearXState ctxRxRecordState ctx
-- | Set 'stCipher' to 'Nothing' in the 'RecordState' held in the 'MVar'
-- selected by @func@. All other fields are left untouched.
clearXState :: (Context -> MVar RecordState) -> Context -> IO ()
clearXState func ctx = modifyMVar_ (func ctx) (return . dropCipher)
  where
    dropCipher rs = rs{stCipher = Nothing}
-- | Record the TLS 1.3 cipher chosen in the hello exchange.
--
-- First call (no pending cipher yet):
--
--   * stores @cipher@ as the pending cipher,
--   * resets the pending compression to 'nullCompression',
--   * replays the buffered handshake messages into a running hash context
--     for the cipher's hash.
--
-- Later calls return @Right ()@ if the same cipher is chosen again.
-- Otherwise they return an 'Error_Protocol' / 'IllegalParameter' error and
-- leave the state unchanged.
setHelloParameters13 :: Cipher -> HandshakeM (Either TLSError ())
setHelloParameters13 cipher = do
    hst <- get
    case hstPendingCipher hst of
        Nothing -> do
            put
                hst
                    { hstPendingCipher = Just cipher
                    , hstPendingCompression = nullCompression
                    , hstHandshakeDigest = updateDigest $ hstHandshakeDigest hst
                    }
            return $ Right ()
        Just oldcipher
            | cipher == oldcipher -> return $ Right ()
            | otherwise ->
                return $
                    Left $
                        Error_Protocol "TLS 1.3 cipher changed after hello retry" IllegalParameter
  where
    hashAlg :: Hash
    hashAlg = cipherHash cipher

    -- Buffered messages are stored newest-first, hence the 'reverse'.
    -- 'foldl'' keeps the accumulated hash context evaluated at each step,
    -- instead of building a chain of thunks as the lazy 'foldl' did.
    updateDigest :: HandshakeDigest -> HandshakeDigest
    updateDigest (HandshakeMessages bytes) =
        HandshakeDigestContext $ foldl' hashUpdate (hashInit hashAlg) $ reverse bytes
    updateDigest (HandshakeDigestContext _) =
        error "cannot initialize digest with another digest"
-- | Fold the handshake digest through 'foldHandshakeDigest', using the hash
-- of the pending cipher.
--
-- The folding function wraps the digest bytes as a handshake message of
-- type 254 with a 3-byte length: two zero bytes followed by the one-byte
-- digest length. This matches the synthetic @message_hash@ construction of
-- RFC 8446, Section 4.4.1.
wrapAsMessageHash13 :: HandshakeM ()
wrapAsMessageHash13 =
    getPendingCipher >>= \cipher ->
        foldHandshakeDigest (cipherHash cipher) toMessageHash
  where
    toMessageHash :: ByteString -> ByteString
    toMessageHash dig =
        -- Digests are well under 256 bytes, so one length byte suffices.
        "\254\0\0" <> B.cons (fromIntegral (B.length dig)) dig
-- | Finalise the running handshake transcript hash of the context.
--
-- The handshake digest must already have been turned into a hash context
-- (e.g. by 'setHelloParameters13'). Raises an error if there is no
-- handshake state, or if the digest is still a list of buffered messages.
transcriptHash :: MonadIO m => Context -> m ByteString
transcriptHash ctx = do
    -- Lazy like the previous 'fromJust', but with a message naming the cause.
    hst <- maybe (error "transcriptHash: no handshake state") id <$> getHState ctx
    case hstHandshakeDigest hst of
        HandshakeDigestContext hashCtx -> return $ hashFinal hashCtx
        HandshakeMessages _ -> error "un-initialized handshake digest"
-- | Overwrite (not append to) the context's list of pending receive actions.
setPendingRecvActions :: Context -> [PendingRecvAction] -> IO ()
setPendingRecvActions = writeIORef . ctxPendingRecvActions
-- | Remove and return the first pending receive action, or 'Nothing' when
-- the list is empty (in which case the ref is not written).
--
-- NOTE(review): the read and the write are separate IORef operations, so
-- this is not atomic with respect to concurrent writers. Confirm that only
-- one thread touches this ref.
popPendingRecvAction :: Context -> IO (Maybe PendingRecvAction)
popPendingRecvAction ctx = do
    pending <- readIORef ref
    case pending of
        [] -> return Nothing
        next : rest -> do
            writeIORef ref rest
            return (Just next)
  where
    ref = ctxPendingRecvActions ctx