-- | Helpers for setting up a TLS connection with the @tls@ package.
-- For further customization, please refer to the @tls@ package.
--
-- Note: functions in this module throw an 'error' if they cannot load the certificates or the CA store.
--
module Metro.TP.TLSSetting
  (
    -- * Make TLS settings
    makeClientParams
  , makeClientParams'
  , makeServerParams
  , makeServerParams'
  ) where

import qualified Data.ByteString            as B (empty, readFile)
import           Data.Default.Class         (def)
import qualified Data.PEM                   as X509 (pemContent, pemParseBS)
import qualified Data.X509                  as X509 (CertificateChain (..),
                                                     HashALG (..),
                                                     decodeSignedCertificate)
import qualified Data.X509.CertificateStore as X509 (CertificateStore,
                                                     makeCertificateStore)
import qualified Data.X509.Validation       as X509 (ServiceID, checkFQHN,
                                                     validate)
import qualified Network.TLS                as TLS
import qualified Network.TLS.Extra          as TLS (ciphersuite_strong)



-- | Build a certificate store from a PEM file holding one or more X.509
-- certificates.
--
-- Throws an 'error' (mentioning the file path) if the file's contents cannot
-- be parsed as PEM, or if any PEM entry fails to decode as a signed
-- certificate.  Errors raised while reading the file itself are not caught
-- here.
makeCAStore :: FilePath -> IO X509.CertificateStore
makeCAStore fp = do
  bs   <- B.readFile fp
  -- Previously an irrefutable @let Right pems = ...@; a malformed file then
  -- failed with an opaque pattern-match failure instead of a useful message.
  pems <- either (failWith "cannot parse PEM data") return (X509.pemParseBS bs)
  cas  <- either (failWith "cannot decode certificate") return
            (mapM (X509.decodeSignedCertificate . X509.pemContent) pems)
  return (X509.makeCertificateStore cas)
  where
    -- Explicit signature keeps the helper polymorphic in its result, since it
    -- is used at two different result types above.
    failWith :: String -> String -> IO a
    failWith what err =
      error ("makeCAStore: " ++ what ++ " in " ++ fp ++ ": " ++ err)

-- | Make a simple tls 'TLS.ClientParams' that validates the server and uses a
-- tls connection without providing the client's own certificate.  Suitable
-- for connecting to servers that don't validate clients.
--
-- The server is validated against the CA certificates loaded from the given
-- PEM file.  'TLS.clientServerIdentification' is set here, to the given
-- 'X509.ServiceID'; it is not left for the connecting phase.  Only
-- 'TLS.ciphersuite_strong' cipher suites are offered.
--
-- Note: tls's default validating method requires the server to have a v3
-- certificate.  You can use openssl's V3 extension to issue such a
-- certificate, or change the 'TLS.ClientParams' before connecting.
--
-- Throws an 'error' if the trusted certificates cannot be loaded.
makeClientParams :: FilePath          -- ^ trusted certificates.
                 -> X509.ServiceID    -- ^ identity the server is expected to present.
                 -> IO TLS.ClientParams
makeClientParams tca servid = do
  caStore <- makeCAStore tca
  let supported = def { TLS.supportedCiphers = TLS.ciphersuite_strong }
      shared    = def
        { TLS.sharedCAStore         = caStore
        , TLS.sharedValidationCache = def
        }
  -- The identity given to 'TLS.defaultParamsClient' is a placeholder; it is
  -- immediately replaced by @servid@ below.
  return (TLS.defaultParamsClient "" B.empty)
    { TLS.clientSupported            = supported
    , TLS.clientServerIdentification = servid
    , TLS.clientShared               = shared
    }

-- | Make a simple tls 'TLS.ClientParams' that validates the server exactly as
-- 'makeClientParams' does, and also presents the client's own certificate
-- chain.  Suitable for connecting to servers that validate clients.
--
-- As with 'makeClientParams', only v3 server certificates are accepted by
-- tls's default validation.
--
-- The loaded credential is placed in the shared credentials and is also
-- returned for every certificate request from the server.  Throws an 'error'
-- if the trusted certificates or the client credential cannot be loaded.
makeClientParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by server, or tls will fail.
                  -> FilePath       -- ^ private key associated.
                  -> FilePath       -- ^ trusted certificates.
                  -> X509.ServiceID
                  -> IO TLS.ClientParams
makeClientParams' pub certs priv tca servid = do
  base   <- makeClientParams tca servid
  loaded <- TLS.credentialLoadX509Chain pub certs priv
  either error (return . withCredential base) loaded
  where
    -- Attach the credential to the shared state and answer every server
    -- certificate request with it.
    withCredential params cred = params
      { TLS.clientShared = (TLS.clientShared params)
          { TLS.sharedCredentials = TLS.Credentials [cred] }
      , TLS.clientHooks = (TLS.clientHooks params)
          { TLS.onCertificateRequest = \_ -> return (Just cred) }
      }

-- | Make a simple tls 'TLS.ServerParams' that presents the given certificate
-- chain and does not validate the client's certificate.
--
-- The certificates of the loaded chain are also used as
-- 'TLS.serverCACertificates'.  Only 'TLS.ciphersuite_strong' cipher suites are
-- offered.  Throws an 'error' if the credential cannot be loaded.
makeServerParams :: FilePath        -- ^ public certificate (X.509 format).
                 -> [FilePath]      -- ^ chain certificates (X.509 format).
                                    --   the root of your certificate chain should be
                                    --   already trusted by client, or tls will fail.
                 -> FilePath        -- ^ private key associated.
                 -> IO TLS.ServerParams
makeServerParams pub certs priv =
  TLS.credentialLoadX509Chain pub certs priv >>= either error fromCredential
  where
    -- Build the server parameters around a successfully loaded credential.
    fromCredential cred@(X509.CertificateChain chain, _) =
      return def
        { TLS.serverCACertificates = chain
        , TLS.serverShared         = def { TLS.sharedCredentials = TLS.Credentials [cred] }
        , TLS.serverSupported      = def { TLS.supportedCiphers = TLS.ciphersuite_strong }
        }

-- | Make a tls 'TLS.ServerParams' that, on top of 'makeServerParams', asks for
-- the client's certificate and validates it.
--
-- Client chains are checked with 'X509.validate' against the CA certificates
-- loaded from the last argument.  Validation uses SHA-256, default hooks, and
-- default checks, except that the FQHN (hostname) check is disabled.  An empty
-- service identity is passed.  Any reported failure rejects the client, and
-- the shown list of failures becomes the rejection message.
--
-- NOTE(review): 'TLS.serverCACertificates' is inherited from
-- 'makeServerParams' (the server's own chain), not derived from the client CA
-- file.  Confirm this is the intended CA list.
makeServerParams' :: FilePath       -- ^ public certificate (X.509 format).
                  -> [FilePath]     -- ^ chain certificates (X.509 format).
                  -> FilePath       -- ^ private key associated.
                  -> FilePath       -- ^ server will use these certificates to validate clients.
                  -> IO TLS.ServerParams
makeServerParams' pub certs priv tca = do
  caStore <- makeCAStore tca
  base    <- makeServerParams pub certs priv
  let hooks = def { TLS.onClientCertificate = checkClient caStore }
  return base
    { TLS.serverWantClientCert = True
    , TLS.serverShared         = (TLS.serverShared base) { TLS.sharedCAStore = caStore }
    , TLS.serverHooks          = hooks
    }
  where
    -- Default validation checks, minus the hostname check.
    checks = def { X509.checkFQHN = False }

    -- Accept the client chain only when validation reports no failures.
    checkClient store chain = do
      failures <- X509.validate X509.HashSHA256 def checks store def ("", B.empty) chain
      if null failures
        then return TLS.CertificateUsageAccept
        else return . TLS.CertificateUsageReject . TLS.CertificateRejectOther $ show failures