{-# LANGUAGE CPP #-}
-- |
-- Module      : Network.TLS.Context
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
module Network.TLS.Context
    (
    -- * Context configuration
      TLSParams

    -- * Context object and accessor
    , Context(..)
    , Hooks(..)
    , Established(..)
    , ctxEOF
    , ctxHasSSLv2ClientHello
    , ctxDisableSSLv2ClientHello
    , ctxEstablished
    , withLog
    , ctxWithHooks
    , contextModifyHooks
    , setEOF
    , setEstablished
    , contextFlush
    , contextClose
    , contextSend
    , contextRecv
    , updateMeasure
    , withMeasure
    , withReadLock
    , withWriteLock
    , withStateLock
    , withRWLock

    -- * information
    , Information(..)
    , contextGetInformation

    -- * New contexts
    , contextNew
    -- * Deprecated new contexts methods
    , contextNewOnHandle
#ifdef INCLUDE_NETWORK
    , contextNewOnSocket
#endif

    -- * Context hooks
    , contextHookSetHandshakeRecv
    , contextHookSetHandshake13Recv
    , contextHookSetCertificateRecv
    , contextHookSetLogging

    -- * Using context states
    , throwCore
    , usingState
    , usingState_
    , runTxState
    , runRxState
    , usingHState
    , getHState
    , getStateRNG
    , tls13orLater
    , getFinished
    , getPeerFinished
    ) where

import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Record.Layer
import Network.TLS.Record.Reading
import Network.TLS.Record.Writing
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.PostHandshake (requestCertificateServer, postHandshakeAuthClientWith, postHandshakeAuthServerWith)
import Network.TLS.X509
import Network.TLS.RNG

import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef

-- deprecated imports
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)

-- | Abstraction over the parameter types a 'Context' can be built from.
-- The instances in this module are 'ClientParams' and 'ServerParams'.
class TLSParams a where
    -- | Project the common part of the parameters; 'contextNew' destructures
    -- the result as a @(supported, shared, debug)@ triple.
    getTLSCommonParams :: a -> CommonParams
    -- | The 'Role' associated with these parameters ('ClientRole' for
    -- 'ClientParams', 'ServerRole' for 'ServerParams').
    getTLSRole         :: a -> Role
    -- | Handshake action stored in the context as @ctxDoHandshake@.
    doHandshake        :: a -> Context -> IO ()
    -- | Handshake action taking a 'Handshake' value, stored in the context as
    -- @ctxDoHandshakeWith@.
    doHandshakeWith    :: a -> Context -> Handshake -> IO ()
    -- | Certificate-request action stored as @ctxDoRequestCertificate@.
    -- The 'ClientParams' instance always returns 'False'.
    doRequestCertificate :: a -> Context -> IO Bool
    -- | Post-handshake authentication action taking a 'Handshake13' value,
    -- stored as @ctxDoPostHandshakeAuthWith@.
    doPostHandshakeAuthWith :: a -> Context -> Handshake13 -> IO ()

-- | Client-side parameters: every method delegates to the client variant of
-- the corresponding handshake / post-handshake function.
instance TLSParams ClientParams where
    -- Common parameters are the client's supported/shared/debug fields.
    getTLSCommonParams cparams = ( clientSupported cparams
                                 , clientShared cparams
                                 , clientDebug cparams
                                 )
    getTLSRole _ = ClientRole
    doHandshake = handshakeClient
    doHandshakeWith = handshakeClientWith
    -- The client instance never performs a certificate request: it ignores
    -- both arguments and reports False.
    doRequestCertificate _ _ = return False
    doPostHandshakeAuthWith = postHandshakeAuthClientWith

-- | Server-side parameters: every method delegates to the server variant of
-- the corresponding handshake / post-handshake function.
instance TLSParams ServerParams where
    -- Common parameters are the server's supported/shared/debug fields.
    getTLSCommonParams sparams = ( serverSupported sparams
                                 , serverShared sparams
                                 , serverDebug sparams
                                 )
    getTLSRole _ = ServerRole
    doHandshake = handshakeServer
    doHandshakeWith = handshakeServerWith
    doRequestCertificate = requestCertificateServer
    doPostHandshakeAuthWith = postHandshakeAuthServerWith

-- | Create a new context using the backend and parameters specified.
--
-- The backend is initialized first, then all mutable state held by the
-- 'Context' (state variables, locks, hooks, measurement counters) is
-- allocated fresh, so the returned context starts out 'NotEstablished'.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend   -- ^ Backend abstraction with specific method to interact with the connection type.
           -> params    -- ^ Parameters of the context.
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params

    -- Use the seed from the debug parameters when one is provided;
    -- otherwise generate a new one and hand it to 'debugPrintSeed'.
    seed <- case debugSeed debug of
                Nothing     -> do seed <- seedNew
                                  debugPrintSeed debug seed
                                  return seed
                Just determ -> return determ
    let rng = newStateRNG seed

    let role = getTLSRole params
        st   = newTLSState rng role

    stvar <- newMVar st
    eof   <- newIORef False
    established <- newIORef NotEstablished
    stats <- newIORef newMeasurement
    -- we enable the reception of SSLv2 ClientHello message only in the
    -- server context, where we might be dealing with an old/compat client.
    sslv2Compat <- newIORef (role == ServerRole)
    needEmptyPacket <- newIORef False
    hooks <- newIORef defaultHooks
    tx    <- newMVar newRecordState
    rx    <- newMVar newRecordState
    hs    <- newMVar Nothing
    as    <- newIORef []
    crs   <- newIORef []
    lockWrite <- newMVar ()
    lockRead  <- newMVar ()
    lockState <- newMVar ()
    finished <- newIORef Nothing
    peerFinished <- newIORef Nothing

    -- 'ctx' and 'recordLayer' are mutually recursive: the context stores the
    -- record layer, and each record-layer operation is partially applied to
    -- the context itself.
    let ctx = Context
            { ctxConnection   = getBackend backend
            , ctxShared       = shared
            , ctxSupported    = supported
            , ctxState        = stvar
            , ctxFragmentSize = Just 16384
            , ctxTxState      = tx
            , ctxRxState      = rx
            , ctxHandshake    = hs
            , ctxDoHandshake  = doHandshake params
            , ctxDoHandshakeWith  = doHandshakeWith params
            , ctxDoRequestCertificate = doRequestCertificate params
            , ctxDoPostHandshakeAuthWith = doPostHandshakeAuthWith params
            , ctxMeasurement  = stats
            , ctxEOF_         = eof
            , ctxEstablished_ = established
            , ctxSSLv2ClientHello = sslv2Compat
            , ctxNeedEmptyPacket  = needEmptyPacket
            , ctxHooks            = hooks
            , ctxLockWrite        = lockWrite
            , ctxLockRead         = lockRead
            , ctxLockState        = lockState
            , ctxPendingActions   = as
            , ctxCertRequests     = crs
            , ctxKeyLogger        = debugKeyLogger debug
            , ctxRecordLayer      = recordLayer
            , ctxHandshakeSync    = HandshakeSync syncNoOp syncNoOp
            , ctxQUICMode         = False
            , ctxFinished         = finished
            , ctxPeerFinished     = peerFinished
            }

        -- Handshake synchronisation callback that ignores both arguments.
        -- Left without a signature so it generalizes to both callback types
        -- expected by 'HandshakeSync'.
        syncNoOp _ _ = return ()

        recordLayer = RecordLayer
            { recordEncode    = encodeRecord ctx
            , recordEncode13  = encodeRecord13 ctx
            , recordSendBytes = sendBytes ctx
            , recordRecv      = recvRecord ctx
            , recordRecv13    = recvRecord13 ctx
            }

    return ctx

-- | Create a new context on a handle.
--
-- This is 'contextNew' with the backend type fixed to 'Handle'.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle -- ^ Handle of the connection.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnHandle = contextNew
{-# DEPRECATED contextNewOnHandle "use contextNew" #-}

#ifdef INCLUDE_NETWORK
-- | Create a new context on a socket.
--
-- This is 'contextNew' with the backend type fixed to 'Socket'.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket -- ^ Socket of the connection.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnSocket sock params = contextNew sock params
{-# DEPRECATED contextNewOnSocket "use contextNew" #-}
#endif

-- | Replace the 'hookRecvHandshake' field of the context's hooks with the
-- given function, leaving the other hooks unchanged.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvHandshake = f })

-- | Replace the 'hookRecvHandshake13' field of the context's hooks with the
-- given function, leaving the other hooks unchanged.
contextHookSetHandshake13Recv :: Context -> (Handshake13 -> IO Handshake13) -> IO ()
contextHookSetHandshake13Recv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvHandshake13 = f })

-- | Replace the 'hookRecvCertificates' field of the context's hooks with the
-- given action, leaving the other hooks unchanged.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context f =
    contextModifyHooks context (\hooks -> hooks { hookRecvCertificates = f })

-- | Replace the 'hookLogging' field of the context's hooks with the given
-- 'Logging' callbacks, leaving the other hooks unchanged.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context loggingCallbacks =
    contextModifyHooks context (\hooks -> hooks { hookLogging = loggingCallbacks })

-- | Get TLS Finished sent to peer.
--
-- Reads the current contents of the context's 'ctxFinished' reference;
-- 'contextNew' initializes it to 'Nothing'.
getFinished :: Context -> IO (Maybe FinishedData)
getFinished = readIORef . ctxFinished

-- | Get TLS Finished received from peer.
--
-- Reads the current contents of the context's 'ctxPeerFinished' reference;
-- 'contextNew' initializes it to 'Nothing'.
getPeerFinished :: Context -> IO (Maybe FinishedData)
getPeerFinished = readIORef . ctxPeerFinished