module Network.TLS.Context
(
TLSParams
, Context(..)
, Hooks(..)
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, withLog
, ctxWithHooks
, contextModifyHooks
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, withReadLock
, withWriteLock
, withStateLock
, withRWLock
, Information(..)
, contextGetInformation
, contextNew
, contextNewOnHandle
#ifdef INCLUDE_NETWORK
, contextNewOnSocket
#endif
, contextHookSetHandshakeRecv
, contextHookSetCertificateRecv
, contextHookSetLogging
, throwCore
, usingState
, usingState_
, runTxState
, runRxState
, usingHState
, getHState
, getStateRNG
) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.X509
import Network.TLS.RNG
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)
-- | A set of parameters from which a TLS 'Context' can be built.
--
-- This module provides instances for 'ClientParams' and 'ServerParams'.
class TLSParams a where
    -- | The settings common to both roles, as a
    -- @(supported, shared, debug)@ triple.
    getTLSCommonParams :: a -> CommonParams
    -- | The connection role this parameter set belongs to.
    getTLSRole         :: a -> Role
    -- | The role-specific handshake action, to be stored in the context.
    doHandshake        :: a -> Context -> IO ()
    -- | Like 'doHandshake', but the role-specific action also receives
    -- a 'Handshake' message.
    doHandshakeWith    :: a -> Context -> Handshake -> IO ()
-- | Client-side parameters: the common settings are read from the
-- @client*@ fields, the role is 'ClientRole', and the handshake actions
-- are 'handshakeClient' / 'handshakeClientWith'.
instance TLSParams ClientParams where
    getTLSCommonParams cp = (sup, shr, dbg)
      where
        sup = clientSupported cp
        shr = clientShared cp
        dbg = clientDebug cp
    getTLSRole      = const ClientRole
    doHandshake     = handshakeClient
    doHandshakeWith = handshakeClientWith
-- | Server-side parameters: the common settings are read from the
-- @server*@ fields, the role is 'ServerRole', and the handshake actions
-- are 'handshakeServer' / 'handshakeServerWith'.
instance TLSParams ServerParams where
    getTLSCommonParams sp = (sup, shr, dbg)
      where
        sup = serverSupported sp
        shr = serverShared sp
        dbg = serverDebug sp
    getTLSRole      = const ServerRole
    doHandshake     = handshakeServer
    doHandshakeWith = handshakeServerWith
-- | Create a new TLS 'Context' on top of a transport @backend@.
--
-- The backend is initialized first with 'initializeBackend'. The random
-- generator is seeded from 'debugSeed' when the debug parameters carry a
-- seed. Otherwise a fresh seed is drawn with 'seedNew' and passed to
-- 'debugPrintSeed'. All mutable state of the context is freshly
-- allocated: TLS state, record states, handshake slot, flags,
-- measurement, hooks and locks.
--
-- No handshake is performed here. The role-specific handshake actions
-- from 'TLSParams' are only stored in the returned context.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend  -- ^ transport to run TLS over
           -> params   -- ^ client or server parameters
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params
        role = getTLSRole params
        -- Only run when no deterministic seed was configured.
        freshSeed = do
            s <- seedNew
            debugPrintSeed debug s
            return s

    seed <- maybe freshSeed return (debugSeed debug)

    stateVar       <- newMVar (newTLSState (newStateRNG seed) role)
    eofRef         <- newIORef False
    establishedRef <- newIORef False
    measureRef     <- newIORef newMeasurement
    -- Reception of SSLv2 ClientHello messages starts out enabled only
    -- for server contexts.
    sslv2Ref       <- newIORef (role == ServerRole)
    emptyPacketRef <- newIORef False
    hooksRef       <- newIORef defaultHooks
    txVar          <- newMVar newRecordState
    rxVar          <- newMVar newRecordState
    handshakeVar   <- newMVar Nothing
    writeLock      <- newMVar ()
    readLock       <- newMVar ()
    stateLock      <- newMVar ()

    return Context
        { ctxConnection       = getBackend backend
        , ctxShared           = shared
        , ctxSupported        = supported
        , ctxState            = stateVar
        , ctxTxState          = txVar
        , ctxRxState          = rxVar
        , ctxHandshake        = handshakeVar
        , ctxDoHandshake      = doHandshake params
        , ctxDoHandshakeWith  = doHandshakeWith params
        , ctxMeasurement      = measureRef
        , ctxEOF_             = eofRef
        , ctxEstablished_     = establishedRef
        , ctxSSLv2ClientHello = sslv2Ref
        , ctxNeedEmptyPacket  = emptyPacketRef
        , ctxHooks            = hooksRef
        , ctxLockWrite        = writeLock
        , ctxLockRead         = readLock
        , ctxLockState        = stateLock
        }
-- | Create a new TLS 'Context' over a 'Handle'.
--
-- This is exactly 'contextNew' with the backend fixed to 'Handle'.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle  -- ^ handle to run TLS over
                   -> params  -- ^ client or server parameters
                   -> m Context
contextNewOnHandle = contextNew
#ifdef INCLUDE_NETWORK
-- | Create a new TLS 'Context' over a network 'Socket'.
--
-- This is exactly 'contextNew' with the backend fixed to 'Socket'.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket  -- ^ socket to run TLS over
                   -> params  -- ^ client or server parameters
                   -> m Context
contextNewOnSocket = contextNew
#endif
-- | Replace the 'hookRecvHandshake' hook of the context with the given
-- function, leaving the other hooks unchanged.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv context hook = contextModifyHooks context install
  where
    install hooks = hooks { hookRecvHandshake = hook }
-- | Replace the 'hookRecvCertificates' hook of the context with the given
-- function, leaving the other hooks unchanged.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv context hook = contextModifyHooks context install
  where
    install hooks = hooks { hookRecvCertificates = hook }
-- | Replace the 'hookLogging' callbacks of the context with the given
-- 'Logging' value, leaving the other hooks unchanged.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging context callbacks = contextModifyHooks context install
  where
    install hooks = hooks { hookLogging = callbacks }