{-# LANGUAGE OverloadedStrings #-}
module Network.TLS.Handshake.Server.Common (
applicationProtocol,
checkValidClientCertChain,
clientCertificate,
credentialDigitalSignatureKey,
filterCredentials,
filterCredentialsWithHashSignatures,
isCredentialAllowed,
storePrivInfoServer,
) where
import Control.Monad.State.Strict
import Data.X509 (ExtKeyUsageFlag (..))
import Network.TLS.Context.Internal
import Network.TLS.Credentials
import Network.TLS.Crypto
import Network.TLS.Extension
import Network.TLS.Handshake.Certificate
import Network.TLS.Handshake.Common
import Network.TLS.Handshake.Key
import Network.TLS.Handshake.State
import Network.TLS.Imports
import Network.TLS.Parameters
import Network.TLS.State
import Network.TLS.Struct
import Network.TLS.Util (catchException)
import Network.TLS.X509
-- | Fetch the client certificate chain recorded in the handshake state.
--
-- Fails with @Error_Protocol errmsg UnexpectedMessage@ (via 'throwCore',
-- so this never returns in that case) when no chain has been stored or
-- when the stored chain is empty according to 'isNullCertificateChain'.
checkValidClientCertChain
    :: MonadIO m => Context -> String -> m CertificateChain
checkValidClientCertChain ctx errmsg = do
    chain <- usingHState ctx getClientCertChain
    case chain of
        Just cc | not (isNullCertificateChain cc) -> return cc
        -- Missing and empty chains are reported identically.
        _ -> throwCore (Error_Protocol errmsg UnexpectedMessage)
-- | The public key of a credential, provided its key pair is accepted by
-- 'isDigitalSignaturePair'; 'Nothing' otherwise.
credentialDigitalSignatureKey :: Credential -> Maybe PubKey
credentialDigitalSignatureKey cred
    | isDigitalSignaturePair keys = Just pubkey
    | otherwise = Nothing
  where
    keys@(pubkey, _) = credentialPublicPrivateKeys cred
-- | Keep only the credentials satisfying the predicate, preserving their
-- original order.
filterCredentials :: (Credential -> Bool) -> Credentials -> Credentials
filterCredentials p (Credentials l) = Credentials (filter p l)
-- | Whether a credential may be used for the negotiated protocol version.
--
-- The credential's public key must be 'versionCompatible' with @ver@ and
-- must pass 'satisfiesEcPredicate' with a group predicate derived from
-- the ClientHello extensions @exts@:
--
-- * below TLS 1.3, a group is acceptable only if the client listed it in
--   its supported_groups extension; when that extension is absent or
--   cannot be decoded, every group is accepted;
-- * from TLS 1.3 on, every group is accepted here (presumably because the
--   curve is then tied to the signature scheme and checked elsewhere —
--   confirm against the TLS 1.3 certificate selection code).
isCredentialAllowed :: Version -> [ExtensionRaw] -> Credential -> Bool
isCredentialAllowed ver exts cred =
    pubkey `versionCompatible` ver && satisfiesEcPredicate groupOk pubkey
  where
    (pubkey, _) = credentialPublicPrivateKeys cred

    -- Groups advertised by the client, if the extension decodes.
    clientGroups :: Maybe SupportedGroups
    clientGroups =
        extensionLookup EID_SupportedGroups exts
            >>= extensionDecode MsgTClientHello

    groupOk :: Group -> Bool
    groupOk
        | ver < TLS13 =
            case clientGroups of
                Nothing -> const True
                Just (SupportedGroups sg) -> (`elem` sg)
        | otherwise = const True
-- | Filter candidate credentials with 'credentialMatchesHashSignatures',
-- using the hash/signature algorithms advertised in the ClientHello.
--
-- The algorithm list is taken from the signature_algorithms_cert
-- extension when it is present and decodes. Otherwise it falls back to
-- signature_algorithms, matching RFC 8446 section 4.2.3, where the latter
-- also governs certificates when the former is absent. When neither
-- extension yields a list, the credentials are returned unchanged.
filterCredentialsWithHashSignatures
    :: [ExtensionRaw] -> Credentials -> Credentials
filterCredentialsWithHashSignatures exts =
    case withExt EID_SignatureAlgorithmsCert of
        Just (SignatureAlgorithmsCert sas) -> withAlgs sas
        Nothing ->
            case withExt EID_SignatureAlgorithms of
                Just (SignatureAlgorithms sas) -> withAlgs sas
                Nothing -> id
  where
    -- Used at two different extension types above. The explicit
    -- polymorphic signature keeps that working even if MonoLocalBinds
    -- is ever enabled for this module.
    withExt :: Extension e => ExtensionID -> Maybe e
    withExt extId =
        extensionLookup extId exts >>= extensionDecode MsgTClientHello

    withAlgs :: [HashAndSignatureAlgorithm] -> Credentials -> Credentials
    withAlgs sas = filterCredentials (credentialMatchesHashSignatures sas)
-- | Record the credential's certificate chain and private key in the
-- context via 'storePrivInfo', discarding the public key it returns.
storePrivInfoServer :: MonadIO m => Context -> Credential -> m ()
storePrivInfoServer ctx (cc, privkey) = void (storePrivInfo ctx cc privkey)
-- | Server side of ALPN (Application-Layer Protocol Negotiation).
--
-- When the ClientHello carries a decodable ALPN extension and the server
-- installed an 'onALPNClientSuggest' hook, the hook is given the client's
-- protocol list and returns the chosen protocol. The choice is then
-- handled as follows:
--
-- * an empty choice aborts with @Error_Protocol ... NoApplicationProtocol@;
-- * otherwise ALPN is marked as in use and the protocol is recorded in the
--   TLS state ('setExtensionALPN', 'setNegotiatedProtocol'), and a single
--   ALPN extension carrying that protocol is returned (presumably for the
--   caller to put in the server's reply).
--
-- In every other case (no or undecodable extension, no hook) the result
-- is @[]@.
--
-- NOTE(review): the hook's result is not checked against the client's
-- list; the hook is trusted to pick one of the offered protocols.
applicationProtocol
    :: Context -> [ExtensionRaw] -> ServerParams -> IO [ExtensionRaw]
applicationProtocol ctx exts sparams =
    case (clientProtos, onALPNClientSuggest (serverHooks sparams)) of
        (Just (ApplicationLayerProtocolNegotiation protos), Just select) -> do
            proto <- select protos
            when (proto == "") $
                throwCore $
                    Error_Protocol "no supported application protocols" NoApplicationProtocol
            usingState_ ctx $ do
                setExtensionALPN True
                setNegotiatedProtocol proto
            return
                [ ExtensionRaw
                    EID_ApplicationLayerProtocolNegotiation
                    (extensionEncode $ ApplicationLayerProtocolNegotiation [proto])
                ]
        _ -> return []
  where
    clientProtos :: Maybe ApplicationLayerProtocolNegotiation
    clientProtos =
        extensionLookup EID_ApplicationLayerProtocolNegotiation exts
            >>= extensionDecode MsgTClientHello
-- | Handle the certificate chain received from the client.
--
-- 1. Runs the context's 'hookRecvCertificates' hook on the chain.
-- 2. Asks the application ('onClientCertificate') whether to accept it. If
--    that callback throws, 'rejectOnException' supplies the usage to apply
--    instead (presumably a rejection).
-- 3. On acceptance, 'verifyLeafKeyUsage' is run with
--    @[KeyUsage_digitalSignature]@. On rejection, 'certificateRejected' is
--    called; its result type is fully polymorphic, so it does not return.
-- 4. Stores the chain in the handshake state ('setClientCertChain'). This
--    is only reached when the chain was not rejected.
clientCertificate :: ServerParams -> Context -> CertificateChain -> IO ()
clientCertificate sparams ctx certs = do
    ctxWithHooks ctx (`hookRecvCertificates` certs)
    usage <-
        onClientCertificate (serverHooks sparams) certs
            `catchException` rejectOnException
    case usage of
        CertificateUsageAccept ->
            verifyLeafKeyUsage [KeyUsage_digitalSignature] certs
        CertificateUsageReject reason ->
            certificateRejected reason
    -- Remember the chain for later handshake steps.
    usingHState ctx $ setClientCertChain certs