{-# LANGUAGE FlexibleContexts #-}
module Network.TLS.Record.Disengage
( disengageRecord
) where
import Control.Monad.State.Strict
import Crypto.Cipher.Types (AuthTag(..))
import Network.TLS.Struct
import Network.TLS.ErrT
import Network.TLS.Cap
import Network.TLS.Record.State
import Network.TLS.Record.Types
import Network.TLS.Cipher
import Network.TLS.Crypto
import Network.TLS.Compression
import Network.TLS.Util
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.Imports
import qualified Data.ByteString as B
import qualified Data.ByteArray as B (convert, xor)
-- | Turn a received ciphertext record into a plaintext record.
--
-- The record is first run through 'decryptRecord'; the compressed record it
-- yields is then handed to 'uncompressRecord'.  A failure raised by the first
-- stage short-circuits the second.
disengageRecord :: Record Ciphertext -> RecordM (Record Plaintext)
disengageRecord ciphertext = decryptRecord ciphertext >>= uncompressRecord
-- | Decompress the fragment of a record, leaving the rest of the record
-- as 'onRecordFragment' rebuilds it.
--
-- The fragment bytes are fed to 'compressionInflate', which is run through
-- 'withCompression' so that the 'Compression' value it returns alongside the
-- output is handed back to the record state.
uncompressRecord :: Record Compressed -> RecordM (Record Plaintext)
uncompressRecord compressed =
    onRecordFragment compressed (fragmentUncompress inflate)
  where
    -- Inflate one fragment's worth of bytes using the current compression state.
    inflate :: ByteString -> RecordM ByteString
    inflate = withCompression . compressionInflate
-- | Remove the encryption layer from a received record.
--
-- * When the record state has no cipher ('stCipher' is 'Nothing') the
--   fragment bytes are passed through unchanged.
-- * When 'recordTLS13' is set in the record options, the TLS 1.3 path is
--   taken: application-data records are decrypted and their inner plaintext
--   is unwrapped to recover the real content type; change-cipher-spec and
--   alert records are passed through unchanged; any other outer content type
--   is rejected with 'UnexpectedMessage'.
-- * Otherwise the fragment is decrypted in place with 'decryptData'.
decryptRecord :: Record Ciphertext -> RecordM (Record Compressed)
decryptRecord record@(Record outerType recVersion fragment) = do
    st <- get
    case stCipher st of
        Nothing -> passThrough
        Just _  -> do
            opts <- getRecordOptions
            let negotiated = recordVersion opts
            if recordTLS13 opts
                then decryptTLS13 negotiated (fragmentGetBytes fragment) st
                else decryptLegacy negotiated st
  where
    -- Re-tag the fragment as "compressed" without touching its bytes.
    passThrough :: RecordM (Record Compressed)
    passThrough = onRecordFragment record (fragmentUncipher return)

    -- Pre-TLS 1.3: decrypt the fragment bytes, keep content type and version.
    decryptLegacy :: Version -> RecordState -> RecordM (Record Compressed)
    decryptLegacy negotiated st =
        onRecordFragment record $ fragmentUncipher $ \payload ->
            decryptData negotiated record payload st

    -- TLS 1.3: dispatch on the outer content type of the record.
    decryptTLS13 :: Version -> ByteString -> RecordState -> RecordM (Record Compressed)
    decryptTLS13 negotiated payload st =
        case outerType of
            ProtocolType_AppData -> do
                inner <- decryptData negotiated record payload st
                either protocolFailure rewrap (unInnerPlaintext inner)
            ProtocolType_ChangeCipherSpec -> passThrough
            ProtocolType_Alert            -> passThrough
            _                             -> protocolFailure "illegal plain text"

    -- Build the resulting record from the inner content type and body; the
    -- version is the one carried by the received record.
    rewrap :: (ProtocolType, ByteString) -> RecordM (Record Compressed)
    rewrap (innerType, body) =
        return (Record innerType recVersion (fragmentCompressed body))

    -- Abort with a protocol error carrying an 'UnexpectedMessage' alert.
    protocolFailure :: String -> RecordM a
    protocolFailure msg = throwError (Error_Protocol msg UnexpectedMessage)
-- | Split a decrypted TLS 1.3 inner plaintext into its content type and body.
--
-- Trailing zero bytes are dropped first; the last remaining byte is then
-- interpreted as the content type and everything before it is the body.
--
-- Returns 'Left' with a message when:
--
-- * nothing remains after dropping the trailing zeros (reported as unknown
--   content type @0@),
-- * the type byte is not recognised by 'valToType', or
-- * the body is empty and the content type is handshake or alert.
unInnerPlaintext :: ByteString -> Either String (ProtocolType, ByteString)
unInnerPlaintext inner = do
    (body, typeByte) <- maybe (unknownType 0) Right (B.unsnoc unpadded)
    ctype            <- maybe (unknownType typeByte) Right (valToType typeByte)
    if B.null body && ctype `elem` mustCarryData
        then Left ("empty " ++ show ctype ++ " record disallowed")
        else Right (ctype, body)
  where
    -- The inner plaintext with its trailing run of zero bytes removed.
    unpadded :: ByteString
    unpadded = fst (B.spanEnd (== 0) inner)

    -- Content types for which an empty body is rejected.
    mustCarryData :: [ProtocolType]
    mustCarryData = [ProtocolType_Handshake, ProtocolType_Alert]

    unknownType :: Word8 -> Either String b
    unknownType c = Left ("unknown TLS 1.3 content type: " ++ show c)
-- | Validate the MAC and padding of decrypted record data and return the
-- content on success.
--
-- * MAC: when present, it is compared (with 'bytesEq') against the digest
--   that 'makeDigest' produces for a header built from the record's content
--   type, version and the content length, together with the content itself.
-- * Padding: when present, for versions below 'TLS10' only the padding
--   length minus one is required to be smaller than the block size; for
--   later versions every padding byte must equal the padding length minus
--   one.
--
-- Both checks are always carried out before the outcome is inspected, and
-- either failing raises the same @"bad record mac"@ protocol error with a
-- 'BadRecordMac' alert.
getCipherData :: Record a -> CipherData -> RecordM ByteString
getCipherData (Record pt ver _) cdata = do
    macValid     <- maybe (return True) checkMac     (cipherDataMAC cdata)
    paddingValid <- maybe (return True) checkPadding (cipherDataPadding cdata)
    -- NOTE(review): (&&!) is presumably a non-short-circuiting conjunction;
    -- confirm against its definition before replacing it with (&&).
    unless (macValid &&! paddingValid) $
        throwError (Error_Protocol "bad record mac" BadRecordMac)
    return payload
  where
    payload :: ByteString
    payload = cipherDataContent cdata

    -- Recompute the digest over the content and compare with the received one.
    checkMac :: ByteString -> RecordM Bool
    checkMac received = do
        let hdr = Header pt ver (fromIntegral (B.length payload))
        expected <- makeDigest hdr payload
        return (expected `bytesEq` received)

    -- Check the padding bytes (or, before TLS10, only the padding length).
    checkPadding :: (ByteString, Int) -> RecordM Bool
    checkPadding (pad, blksz) = do
        cver <- getRecordVersion
        let padLen = B.length pad
            fill   = padLen - 1
        return $
            if cver < TLS10
                then fill < blksz
                else B.replicate padLen (fromIntegral fill) `bytesEq` pad
-- | Decrypt the encrypted content of a record and return the bare content.
--
-- The way the bytes are processed depends on the bulk state stored in the
-- crypt state of the supplied 'RecordState':
--
-- * block cipher: optional explicit IV, then content, MAC and padding;
-- * stream cipher: content followed by MAC;
-- * AEAD: optional explicit nonce, ciphertext and authentication tag;
-- * uninitialized: rejected with an 'InternalError' protocol error.
--
-- Arguments: the protocol version in use, the record being processed (used
-- for its header fields), the encrypted bytes, and the record state to take
-- cipher, keys, IV and sequence number from.
--
-- Inputs too short for the cipher parameters, or that cannot be split at the
-- outer layer, raise 'Error_Packet'.  Format problems found in decrypted
-- data and integrity failures raise a protocol error with 'BadRecordMac'.
decryptData :: Version -> Record Ciphertext -> ByteString -> RecordState -> RecordM ByteString
decryptData ver record econtent tst =
    case cstKey cst of
        BulkStateBlock decryptF               -> decryptBlock decryptF
        BulkStateStream (BulkStream decryptF) -> decryptStream decryptF
        BulkStateAEAD decryptF                -> decryptAEAD decryptF
        BulkStateUninitialized                ->
            throwError $ Error_Protocol "decrypt state uninitialized" InternalError
  where
    cipher :: Cipher
    cipher = fromJust "cipher" (stCipher tst)

    bulk :: Bulk
    bulk = cipherBulk cipher

    cst :: CryptState
    cst = stCryptState tst

    macSize, blockSize, econtentLen :: Int
    macSize     = hashDigestSize (cipherHash cipher)
    blockSize   = bulkBlockSize bulk
    econtentLen = B.length econtent

    explicitIV :: Bool
    explicitIV = hasExplicitBlockIV ver

    sanityCheckError :: RecordM a
    sanityCheckError =
        throwError (Error_Packet "encrypted content too small for encryption parameters")

    -- Hand the separated parts over for MAC (and padding) verification.
    verify :: ByteString -> ByteString -> Maybe (ByteString, Int) -> RecordM ByteString
    verify content mac padding =
        getCipherData record CipherData
            { cipherDataContent = content
            , cipherDataMAC     = Just mac
            , cipherDataPadding = padding
            }

    ---------------------------------------------------------------------
    -- Block cipher
    ---------------------------------------------------------------------
    decryptBlock decryptF = do
        let ivPrefix   = if explicitIV then bulkIVSize bulk else 0
            minContent = ivPrefix + max (macSize + 1) blockSize

        -- The input must be a whole number of blocks and large enough to
        -- hold the explicit IV (if any) plus a MAC and one padding byte.
        when (econtentLen `mod` blockSize /= 0 || econtentLen < minContent)
            sanityCheckError

        -- The IV is either carried at the front of the record or taken
        -- from the crypt state.
        (iv, body) <-
            if explicitIV
                then splitOuter2 econtent (bulkIVSize bulk, econtentLen - bulkIVSize bulk)
                else return (cstIV cst, econtent)

        let (decrypted, nextIV) = decryptF iv body
        modify $ \txs -> txs { stCryptState = cst { cstIV = nextIV } }

        -- The last decrypted byte gives the padding length, not counting
        -- that length byte itself.
        let paddingLen = fromIntegral (B.last decrypted) + 1
            contentLen = B.length decrypted - paddingLen - macSize
        (content, mac, padding) <- splitInner3 decrypted (contentLen, macSize, paddingLen)
        verify content mac (Just (padding, blockSize))

    ---------------------------------------------------------------------
    -- Stream cipher
    ---------------------------------------------------------------------
    decryptStream decryptF = do
        when (econtentLen < macSize) sanityCheckError

        let (decrypted, nextStream) = decryptF econtent
            contentLen              = B.length decrypted - macSize
        (content, mac) <- splitInner2 decrypted (contentLen, macSize)
        -- Store the advanced stream state so the next record continues from it.
        modify $ \txs -> txs { stCryptState = cst { cstKey = BulkStateStream nextStream } }
        verify content mac Nothing

    ---------------------------------------------------------------------
    -- AEAD cipher
    ---------------------------------------------------------------------
    decryptAEAD decryptF = do
        let authTagLen  = bulkAuthTagLen bulk
            nonceExpLen = bulkExplicitIV bulk
            cipherLen   = econtentLen - authTagLen - nonceExpLen

        when (econtentLen < authTagLen + nonceExpLen) sanityCheckError

        (enonce, body, authTag) <- splitOuter3 econtent (nonceExpLen, cipherLen, authTagLen)

        let seqBytes       = encodeWord64 (msSequence (stMacState tst))
            baseIV         = cstIV cst
            Header typ v _ = recordToHeader record
            isTLS13        = ver >= TLS13

            -- The header used as associated data carries the full encrypted
            -- length for TLS 1.3 and the bare ciphertext length otherwise.
            hdr = Header typ v (fromIntegral (if isTLS13 then econtentLen else cipherLen))

            -- Before TLS 1.3 the sequence number is prepended to the header.
            ad | isTLS13   = encodeHeader hdr
               | otherwise = B.concat [seqBytes, encodeHeader hdr]

            -- Sequence number left-padded with zeros up to the IV length
            -- (assumes 'encodeWord64' yields 8 bytes -- TODO confirm).
            paddedSeq = B.replicate (B.length baseIV - 8) 0 `B.append` seqBytes

            -- Without an explicit nonce part the nonce is IV xor sequence;
            -- with one, it is the IV followed by the bytes from the record.
            nonce | nonceExpLen == 0 = B.xor baseIV paddedSeq
                  | otherwise        = baseIV `B.append` enonce

            (content, computedTag) = decryptF nonce body ad

        when (AuthTag (B.convert authTag) /= computedTag) $
            throwError $ Error_Protocol "bad record mac" BadRecordMac

        modify incrRecordState
        return content

    ---------------------------------------------------------------------
    -- Splitting helpers
    ---------------------------------------------------------------------
    -- Splits applied to the still-encrypted outer layer report 'Error_Packet'.
    splitOuter3 :: ByteString -> (Int, Int, Int) -> RecordM (ByteString, ByteString, ByteString)
    splitOuter3 bytes lens =
        case partition3 bytes lens of
            Nothing    -> throwError (Error_Packet "record bad format")
            Just parts -> return parts

    splitOuter2 :: ByteString -> (Int, Int) -> RecordM (ByteString, ByteString)
    splitOuter2 bytes (n1, n2) = do
        (p1, p2, _) <- splitOuter3 bytes (n1, n2, 0)
        return (p1, p2)

    -- Splits applied to decrypted data report a 'BadRecordMac' protocol
    -- error, the same alert as a genuine integrity failure.
    splitInner3 :: ByteString -> (Int, Int, Int) -> RecordM (ByteString, ByteString, ByteString)
    splitInner3 bytes lens =
        case partition3 bytes lens of
            Nothing    -> throwError (Error_Protocol "record bad format" BadRecordMac)
            Just parts -> return parts

    splitInner2 :: ByteString -> (Int, Int) -> RecordM (ByteString, ByteString)
    splitInner2 bytes (n1, n2) = do
        (p1, p2, _) <- splitInner3 bytes (n1, n2, 0)
        return (p1, p2)