-- |
-- Module      : Network.TLS.Receiving
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Receiving module contains functions for unmarshalling packets according
-- to the TLS state.
--
{-# LANGUAGE FlexibleContexts #-}

module Network.TLS.Receiving
    ( processPacket
    , processPacket13
    ) where

import Network.TLS.Cipher
import Network.TLS.Context.Internal
import Network.TLS.ErrT
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Packet
import Network.TLS.Packet13
import Network.TLS.Record
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Util
import Network.TLS.Wire

import Control.Concurrent.MVar
import Control.Monad.State.Strict

-- | Convert a plaintext record into a 'Packet', dispatching on the record's
-- protocol type.
--
-- * Application data: the fragment bytes are wrapped in 'AppData' unchanged.
-- * Alert: the fragment bytes are decoded with 'decodeAlerts'.
-- * Change cipher spec: the fragment is decoded with
--   'decodeChangeCipherSpec'; only on success is the receive-side record
--   state switched via 'switchRxEncryption'.
-- * Handshake: the fragment may hold several handshake messages and/or a
--   partial one, so a parsing continuation is kept in 'stHandshakeRecordCont'
--   between calls.
-- * Deprecated handshake: decoded with 'decodeDeprecatedHandshake' and
--   returned as a single-element 'Handshake' packet.
--
-- Decoding failures are returned as @Left@ rather than thrown.
processPacket :: Context -> Record Plaintext -> IO (Either TLSError Packet)
processPacket _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData $ fragmentGetBytes fragment

processPacket _ (Record ProtocolType_Alert _ fragment) =
    return (Alert `fmapEither` decodeAlerts (fragmentGetBytes fragment))

processPacket ctx (Record ProtocolType_ChangeCipherSpec _ fragment) =
    case decodeChangeCipherSpec $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right _  -> do
            -- The receive state is switched only once the message decoded.
            switchRxEncryption ctx
            return $ Right ChangeCipherSpec

processPacket ctx (Record ProtocolType_Handshake ver fragment) = do
    -- The key exchange type of the pending cipher (if any handshake state and
    -- pending cipher exist) is passed to the handshake decoder.
    mhs <- getHState ctx
    let keyxchg = cipherKeyExchange <$> (mhs >>= hstPendingCipher)
    usingState ctx $ do
        let currentParams = CurrentParams
                { cParamsVersion     = ver
                , cParamsKeyXchgType = keyxchg
                }
        -- Take the optional continuation left by a previous record out of the
        -- state, then parse as many handshake messages as possible.
        mCont <- gets stHandshakeRecordCont
        modify (\st -> st { stHandshakeRecordCont = Nothing })
        hss <- parseMany currentParams mCont (fragmentGetBytes fragment)
        return $ Handshake hss
  where
    -- Resume from the stored continuation when there is one, otherwise start
    -- a fresh 'decodeHandshakeRecord'; recurse on any bytes left over.
    parseMany currentParams mCont bs =
        case fromMaybe decodeHandshakeRecord mCont bs of
            GotError err ->
                throwError err
            GotPartial cont -> do
                -- Incomplete message: remember where to resume on the next
                -- record and yield nothing for now.
                modify (\st -> st { stHandshakeRecordCont = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (:[])) $
                    decodeHandshake currentParams ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake currentParams ty content of
                    Left err -> throwError err
                    Right hh -> (hh :) <$> parseMany currentParams Nothing left

processPacket _ (Record ProtocolType_DeprecatedHandshake _ fragment) =
    case decodeDeprecatedHandshake $ fragmentGetBytes fragment of
        Left err -> return $ Left err
        Right hs -> return $ Right $ Handshake [hs]

-- | Replace the context's receive-side record state with the pending
-- receive state held in the handshake state.
--
-- The pending state is read through 'usingHState' and written into the
-- 'ctxRxState' 'MVar'; the previous contents of the 'MVar' are discarded.
-- The pending value is unwrapped with the labelled @fromJust "rx-state"@, so
-- an absent pending state is not reported as a 'TLSError' by this function.
--
-- NOTE(review): the unwrapped value is stored via a lazy @return@, so
-- presumably a missing pending state would only surface when the stored
-- state is later forced -- confirm before relying on it.
switchRxEncryption :: Context -> IO ()
switchRxEncryption ctx = do
    rx <- usingHState ctx (gets hstPendingRxState)
    modifyMVar_ (ctxRxState ctx) (\_ -> return $ fromJust "rx-state" rx)

----------------------------------------------------------------

-- | Convert a plaintext record into a 'Packet13', dispatching on the
-- record's protocol type.
--
-- * Change cipher spec: the fragment is ignored and 'ChangeCipherSpec13' is
--   returned; no record state is switched here.
-- * Application data: the fragment bytes are wrapped in 'AppData13'.
-- * Alert: the fragment bytes are decoded with 'decodeAlerts'.
-- * Handshake: the fragment may hold several handshake messages and/or a
--   partial one, so a parsing continuation is kept in
--   'stHandshakeRecordCont13' between calls.
-- * Deprecated handshake: always rejected with an 'Error_Packet'.
--
-- Decoding failures are returned as @Left@ rather than thrown.
processPacket13 :: Context -> Record Plaintext -> IO (Either TLSError Packet13)
processPacket13 _ (Record ProtocolType_ChangeCipherSpec _ _) =
    return $ Right ChangeCipherSpec13
processPacket13 _ (Record ProtocolType_AppData _ fragment) =
    return $ Right $ AppData13 $ fragmentGetBytes fragment
processPacket13 _ (Record ProtocolType_Alert _ fragment) =
    return (Alert13 `fmapEither` decodeAlerts (fragmentGetBytes fragment))
processPacket13 ctx (Record ProtocolType_Handshake _ fragment) = usingState ctx $ do
    -- Take the optional continuation left by a previous record out of the
    -- state, then parse as many handshake messages as possible.
    mCont <- gets stHandshakeRecordCont13
    modify (\st -> st { stHandshakeRecordCont13 = Nothing })
    hss <- parseMany mCont (fragmentGetBytes fragment)
    return $ Handshake13 hss
  where
    -- Resume from the stored continuation when there is one, otherwise start
    -- a fresh 'decodeHandshakeRecord13'; recurse on any bytes left over.
    parseMany mCont bs =
        case fromMaybe decodeHandshakeRecord13 mCont bs of
            GotError err ->
                throwError err
            GotPartial cont -> do
                -- Incomplete message: remember where to resume on the next
                -- record and yield nothing for now.
                modify (\st -> st { stHandshakeRecordCont13 = Just cont })
                return []
            GotSuccess (ty, content) ->
                either throwError (return . (:[])) $ decodeHandshake13 ty content
            GotSuccessRemaining (ty, content) left ->
                case decodeHandshake13 ty content of
                    Left err -> throwError err
                    Right hh -> (hh :) <$> parseMany Nothing left
processPacket13 _ (Record ProtocolType_DeprecatedHandshake _ _) =
    return (Left $ Error_Packet "deprecated handshake packet 1.3")