{-# LANGUAGE CPP #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Network.TLS.State
( TLSState(..)
, TLSSt
, runTLSState
, newTLSState
, withTLSRNG
, updateVerifiedData
, finishHandshakeTypeMaterial
, finishHandshakeMaterial
, certVerifyHandshakeTypeMaterial
, certVerifyHandshakeMaterial
, setVersion
, setVersionIfUnset
, getVersion
, getVersionWithDefault
, setSecureRenegotiation
, getSecureRenegotiation
, setExtensionALPN
, getExtensionALPN
, setNegotiatedProtocol
, getNegotiatedProtocol
, setClientALPNSuggest
, getClientALPNSuggest
, setClientEcPointFormatSuggest
, getClientEcPointFormatSuggest
, getClientCertificateChain
, setClientCertificateChain
, setClientSNI
, getClientSNI
, getVerifiedData
, setSession
, getSession
, isSessionResuming
, isClientContext
, setExporterMasterSecret
, getExporterMasterSecret
, setTLS13KeyShare
, getTLS13KeyShare
, setTLS13PreSharedKey
, getTLS13PreSharedKey
, setTLS13HRR
, getTLS13HRR
, setTLS13Cookie
, getTLS13Cookie
, setClientSupportsPHA
, getClientSupportsPHA
, genRandom
, withRNG
) where
import Network.TLS.Imports
import Network.TLS.Struct
import Network.TLS.Struct13
import Network.TLS.RNG
import Network.TLS.Types (Role(..), HostName)
import Network.TLS.Wire (GetContinuation)
import Network.TLS.Extension
import qualified Data.ByteString as B
import Control.Monad.State.Strict
import Network.TLS.ErrT
import Crypto.Random
import Data.X509 (CertificateChain)
-- | State record carried by the 'TLSSt' monad.
--
-- Field order matters for positional construction, so it is kept as is.
data TLSState = TLSState
    { stSession                    :: Session
      -- ^ Stored by 'setSession', read by 'getSession'.
    , stSessionResuming            :: Bool
      -- ^ Resumption flag stored together with the session by 'setSession'.
    , stSecureRenegotiation        :: Bool
      -- ^ Flag managed by 'setSecureRenegotiation' / 'getSecureRenegotiation'.
    , stClientVerifiedData         :: ByteString
      -- ^ Written by 'updateVerifiedData'; returned by 'getVerifiedData'
      -- when asked for 'ClientRole'.
    , stServerVerifiedData         :: ByteString
      -- ^ Written by 'updateVerifiedData'; returned by 'getVerifiedData'
      -- for any role other than 'ClientRole'.
    , stExtensionALPN              :: Bool
      -- ^ Flag managed by 'setExtensionALPN' / 'getExtensionALPN'.
    , stHandshakeRecordCont        :: Maybe (GetContinuation (HandshakeType, ByteString))
      -- ^ No accessor in this module; presumably a pending handshake
      -- parser continuation -- confirm at the use sites.
    , stNegotiatedProtocol         :: Maybe ByteString
      -- ^ 'Nothing' until 'setNegotiatedProtocol' is run.
    , stHandshakeRecordCont13      :: Maybe (GetContinuation (HandshakeType13, ByteString))
      -- ^ Counterpart of 'stHandshakeRecordCont' typed over 'HandshakeType13'.
    , stClientALPNSuggest          :: Maybe [ByteString]
      -- ^ 'Nothing' until 'setClientALPNSuggest' is run.
    , stClientGroupSuggest         :: Maybe [Group]
      -- ^ No accessor is defined for this field in this module.
    , stClientEcPointFormatSuggest :: Maybe [EcPointFormat]
      -- ^ 'Nothing' until 'setClientEcPointFormatSuggest' is run.
    , stClientCertificateChain     :: Maybe CertificateChain
      -- ^ 'Nothing' until 'setClientCertificateChain' is run.
    , stClientSNI                  :: Maybe HostName
      -- ^ 'Nothing' until 'setClientSNI' is run.
    , stRandomGen                  :: StateRNG
      -- ^ Generator state read and replaced by 'withRNG'.
    , stVersion                    :: Maybe Version
      -- ^ 'Nothing' until 'setVersion' or 'setVersionIfUnset' is run.
    , stClientContext              :: Role
      -- ^ Role passed to 'newTLSState'; returned by 'isClientContext'.
    , stTLS13KeyShare              :: Maybe KeyShare
      -- ^ Managed by 'setTLS13KeyShare' / 'getTLS13KeyShare'.
    , stTLS13PreSharedKey          :: Maybe PreSharedKey
      -- ^ Managed by 'setTLS13PreSharedKey' / 'getTLS13PreSharedKey'.
    , stTLS13HRR                   :: !Bool
      -- ^ Strict flag managed by 'setTLS13HRR' / 'getTLS13HRR'.
    , stTLS13Cookie                :: Maybe Cookie
      -- ^ Managed by 'setTLS13Cookie' / 'getTLS13Cookie'.
    , stExporterMasterSecret       :: Maybe ByteString
      -- ^ 'Nothing' until 'setExporterMasterSecret' is run.
    , stClientSupportsPHA          :: !Bool
      -- ^ Strict flag managed by 'setClientSupportsPHA' / 'getClientSupportsPHA'.
    }
-- | Computation over a 'TLSState' that can abort with a 'TLSError'.
--
-- The wrapped stack is @ErrT TLSError (State TLSState)@; run it with
-- 'runTLSState'.
newtype TLSSt a = TLSSt { runTLSSt :: ErrT TLSError (State TLSState) a }
    deriving (Functor, Applicative, Monad, MonadError TLSError)
-- | State operations are forwarded to the inner 'State' layer by lifting
-- them through the error transformer.
instance MonadState TLSState TLSSt where
    get = TLSSt $ lift get
    put = TLSSt . lift . put
-- 'state' is only defined here when building against mtl >= 2.1.0.
#if MIN_VERSION_mtl(2,1,0)
    state = TLSSt . lift . state
#endif
-- | Run a 'TLSSt' computation from the given starting state.
--
-- Returns the outcome ('Left' for a 'TLSError', 'Right' for a result)
-- paired with the final state.
runTLSState :: TLSSt a -> TLSState -> (Either TLSError a, TLSState)
runTLSState action initial = runState unwrapped initial
  where
    unwrapped = runErrT (runTLSSt action)
-- | Build the initial 'TLSState' from a random generator and a 'Role'.
--
-- Every optional field starts as 'Nothing', every flag as 'False', and
-- both verified-data fields as the empty byte string.
newTLSState :: StateRNG -> Role -> TLSState
newTLSState gen role = TLSState
    { stSession                    = Session Nothing
    , stSessionResuming            = False
    , stSecureRenegotiation        = False
    , stClientVerifiedData         = B.empty
    , stServerVerifiedData         = B.empty
    , stExtensionALPN              = False
    , stHandshakeRecordCont        = Nothing
    , stNegotiatedProtocol         = Nothing
    , stHandshakeRecordCont13      = Nothing
    , stClientALPNSuggest          = Nothing
    , stClientGroupSuggest         = Nothing
    , stClientEcPointFormatSuggest = Nothing
    , stClientCertificateChain     = Nothing
    , stClientSNI                  = Nothing
    , stRandomGen                  = gen
    , stVersion                    = Nothing
    , stClientContext              = role
    , stTLS13KeyShare              = Nothing
    , stTLS13PreSharedKey          = Nothing
    , stTLS13HRR                   = False
    , stTLS13Cookie                = Nothing
    , stExporterMasterSecret       = Nothing
    , stClientSupportsPHA          = False
    }
-- | Store verified data in one of the two verified-data fields.
--
-- The role held in the state is compared with the @sending@ argument:
-- when they differ the bytes go to 'stServerVerifiedData', otherwise to
-- 'stClientVerifiedData'.
updateVerifiedData :: Role -> ByteString -> TLSSt ()
updateVerifiedData sending bs = isClientContext >>= record
  where
    record ctxRole
        | ctxRole /= sending = modify $ \st -> st { stServerVerifiedData = bs }
        | otherwise          = modify $ \st -> st { stClientVerifiedData = bs }
-- | Classify a handshake type: 'False' only for
-- 'HandshakeType_HelloRequest', 'True' for every other type listed.
--
-- NOTE(review): judging by the name, 'True' marks types that feed the
-- Finished computation -- confirm against the callers.
finishHandshakeTypeMaterial :: HandshakeType -> Bool
finishHandshakeTypeMaterial ty = case ty of
    HandshakeType_HelloRequest    -> False
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg   -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
    HandshakeType_CertVerify      -> True
    HandshakeType_Finished        -> True
-- | Apply 'finishHandshakeTypeMaterial' to the type of a handshake message.
finishHandshakeMaterial :: Handshake -> Bool
finishHandshakeMaterial hs = finishHandshakeTypeMaterial (typeOfHandshake hs)
-- | Classify a handshake type: 'False' for
-- 'HandshakeType_HelloRequest', 'HandshakeType_CertVerify' and
-- 'HandshakeType_Finished', 'True' for every other type listed.
--
-- NOTE(review): judging by the name, 'True' marks types that feed the
-- certificate-verify computation -- confirm against the callers.
certVerifyHandshakeTypeMaterial :: HandshakeType -> Bool
certVerifyHandshakeTypeMaterial ty = case ty of
    HandshakeType_HelloRequest    -> False
    HandshakeType_CertVerify      -> False
    HandshakeType_Finished        -> False
    HandshakeType_ClientHello     -> True
    HandshakeType_ServerHello     -> True
    HandshakeType_Certificate     -> True
    HandshakeType_ServerHelloDone -> True
    HandshakeType_ClientKeyXchg   -> True
    HandshakeType_ServerKeyXchg   -> True
    HandshakeType_CertRequest     -> True
-- | Apply 'certVerifyHandshakeTypeMaterial' to the type of a handshake
-- message.
certVerifyHandshakeMaterial :: Handshake -> Bool
certVerifyHandshakeMaterial hs = certVerifyHandshakeTypeMaterial (typeOfHandshake hs)
-- | Store the session together with its resumption flag in one update.
setSession :: Session -> Bool -> TLSSt ()
setSession session resuming = modify store
  where
    store s = s { stSession = session, stSessionResuming = resuming }
-- | Read the stored session.
getSession :: TLSSt Session
getSession = stSession <$> get
-- | Read the resumption flag last stored by 'setSession'.
isSessionResuming :: TLSSt Bool
isSessionResuming = stSessionResuming <$> get
-- | Store the version, replacing any previously stored one.
setVersion :: Version -> TLSSt ()
setVersion ver = modify $ \s -> s { stVersion = Just ver }
-- | Store the version only when none is stored yet; an existing version
-- leaves the state untouched.
setVersionIfUnset :: Version -> TLSSt ()
setVersionIfUnset ver = modify keepOrSet
  where
    keepOrSet s = maybe (s { stVersion = Just ver }) (const s) (stVersion s)
-- | Read the stored version.
--
-- If no version has been stored, the returned value is a call to 'error'
-- that fires when the result is evaluated.
getVersion :: TLSSt Version
getVersion = gets (fromMaybe unset . stVersion)
  where
    unset = error "internal error: version hasn't been set yet"
-- | Read the stored version, falling back to the supplied default when
-- none has been stored.
getVersionWithDefault :: Version -> TLSSt Version
getVersionWithDefault defaultVer = gets (fromMaybe defaultVer . stVersion)
-- | Store the secure-renegotiation flag.
setSecureRenegotiation :: Bool -> TLSSt ()
setSecureRenegotiation b = modify $ \s -> s { stSecureRenegotiation = b }
-- | Read the secure-renegotiation flag.
getSecureRenegotiation :: TLSSt Bool
getSecureRenegotiation = stSecureRenegotiation <$> get
-- | Store the ALPN-extension flag.
setExtensionALPN :: Bool -> TLSSt ()
setExtensionALPN b = modify $ \s -> s { stExtensionALPN = b }
-- | Read the ALPN-extension flag.
getExtensionALPN :: TLSSt Bool
getExtensionALPN = stExtensionALPN <$> get
-- | Store the negotiated protocol, wrapped in 'Just'.
setNegotiatedProtocol :: B.ByteString -> TLSSt ()
setNegotiatedProtocol s = modify remember
  where
    remember st = st { stNegotiatedProtocol = Just s }
-- | Read the negotiated protocol; 'Nothing' when none has been stored.
getNegotiatedProtocol :: TLSSt (Maybe B.ByteString)
getNegotiatedProtocol = stNegotiatedProtocol <$> get
-- | Store the client's ALPN suggestion list, wrapped in 'Just'.
setClientALPNSuggest :: [B.ByteString] -> TLSSt ()
setClientALPNSuggest ps = modify remember
  where
    remember st = st { stClientALPNSuggest = Just ps }
-- | Read the client's ALPN suggestion list; 'Nothing' when none has been
-- stored.
getClientALPNSuggest :: TLSSt (Maybe [B.ByteString])
getClientALPNSuggest = stClientALPNSuggest <$> get
-- | Store the client's EC point format suggestion list, wrapped in 'Just'.
setClientEcPointFormatSuggest :: [EcPointFormat] -> TLSSt ()
setClientEcPointFormatSuggest epf = modify remember
  where
    remember st = st { stClientEcPointFormatSuggest = Just epf }
-- | Read the client's EC point format suggestion list; 'Nothing' when
-- none has been stored.
getClientEcPointFormatSuggest :: TLSSt (Maybe [EcPointFormat])
getClientEcPointFormatSuggest = stClientEcPointFormatSuggest <$> get
-- | Store the client certificate chain, wrapped in 'Just'.
setClientCertificateChain :: CertificateChain -> TLSSt ()
setClientCertificateChain s = modify remember
  where
    remember st = st { stClientCertificateChain = Just s }
-- | Read the client certificate chain; 'Nothing' when none has been stored.
getClientCertificateChain :: TLSSt (Maybe CertificateChain)
getClientCertificateChain = stClientCertificateChain <$> get
-- | Store the client's SNI host name, wrapped in 'Just'.
setClientSNI :: HostName -> TLSSt ()
setClientSNI hn = modify $ \s -> s { stClientSNI = Just hn }
-- | Read the client's SNI host name; 'Nothing' when none has been stored.
getClientSNI :: TLSSt (Maybe HostName)
getClientSNI = stClientSNI <$> get
-- | Read verified data for the requested role: the client field for
-- 'ClientRole', the server field for anything else.
getVerifiedData :: Role -> TLSSt ByteString
getVerifiedData client = gets selector
  where
    selector
        | client == ClientRole = stClientVerifiedData
        | otherwise            = stServerVerifiedData
-- | Read the 'Role' held in the state (despite the @is@ prefix, the
-- result is a 'Role', not a 'Bool').
isClientContext :: TLSSt Role
isClientContext = stClientContext <$> get
-- | Produce @n@ random bytes using the generator held in the state,
-- which is advanced via 'withRNG'.
genRandom :: Int -> TLSSt ByteString
genRandom n = withRNG $ getRandomBytes n
-- | Run a pseudo-random computation against the generator held in the
-- state, write the resulting generator back, and return the
-- computation's result.
withRNG :: MonadPseudoRandom StateRNG a -> TLSSt a
withRNG action = do
    current <- get
    -- Kept as a lazy pair: nothing is forced until a component is used.
    let outcome = withTLSRNG (stRandomGen current) action
    put current { stRandomGen = snd outcome }
    return (fst outcome)
-- | Store the exporter master secret, wrapped in 'Just'.
setExporterMasterSecret :: ByteString -> TLSSt ()
setExporterMasterSecret key = modify remember
  where
    remember st = st { stExporterMasterSecret = Just key }
-- | Read the exporter master secret; 'Nothing' when none has been stored.
getExporterMasterSecret :: TLSSt (Maybe ByteString)
getExporterMasterSecret = stExporterMasterSecret <$> get
-- | Replace the stored TLS 1.3 key share; passing 'Nothing' clears it.
setTLS13KeyShare :: Maybe KeyShare -> TLSSt ()
setTLS13KeyShare mks = modify $ \s -> s { stTLS13KeyShare = mks }
-- | Read the stored TLS 1.3 key share.
getTLS13KeyShare :: TLSSt (Maybe KeyShare)
getTLS13KeyShare = stTLS13KeyShare <$> get
-- | Replace the stored TLS 1.3 pre-shared key; passing 'Nothing' clears it.
setTLS13PreSharedKey :: Maybe PreSharedKey -> TLSSt ()
setTLS13PreSharedKey mpsk = modify $ \s -> s { stTLS13PreSharedKey = mpsk }
-- | Read the stored TLS 1.3 pre-shared key.
getTLS13PreSharedKey :: TLSSt (Maybe PreSharedKey)
getTLS13PreSharedKey = stTLS13PreSharedKey <$> get
-- | Store the TLS 1.3 HRR flag.
setTLS13HRR :: Bool -> TLSSt ()
setTLS13HRR b = modify $ \s -> s { stTLS13HRR = b }
-- | Read the TLS 1.3 HRR flag.
getTLS13HRR :: TLSSt Bool
getTLS13HRR = stTLS13HRR <$> get
-- | Replace the stored TLS 1.3 cookie; passing 'Nothing' clears it.
setTLS13Cookie :: Maybe Cookie -> TLSSt ()
setTLS13Cookie mcookie = modify $ \s -> s { stTLS13Cookie = mcookie }
-- | Read the stored TLS 1.3 cookie.
getTLS13Cookie :: TLSSt (Maybe Cookie)
getTLS13Cookie = stTLS13Cookie <$> get
-- | Store the client-supports-PHA flag.
setClientSupportsPHA :: Bool -> TLSSt ()
setClientSupportsPHA b = modify $ \s -> s { stClientSupportsPHA = b }
-- | Read the client-supports-PHA flag.
getClientSupportsPHA :: TLSSt Bool
getClientSupportsPHA = stClientSupportsPHA <$> get