{-# LANGUAGE MultiParamTypeClasses #-}
module Network.TLS.Record.State
( CryptState(..)
, CryptLevel(..)
, HasCryptLevel(..)
, MacState(..)
, RecordOptions(..)
, RecordState(..)
, newRecordState
, incrRecordState
, RecordM
, runRecordM
, getRecordOptions
, getRecordVersion
, setRecordIV
, withCompression
, computeDigest
, makeDigest
, getBulk
, getMacSequence
) where
import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import Network.TLS.Types
import qualified Data.ByteString as B
-- | Cryptographic material for one record-layer state.
--
-- All fields are strict, so a 'CryptState' never carries unevaluated
-- key/IV/secret thunks.
data CryptState = CryptState
    { cstKey       :: !BulkState
      -- ^ Bulk-cipher state; 'newRecordState' starts it as
      --   'BulkStateUninitialized'.
    , cstIV        :: !ByteString
      -- ^ Initialisation vector; replaced wholesale by 'setRecordIV'.
    , cstMacSecret :: !ByteString
      -- ^ Key handed to the MAC function in 'computeDigest'.
    } deriving (Show)
-- | MAC bookkeeping for a record-layer state.
newtype MacState = MacState
    { msSequence :: Word64
      -- ^ Sequence number; encoded as the first component of every MAC
      --   input in 'computeDigest' and bumped by 'incrRecordState'.
    } deriving (Show)
-- | Read-only settings available to every 'RecordM' computation
-- (see 'getRecordOptions').
data RecordOptions = RecordOptions
    { recordVersion :: Version
      -- ^ Protocol version; 'computeDigest' picks the MAC construction
      --   from it (via 'makeDigest').
    , recordTLS13   :: Bool
      -- ^ Flag not read in this module; presumably selects TLS 1.3 record
      --   processing elsewhere (TODO confirm against callers).
    }
-- | Encryption level of a record-layer state (stored in 'stCryptLevel').
data CryptLevel
    = CryptInitial
      -- ^ Level of a freshly created state ('newRecordState').
    | CryptMasterSecret
      -- ^ No 'HasCryptLevel' instance in this module maps here; presumably
      --   the pre-TLS 1.3 master-secret level (TODO confirm).
    | CryptEarlySecret
      -- ^ Level reported by 'getCryptLevel' for 'EarlySecret'.
    | CryptHandshakeSecret
      -- ^ Level reported by 'getCryptLevel' for 'HandshakeSecret'.
    | CryptApplicationSecret
      -- ^ Level reported by 'getCryptLevel' for 'ApplicationSecret'.
    deriving (Show, Eq)
-- | Types that statically determine a 'CryptLevel'.
--
-- The level is selected purely from the type @a@; the instances below
-- ignore the proxy value itself.
class HasCryptLevel a where
    getCryptLevel :: proxy a -> CryptLevel

instance HasCryptLevel EarlySecret where
    getCryptLevel _ = CryptEarlySecret

instance HasCryptLevel HandshakeSecret where
    getCryptLevel _ = CryptHandshakeSecret

instance HasCryptLevel ApplicationSecret where
    getCryptLevel _ = CryptApplicationSecret
-- | Mutable (threaded) state of one record-layer direction, carried
-- through 'RecordM'.
data RecordState = RecordState
    { stCipher      :: Maybe Cipher
      -- ^ Negotiated cipher; 'Nothing' in 'newRecordState'.  'computeDigest'
      --   and 'getBulk' require it to be 'Just'.
    , stCompression :: Compression
      -- ^ Compression state, threaded through 'withCompression'.
    , stCryptLevel  :: !CryptLevel
      -- ^ Current encryption level; 'CryptInitial' in 'newRecordState'.
    , stCryptState  :: !CryptState
      -- ^ Keys, IV and MAC secret.
    , stMacState    :: !MacState
      -- ^ MAC sequence counter, advanced by 'incrRecordState'.
    } deriving (Show)
-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may fail with a 'TLSError'.
--
-- On success the result is paired with the updated state.
newtype RecordM a = RecordM
    { runRecordM :: RecordOptions
                 -> RecordState
                 -> Either TLSError (a, RecordState)
    }
-- | Canonical 'Applicative' instance.
--
-- 'pure' is defined here directly (instead of @pure = return@) so that the
-- 'Monad' instance no longer has to override 'return'.  This is the form
-- GHC's @-Wnoncanonical-monad-instances@ warning asks for, and it keeps the
-- instance valid if 'return' stops being a 'Monad' method.
instance Applicative RecordM where
    -- Succeed with @a@, leaving the state untouched and ignoring options.
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>) = ap

-- | Sequencing: run the first computation, and on success feed its result
-- and updated state to the continuation; a 'Left' short-circuits.
-- 'return' is inherited from 'pure'.
instance Monad RecordM where
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err       -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2

-- | Map over a successful result; errors and the threaded state pass
-- through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err       -> Left err
            Right (a, st2) -> Right (f a, st2)
-- | Fetch the read-only 'RecordOptions' without touching the state.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM (\options current -> Right (options, current))
-- | The protocol version stored in the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = fmap recordVersion getRecordOptions
-- | Direct access to the threaded 'RecordState'.  None of these methods
-- inspect the 'RecordOptions' argument, and none of them can fail.
instance MonadState RecordState RecordM where
    get = RecordM (\_ current -> Right (current, current))
    put new = RecordM (\_ _ -> Right ((), new))
    state step = RecordM (\_ current -> Right (step current))
-- | Error handling for 'RecordM'.
--
-- 'catchError' runs the handler from the state as it was *before* the
-- failed action, so any state changes made by the failing action are
-- discarded.  Successful results are passed through unchanged.
instance MonadError TLSError RecordM where
    throwError e = RecordM (\_ _ -> Left e)
    catchError m handler = RecordM $ \opt st ->
        either (\err -> runRecordM (handler err) opt st)
               Right
               (runRecordM m opt st)
-- | Initial record state: no cipher, 'nullCompression', level
-- 'CryptInitial', uninitialised bulk state with empty IV and MAC secret,
-- and sequence number 0.
newRecordState :: RecordState
newRecordState =
    RecordState
        { stCipher      = Nothing
        , stCompression = nullCompression
        , stCryptLevel  = CryptInitial
        , stCryptState  = blankCrypt
        , stMacState    = MacState 0
        }
  where
    blankCrypt = CryptState
        { cstKey       = BulkStateUninitialized
        , cstIV        = B.empty
        , cstMacSecret = B.empty
        }
-- | Advance the MAC sequence number by one, leaving all other fields as-is.
--
-- NOTE(review): 'Word64' addition wraps silently at 'maxBound'; nothing here
-- guards against that.
incrRecordState :: RecordState -> RecordState
incrRecordState ts =
    let next = msSequence (stMacState ts) + 1
     in ts { stMacState = MacState next }
-- | Replace the IV inside the state's 'CryptState'; keys and MAC secret are
-- kept.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st = st { stCryptState = withNewIV }
  where
    withNewIV = (stCryptState st) { cstIV = iv }
-- | Run a step over the current compression state.
--
-- The step receives the current 'Compression' and returns the replacement
-- compression together with a result; the replacement is stored in the
-- record state and the result is returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f =
    get >>= \st ->
        let (nc, a) = f (stCompression st)
         in put (st { stCompression = nc }) >> return a
-- | Compute the MAC for one record and return it together with the state
-- whose sequence number has been incremented.
--
-- The MAC input is the 8-byte encoded sequence number, the encoded header
-- and the content.  For versions ordered below 'TLS10' the header is encoded
-- without its version and 'macSSL' is used; otherwise the full header is
-- encoded and 'hmac' is used.  The key is 'cstMacSecret'.
--
-- Requires 'stCipher' to be 'Just': otherwise @fromJust "cipher"@ fails
-- when the digest is demanded.
computeDigest
    :: Version
    -> RecordState
    -> Header
    -> ByteString
    -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
  where
    -- Pre-TLS 1.0 versions use the SSL MAC and a version-less header.
    legacy = ver < TLS10
    macFn
        | legacy    = macSSL hashAlg
        | otherwise = hmac hashAlg
    encodedHdr
        | legacy    = encodeHeaderNoVer hdr
        | otherwise = encodeHeader hdr
    hashAlg  = cipherHash (fromJust "cipher" (stCipher tstate))
    seqBytes = encodeWord64 (msSequence (stMacState tstate))
    macKey   = cstMacSecret (stCryptState tstate)
    digest   = macFn macKey (B.concat [seqBytes, encodedHdr, content])
-- | 'computeDigest' lifted into 'RecordM'.
--
-- Uses the version from 'RecordOptions', stores the state with the advanced
-- sequence number, and returns the MAC.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content =
    getRecordVersion >>= \version ->
    get >>= \current ->
        let (mac, next) = computeDigest version current hdr content
         in put next >> return mac
-- | The bulk algorithm of the current cipher.
--
-- Requires 'stCipher' to be 'Just'; otherwise @fromJust "cipher"@ fails
-- when the returned value is demanded (not when 'getBulk' runs).
getBulk :: RecordM Bulk
getBulk = fmap bulkOf get
  where
    bulkOf st = cipherBulk (fromJust "cipher" (stCipher st))
-- | The current MAC sequence number, without modifying the state.
getMacSequence :: RecordM Word64
getMacSequence = do
    st <- get
    return (msSequence (stMacState st))