{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}
module Network.TLS.Record.State
( CryptState(..)
, MacState(..)
, RecordOptions(..)
, RecordState(..)
, newRecordState
, incrRecordState
, RecordM
, runRecordM
, getRecordOptions
, getRecordVersion
, setRecordIV
, withCompression
, computeDigest
, makeDigest
, getBulk
, getMacSequence
) where
import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import qualified Data.ByteString as B
-- | Cryptographic material held inside a 'RecordState'.
-- 'newRecordState' starts it as an uninitialized bulk state with an empty
-- IV and an empty MAC secret.
data CryptState = CryptState
    { cstKey :: !BulkState -- ^ bulk cipher state; 'BulkStateUninitialized' in 'newRecordState'
    , cstIV :: !ByteString -- ^ IV bytes; replaced by 'setRecordIV'
    , cstMacSecret :: !ByteString -- ^ key passed to the MAC function in 'computeDigest'
    } deriving (Show)
-- | MAC bookkeeping for a 'RecordState'.
newtype MacState = MacState
    { msSequence :: Word64 -- ^ sequence number encoded into every MAC input by 'computeDigest'; bumped by 'incrRecordState'
    } deriving (Show)
-- | Read-only settings that every 'RecordM' step receives unchanged.
data RecordOptions = RecordOptions
    { recordVersion :: Version -- ^ protocol version; read by 'getRecordVersion' and used to select the MAC construction
    , recordTLS13 :: Bool -- ^ NOTE(review): presumably marks TLS 1.3 record handling; not read anywhere in this section
    }
-- | Mutable state threaded through 'RecordM' computations.
data RecordState = RecordState
    { stCipher :: Maybe Cipher -- ^ 'Nothing' in 'newRecordState'; 'computeDigest' and 'getBulk' pass it through 'fromJust', so it must be set before they run
    , stCompression :: Compression -- ^ compression context, read and replaced by 'withCompression'
    , stCryptState :: !CryptState -- ^ keys/IV/MAC secret
    , stMacState :: !MacState -- ^ MAC sequence counter
    } deriving (Show)
-- | A state-plus-error computation over 'RecordState' that can read, but
-- never alter, the 'RecordOptions' it is run with. Running it produces
-- either a 'TLSError' or the result paired with the updated state.
newtype RecordM a = RecordM { runRecordM :: RecordOptions
                                         -> RecordState
                                         -> Either TLSError (a, RecordState) }
-- | Map over the result of a 'RecordM' step; an error is passed through
-- untouched and the resulting state is kept.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
        case runRecordM m opt st of
            Left err -> Left err
            Right (a, st2) -> Right (f a, st2)

-- | 'pure' returns its argument and leaves the state unchanged.
--
-- 'pure' is implemented directly, and 'return' (in the 'Monad' instance
-- below) delegates to it. This is the canonical form; the previous
-- @pure = return@ arrangement is flagged by
-- @-Wnoncanonical-monad-instances@.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>) = ap

-- | Sequencing passes the same 'RecordOptions' to both steps and threads
-- the 'RecordState' from the first into the second. A 'Left' from the
-- first step short-circuits, so the second step never runs.
instance Monad RecordM where
    -- Kept explicitly (canonical alias) so the instance also builds on
    -- compilers where 'return' has no default in terms of 'pure'.
    return = pure
    m1 >>= m2 = RecordM $ \opt st ->
        case runRecordM m1 opt st of
            Left err -> Left err
            Right (a, st2) -> runRecordM (m2 a) opt st2
-- | Return the 'RecordOptions' the computation was run with, leaving the
-- 'RecordState' untouched.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM readOpts
  where
    readOpts opts current = Right (opts, current)
-- | The protocol version stored in the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = do
    opts <- getRecordOptions
    return (recordVersion opts)
-- | Direct access to the threaded 'RecordState'. None of these operations
-- can fail.
instance MonadState RecordState RecordM where
    get = RecordM (\_ s -> Right (s, s))
    put s = RecordM (\_ _ -> Right ((), s))
#if MIN_VERSION_mtl(2,1,0)
    -- 'state' became a class method in mtl 2.1; implement it directly there.
    state step = RecordM (\_ -> Right . step)
#endif
-- | 'TLSError' is the failure type. 'catchError' re-runs the handler with
-- the options and the state from *before* the failed action, so any state
-- changes made by that action are discarded.
instance MonadError TLSError RecordM where
    throwError err = RecordM (\_ _ -> Left err)
    catchError action handler = RecordM $ \opt st ->
        either (\err -> runRecordM (handler err) opt st)
               Right
               (runRecordM action opt st)
-- | The initial record state: no cipher, 'nullCompression', an
-- uninitialized bulk state with empty IV and MAC secret, and a MAC
-- sequence number of 0.
newRecordState :: RecordState
newRecordState = RecordState
    { stCipher = Nothing
    , stCompression = nullCompression
    , stCryptState = initialCrypt
    , stMacState = MacState { msSequence = 0 }
    }
  where
    initialCrypt = CryptState
        { cstKey = BulkStateUninitialized
        , cstIV = B.empty
        , cstMacSecret = B.empty
        }
-- | Advance the MAC sequence number by one. 'Word64' addition is modular,
-- so the counter wraps to 0 after 'maxBound' rather than failing.
incrRecordState :: RecordState -> RecordState
incrRecordState ts = ts { stMacState = bump (stMacState ts) }
  where
    -- Plain (+ 1) on purpose: 'succ' would throw at maxBound.
    bump (MacState n) = MacState (n + 1)
-- | Replace the IV in the state's 'CryptState', keeping its key and MAC
-- secret.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st =
    let crypt = stCryptState st
     in st { stCryptState = crypt { cstIV = iv } }
-- | Run a step against the current compression context. The step returns
-- the replacement context, which is stored back into the state, together
-- with a result, which is returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f =
    get >>= \st ->
        let (nc, a) = f (stCompression st)
         in put st { stCompression = nc } >> return a
-- | Compute the MAC for one record and return it with the state advanced.
--
-- The MAC input is the concatenation of three parts: the encoded sequence
-- number from 'stMacState', the encoded record header, and the content.
-- For versions below 'TLS10', 'macSSL' is applied over a header encoded
-- without its version field. Otherwise 'hmac' is applied over the full
-- header. In both cases the key is 'cstMacSecret'.
--
-- The returned state has its sequence number bumped via
-- 'incrRecordState'. 'stCipher' is passed through 'fromJust', so it must
-- already be set.
computeDigest :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (digest, incrRecordState tstate)
  where
    digest = macWith macKey (B.concat [seqBytes, hdrBytes, content])
    macKey = cstMacSecret (stCryptState tstate)
    hashAlg = cipherHash (fromJust "cipher" (stCipher tstate))
    seqBytes = encodeWord64 (msSequence (stMacState tstate))
    (macWith, hdrBytes)
        | ver < TLS10 = (macSSL hashAlg, encodeHeaderNoVer hdr)
        | otherwise = (hmac hashAlg, encodeHeader hdr)
-- | Compute the MAC for a record using the configured version and the
-- current state, store the advanced state, and return the MAC.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    version <- getRecordVersion
    -- Lazy pattern keeps the tuple unforced, matching a plain let-binding.
    ~(digest, next) <- gets (\st -> computeDigest version st hdr content)
    put next
    return digest
-- | The bulk cipher of the current 'stCipher'. 'stCipher' goes through
-- 'fromJust', so a cipher must already be set.
getBulk :: RecordM Bulk
getBulk = gets (cipherBulk . fromJust "cipher" . stCipher)
-- | The current MAC sequence number, without changing it.
getMacSequence :: RecordM Word64
getMacSequence = do
    st <- get
    return (msSequence (stMacState st))