{-# LANGUAGE CPP #-}
-- | Construction of TLS 'Context' values, the 'TLSParams' class used to
-- build them from client or server parameters, and setters for the hooks
-- stored in a context.  The remaining exported names are not defined in
-- this file; they come from the modules imported below.
module Network.TLS.Context
    (
    -- * Context configuration
      TLSParams
    -- * Context object and accessors
    , Context(..)
    , Hooks(..)
    , Established(..)
    , ctxEOF
    , ctxHasSSLv2ClientHello
    , ctxDisableSSLv2ClientHello
    , ctxEstablished
    , withLog
    , ctxWithHooks
    , contextModifyHooks
    , setEOF
    , setEstablished
    , contextFlush
    , contextClose
    , contextSend
    , contextRecv
    , updateMeasure
    , withMeasure
    , withReadLock
    , withWriteLock
    , withStateLock
    , withRWLock
    -- * Information
    , Information(..)
    , contextGetInformation
    -- * New contexts
    , contextNew
    -- * Deprecated ways of creating a context (both forward to 'contextNew')
    , contextNewOnHandle
#ifdef INCLUDE_NETWORK
    , contextNewOnSocket
#endif
    -- * Context hooks
    , contextHookSetHandshakeRecv
    , contextHookSetHandshake13Recv
    , contextHookSetCertificateRecv
    , contextHookSetLogging
    -- * Using context states
    , throwCore
    , usingState
    , usingState_
    , runTxState
    , runRxState
    , usingHState
    , getHState
    , getStateRNG
    , tls13orLater
    ) where
import Network.TLS.Backend
import Network.TLS.Context.Internal
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.State
import Network.TLS.Hooks
import Network.TLS.Record.State
import Network.TLS.Parameters
import Network.TLS.Measurement
import Network.TLS.Types (Role(..))
import Network.TLS.Handshake (handshakeClient, handshakeClientWith, handshakeServer, handshakeServerWith)
import Network.TLS.PostHandshake (requestCertificateServer, postHandshakeAuthClientWith, postHandshakeAuthServerWith)
import Network.TLS.X509
import Network.TLS.RNG
import Control.Concurrent.MVar
import Control.Monad.State.Strict
import Data.IORef
#ifdef INCLUDE_NETWORK
import Network.Socket (Socket)
#endif
import System.IO (Handle)
-- | Parameter sets from which a 'Context' can be built.
--
-- 'contextNew' reads the common parameters and the role from an instance
-- and stores the remaining four actions in the resulting 'Context'.
class TLSParams a where
    -- | The (supported, shared, debug) parameter triple.
    getTLSCommonParams      :: a -> CommonParams
    -- | The 'Role' these parameters stand for.
    getTLSRole              :: a -> Role
    -- | Handshake action stored in the context as @ctxDoHandshake@.
    doHandshake             :: a -> Context -> IO ()
    -- | Handshake action that additionally takes a 'Handshake' message;
    -- stored in the context as @ctxDoHandshakeWith@.
    doHandshakeWith         :: a -> Context -> Handshake -> IO ()
    -- | Certificate-request action stored as @ctxDoRequestCertificate@.
    doRequestCertificate    :: a -> Context -> IO Bool
    -- | Post-handshake authentication action taking a 'Handshake13'
    -- message; stored as @ctxDoPostHandshakeAuthWith@.
    doPostHandshakeAuthWith :: a -> Context -> Handshake13 -> IO ()
-- | Client side: the role is 'ClientRole', the handshake actions are the
-- client variants, and requesting a certificate always yields 'False'.
instance TLSParams ClientParams where
    getTLSRole = const ClientRole
    getTLSCommonParams cp =
        (clientSupported cp, clientShared cp, clientDebug cp)
    doHandshake             = handshakeClient
    doHandshakeWith         = handshakeClientWith
    doPostHandshakeAuthWith = postHandshakeAuthClientWith
    -- Both arguments are ignored; the result is always False.
    doRequestCertificate _ _ = return False
-- | Server side: the role is 'ServerRole' and every action is delegated to
-- the corresponding server routine.
instance TLSParams ServerParams where
    getTLSRole = const ServerRole
    getTLSCommonParams sp =
        (serverSupported sp, serverShared sp, serverDebug sp)
    doHandshake             = handshakeServer
    doHandshakeWith         = handshakeServerWith
    doRequestCertificate    = requestCertificateServer
    doPostHandshakeAuthWith = postHandshakeAuthServerWith
-- | Create a new 'Context' on top of a backend.
--
-- The backend is initialised first.  The RNG seed is taken from the debug
-- parameters when one is supplied; otherwise a fresh seed is generated and
-- passed to 'debugPrintSeed'.  All mutable cells of the context are then
-- allocated with their initial values.
contextNew :: (MonadIO m, HasBackend backend, TLSParams params)
           => backend   -- ^ Backend the context talks through.
           -> params    -- ^ Client or server parameters of the context.
           -> m Context
contextNew backend params = liftIO $ do
    initializeBackend backend

    let (supported, shared, debug) = getTLSCommonParams params
        role = getTLSRole params

    -- Used only when no deterministic seed is configured; the debug printer
    -- is called for freshly generated seeds only.
    let freshSeed = do
            generated <- seedNew
            debugPrintSeed debug generated
            return generated
    seed <- maybe freshSeed return (debugSeed debug)

    -- MVar-backed state.
    stateVar     <- newMVar (newTLSState (newStateRNG seed) role)
    txVar        <- newMVar newRecordState
    rxVar        <- newMVar newRecordState
    handshakeVar <- newMVar Nothing
    writeLock    <- newMVar ()
    readLock     <- newMVar ()
    stateLock    <- newMVar ()

    -- IORef-backed flags and bookkeeping.
    eofRef          <- newIORef False
    establishedRef  <- newIORef NotEstablished
    measureRef      <- newIORef newMeasurement
    -- Starts out True for the server role only.
    -- NOTE(review): presumably so a server can accept an SSLv2-style
    -- ClientHello from legacy clients; confirm against the receive path.
    sslv2HelloRef   <- newIORef (role == ServerRole)
    emptyPacketRef  <- newIORef False
    hooksRef        <- newIORef defaultHooks
    pendingRef      <- newIORef []
    certRequestsRef <- newIORef []

    return Context
        { ctxConnection              = getBackend backend
        , ctxSupported               = supported
        , ctxShared                  = shared
        , ctxState                   = stateVar
        , ctxMeasurement             = measureRef
        , ctxEOF_                    = eofRef
        , ctxEstablished_            = establishedRef
        , ctxNeedEmptyPacket         = emptyPacketRef
        , ctxSSLv2ClientHello        = sslv2HelloRef
        , ctxTxState                 = txVar
        , ctxRxState                 = rxVar
        , ctxHandshake               = handshakeVar
        , ctxDoHandshake             = doHandshake params
        , ctxDoHandshakeWith         = doHandshakeWith params
        , ctxDoRequestCertificate    = doRequestCertificate params
        , ctxDoPostHandshakeAuthWith = doPostHandshakeAuthWith params
        , ctxHooks                   = hooksRef
        , ctxLockWrite               = writeLock
        , ctxLockRead                = readLock
        , ctxLockState               = stateLock
        , ctxPendingActions          = pendingRef
        , ctxCertRequests            = certRequestsRef
        , ctxKeyLogger               = debugKeyLogger debug
        }
-- | Create a new 'Context' over a 'Handle'.
--
-- Deprecated: this only forwards its arguments to 'contextNew'.
contextNewOnHandle :: (MonadIO m, TLSParams params)
                   => Handle -- ^ Handle used as the backend.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnHandle hdl params = contextNew hdl params
{-# DEPRECATED contextNewOnHandle "use contextNew" #-}
#ifdef INCLUDE_NETWORK
-- | Create a new 'Context' over a 'Socket'.
--
-- Deprecated: this is 'contextNew' specialised to a socket backend.
contextNewOnSocket :: (MonadIO m, TLSParams params)
                   => Socket -- ^ Socket used as the backend.
                   -> params -- ^ Parameters of the context.
                   -> m Context
contextNewOnSocket = contextNew
{-# DEPRECATED contextNewOnSocket "use contextNew" #-}
#endif
-- | Replace the 'hookRecvHandshake' field of the context's hooks with the
-- given function, leaving every other hook untouched.
contextHookSetHandshakeRecv :: Context -> (Handshake -> IO Handshake) -> IO ()
contextHookSetHandshakeRecv ctx recvHook = contextModifyHooks ctx install
  where
    install current = current { hookRecvHandshake = recvHook }
-- | Replace the 'hookRecvHandshake13' field of the context's hooks with the
-- given function, leaving every other hook untouched.
contextHookSetHandshake13Recv :: Context -> (Handshake13 -> IO Handshake13) -> IO ()
contextHookSetHandshake13Recv ctx recvHook = contextModifyHooks ctx install
  where
    install current = current { hookRecvHandshake13 = recvHook }
-- | Replace the 'hookRecvCertificates' field of the context's hooks with
-- the given action, leaving every other hook untouched.
contextHookSetCertificateRecv :: Context -> (CertificateChain -> IO ()) -> IO ()
contextHookSetCertificateRecv ctx onChain = contextModifyHooks ctx install
  where
    install current = current { hookRecvCertificates = onChain }
-- | Replace the 'hookLogging' field of the context's hooks with the given
-- logging callbacks, leaving every other hook untouched.
contextHookSetLogging :: Context -> Logging -> IO ()
contextHookSetLogging ctx callbacks = contextModifyHooks ctx install
  where
    install current = current { hookLogging = callbacks }