{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
-- |
-- Module      : Network.TLS.Context.Internal
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Context.Internal
    (
    -- * Context configuration
      ClientParams(..)
    , ServerParams(..)
    , defaultParamsClient
    , SessionID
    , SessionData(..)
    , MaxFragmentEnum(..)
    , Measurement(..)

    -- * Context object and accessor
    , Context(..)
    , Hooks(..)
    , Established(..)
    , PendingAction(..)
    , ctxEOF
    , ctxHasSSLv2ClientHello
    , ctxDisableSSLv2ClientHello
    , ctxEstablished
    , withLog
    , ctxWithHooks
    , contextModifyHooks
    , setEOF
    , setEstablished
    , contextFlush
    , contextClose
    , contextSend
    , contextRecv
    , updateRecordLayer
    , updateMeasure
    , withMeasure
    , withReadLock
    , withWriteLock
    , withStateLock
    , withRWLock

    -- * information
    , Information(..)
    , contextGetInformation

    -- * Using context states
    , throwCore
    , failOnEitherError
    , usingState
    , usingState_
    , runTxState
    , runRxState
    , usingHState
    , getHState
    , saveHState
    , restoreHState
    , getStateRNG
    , tls13orLater
    , addCertRequest13
    , getCertRequest13
    , decideRecordVersion

    -- * Misc
    , HandshakeSync(..)
    ) where

import Network.TLS.Backend
import Network.TLS.Cipher
import Network.TLS.Compression (Compression)
import Network.TLS.Extension
import Network.TLS.Handshake.Control
import Network.TLS.Handshake.State
import Network.TLS.Hooks
import Network.TLS.Imports
import Network.TLS.Measurement
import Network.TLS.Parameters
import Network.TLS.Record.Layer
import Network.TLS.Record.State
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.Types
import Network.TLS.Util

import Control.Concurrent.MVar
import Control.Exception (throwIO)
import Control.Monad.State.Strict
import qualified Data.ByteString as B
import Data.IORef
import Data.Tuple

-- | Snapshot of the negotiated parameters of a running context, as produced
-- by 'contextGetInformation' (protocol version, current cipher, ...).
data Information = Information
    { infoVersion             :: Version
      -- ^ negotiated protocol version
    , infoCipher              :: Cipher
      -- ^ cipher taken from the receive record state
    , infoCompression         :: Compression
      -- ^ compression taken from the receive record state
    , infoMasterSecret        :: Maybe ByteString
      -- ^ master secret from the handshake state, if known
    , infoExtendedMasterSec   :: Bool
      -- ^ extended-master-secret flag from the handshake state
    , infoClientRandom        :: Maybe ClientRandom
      -- ^ client random, when a handshake state exists
    , infoServerRandom        :: Maybe ServerRandom
      -- ^ server random, if known
    , infoNegotiatedGroup     :: Maybe Group
      -- ^ negotiated key-exchange group, if any
    , infoTLS13HandshakeMode  :: Maybe HandshakeMode13
      -- ^ handshake mode; only populated for TLS 1.3
    , infoIsEarlyDataAccepted :: Bool
      -- ^ 'True' when the recorded 0-RTT status is 'RTT0Accepted'
    } deriving (Show, Eq)

-- | A TLS context: the TLS-specific state, parameters and backend
-- information for one connection.
--
-- The record layer's byte type is hidden behind the existential @bytes@.
data Context = forall bytes . Monoid bytes => Context
    { ctxConnection       :: Backend
      -- ^ backend object associated with this context
    , ctxSupported        :: Supported
    , ctxShared           :: Shared
    , ctxState            :: MVar TLSState
      -- ^ TLS state, run through 'usingState' / 'usingState_'
    , ctxMeasurement      :: IORef Measurement
      -- ^ traffic counters, updated through 'updateMeasure'
    , ctxEOF_             :: IORef Bool
      -- ^ whether the handle has reached EOF
    , ctxEstablished_     :: IORef Established
      -- ^ whether the handshake has been done and was successful
    , ctxNeedEmptyPacket  :: IORef Bool
      -- ^ empty packet workaround for CBC guessability
    , ctxSSLv2ClientHello :: IORef Bool
      -- ^ enable reception of a compatibility SSLv2 client hello; the flag
      -- is set to 'False' after the first packet received, regardless of
      -- its initial value
    , ctxFragmentSize     :: Maybe Int
      -- ^ maximum size of plaintext fragments
    , ctxTxState          :: MVar RecordState
      -- ^ current transmit record state
    , ctxRxState          :: MVar RecordState
      -- ^ current receive record state
    , ctxHandshake        :: MVar (Maybe HandshakeState)
      -- ^ optional handshake state ('usingHState' throws when 'Nothing')
    , ctxDoHandshake      :: Context -> IO ()
    , ctxDoHandshakeWith  :: Context -> Handshake -> IO ()
    , ctxDoRequestCertificate :: Context -> IO Bool
    , ctxDoPostHandshakeAuthWith :: Context -> Handshake13 -> IO ()
    , ctxHooks            :: IORef Hooks
      -- ^ hooks for this context
    , ctxLockWrite        :: MVar ()
      -- ^ lock used for writing data (including updating the state)
    , ctxLockRead         :: MVar ()
      -- ^ lock used for reading data (including updating the state)
    , ctxLockState        :: MVar ()
      -- ^ lock used during read/write when receiving and sending packets;
      -- usually nested inside the write or read lock
    , ctxPendingActions   :: IORef [PendingAction]
    , ctxCertRequests     :: IORef [Handshake13]
      -- ^ pending post-handshake-authentication certificate requests
    , ctxKeyLogger        :: String -> IO ()
    , ctxRecordLayer      :: RecordLayer bytes
      -- ^ record layer; replaced with 'updateRecordLayer'
    , ctxHandshakeSync    :: HandshakeSync
    , ctxQUICMode         :: Bool
    , ctxFinished         :: IORef (Maybe FinishedData)
      -- ^ presumably our own Finished data — confirm at use sites
    , ctxPeerFinished     :: IORef (Maybe FinishedData)
      -- ^ presumably the peer's Finished data — confirm at use sites
    }

-- | Two callbacks that receive the context together with either a
-- 'ClientState' (first field) or a 'ServerState' (second field).
--
-- NOTE(review): presumably used to report handshake progress to an outside
-- party (e.g. a QUIC stack) — confirm at the call sites.
data HandshakeSync = HandshakeSync
    (Context -> ClientState -> IO ())
    (Context -> ServerState -> IO ())

-- | Build a context identical to the given one except that it uses
-- @newLayer@ as its record layer. Because the record layer's byte type is
-- existential, the context is rebuilt field-by-field (record update syntax
-- cannot change an existential field's type).
updateRecordLayer :: Monoid bytes => RecordLayer bytes -> Context -> Context
updateRecordLayer newLayer Context{ctxRecordLayer = _, ..} =
    Context{ctxRecordLayer = newLayer, ..}

-- | How far the connection has progressed towards an established session.
data Established
    = NotEstablished
    | EarlyDataAllowed Int
      -- ^ remaining 0-RTT bytes allowed
    | EarlyDataNotAllowed Int
      -- ^ remaining 0-RTT packets allowed to skip
    | Established
    deriving (Eq, Show)

-- | An action waiting to be run on an incoming TLS 1.3 handshake message.
--
-- NOTE(review): the meaning of the 'Bool' flag is not visible in this
-- module — confirm at the construction sites.
data PendingAction
    = PendingAction Bool (Handshake13 -> IO ())
      -- ^ simple pending action
    | PendingActionHash Bool (ByteString -> Handshake13 -> IO ())
      -- ^ pending action that also receives the transcript hash up to the
      -- preceding message

-- | Apply a pure update to the context's 'Measurement' counters. The new
-- value is forced to WHNF on write ('modifyIORef'').
updateMeasure :: Context -> (Measurement -> Measurement) -> IO ()
updateMeasure ctx update = modifyIORef' (ctxMeasurement ctx) update

-- | Read the context's current 'Measurement' and pass it to @k@.
withMeasure :: Context -> (Measurement -> IO a) -> IO a
withMeasure ctx k = do
    current <- readIORef (ctxMeasurement ctx)
    k current

-- | Flush the context's backend; shorthand for
-- @'backendFlush' . 'ctxConnection'@.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)

-- | Close the context's backend; shorthand for
-- @'backendClose' . 'ctxConnection'@.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)

-- | Information about the current context.
--
-- Returns 'Nothing' unless both a protocol version (from the TLS state) and
-- a cipher (from the receive record state) are available. Fields derived
-- from the handshake state fall back to 'Nothing' / 'False' when there is
-- no handshake state.
contextGetInformation :: Context -> IO (Maybe Information)
contextGetInformation ctx = do
    mVersion <- usingState_ ctx $ gets stVersion
    mHState  <- getHState ctx
    rxState  <- readMVar (ctxRxState ctx)
    let build v c = Information
            { infoVersion             = v
            , infoCipher              = c
            , infoCompression         = stCompression rxState
            , infoMasterSecret        = mHState >>= hstMasterSecret
            , infoExtendedMasterSec   = maybe False hstExtendedMasterSec mHState
            , infoClientRandom        = hstClientRandom <$> mHState
            , infoServerRandom        = mHState >>= hstServerRandom
            , infoNegotiatedGroup     = mHState >>= hstNegotiatedGroup
              -- the handshake mode is only meaningful for TLS 1.3
            , infoTLS13HandshakeMode  =
                if v == TLS13 then hstTLS13HandshakeMode <$> mHState else Nothing
            , infoIsEarlyDataAccepted =
                maybe False ((== RTT0Accepted) . hstTLS13RTT0Status) mHState
            }
    -- both the version and the cipher must be present
    return (build <$> mVersion <*> stCipher rxState)

-- | Record the payload size in the sent-bytes counter, then hand the bytes
-- to the backend.
contextSend :: Context -> ByteString -> IO ()
contextSend c b = do
    updateMeasure c (addBytesSent (B.length b))
    backendSend (ctxConnection c) b

-- | Ask the backend for @sz@ bytes and add the number of bytes actually
-- returned to the context's received-bytes counter.
--
-- The counter is updated after the read and uses the length of the returned
-- data, not the requested size. A short read (e.g. at EOF) or a backend
-- exception therefore no longer inflates the counter.
contextRecv :: Context -> Int -> IO ByteString
contextRecv c sz = do
    received <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived (B.length received))
    return received

-- | Whether the context has been marked as having reached EOF (see 'setEOF').
ctxEOF :: Context -> IO Bool
ctxEOF = readIORef . ctxEOF_

-- | Current value of the "accept an SSLv2-compatible client hello" flag.
ctxHasSSLv2ClientHello :: Context -> IO Bool
ctxHasSSLv2ClientHello = readIORef . ctxSSLv2ClientHello

-- | Clear the "accept an SSLv2-compatible client hello" flag.
ctxDisableSSLv2ClientHello :: Context -> IO ()
ctxDisableSSLv2ClientHello ctx =
    let flag = ctxSSLv2ClientHello ctx
     in writeIORef flag False

-- | Mark the context as having reached EOF.
setEOF :: Context -> IO ()
setEOF ctx =
    let eofRef = ctxEOF_ ctx
     in writeIORef eofRef True

-- | Read the context's current 'Established' status.
ctxEstablished :: Context -> IO Established
ctxEstablished = readIORef . ctxEstablished_

-- | Run @f@ with the context's current 'Hooks'.
ctxWithHooks :: Context -> (Hooks -> IO a) -> IO a
ctxWithHooks ctx f = f =<< readIORef (ctxHooks ctx)

-- | Modify the hooks stored in the context.
--
-- The new 'Hooks' value is forced to WHNF on write ('modifyIORef''). The
-- lazy 'modifyIORef' would instead keep a growing chain of unevaluated
-- updates in the reference until the hooks are next read.
contextModifyHooks :: Context -> (Hooks -> Hooks) -> IO ()
contextModifyHooks ctx f = modifyIORef' (ctxHooks ctx) f

-- | Overwrite the context's 'Established' status.
setEstablished :: Context -> Established -> IO ()
setEstablished ctx status = writeIORef (ctxEstablished_ ctx) status

-- | Run @f@ with the 'Logging' part of the context's current hooks.
withLog :: Context -> (Logging -> IO ()) -> IO ()
withLog ctx f = ctxWithHooks ctx $ \hooks -> f (hookLogging hooks)

-- | Throw a 'TLSError' as an IO exception, wrapped in 'Uncontextualized'.
throwCore :: MonadIO m => TLSError -> m a
throwCore err = liftIO (throwIO (Uncontextualized err))

-- | Run @f@. A 'Left' result is rethrown via 'throwCore'; a 'Right' result
-- is returned unwrapped.
failOnEitherError :: MonadIO m => m (Either TLSError a) -> m a
failOnEitherError f = f >>= either throwCore return

-- | Run a 'TLSSt' computation against the context's TLS state while holding
-- its 'MVar'. The new state is stored (forced to WHNF) even when the
-- computation reports an error, and the result is returned as an 'Either'.
usingState :: Context -> TLSSt a -> IO (Either TLSError a)
usingState ctx f =
    modifyMVar (ctxState ctx) $ \current -> do
        let (outcome, next) = runTLSState f current
        next `seq` return (next, outcome)

-- | Like 'usingState', but an error is thrown (via 'failOnEitherError')
-- instead of being returned.
usingState_ :: Context -> TLSSt a -> IO a
usingState_ ctx f = failOnEitherError (usingState ctx f)

-- | Run a 'HandshakeM' computation against the context's handshake state,
-- storing the updated state back.
--
-- Throws 'MissingHandshake' when the context has no handshake state. The
-- handshake result is forced to WHNF inside 'modifyMVar', so an exception
-- raised while evaluating it leaves the 'MVar' contents restored.
usingHState :: MonadIO m => Context -> HandshakeM a -> m a
usingHState ctx action = liftIO $ modifyMVar (ctxHandshake ctx) step
  where
    step Nothing   = throwIO MissingHandshake
    step (Just hs) = case runHandshake hs action of
        (result, hs') -> return (Just hs', result)

-- | Read (without taking) the context's optional handshake state.
getHState :: MonadIO m => Context -> m (Maybe HandshakeState)
getHState = liftIO . readMVar . ctxHandshake

-- | Save the context's handshake state with 'saveMVar' so that it can later
-- be put back with 'restoreHState'.
saveHState :: Context -> IO (Saved (Maybe HandshakeState))
saveHState = saveMVar . ctxHandshake

-- | Put a previously saved handshake state back into the context with
-- 'restoreMVar', returning whatever 'restoreMVar' returns.
restoreHState :: Context
              -> Saved (Maybe HandshakeState)
              -> IO (Saved (Maybe HandshakeState))
restoreHState = restoreMVar . ctxHandshake

-- | Choose the version to put in outgoing record headers, and report whether
-- the connection is TLS 1.3 or later.
--
-- The version comes from the TLS state, defaulting to the highest supported
-- version. NOTE(review): 'maximum' fails on an empty 'supportedVersions'
-- list; presumably that list is validated non-empty elsewhere — confirm.
decideRecordVersion :: Context -> IO (Version, Bool)
decideRecordVersion ctx = usingState_ ctx $ do
    let highest = maximum (supportedVersions (ctxSupported ctx))
    ver <- getVersionWithDefault highest
    hrr <- getTLS13HRR
    let tls13 = ver >= TLS13
        -- For TLS 1.3 the record version only matters for ClientHello:
        -- the first ClientHello SHOULD use TLS 1.0, the second one (after a
        -- HelloRetryRequest) MUST use TLS 1.2.
        recordVer
            | not tls13 = ver
            | hrr       = TLS12
            | otherwise = TLS10
    return (recordVer, tls13)

-- | Run a 'RecordM' computation against the transmit record state.
--
-- Record options come from 'decideRecordVersion'. On success the new record
-- state is stored; on failure the previous state is kept and the error is
-- returned.
runTxState :: Context -> RecordM a -> IO (Either TLSError a)
runTxState ctx action = do
    (ver, tls13) <- decideRecordVersion ctx
    let opts = RecordOptions { recordVersion = ver, recordTLS13 = tls13 }
    modifyMVar (ctxTxState ctx) $ \current ->
        case runRecordM action opts current of
            Right (value, next) -> return (next, Right value)
            Left err            -> return (current, Left err)

-- | Run a 'RecordM' computation against the receive record state, using the
-- version stored in the TLS state.
--
-- On success the new record state is stored; on failure the previous state
-- is kept and the error is returned.
runRxState :: Context -> RecordM a -> IO (Either TLSError a)
runRxState ctx action = do
    ver <- usingState_ ctx getVersion
    -- For TLS 1.3 the version is ignored, so it need not be converted.
    let opts = RecordOptions { recordVersion = ver, recordTLS13 = ver >= TLS13 }
    modifyMVar (ctxRxState ctx) $ \current ->
        case runRecordM action opts current of
            Right (value, next) -> return (next, Right value)
            Left err            -> return (current, Left err)

-- | Produce @n@ random bytes from the RNG held in the TLS state.
getStateRNG :: Context -> Int -> IO ByteString
getStateRNG ctx = usingState_ ctx . genRandom

-- | Run @f@ while holding the context's read lock.
withReadLock :: Context -> IO a -> IO a
withReadLock ctx f = withMVar (ctxLockRead ctx) $ \_ -> f

-- | Run @f@ while holding the context's write lock.
withWriteLock :: Context -> IO a -> IO a
withWriteLock ctx f = withMVar (ctxLockWrite ctx) $ \_ -> f

-- | Run an action holding both locks: the read lock is taken first, then
-- the write lock inside it.
withRWLock :: Context -> IO a -> IO a
withRWLock ctx = withReadLock ctx . withWriteLock ctx

-- | Run @f@ while holding the context's state lock.
withStateLock :: Context -> IO a -> IO a
withStateLock ctx f = withMVar (ctxLockState ctx) $ \_ -> f

-- | Whether the version stored in the TLS state is TLS 1.3 or later.
--
-- An error from the state computation counts as 'False'. When no version is
-- stored, 'TLS10' is used as the default (flagged "fixme" by the original
-- author).
tls13orLater :: MonadIO m => Context -> m Bool
tls13orLater ctx =
    either (const False) (>= TLS13)
        <$> liftIO (usingState ctx (getVersionWithDefault TLS10)) -- fixme

-- | Record a pending post-handshake certificate request (prepended to the
-- list consumed by 'getCertRequest13').
--
-- The update is atomic and strict ('atomicModifyIORef''). Requests may be
-- added and consumed on different paths — e.g. under the write and the read
-- lock respectively (assumed; confirm at call sites) — and a non-atomic
-- update could then lose an entry.
addCertRequest13 :: Context -> Handshake13 -> IO ()
addCertRequest13 ctx certReq =
    atomicModifyIORef' (ctxCertRequests ctx) $ \reqs -> (certReq : reqs, ())

-- | Take the pending certificate request whose request context equals
-- @context@.
--
-- All matching entries are removed from the pending list and the first one
-- is returned; 'Nothing' (list unchanged) when none match.
--
-- Fixes over the previous version:
--
-- * the matcher is total: entries that are not 'CertRequest13' are simply
--   not matched instead of crashing with a pattern-match failure;
-- * the read-partition-write happens in one 'atomicModifyIORef'', so a
--   request added concurrently via 'addCertRequest13' cannot be lost.
getCertRequest13 :: Context -> CertReqContext -> IO (Maybe Handshake13)
getCertRequest13 ctx context =
    atomicModifyIORef' (ctxCertRequests ctx) takeMatching
  where
    matches (CertRequest13 c _) = c == context
    matches _                   = False

    takeMatching reqs = case partition matches reqs of
        ([], _)               -> (reqs, Nothing)
        (certReq : _, others) -> (others, Just certReq)