{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE CPP #-}
module Network.TLS.Record.State
    ( CryptState(..)
    , MacState(..)
    , RecordOptions(..)
    , RecordState(..)
    , newRecordState
    , incrRecordState
    , RecordM
    , runRecordM
    , getRecordOptions
    , getRecordVersion
    , setRecordIV
    , withCompression
    , computeDigest
    , makeDigest
    , getBulk
    , getMacSequence
    ) where
import Control.Monad.State.Strict
import Network.TLS.Compression
import Network.TLS.Cipher
import Network.TLS.ErrT
import Network.TLS.Struct
import Network.TLS.Wire
import Network.TLS.Packet
import Network.TLS.MAC
import Network.TLS.Util
import Network.TLS.Imports
import qualified Data.ByteString as B
-- | Cryptographic material used by the record layer: bulk cipher state,
-- IV bytes and the secret used as the MAC key.
data CryptState = CryptState
    { cstKey        :: !BulkState
      -- ^ bulk cipher state; 'BulkStateUninitialized' in 'newRecordState'
    , cstIV         :: !ByteString
      -- ^ IV bytes; replaced wholesale by 'setRecordIV'
    , cstMacSecret  :: !ByteString
      -- ^ secret passed as the key to the MAC function in 'computeDigest'
    } deriving (Show)
-- | Record sequence counter. 'computeDigest' encodes it with 'encodeWord64'
-- at the front of every MAC input, and 'incrRecordState' advances it by one.
newtype MacState = MacState
    { msSequence :: Word64
    } deriving (Show)
-- | Read-only settings visible to every 'RecordM' computation through
-- 'getRecordOptions'.
data RecordOptions = RecordOptions
    { recordVersion :: Version
      -- ^ protocol version; 'makeDigest' passes it to 'computeDigest', which
      -- uses the SSL MAC below 'TLS10' and HMAC otherwise
    , recordTLS13 :: Bool
      -- ^ presumably set when the record layer runs in TLS 1.3 mode; not
      -- read in this module -- TODO confirm at call sites
    }
-- | State threaded through 'RecordM' computations.
data RecordState = RecordState
    { stCipher      :: Maybe Cipher
      -- ^ negotiated cipher; 'Nothing' in 'newRecordState'. 'computeDigest'
      -- and 'getBulk' call 'fromJust' on it, so it must be 'Just' first.
    , stCompression :: Compression
      -- ^ compression state, read and replaced by 'withCompression'
    , stCryptState  :: !CryptState
      -- ^ cipher state, IV and MAC secret
    , stMacState    :: !MacState
      -- ^ sequence counter fed into the MAC by 'computeDigest'
    } deriving (Show)
-- | Record-layer computation: reads 'RecordOptions', threads a
-- 'RecordState', and may abort with a 'TLSError'. 'runRecordM' runs it,
-- yielding either the error or the result paired with the final state.
newtype RecordM a = RecordM { runRecordM :: RecordOptions
                                         -> RecordState
                                         -> Either TLSError (a, RecordState) }
-- | Map over the result of a successful computation. The options, the
-- threaded state and any 'Left' error pass through unchanged.
instance Functor RecordM where
    fmap f m = RecordM $ \opt st ->
                case runRecordM m opt st of
                    Left err       -> Left err
                    Right (a, st2) -> Right (f a, st2)

-- | 'pure' yields its argument without touching the state. '<*>' is
-- derived from '>>=' via 'ap', so the left-hand computation runs first.
instance Applicative RecordM where
    pure a = RecordM $ \_ st -> Right (a, st)
    (<*>)  = ap

-- | Sequencing gives the same 'RecordOptions' to both computations, feeds
-- the state produced by the first into the second, and stops at the first
-- 'Left'.
--
-- The implementation lives in 'pure' and 'return' merely aliases it. This
-- is the canonical form; the reverse (@pure = return@) is what GHC's
-- @-Wnoncanonical-monad-instances@ warns about.
instance Monad RecordM where
    return    = pure
    m1 >>= m2 = RecordM $ \opt st ->
                    case runRecordM m1 opt st of
                        Left err       -> Left err
                        Right (a, st2) -> runRecordM (m2 a) opt st2
-- | Fetch the read-only options without altering the state. 'curry Right'
-- turns the options and the current state into @Right (options, state)@.
getRecordOptions :: RecordM RecordOptions
getRecordOptions = RecordM (curry Right)
-- | The protocol version recorded in the current 'RecordOptions'.
getRecordVersion :: RecordM Version
getRecordVersion = fmap recordVersion getRecordOptions
-- | Direct access to the threaded 'RecordState'. None of these methods
-- can fail or look at the options.
instance MonadState RecordState RecordM where
    get        = RecordM $ \_ cur -> Right (cur, cur)
    put newSt  = RecordM $ \_ _   -> Right ((), newSt)
#if MIN_VERSION_mtl(2,1,0)
    state step = RecordM $ \_ cur -> Right (step cur)
#endif
-- | Errors short-circuit the computation. A handler passed to 'catchError'
-- restarts from the options and state that were current when the guarded
-- action began, so state changes made by the failed action are discarded.
instance MonadError TLSError RecordM where
    throwError failure = RecordM $ \_ _ -> Left failure
    catchError action handler = RecordM $ \opt st ->
        either (\failure -> runRecordM (handler failure) opt st)
               Right
               (runRecordM action opt st)
-- | Initial record state: no cipher, null compression, uninitialised bulk
-- state with empty IV and MAC secret, and a sequence number of zero.
newRecordState :: RecordState
newRecordState = RecordState
    { stMacState    = MacState { msSequence = 0 }
    , stCryptState  = CryptState
        { cstKey       = BulkStateUninitialized
        , cstIV        = B.empty
        , cstMacSecret = B.empty
        }
    , stCompression = nullCompression
    , stCipher      = Nothing
    }
-- | Advance the MAC sequence number by one, leaving every other field as is.
incrRecordState :: RecordState -> RecordState
incrRecordState ts = ts { stMacState = next (stMacState ts) }
  where
    next (MacState n) = MacState (n + 1)
-- | Replace the IV in the record's crypt state, keeping the key and MAC
-- secret untouched.
setRecordIV :: ByteString -> RecordState -> RecordState
setRecordIV iv st =
    let crypt = stCryptState st
    in  st { stCryptState = crypt { cstIV = iv } }
-- | Run a transition on the compression state. The first component of the
-- result becomes the new compression state and the second is returned.
withCompression :: (Compression -> (Compression, a)) -> RecordM a
withCompression f = state $ \st ->
    let (comp', result) = f (stCompression st)
    in  (result, st { stCompression = comp' })
-- | Compute the MAC over one record and return it with the state whose
-- sequence number has been advanced.
--
-- The MAC input is the 8-byte encoded sequence number, the encoded header
-- and the content:
--
-- * below 'TLS10': 'macSSL', with the header encoded without its version;
-- * otherwise: 'hmac', with the full header.
--
-- The key is the crypt state's MAC secret. The hash comes from the current
-- cipher, so 'stCipher' must be 'Just'.
computeDigest :: Version -> RecordState -> Header -> ByteString -> (ByteString, RecordState)
computeDigest ver tstate hdr content = (macFun secret payload, incrRecordState tstate)
  where
    secret     = cstMacSecret (stCryptState tstate)
    hashAlg    = cipherHash (fromJust "cipher" (stCipher tstate))
    seqBytes   = encodeWord64 (msSequence (stMacState tstate))
    isSSL      = ver < TLS10
    macFun     = if isSSL then macSSL hashAlg else hmac hashAlg
    encodedHdr = if isSSL then encodeHeaderNoVer hdr else encodeHeader hdr
    payload    = B.concat [seqBytes, encodedHdr, content]
-- | Compute the record MAC for the configured version and advance the
-- sequence number held in the state.
makeDigest :: Header -> ByteString -> RecordM ByteString
makeDigest hdr content = do
    ver <- getRecordVersion
    state $ \st -> computeDigest ver st hdr content
-- | The bulk algorithm of the current cipher. Calls 'fromJust' on
-- 'stCipher', so a cipher must already be installed.
getBulk :: RecordM Bulk
getBulk = gets (cipherBulk . fromJust "cipher" . stCipher)
-- | The current MAC sequence number, read without modifying the state.
getMacSequence :: RecordM Word64
getMacSequence = gets (msSequence . stMacState)