module Network.TLS.Context
(
Params(..)
, RoleParams(..)
, ClientParams(..)
, ServerParams(..)
, updateClientParams
, updateServerParams
, Logging(..)
, SessionID
, SessionData(..)
, MaxFragmentEnum(..)
, Measurement(..)
, CertificateUsage(..)
, CertificateRejectReason(..)
, defaultLogging
, defaultParamsClient
, defaultParamsServer
, withSessionManager
, setSessionManager
, Backend(..)
, Context
, ctxParams
, ctxConnection
, ctxEOF
, ctxHasSSLv2ClientHello
, ctxDisableSSLv2ClientHello
, ctxEstablished
, ctxLogging
, setEOF
, setEstablished
, contextFlush
, contextClose
, contextSend
, contextRecv
, updateMeasure
, withMeasure
, TLSParams
, TLSLogging
, TLSCertificateUsage
, TLSCertificateRejectReason
, TLSCtx
, defaultParams
, contextNew
, contextNewOnHandle
, throwCore
, usingState
, usingState_
, getStateRNG
) where
import Network.BSD (HostName)
import Network.TLS.Extension
import Network.TLS.Struct
import qualified Network.TLS.Struct as Struct
import Network.TLS.Session
import Network.TLS.Cipher
import Network.TLS.Compression
import Network.TLS.Crypto
import Network.TLS.State
import Network.TLS.Measurement
import Data.Certificate.X509
import Data.List (intercalate)
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import Crypto.Random.API
import Control.Concurrent.MVar
import Control.Monad.State
import Control.Exception (throwIO, Exception())
import Data.IORef
import System.IO (Handle, hSetBuffering, BufferMode(..), hFlush, hClose)
-- | Tracing hooks carried in 'Params' (field 'pLogging').
--
-- Every hook runs in 'IO'; 'defaultLogging' supplies hooks that do nothing.
-- NOTE(review): the call sites live outside this module, so exactly when
-- each hook fires is not visible here.
data Logging = Logging
    { loggingPacketSent :: String -> IO ()
      -- ^ hook for outgoing packets, given as a 'String'
      -- (presumably a rendered packet description; confirm at call site)
    , loggingPacketRecv :: String -> IO ()
      -- ^ hook for incoming packets, given as a 'String'
      -- (presumably a rendered packet description; confirm at call site)
    , loggingIOSent :: B.ByteString -> IO ()
      -- ^ hook receiving raw outgoing bytes
    , loggingIORecv :: Header -> B.ByteString -> IO ()
      -- ^ hook receiving a record 'Header' together with raw incoming bytes
    }
-- | Parameters that only apply when the context plays the client role
-- (selected through the 'Client' constructor of 'RoleParams').
data ClientParams = ClientParams
    { clientUseMaxFragmentLength :: Maybe MaxFragmentEnum
      -- ^ maximum fragment length to request, or 'Nothing' to not request one
    , clientUseServerName :: Maybe HostName
      -- ^ server host name to advertise, if any (presumably sent via SNI)
    , clientWantSessionResume :: Maybe (SessionID, SessionData)
      -- ^ a previously established session to try to resume, if any
    , onCertificateRequest :: ([CertificateType],
                               Maybe [HashAndSignatureAlgorithm],
                               [DistinguishedName]) -> IO [(X509, Maybe PrivateKey)]
      -- ^ called with the contents of a server certificate request
      -- (certificate types, optional hash/signature algorithms, and
      -- distinguished names); returns the certificates, each with an
      -- optional private key, to answer with. The default returns @[]@.
    }
-- | Parameters that only apply when the context plays the server role
-- (selected through the 'Server' constructor of 'RoleParams').
data ServerParams = ServerParams
    { serverWantClientCert :: Bool
      -- ^ whether to ask the client for a certificate ('False' by default)
    , serverCACertificates :: [X509]
      -- ^ CA certificates (presumably advertised when requesting a client
      -- certificate; confirm at call site)
    , onClientCertificate :: [X509] -> IO CertificateUsage
      -- ^ decides whether to accept the certificate chain sent by the client
    , onUnverifiedClientCert :: IO Bool
      -- ^ NOTE(review): presumably consulted when a client certificate could
      -- not be verified; 'False' by default. Confirm semantics at call site.
    , onCipherChoosing :: Version -> [Cipher] -> Cipher
      -- ^ picks a single cipher, given the protocol 'Version' and a list of
      -- candidate ciphers
    }
-- | Role-specific parameters. The chosen constructor also decides whether
-- 'contextNew' marks the new context as a client or a server context.
data RoleParams = Client ClientParams | Server ServerParams
-- | Configuration of a TLS context, shared by both roles.
--
-- The session manager is existentially quantified: any 'SessionManager'
-- instance can be stored. Because its type is hidden, it is replaced with
-- 'setSessionManager' and consumed with 'withSessionManager'.
data Params = forall s . SessionManager s => Params
    { pConnectVersion :: Version
      -- ^ protocol version used when initiating a connection
    , pAllowedVersions :: [Version]
      -- ^ protocol versions that are accepted
    , pCiphers :: [Cipher]
      -- ^ supported ciphers (empty in the defaults; callers supply their own)
    , pCompressions :: [Compression]
      -- ^ supported compression methods
    , pHashSignatures :: [HashAndSignatureAlgorithm]
      -- ^ supported hash/signature algorithm pairs
    , pUseSecureRenegotiation :: Bool
      -- ^ whether secure renegotiation is used
    , pUseSession :: Bool
      -- ^ whether sessions are used (presumably for resumption)
    , pCertificates :: [(X509, Maybe PrivateKey)]
      -- ^ local certificates, each with an optional private key
    , pLogging :: Logging
      -- ^ tracing hooks
    , onHandshake :: Measurement -> IO Bool
      -- ^ called with the connection's 'Measurement'; NOTE(review): the
      -- meaning of the 'Bool' is decided by the handshake code (presumably
      -- whether to proceed)
    , onCertificatesRecv :: [X509] -> IO CertificateUsage
      -- ^ decides whether to accept a received certificate chain
    , pSessionManager :: s
      -- ^ session storage; its concrete type is hidden (see above)
    , onSuggestNextProtocols :: IO (Maybe [B.ByteString])
      -- ^ protocols to suggest, if any (presumably for Next Protocol
      -- Negotiation; confirm at call site)
    , onNPNServerSuggest :: Maybe ([B.ByteString] -> IO B.ByteString)
      -- ^ optional callback choosing one protocol from a list suggested by
      -- the server
    , roleParams :: RoleParams
      -- ^ client- or server-specific parameters
    }
-- | Swap in a different session manager, keeping every other field.
--
-- A record update cannot change the type of an existential field, so the
-- value is taken apart and rebuilt with a record wildcard instead.
setSessionManager :: SessionManager s => s -> Params -> Params
setSessionManager newManager params =
    case params of
        Params {..} -> Params { pSessionManager = newManager, .. }
-- | Apply a continuation to the session manager stored in 'Params'.
--
-- The manager's concrete type is hidden, so the continuation has to work
-- for any 'SessionManager' instance.
withSessionManager :: Params -> (forall s . SessionManager s => s -> a) -> a
withSessionManager params k =
    case params of
        Params { pSessionManager = manager } -> k manager
-- | Logging hooks that ignore every event.
defaultLogging :: Logging
defaultLogging = Logging
    { loggingPacketSent = const (return ())
    , loggingPacketRecv = const (return ())
    , loggingIOSent     = const (return ())
    , loggingIORecv     = \_ -> const (return ())
    }
-- | Default parameters for a client context.
--
-- Connects with TLS 1.0 while allowing TLS 1.0 through 1.2, offers only the
-- null compression, and advertises RSA signatures over SHA-512/384/256/224.
-- Every received certificate chain is accepted and no session manager is
-- used. The cipher list is left empty; callers supply their own 'pCiphers'.
defaultParamsClient :: Params
defaultParamsClient = Params
    { pConnectVersion         = TLS10
    , pAllowedVersions        = [TLS10, TLS11, TLS12]
    , pCiphers                = []
    , pCompressions           = [nullCompression]
    , pHashSignatures         = map withRSA [ Struct.HashSHA512
                                            , Struct.HashSHA384
                                            , Struct.HashSHA256
                                            , Struct.HashSHA224
                                            ]
    , pUseSecureRenegotiation = True
    , pUseSession             = True
    , pCertificates           = []
    , pLogging                = defaultLogging
    , onHandshake             = const (return True)
    , onCertificatesRecv      = const (return CertificateUsageAccept)
    , pSessionManager         = NoSessionManager
    , onSuggestNextProtocols  = return Nothing
    , onNPNServerSuggest      = Nothing
    , roleParams              = Client clientRole
    }
  where
    -- pair a hash algorithm with RSA signatures
    withRSA hashAlg = (hashAlg, SignatureRSA)
    clientRole = ClientParams
        { clientWantSessionResume    = Nothing
        , clientUseMaxFragmentLength = Nothing
        , clientUseServerName        = Nothing
        , onCertificateRequest       = const (return [])
        }
-- | Default parameters for a server context.
--
-- Same as 'defaultParamsClient' except for the role: no client certificate
-- is requested, any client certificate that arrives anyway is rejected,
-- unverified client certificates are refused, and the first candidate
-- cipher is chosen.
defaultParamsServer :: Params
defaultParamsServer = defaultParamsClient { roleParams = Server serverRole }
  where
    serverRole = ServerParams
        { serverWantClientCert   = False
        , onCipherChoosing       = \_ -> firstCipher
        , serverCACertificates   = []
        , onClientCertificate    = \ _ -> return $ CertificateUsageReject $ CertificateRejectOther "no client certificates expected"
        , onUnverifiedClientCert = return False
        }
    -- Choose the first candidate cipher. The previous default used the
    -- partial 'head', which fails with an uninformative
    -- "Prelude.head: empty list" if ever called with no candidates; fail
    -- with a message that names the culprit instead.
    firstCipher (cipher:_) = cipher
    firstCipher []         =
        error "Network.TLS.Context.defaultParamsServer: onCipherChoosing called with no candidate ciphers"
-- | Modify the role-specific parameters. The first function is applied when
-- the params describe a client, the second when they describe a server.
updateRoleParams :: (ClientParams -> ClientParams) -> (ServerParams -> ServerParams) -> Params -> Params
updateRoleParams onClient onServer params =
    case roleParams params of
        Client clientPs -> withRole (Client (onClient clientPs))
        Server serverPs -> withRole (Server (onServer serverPs))
  where
    withRole newRole = params { roleParams = newRole }
-- | Modify the client-specific parameters; server params come back with
-- their role parameters unchanged.
updateClientParams :: (ClientParams -> ClientParams) -> Params -> Params
updateClientParams = flip updateRoleParams id
-- | Modify the server-specific parameters; client params come back with
-- their role parameters unchanged.
updateServerParams :: (ServerParams -> ServerParams) -> Params -> Params
updateServerParams = updateRoleParams id
-- | Alias for 'defaultParamsClient'.
--
-- NOTE(review): presumably kept for backward compatibility with code written
-- before the client/server split; consider deprecating.
defaultParams :: Params
defaultParams = defaultParamsClient
-- | Summarise the parameters: versions, ciphers, compressions and the number
-- of certificates. Callbacks, logging and the session manager are omitted.
instance Show Params where
    show params = "Params { " ++ rendered ++ " }"
      where
        rendered = intercalate "," [ key ++ "=" ++ value | (key, value) <- entries ]
        entries =
            [ ("connectVersion",  show (pConnectVersion params))
            , ("allowedVersions", show (pAllowedVersions params))
            , ("ciphers",         show (pCiphers params))
            , ("compressions",    show (pCompressions params))
            , ("certificates",    show (length (pCertificates params)))
            ]
-- | Reason attached to a 'CertificateUsageReject' verdict.
data CertificateRejectReason =
          CertificateRejectExpired      -- ^ the certificate has expired
        | CertificateRejectRevoked      -- ^ the certificate has been revoked
        | CertificateRejectUnknownCA    -- ^ the issuing CA is unknown
        | CertificateRejectOther String -- ^ any other reason, described in free text
        deriving (Show,Eq)
-- | Verdict returned by certificate callbacks such as 'onCertificatesRecv'
-- and 'onClientCertificate'.
data CertificateUsage =
          CertificateUsageAccept                         -- ^ accept the certificate chain
        | CertificateUsageReject CertificateRejectReason -- ^ reject it, giving a reason
        deriving (Show,Eq)
-- | Transport primitives used by a 'Context' for all of its I/O.
-- 'contextNewOnHandle' builds one on top of a 'Handle'.
data Backend = Backend
    { backendFlush :: IO ()                -- ^ flush pending output
    , backendClose :: IO ()                -- ^ close the transport
    , backendSend  :: ByteString -> IO ()  -- ^ send the given bytes
    , backendRecv  :: Int -> IO ByteString
      -- ^ receive the requested number of bytes; the handle-based backend
      -- uses 'B.hGet', which can return fewer at end of input
    }
-- | A TLS connection: its transport, configuration and mutable
-- per-connection state.
data Context = Context
    { ctxConnection       :: Backend
      -- ^ transport used for sending, receiving, flushing and closing
    , ctxParams           :: Params
      -- ^ configuration the context was created with
    , ctxState            :: MVar TLSState
      -- ^ protocol state, accessed through 'usingState'
    , ctxMeasurement      :: IORef Measurement
      -- ^ traffic counters, updated by 'contextSend' and 'contextRecv'
    , ctxEOF_             :: IORef Bool
      -- ^ end-of-input flag, set by 'setEOF' and read by 'ctxEOF'
    , ctxEstablished_     :: IORef Bool
      -- ^ flag written by 'setEstablished' and read by 'ctxEstablished'
    , ctxSSLv2ClientHello :: IORef Bool
      -- ^ SSLv2 ClientHello compatibility flag; starts 'True' only for
      -- server contexts (presumably whether such a hello is still accepted)
    }
-- Older, TLS-prefixed names for types defined in this module.
-- NOTE(review): presumably kept for source compatibility with earlier
-- releases; confirm before removing.
type TLSParams = Params
type TLSCtx = Context
type TLSLogging = Logging
type TLSCertificateUsage = CertificateUsage
type TLSCertificateRejectReason = CertificateRejectReason
-- | Apply a function to the context's 'Measurement' counters.
--
-- The new value is forced before it is written back, so repeated updates do
-- not pile up thunks inside the 'IORef'.
updateMeasure :: MonadIO m => Context -> (Measurement -> Measurement) -> m ()
updateMeasure ctx update =
    liftIO $ readIORef ref >>= \old -> writeIORef ref $! update old
  where
    ref = ctxMeasurement ctx
-- | Run an action on a snapshot of the context's current 'Measurement'.
withMeasure :: MonadIO m => Context -> (Measurement -> IO a) -> m a
withMeasure ctx action = liftIO $ do
    snapshot <- readIORef (ctxMeasurement ctx)
    action snapshot
-- | Flush the context's transport.
contextFlush :: Context -> IO ()
contextFlush ctx = backendFlush (ctxConnection ctx)
-- | Close the context's transport.
contextClose :: Context -> IO ()
contextClose ctx = backendClose (ctxConnection ctx)
-- | Send raw bytes over the transport, adding their length to the sent-bytes
-- counter first.
contextSend :: Context -> Bytes -> IO ()
contextSend ctx payload = do
    updateMeasure ctx (addBytesSent (B.length payload))
    backendSend (ctxConnection ctx) payload
-- | Receive up to @sz@ bytes from the transport.
--
-- The received-bytes counter is credited with the length of what the backend
-- actually returned, not with the requested size: a short read (for example
-- 'B.hGet' hitting end of input) must not inflate the measurement.
contextRecv :: Context -> Int -> IO Bytes
contextRecv c sz = do
    received <- backendRecv (ctxConnection c) sz
    updateMeasure c (addBytesReceived (B.length received))
    return received
-- | Whether end of input has been recorded for this context (see 'setEOF').
ctxEOF :: MonadIO m => Context -> m Bool
ctxEOF = liftIO . readIORef . ctxEOF_
-- | Read the context's SSLv2 ClientHello compatibility flag
-- ('True' initially for servers, 'False' for clients).
ctxHasSSLv2ClientHello :: MonadIO m => Context -> m Bool
ctxHasSSLv2ClientHello = liftIO . readIORef . ctxSSLv2ClientHello
-- | Clear the context's SSLv2 ClientHello compatibility flag.
ctxDisableSSLv2ClientHello :: MonadIO m => Context -> m ()
ctxDisableSSLv2ClientHello ctx =
    liftIO $ writeIORef (ctxSSLv2ClientHello ctx) False
-- | Record that end of input has been reached (read back with 'ctxEOF').
setEOF :: MonadIO m => Context -> m ()
setEOF ctx = liftIO (writeIORef (ctxEOF_ ctx) True)
-- | Read the "established" flag last written by 'setEstablished'
-- ('False' for a fresh context).
ctxEstablished :: MonadIO m => Context -> m Bool
ctxEstablished = liftIO . readIORef . ctxEstablished_
-- | Overwrite the context's "established" flag.
setEstablished :: MonadIO m => Context -> Bool -> m ()
setEstablished ctx = liftIO . writeIORef (ctxEstablished_ ctx)
-- | The logging hooks from the context's parameters.
ctxLogging :: Context -> Logging
ctxLogging ctx = pLogging (ctxParams ctx)
-- | Create a fresh 'Context' over a 'Backend'.
--
-- The constructor of 'roleParams' decides whether the TLS state is marked
-- as a client context. A server context starts with the SSLv2 ClientHello
-- compatibility flag set; a client context starts with it cleared. The EOF
-- and established flags start 'False' and the counters start at
-- 'newMeasurement'.
contextNew :: (MonadIO m, CPRG rng)
           => Backend
           -> Params
           -> rng
           -> m Context
contextNew backend params rng = liftIO $ do
    stateVar       <- newMVar initialState
    eofRef         <- newIORef False
    establishedRef <- newIORef False
    measureRef     <- newIORef newMeasurement
    sslv2Ref       <- newIORef (not isClient)
    return Context
        { ctxConnection       = backend
        , ctxParams           = params
        , ctxState            = stateVar
        , ctxMeasurement      = measureRef
        , ctxEOF_             = eofRef
        , ctxEstablished_     = establishedRef
        , ctxSSLv2ClientHello = sslv2Ref
        }
  where
    isClient = case roleParams params of
        Client {} -> True
        Server {} -> False
    initialState = (newTLSState rng) { stClientContext = isClient }
-- | Create a fresh 'Context' whose transport is a 'Handle'.
--
-- The handle is switched to 'NoBuffering' first. Flushing, closing, sending
-- and receiving then map directly onto 'hFlush', 'hClose', 'B.hPut' and
-- 'B.hGet'.
contextNewOnHandle :: (MonadIO m, CPRG rng)
                   => Handle
                   -> Params
                   -> rng
                   -> m Context
contextNewOnHandle handle params st = do
    liftIO $ hSetBuffering handle NoBuffering
    contextNew handleBackend params st
  where
    handleBackend = Backend
        { backendFlush = hFlush handle
        , backendClose = hClose handle
        , backendSend  = B.hPut handle
        , backendRecv  = B.hGet handle
        }
-- | Throw an exception from any 'MonadIO' context via 'throwIO'.
throwCore :: (MonadIO m, Exception e) => e -> m a
throwCore err = liftIO (throwIO err)
-- | Run a 'TLSSt' computation against the context's protocol state.
--
-- The state 'MVar' is held for the duration, and the updated state is
-- forced to WHNF before being stored back. The computation's result
-- (possibly a 'TLSError') is returned.
usingState :: MonadIO m => Context -> TLSSt a -> m (Either TLSError a)
usingState ctx action = liftIO $ modifyMVar (ctxState ctx) step
  where
    step before = case runTLSState action before of
        (outcome, after) -> after `seq` return (after, outcome)
-- | Like 'usingState', but a 'TLSError' result is thrown with 'throwCore'
-- instead of being returned.
usingState_ :: MonadIO m => Context -> TLSSt a -> m a
usingState_ ctx action = usingState ctx action >>= either throwCore return
-- | Draw the given number of random bytes from the RNG held in the context's
-- protocol state, throwing on a 'TLSError'.
getStateRNG :: MonadIO m => Context -> Int -> m Bytes
getStateRNG ctx = usingState_ ctx . genTLSRandom