-- |
-- Module      : Network.TLS.Sending
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Sending module contains functions for marshalling packets according
-- to the TLS state.
--
module Network.TLS.Sending (
    encodePacket
  , encodePacket13
  , updateHandshake
  , updateHandshake13
  ) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.Handshake.Random
import Network.TLS.Handshake.State
import Network.TLS.Handshake.State13
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Parameters
import Network.TLS.Record
import Network.TLS.Record.Layer
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types (Role(..))
import Network.TLS.Util

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef

-- | Transform a 'Packet' into marshalled data according to the current
-- connection state, updating that state along the way.
--
-- The packet is first broken into payload fragments by 'packetToFragments'
-- (using the fragment size configured in the context).  Each fragment is
-- wrapped in a plaintext 'Record' carrying the packet's protocol type and the
-- version returned by 'decideRecordVersion', and every record is then passed
-- to 'recordEncode' of the supplied record layer.  The encoded records are
-- combined with 'mconcat'.
--
-- Returns 'Left' when encoding the records fails (presumably on the first
-- failing record -- confirm against 'forEitherM'), otherwise 'Right' with the
-- concatenated output.
encodePacket :: Monoid bytes
             => Context -> RecordLayer bytes -> Packet -> IO (Either TLSError bytes)
encodePacket ctx recordLayer pkt = do
    (ver, _) <- decideRecordVersion ctx
    let pt = packetType pkt
        mkRecord bs = Record pt ver (fragmentPlaintext bs)
        len = ctxFragmentSize ctx
    records <- map mkRecord <$> packetToFragments ctx len pkt
    bs <- fmap mconcat <$> forEitherM records (recordEncode recordLayer)
    -- The ChangeCipherSpec itself is encoded under the old transmit state;
    -- only afterwards is the pending state installed.  This happens whether
    -- or not the encoding above succeeded.
    when (pkt == ChangeCipherSpec) $ switchTxEncryption ctx
    return bs

-- Decompose a packet into the payloads that will each become one record.
--
-- Handshake packets: every message is run through 'updateHandshake' (which
-- also updates the handshake state), the encodings are concatenated, and the
-- result is split by 'getChunks' according to the given length.  Alert and
-- ChangeCipherSpec packets produce a single encoded payload.  AppData
-- packets are not fragmented here but by callers of sendPacket, so that the
-- empty-packet countermeasure may be applied to each fragment independently.
--
-- NOTE(review): the role handed to 'updateHandshake' is always 'ClientRole',
-- independent of which side this context is -- confirm against
-- 'updateVerifiedData' how that argument is interpreted.
packetToFragments :: Context -> Maybe Int -> Packet -> IO [ByteString]
packetToFragments ctx len (Handshake hss)  =
    getChunks len . B.concat <$> mapM (updateHandshake ctx ClientRole) hss
packetToFragments _   _   (Alert a)        = return [encodeAlerts a]
packetToFragments _   _   ChangeCipherSpec = return [encodeChangeCipherSpec]
packetToFragments _   _   (AppData x)      = return [x]

-- Install the pending transmit record state as the active one.
--
-- The pending state is read from 'hstPendingTxState' in the handshake state
-- (unwrapped with 'fromJust' under the label "tx-state"; presumably this
-- fails with that label when no pending state exists -- confirm against the
-- helper) and written into the context's 'ctxTxState'.
--
-- Afterwards 'ctxNeedEmptyPacket' is set to 'True' when all of the following
-- hold: the negotiated version is at most 'TLS10', 'isClientContext' reports
-- 'ClientRole', the new state's cipher has a non-zero bulk block size, and
-- 'supportedEmptyPacket' is enabled in the context's 'Supported' settings.
switchTxEncryption :: Context -> IO ()
switchTxEncryption ctx = do
    tx <- usingHState ctx (fromJust "tx-state" <$> gets hstPendingTxState)
    (ver, cc) <- usingState_ ctx $ do
        v <- getVersion
        c <- isClientContext
        return (v, c)
    -- The previous transmit state is discarded unconditionally.
    modifyMVar_ (ctxTxState ctx) (\_ -> return tx)
    -- Enable the empty-packet countermeasure if the conditions are met.
    when (ver <= TLS10 && cc == ClientRole && isCBC tx && supportedEmptyPacket (ctxSupported ctx)) $
        writeIORef (ctxNeedEmptyPacket ctx) True
  where
    -- A record state counts as CBC when it has a cipher whose bulk block
    -- size is greater than zero; a state without a cipher does not.
    isCBC st = maybe False (\c -> bulkBlockSize (cipherBulk c) > 0) (stCipher st)

-- | Encode a handshake message and record it in the connection state.
--
-- * For a 'Finished' message, its payload is passed to 'updateVerifiedData'
--   together with the given role.
-- * When 'certVerifyHandshakeMaterial' accepts the message, its encoding is
--   stored with 'addHandshakeMessage'.
-- * When 'finishHandshakeTypeMaterial' accepts the message's type, its
--   encoding is fed to 'updateHandshakeDigest'.
--
-- Returns the encoded message (the output of 'encodeHandshake').
updateHandshake :: Context -> Role -> Handshake -> IO ByteString
updateHandshake ctx role hs = do
    case hs of
        Finished fdata -> usingState_ ctx $ updateVerifiedData role fdata
        _              -> return ()
    usingHState ctx $ do
        when (certVerifyHandshakeMaterial hs) $ addHandshakeMessage encoded
        when (finishHandshakeTypeMaterial $ typeOfHandshake hs) $ updateHandshakeDigest encoded
    return encoded
  where
    -- Computed once and shared by both state updates and the result.
    encoded = encodeHandshake hs

----------------------------------------------------------------

-- | Transform a 'Packet13' into marshalled data through the given record
-- layer.
--
-- The packet is broken into payload fragments by 'packetToFragments13' (using
-- the fragment size configured in the context); each fragment is wrapped in a
-- plaintext 'Record' carrying the packet's content type and passed to
-- 'recordEncode13'.  The encoded records are combined with 'mconcat'.
--
-- Returns 'Left' when encoding the records fails (presumably on the first
-- failing record -- confirm against 'forEitherM'), otherwise 'Right' with the
-- concatenated output.
encodePacket13 :: Monoid bytes
               => Context -> RecordLayer bytes -> Packet13 -> IO (Either TLSError bytes)
encodePacket13 ctx recordLayer pkt = do
    let pt = contentType pkt
        -- The record version is fixed at TLS12 regardless of context state;
        -- presumably this is the legacy_record_version of RFC 8446 -- confirm.
        mkRecord bs = Record pt TLS12 (fragmentPlaintext bs)
        len = ctxFragmentSize ctx
    records <- map mkRecord <$> packetToFragments13 ctx len pkt
    fmap mconcat <$> forEitherM records (recordEncode13 recordLayer)

-- Decompose a 'Packet13' into the payloads that will each become one record.
--
-- Handshake packets: every message is run through 'updateHandshake13' (which
-- may also update the handshake state), the encodings are concatenated, and
-- the result is split by 'getChunks' according to the given length.  Alert
-- and ChangeCipherSpec packets produce a single encoded payload, and
-- application data is passed through as one unsplit payload.
packetToFragments13 :: Context -> Maybe Int -> Packet13 -> IO [ByteString]
packetToFragments13 ctx len (Handshake13 hss)  =
    getChunks len . B.concat <$> mapM (updateHandshake13 ctx) hss
packetToFragments13 _   _   (Alert13 a)        = return [encodeAlerts a]
packetToFragments13 _   _   (AppData13 x)      = return [x]
packetToFragments13 _   _   ChangeCipherSpec13 = return [encodeChangeCipherSpec]

-- | Encode a 'Handshake13' message and, unless it is an ignored kind, record
-- it in the handshake state.
--
-- 'NewSessionTicket13' and 'KeyUpdate13' messages are only encoded; the
-- handshake state is left untouched for them.  For every other message the
-- encoding is fed to 'updateHandshakeDigest' and stored with
-- 'addHandshakeMessage'.  When the message is a 'ServerHello13' whose random
-- satisfies 'isHelloRetryRequest', 'wrapAsMessageHash13' is run before the
-- digest is updated.
--
-- Returns the encoded message (the output of 'encodeHandshake13').
updateHandshake13 :: Context -> Handshake13 -> IO ByteString
updateHandshake13 ctx hs
    | isIgnored hs = return encoded
    | otherwise    = usingHState ctx $ do
        when (isHRR hs) wrapAsMessageHash13
        updateHandshakeDigest encoded
        addHandshakeMessage encoded
        return encoded
  where
    encoded = encodeHandshake13 hs

    -- A HelloRetryRequest is recognised from the ServerHello's random value.
    isHRR (ServerHello13 srand _ _ _) = isHelloRetryRequest srand
    isHRR _                           = False

    -- Message kinds that are encoded without touching the handshake state.
    isIgnored NewSessionTicket13{} = True
    isIgnored KeyUpdate13{}        = True
    isIgnored _                    = False