-- | TLS bindings for [Rustls](https://github.com/rustls/rustls) via
-- [rustls-ffi](https://github.com/rustls/rustls-ffi).
--
-- See the [README on GitHub](https://github.com/amesgen/hs-rustls/tree/main/rustls)
-- for setup instructions.
--
-- Currently, most of the functionality exposed by rustls-ffi is available,
-- while rustls-ffi is still missing some more niche Rustls features.
--
-- Also see [http-client-rustls](https://hackage.haskell.org/package/http-client-rustls)
-- for making HTTPS requests using
-- [http-client](https://hackage.haskell.org/package/http-client) and Rustls.
--
-- == Client example
--
-- Suppose you have already opened a 'Network.Socket.Socket' to @example.org@,
-- port 443 (see e.g. the examples at "Network.Socket"). This small example
-- showcases how to perform a simple HTTP GET request:
--
-- >>> :set -XOverloadedStrings
-- >>> import qualified Rustls
-- >>> import Network.Socket (Socket)
-- >>> import Data.Acquire (withAcquire)
-- >>> :{
-- example :: Socket -> IO ()
-- example socket = do
--   -- It is encouraged to share a single `clientConfig` when creating multiple
--   -- TLS connections.
--   clientConfig <-
--     Rustls.buildClientConfig $ Rustls.defaultClientConfigBuilder roots
--   let newConnection =
--         Rustls.newClientConnection socket clientConfig "example.org"
--   withAcquire newConnection $ \conn -> do
--     Rustls.writeBS conn "GET / HTTP/1.1\r\nHost: example.org\r\nConnection: close\r\n\r\n"
--     recv <- Rustls.readBS conn 1000 -- max number of bytes to read
--     print recv
--   where
--     -- For now, rustls-ffi does not provide a built-in way to access
--     -- the OS certificate store.
--     roots = Rustls.ClientRootsFromFile "/etc/ssl/certs/ca-certificates.crt"
-- :}
--
-- == Using 'Acquire'
--
-- Some API functions (like 'newClientConnection' and 'newServerConnection')
-- return an 'Acquire' from
-- [resourcet](https://hackage.haskell.org/package/resourcet), as it is a
-- convenient abstraction for exposing a value that should be consumed in a
-- "bracketed" manner.
--
-- Usually, it can be used via 'Data.Acquire.with' or 'withAcquire', or via
-- 'allocateAcquire' when a 'Control.Monad.Trans.Resource.MonadResource'
-- constraint is available. If you really need the extra flexibility, you can
-- also access separate @open…@ and @close…@ functions by reaching for
-- "Data.Acquire.Internal".
module Rustls
  ( -- * Client

    -- ** Builder
    ClientConfigBuilder (..),
    defaultClientConfigBuilder,
    ClientRoots (..),
    PEMCertificates (..),

    -- ** Config
    ClientConfig,
    clientConfigLogCallback,
    buildClientConfig,

    -- ** Open a connection
    newClientConnection,

    -- * Server

    -- ** Builder
    ServerConfigBuilder (..),
    defaultServerConfigBuilder,
    ClientCertVerifier (..),

    -- ** Config
    ServerConfig,
    serverConfigLogCallback,
    buildServerConfig,

    -- ** Open a connection
    newServerConnection,

    -- * Connection
    Connection,
    Side (..),

    -- ** Read and write
    readBS,
    writeBS,

    -- ** Handshaking
    handshake,
    HandshakeQuery,
    getALPNProtocol,
    getTLSVersion,
    getCipherSuite,
    getSNIHostname,
    getPeerCertificate,

    -- ** Closing
    sendCloseNotify,

    -- ** Logging
    LogCallback,
    newLogCallback,
    LogLevel (..),

    -- ** Raw 'Ptr'-based API
    readPtr,
    writePtr,

    -- * Misc
    version,

    -- ** Backend
    Backend (..),
    ByteStringBackend (..),

    -- ** Types
    ALPNProtocol (..),
    CertifiedKey (..),
    DERCertificate (..),
    TLSVersion (TLS12, TLS13, unTLSVersion),
    defaultTLSVersions,
    allTLSVersions,
    CipherSuite,
    cipherSuiteID,
    showCipherSuite,
    defaultCipherSuites,
    allCipherSuites,

    -- ** Exceptions
    RustlsException,
    isCertError,
    RustlsLogException (..),
  )
where

import Control.Concurrent (forkFinally, killThread)
import Control.Concurrent.MVar
import qualified Control.Exception as E
import Control.Monad (forever, when, (<=<))
import Control.Monad.IO.Class
import Control.Monad.Trans.Cont
import Control.Monad.Trans.Reader
import Data.Acquire
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import Data.Coerce
import Data.Foldable (for_)
import Data.List.NonEmpty (NonEmpty)
import qualified Data.List.NonEmpty as NE
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Foreign as T
import Foreign
import Foreign.C
import GHC.Conc (reportError)
import GHC.Generics (Generic)
import Rustls.Internal
import Rustls.Internal.FFI (ConstPtr (..), TLSVersion (..))
import qualified Rustls.Internal.FFI as FFI
import System.IO.Unsafe (unsafePerformIO)

-- $setup
-- >>> import Control.Monad.IO.Class
-- >>> import Data.Acquire

-- | Combined version string of Rustls and rustls-ffi.
--
-- >>> version
-- "rustls-ffi/0.9.2/rustls/0.20.8"
version :: Text
version = unsafePerformIO queryVersion
  where
    -- Ask rustls-ffi to write its version 'Str' into a temporary cell, then
    -- convert it to 'Text'. The NOINLINE pragma below ensures this native
    -- query is performed at most once.
    queryVersion = alloca \strPtr -> do
      FFI.hsVersion strPtr
      rawVersion <- peek strPtr
      strToText rawVersion
{-# NOINLINE version #-}

-- | Read a C array of @len@ elements starting at the given pointer and coerce
-- every element to @b@.
--
-- The read happens via 'unsafePerformIO', so this must only be used on arrays
-- whose contents never change. In this module it reads the TLS version and
-- cipher suite tables exported by rustls-ffi.
--
-- Calls 'error' with a descriptive message if the array turns out to be empty
-- (@len == 0@), since every caller promises a 'NonEmpty' result.
peekNonEmpty :: (Storable a, Coercible a b) => ConstPtr a -> CSize -> NonEmpty b
peekNonEmpty (ConstPtr as) len =
  case NE.nonEmpty (coerce elems) of
    Just nonEmptyElems -> nonEmptyElems
    Nothing ->
      error "Rustls.peekNonEmpty: rustls-ffi returned an empty array"
  where
    elems = unsafePerformIO $ peekArray (cSizeToInt len) as

-- | All 'TLSVersion's supported by Rustls.
--
-- Read from the static table exported by rustls-ffi; NOINLINE keeps this a
-- single shared value so the table is only read once.
allTLSVersions :: NonEmpty TLSVersion
allTLSVersions = FFI.allVersions `peekNonEmpty` FFI.allVersionsLen
{-# NOINLINE allTLSVersions #-}

-- | The default 'TLSVersion's used by Rustls. A subset of 'allTLSVersions'.
--
-- Read from the static table exported by rustls-ffi; NOINLINE keeps this a
-- single shared value so the table is only read once.
defaultTLSVersions :: NonEmpty TLSVersion
defaultTLSVersions = FFI.defaultVersions `peekNonEmpty` FFI.defaultVersionsLen
{-# NOINLINE defaultTLSVersions #-}

-- | All 'CipherSuite's supported by Rustls.
--
-- Read from the static table exported by rustls-ffi; NOINLINE keeps this a
-- single shared value so the table is only read once.
allCipherSuites :: NonEmpty CipherSuite
allCipherSuites = FFI.allCipherSuites `peekNonEmpty` FFI.allCipherSuitesLen
{-# NOINLINE allCipherSuites #-}

-- | The default 'CipherSuite's used by Rustls. A subset of 'allCipherSuites'.
--
-- Read from the static table exported by rustls-ffi; NOINLINE keeps this a
-- single shared value so the table is only read once.
defaultCipherSuites :: NonEmpty CipherSuite
defaultCipherSuites =
  FFI.defaultCipherSuites `peekNonEmpty` FFI.defaultCipherSuitesLen
{-# NOINLINE defaultCipherSuites #-}

-- | A 'ClientConfigBuilder' with good defaults, trusting the given roots.
--
-- The TLS version and cipher suite lists are left empty, which makes
-- 'buildClientConfig' fall back to Rustls' defaults ('defaultTLSVersions'
-- and 'defaultCipherSuites'). SNI is enabled, and neither ALPN protocols nor
-- client certificates are configured.
defaultClientConfigBuilder :: ClientRoots -> ClientConfigBuilder
defaultClientConfigBuilder trustedRoots =
  ClientConfigBuilder
    { clientConfigRoots = trustedRoots,
      clientConfigEnableSNI = True,
      clientConfigTLSVersions = [],
      clientConfigCipherSuites = [],
      clientConfigALPNProtocols = [],
      clientConfigCertifiedKeys = []
    }

-- | Build a native certified key for every 'CertifiedKey' and pass the
-- resulting C array of key pointers (and its length) to the callback.
--
-- The certificate chain and private key bytes are handed to rustls-ffi
-- without copying ('BU.unsafeUseAsCStringLen'). Those buffers, the individual
-- key pointers and the array are only valid while the callback runs. A
-- non-OK 'FFI.Result' from 'FFI.certifiedKeyBuild' is passed to 'rethrowR'.
--
-- NOTE(review): the keys produced by 'FFI.certifiedKeyBuild' are never
-- released here. rustls-ffi documents @rustls_certified_key_free@ for that
-- purpose. Confirm whether the config builders take their own reference and,
-- if so, free each key once the callback has returned.
withCertifiedKeys :: [CertifiedKey] -> ((ConstPtr (ConstPtr FFI.CertifiedKey), CSize) -> IO a) -> IO a
withCertifiedKeys certifiedKeys cb = evalContT do
  -- Each key is built inside the scope of the previous one, in list order.
  keyPtrs <- traverse (ContT . withCertifiedKey) certifiedKeys
  (count, arrayPtr) <- ContT (withArrayLen keyPtrs . curry)
  liftIO (cb (ConstPtr arrayPtr, intToCSize count))
  where
    -- Build one native key and pass its pointer to the continuation.
    withCertifiedKey :: CertifiedKey -> (ConstPtr FFI.CertifiedKey -> IO r) -> IO r
    withCertifiedKey CertifiedKey {..} = runContT do
      (certPtr, certLen) <- ContT (BU.unsafeUseAsCStringLen certificateChain)
      (privPtr, privLen) <- ContT (BU.unsafeUseAsCStringLen privateKey)
      outPtr <- ContT alloca
      liftIO do
        rethrowR
          =<< FFI.certifiedKeyBuild
            (ConstPtr (castPtr certPtr))
            (intToCSize certLen)
            (ConstPtr (castPtr privPtr))
            (intToCSize privLen)
            outPtr
        peek outPtr

-- | Marshal ALPN protocol names into a C array of 'FFI.SliceBytes' and pass
-- it (with its length) to the callback.
--
-- Each slice points directly into its 'ByteString' (no copy), so the slices
-- and the array are only valid while the callback runs.
withALPNProtocols :: [ALPNProtocol] -> ((ConstPtr FFI.SliceBytes, CSize) -> IO a) -> IO a
withALPNProtocols protocols cb = evalContT do
  slices <- traverse (ContT . withSliceBytes) (coerce protocols :: [ByteString])
  (count, slicesPtr) <- ContT (withArrayLen slices . curry)
  liftIO (cb (ConstPtr slicesPtr, intToCSize count))
  where
    -- View a 'ByteString' as a borrowed rustls-ffi byte slice.
    withSliceBytes :: ByteString -> (FFI.SliceBytes -> IO r) -> IO r
    withSliceBytes bytes k =
      BU.unsafeUseAsCStringLen bytes \(buf, len) ->
        k (FFI.SliceBytes (castPtr buf) (intToCSize len))

-- | Create a new native config builder for the given cipher suites and TLS
-- versions by calling the supplied rustls-ffi constructor (the client or the
-- server variant).
--
-- An empty list selects rustls-ffi's default table instead
-- ('FFI.defaultCipherSuites' / 'FFI.defaultVersions'). A non-OK 'FFI.Result'
-- from the constructor is passed to 'rethrowR'.
configBuilderNew ::
  ( ConstPtr (ConstPtr FFI.SupportedCipherSuite) ->
    CSize ->
    ConstPtr TLSVersion ->
    CSize ->
    Ptr (Ptr configBuilder) ->
    IO FFI.Result
  ) ->
  [CipherSuite] ->
  [TLSVersion] ->
  IO (Ptr configBuilder)
configBuilderNew newCustom cipherSuites tlsVersions = evalContT do
  outPtr <- ContT alloca
  (suitesLen, suitesPtr) <-
    arrayOrDefault
      (FFI.defaultCipherSuitesLen, FFI.defaultCipherSuites)
      (coerce cipherSuites)
  (versionsLen, versionsPtr) <-
    arrayOrDefault (FFI.defaultVersionsLen, FFI.defaultVersions) tlsVersions
  liftIO do
    rethrowR =<< newCustom suitesPtr suitesLen versionsPtr versionsLen outPtr
    peek outPtr
  where
    -- Marshal a non-empty list into a temporary C array (valid for the rest
    -- of the enclosing 'ContT' block). An empty list yields the fallback.
    arrayOrDefault :: (Storable e) => (CSize, ConstPtr e) -> [e] -> ContT r IO (CSize, ConstPtr e)
    arrayOrDefault fallback [] = pure fallback
    arrayOrDefault _ xs =
      ContT \k -> withArrayLen xs \n p -> k (intToCSize n, ConstPtr p)

-- | Create a temporary native root certificate store, add all given
-- certificates to it and pass it to the action.
--
-- The store is freed ('E.bracket') when the action returns or throws.
-- 'PEMCertificatesStrict' entries are added with the @strict@ flag of
-- 'FFI.rootCertStoreAddPEM' set, and 'PEMCertificatesLax' entries with it
-- cleared. A non-OK 'FFI.Result' is passed to 'rethrowR'.
withRootCertStore :: [PEMCertificates] -> (ConstPtr FFI.RootCertStore -> IO a) -> IO a
withRootCertStore certs action =
  E.bracket FFI.rootCertStoreNew FFI.rootCertStoreFree \store -> do
    mapM_ (addCertificates store) certs
    action (ConstPtr store)
  where
    -- Add one PEM bundle, choosing strict or lax parsing by constructor.
    addCertificates store = \case
      PEMCertificatesStrict pem -> addPEM store pem True
      PEMCertificatesLax pem -> addPEM store pem False
    addPEM store pem strict =
      BU.unsafeUseAsCStringLen pem \(buf, len) ->
        rethrowR
          =<< FFI.rootCertStoreAddPEM
            store
            (ConstPtr (castPtr buf))
            (intToCSize len)
            (fromBool strict)

-- | Build a 'ClientConfigBuilder' into a 'ClientConfig'.
--
-- This is a relatively expensive operation, so it is a good idea to share one
-- 'ClientConfig' when creating multiple 'Connection's.
--
-- Configuration steps, in order: trusted roots, ALPN protocols, SNI, client
-- certificates. The native builder is freed ('E.bracketOnError') if any step
-- throws. On success it is handed to 'FFI.clientConfigBuilderBuild' and not
-- freed here, which suggests that call takes ownership of it. The resulting
-- native config is owned by a 'ForeignPtr' with 'FFI.clientConfigFree' as its
-- finalizer. Everything runs under 'E.mask_', so no asynchronous exception can
-- arrive between building the config and attaching that finalizer.
buildClientConfig :: (MonadIO m) => ClientConfigBuilder -> m ClientConfig
buildClientConfig ClientConfigBuilder {..} =
  liftIO . E.mask_ $
    E.bracketOnError newBuilder FFI.clientConfigBuilderFree \builder -> do
      setRoots builder
      withALPNProtocols clientConfigALPNProtocols \(alpnPtr, alpnLen) ->
        rethrowR
          =<< FFI.clientConfigBuilderSetALPNProtocols builder alpnPtr alpnLen
      FFI.clientConfigBuilderSetEnableSNI builder (fromBool clientConfigEnableSNI)
      withCertifiedKeys clientConfigCertifiedKeys \(keysPtr, keysLen) ->
        rethrowR
          =<< FFI.clientConfigBuilderSetCertifiedKey builder keysPtr keysLen
      nativeConfig <- FFI.clientConfigBuilderBuild builder
      configPtr <- newForeignPtr FFI.clientConfigFree (unConstPtr nativeConfig)
      pure
        ClientConfig
          { clientConfigPtr = configPtr,
            clientConfigLogCallback = Nothing
          }
  where
    newBuilder =
      configBuilderNew
        FFI.clientConfigBuilderNewCustom
        clientConfigCipherSuites
        clientConfigTLSVersions
    -- Install the trusted roots, either from a file or from in-memory PEM.
    setRoots builder = case clientConfigRoots of
      ClientRootsFromFile rootsPath ->
        withCString rootsPath \pathPtr ->
          rethrowR
            =<< FFI.clientConfigBuilderLoadRootsFromFile builder (ConstPtr pathPtr)
      ClientRootsInMemory certs ->
        withRootCertStore certs \roots ->
          rethrowR =<< FFI.clientConfigBuilderUseRoots builder roots

-- | Build a 'ServerConfigBuilder' into a 'ServerConfig'.
--
-- This is a relatively expensive operation, so it is a good idea to share one
-- 'ServerConfig' when creating multiple 'Connection's.
buildServerConfig :: (MonadIO m) => ServerConfigBuilder -> m ServerConfig
buildServerConfig :: forall (m :: * -> *).
MonadIO m =>
ServerConfigBuilder -> m ServerConfig
buildServerConfig ServerConfigBuilder {Bool
[TLSVersion]
[CipherSuite]
[ALPNProtocol]
Maybe ClientCertVerifier
NonEmpty CertifiedKey
serverConfigCertifiedKeys :: NonEmpty CertifiedKey
serverConfigTLSVersions :: [TLSVersion]
serverConfigCipherSuites :: [CipherSuite]
serverConfigALPNProtocols :: [ALPNProtocol]
serverConfigIgnoreClientOrder :: Bool
serverConfigClientCertVerifier :: Maybe ClientCertVerifier
serverConfigCertifiedKeys :: ServerConfigBuilder -> NonEmpty CertifiedKey
serverConfigTLSVersions :: ServerConfigBuilder -> [TLSVersion]
serverConfigCipherSuites :: ServerConfigBuilder -> [CipherSuite]
serverConfigALPNProtocols :: ServerConfigBuilder -> [ALPNProtocol]
serverConfigIgnoreClientOrder :: ServerConfigBuilder -> Bool
serverConfigClientCertVerifier :: ServerConfigBuilder -> Maybe ClientCertVerifier
..} = IO ServerConfig -> m ServerConfig
forall a. IO a -> m a
forall (m :: * -> *) a. MonadIO m => IO a -> m a
liftIO (IO ServerConfig -> m ServerConfig)
-> (IO ServerConfig -> IO ServerConfig)
-> IO ServerConfig
-> m ServerConfig
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IO ServerConfig -> IO ServerConfig
forall a. IO a -> IO a
E.mask_ (IO ServerConfig -> m ServerConfig)
-> IO ServerConfig -> m ServerConfig
forall a b. (a -> b) -> a -> b
$
  IO (Ptr ServerConfigBuilder)
-> (Ptr ServerConfigBuilder -> IO ())
-> (Ptr ServerConfigBuilder -> IO ServerConfig)
-> IO ServerConfig
forall a b c. IO a -> (a -> IO b) -> (a -> IO c) -> IO c
E.bracketOnError
    ( (ConstPtr (ConstPtr SupportedCipherSuite)
 -> CSize
 -> ConstPtr TLSVersion
 -> CSize
 -> Ptr (Ptr ServerConfigBuilder)
 -> IO Result)
-> [CipherSuite] -> [TLSVersion] -> IO (Ptr ServerConfigBuilder)
forall configBuilder.
(ConstPtr (ConstPtr SupportedCipherSuite)
 -> CSize
 -> ConstPtr TLSVersion
 -> CSize
 -> Ptr (Ptr configBuilder)
 -> IO Result)
-> [CipherSuite] -> [TLSVersion] -> IO (Ptr configBuilder)
configBuilderNew
        ConstPtr (ConstPtr SupportedCipherSuite)
-> CSize
-> ConstPtr TLSVersion
-> CSize
-> Ptr (Ptr ServerConfigBuilder)
-> IO Result
FFI.serverConfigBuilderNewCustom
        [CipherSuite]
serverConfigCipherSuites
        [TLSVersion]
serverConfigTLSVersions
    )
    Ptr ServerConfigBuilder -> IO ()
FFI.serverConfigBuilderFree
    \Ptr ServerConfigBuilder
builder -> do
      [ALPNProtocol] -> ((ConstPtr SliceBytes, CSize) -> IO ()) -> IO ()
forall a.
[ALPNProtocol] -> ((ConstPtr SliceBytes, CSize) -> IO a) -> IO a
withALPNProtocols [ALPNProtocol]
serverConfigALPNProtocols \(ConstPtr SliceBytes
alpnPtr, CSize
len) ->
        Result -> IO ()
rethrowR (Result -> IO ()) -> IO Result -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr ServerConfigBuilder
-> ConstPtr SliceBytes -> CSize -> IO Result
FFI.serverConfigBuilderSetALPNProtocols Ptr ServerConfigBuilder
builder ConstPtr SliceBytes
alpnPtr CSize
len
      Result -> IO ()
rethrowR
        (Result -> IO ()) -> IO Result -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr ServerConfigBuilder -> CBool -> IO Result
FFI.serverConfigBuilderSetIgnoreClientOrder
          Ptr ServerConfigBuilder
builder
          (forall a. Num a => Bool -> a
fromBool @CBool Bool
serverConfigIgnoreClientOrder)
      [CertifiedKey]
-> ((ConstPtr (ConstPtr CertifiedKey), CSize) -> IO ()) -> IO ()
forall a.
[CertifiedKey]
-> ((ConstPtr (ConstPtr CertifiedKey), CSize) -> IO a) -> IO a
withCertifiedKeys (NonEmpty CertifiedKey -> [CertifiedKey]
forall a. NonEmpty a -> [a]
NE.toList NonEmpty CertifiedKey
serverConfigCertifiedKeys) \(ConstPtr (ConstPtr CertifiedKey)
ptr, CSize
len) ->
        Result -> IO ()
rethrowR (Result -> IO ()) -> IO Result -> IO ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr ServerConfigBuilder
-> ConstPtr (ConstPtr CertifiedKey) -> CSize -> IO Result
FFI.serverConfigBuilderSetCertifiedKeys Ptr ServerConfigBuilder
builder ConstPtr (ConstPtr CertifiedKey)
ptr CSize
len
      let setBuilderCCV :: [PEMCertificates]
-> (ConstPtr RootCertStore -> IO a)
-> (a -> IO b)
-> (Ptr ServerConfigBuilder -> a -> IO a)
-> IO a
setBuilderCCV [PEMCertificates]
certs ConstPtr RootCertStore -> IO a
ccvNew a -> IO b
ccvFree Ptr ServerConfigBuilder -> a -> IO a
setCCV =
            [PEMCertificates] -> (ConstPtr RootCertStore -> IO a) -> IO a
forall a.
[PEMCertificates] -> (ConstPtr RootCertStore -> IO a) -> IO a
withRootCertStore [PEMCertificates]
certs \ConstPtr RootCertStore
roots ->
              IO a -> (a -> IO b) -> (a -> IO a) -> IO a
forall a b c. IO a -> (a -> IO b) -> (a -> IO c) -> IO c
E.bracket (ConstPtr RootCertStore -> IO a
ccvNew ConstPtr RootCertStore
roots) a -> IO b
ccvFree ((a -> IO a) -> IO a) -> (a -> IO a) -> IO a
forall a b. (a -> b) -> a -> b
$ Ptr ServerConfigBuilder -> a -> IO a
setCCV Ptr ServerConfigBuilder
builder
      Maybe ClientCertVerifier -> (ClientCertVerifier -> IO ()) -> IO ()
forall (t :: * -> *) (f :: * -> *) a b.
(Foldable t, Applicative f) =>
t a -> (a -> f b) -> f ()
for_ Maybe ClientCertVerifier
serverConfigClientCertVerifier \case
        ClientCertVerifier [PEMCertificates]
certs -> do
          [PEMCertificates]
-> (ConstPtr RootCertStore -> IO (ConstPtr ClientCertVerifier))
-> (ConstPtr ClientCertVerifier -> IO ())
-> (Ptr ServerConfigBuilder
    -> ConstPtr ClientCertVerifier -> IO ())
-> IO ()
forall {a} {b} {a}.
[PEMCertificates]
-> (ConstPtr RootCertStore -> IO a)
-> (a -> IO b)
-> (Ptr ServerConfigBuilder -> a -> IO a)
-> IO a
setBuilderCCV
            [PEMCertificates]
certs
            ConstPtr RootCertStore -> IO (ConstPtr ClientCertVerifier)
FFI.clientCertVerifierNew
            ConstPtr ClientCertVerifier -> IO ()
FFI.clientCertVerifierFree
            Ptr ServerConfigBuilder -> ConstPtr ClientCertVerifier -> IO ()
FFI.serverConfigBuilderSetClientVerifier
        ClientCertVerifierOptional [PEMCertificates]
certs -> do
          [PEMCertificates]
-> (ConstPtr RootCertStore
    -> IO (ConstPtr ClientCertVerifierOptional))
-> (ConstPtr ClientCertVerifierOptional -> IO ())
-> (Ptr ServerConfigBuilder
    -> ConstPtr ClientCertVerifierOptional -> IO ())
-> IO ()
forall {a} {b} {a}.
[PEMCertificates]
-> (ConstPtr RootCertStore -> IO a)
-> (a -> IO b)
-> (Ptr ServerConfigBuilder -> a -> IO a)
-> IO a
setBuilderCCV
            [PEMCertificates]
certs
            ConstPtr RootCertStore -> IO (ConstPtr ClientCertVerifierOptional)
FFI.clientCertVerifierOptionalNew
            ConstPtr ClientCertVerifierOptional -> IO ()
FFI.clientCertVerifierOptionalFree
            Ptr ServerConfigBuilder
-> ConstPtr ClientCertVerifierOptional -> IO ()
FFI.serverConfigBuilderSetClientVerifierOptional
      ForeignPtr ServerConfig
serverConfigPtr <-
        FinalizerPtr ServerConfig
-> Ptr ServerConfig -> IO (ForeignPtr ServerConfig)
forall a. FinalizerPtr a -> Ptr a -> IO (ForeignPtr a)
newForeignPtr FinalizerPtr ServerConfig
FFI.serverConfigFree (Ptr ServerConfig -> IO (ForeignPtr ServerConfig))
-> (ConstPtr ServerConfig -> Ptr ServerConfig)
-> ConstPtr ServerConfig
-> IO (ForeignPtr ServerConfig)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConstPtr ServerConfig -> Ptr ServerConfig
forall a. ConstPtr a -> Ptr a
unConstPtr
          (ConstPtr ServerConfig -> IO (ForeignPtr ServerConfig))
-> IO (ConstPtr ServerConfig) -> IO (ForeignPtr ServerConfig)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Ptr ServerConfigBuilder -> IO (ConstPtr ServerConfig)
FFI.serverConfigBuilderBuild Ptr ServerConfigBuilder
builder
      let serverConfigLogCallback :: Maybe a
serverConfigLogCallback = Maybe a
forall a. Maybe a
Nothing
      ServerConfig -> IO ServerConfig
forall a. a -> IO a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ServerConfig {Maybe LogCallback
ForeignPtr ServerConfig
forall a. Maybe a
serverConfigLogCallback :: Maybe LogCallback
serverConfigPtr :: ForeignPtr ServerConfig
serverConfigLogCallback :: forall a. Maybe a
serverConfigPtr :: ForeignPtr ServerConfig
..}

-- | A 'ServerConfigBuilder' with good defaults.
--
-- It serves the given certified keys. TLS versions, cipher suites and ALPN
-- protocols are left as empty lists, 'serverConfigIgnoreClientOrder' is
-- disabled, and no 'ClientCertVerifier' is configured.
defaultServerConfigBuilder :: NonEmpty CertifiedKey -> ServerConfigBuilder
defaultServerConfigBuilder keys =
  ServerConfigBuilder
    { serverConfigClientCertVerifier = Nothing,
      serverConfigIgnoreClientOrder = False,
      serverConfigALPNProtocols = [],
      serverConfigCipherSuites = [],
      serverConfigTLSVersions = [],
      serverConfigCertifiedKeys = keys
    }

-- | Allocate a new logging callback. The given function receives the
-- 'LogLevel' and the message of every log record emitted by rustls.
--
-- Exceptions thrown by the function are wrapped in a 'RustlsLogException' and
-- passed to 'reportError'. A log level unknown to this library is reported the
-- same way, as a 'RustlsUnknownLogLevel', and the function is not called.
--
-- 🚫 The lifetime of the returned callback must enclose the lifetimes of all
-- 'Connection's configured to use it.
newLogCallback :: (LogLevel -> Text -> IO ()) -> Acquire LogCallback
newLogCallback cb =
  LogCallback <$> mkAcquire mkFunPtr freeHaskellFunPtr
  where
    -- The FunPtr is freed when the surrounding 'Acquire' is released.
    mkFunPtr = FFI.mkLogCallback \_userdata (ConstPtr paramsPtr) ->
      ignoreExceptions $ peek paramsPtr >>= onLogParams

    onLogParams FFI.LogParams {..} =
      case decodeLevel rustlsLogParamsLevel of
        Nothing ->
          report . E.SomeException $ RustlsUnknownLogLevel rustlsLogParamsLevel
        Just level -> do
          msg <- strToText rustlsLogParamsMessage
          E.catch (cb level msg) report

    -- Map rustls-ffi's numeric log levels to 'LogLevel'.
    decodeLevel = \case
      FFI.LogLevel 1 -> Just LogLevelError
      FFI.LogLevel 2 -> Just LogLevelWarn
      FFI.LogLevel 3 -> Just LogLevelInfo
      FFI.LogLevel 4 -> Just LogLevelDebug
      FFI.LogLevel 5 -> Just LogLevelTrace
      _ -> Nothing

    report :: E.SomeException -> IO ()
    report = reportError . E.SomeException . RustlsLogException

-- | Create a new rustls connection for the given config, managed by an
-- 'Acquire'.
--
-- A dedicated thread (@interactThread@) performs the 'FFI.connectionReadTls'
-- and 'FFI.connectionWriteTls' calls:
--
-- * It waits for a 'Request' on @ioMsgReq@.
-- * It signals completion with 'DoneFFI' on @ioMsgRes@.
-- * While rustls is inside such a call, the read/write callbacks hand the
--   rustls buffer out via 'UsingBuffer' on @ioMsgRes@. They then block until
--   a 'Done' result is put on @ioMsgReq@.
--
-- On release, the rustls connection and the length buffer are freed and the
-- thread is killed. The thread's finalizer frees the callback 'FunPtr's.
newConnection ::
  (Backend b) =>
  b ->
  ForeignPtr config ->
  Maybe LogCallback ->
  -- | Creates the rustls connection from the config and stores it in the
  -- out-pointer.
  (ConstPtr config -> Ptr (Ptr FFI.Connection) -> IO FFI.Result) ->
  Acquire (Connection side)
newConnection backend configPtr logCallback connectionNew =
  mkAcquire acquire release
  where
    acquire = do
      conn <-
        alloca \connPtrPtr ->
          withForeignPtr configPtr \cfgPtr -> do
            rethrowR =<< connectionNew (ConstPtr cfgPtr) connPtrPtr
            peek connPtrPtr
      ioMsgReq <- newEmptyMVar
      ioMsgRes <- newEmptyMVar
      lenPtr <- malloc
      -- Shared by the read and write callbacks. They only differ in how the
      -- buffer pointer is converted to a 'Ptr Word8'.
      let readWriteCallback toBuf _ud buf len iPtr = do
            putMVar ioMsgRes $ UsingBuffer (toBuf buf) len iPtr
            Done ioResult <- takeMVar ioMsgReq
            pure ioResult
      readCallback <- FFI.mkReadCallback $ readWriteCallback id
      writeCallback <- FFI.mkWriteCallback $ readWriteCallback unConstPtr
      let freeCallback = do
            freeHaskellFunPtr readCallback
            freeHaskellFunPtr writeCallback
          interactLoop = forever do
            Request readOrWrite <- takeMVar ioMsgReq
            let readOrWriteTls = case readOrWrite of
                  Read -> flip FFI.connectionReadTls readCallback
                  Write -> flip FFI.connectionWriteTls writeCallback
            _ <- readOrWriteTls conn nullPtr lenPtr
            putMVar ioMsgRes DoneFFI
      interactThread <- forkFinally interactLoop (const freeCallback)
      for_ logCallback $ FFI.connectionSetLogCallback conn . unLogCallback
      Connection <$> newMVar Connection' {..}

    release (Connection c) =
      tryTakeMVar c >>= \case
        -- The connection state is not available here. Presumably it is still
        -- held by another operation, or it was already released (TODO
        -- confirm). Report this with a descriptive error instead of an
        -- incomplete pattern match.
        Nothing ->
          fail "rustls: cannot release a connection that is in use or already released"
        Just Connection' {..} -> do
          FFI.connectionFree conn
          free lenPtr
          killThread interactThread

-- | Initialize a TLS connection as a client, using the given 'ClientConfig'
-- and the server hostname to connect to.
newClientConnection ::
  (Backend b) =>
  b ->
  ClientConfig ->
  -- | Hostname.
  Text ->
  Acquire (Connection Client)
newClientConnection b ClientConfig {clientConfigPtr, clientConfigLogCallback} hostname =
  newConnection b clientConfigPtr clientConfigLogCallback connect
  where
    -- 'T.withCString' provides the hostname as a C string that is only valid
    -- for the duration of the FFI call.
    connect cfgPtr connPtrPtr =
      T.withCString hostname $ \cHostname ->
        FFI.clientConnectionNew cfgPtr (ConstPtr cHostname) connPtrPtr

-- | Initialize a TLS connection as a server, using the given 'ServerConfig'.
newServerConnection ::
  (Backend b) =>
  b ->
  ServerConfig ->
  Acquire (Connection Server)
newServerConnection backend ServerConfig {serverConfigPtr = cfgPtr, serverConfigLogCallback = logCb} =
  newConnection backend cfgPtr logCb FFI.serverConnectionNew

-- | Make sure the TLS handshake has been performed on the connection, then
-- run the given 'HandshakeQuery' against it. Calling this is only necessary
-- when you want to obtain connection information.
--
-- >>> :{
-- getALPNAndTLSVersion ::
--   MonadIO m =>
--   Connection side ->
--   m (Maybe ALPNProtocol, TLSVersion)
-- getALPNAndTLSVersion conn =
--   handshake conn $ (,) <$> getALPNProtocol <*> getTLSVersion
-- :}
handshake :: (MonadIO m) => Connection side -> HandshakeQuery side a -> m a
handshake conn (HandshakeQuery query) =
  liftIO . withConnection conn $ \c ->
    runTLS c TLSHandshake >> runReaderT query c

-- | Get the negotiated ALPN protocol, if any.
--
-- Returns 'Nothing' if no ALPN protocol was negotiated.
getALPNProtocol :: HandshakeQuery side (Maybe ALPNProtocol)
getALPNProtocol = handshakeQuery \Connection' {conn, lenPtr} ->
  alloca \bufPtrPtr -> do
    FFI.connectionGetALPNProtocol (ConstPtr conn) bufPtrPtr lenPtr
    ConstPtr bufPtr <- peek bufPtrPtr
    len <- peek lenPtr
    -- rustls-ffi reports "no protocol" as a NULL pointer with length 0. Do not
    -- hand a NULL pointer to 'B.packCStringLen', and skip allocating an empty
    -- ByteString.
    if bufPtr == nullPtr || len == 0
      then pure Nothing
      else do
        !alpn <- B.packCStringLen (castPtr bufPtr, cSizeToInt len)
        pure . Just $ ALPNProtocol alpn

-- | Get the TLS protocol version negotiated for this connection.
--
-- Fails in 'IO' if rustls reports version @0@, meaning no version was
-- negotiated.
getTLSVersion :: HandshakeQuery side TLSVersion
getTLSVersion = handshakeQuery \Connection' {conn} ->
  FFI.connectionGetProtocolVersion (ConstPtr conn) >>= \ver ->
    if unTLSVersion ver /= 0
      then pure ver
      else fail "internal rustls error: no protocol version negotiated"

-- | Get the cipher suite negotiated for this connection.
--
-- Fails in 'IO' if rustls returns a NULL cipher suite pointer.
getCipherSuite :: HandshakeQuery side CipherSuite
getCipherSuite = handshakeQuery \Connection' {conn} -> do
  suitePtr <- FFI.connectionGetNegotiatedCipherSuite (ConstPtr conn)
  if unConstPtr suitePtr == nullPtr
    then fail "internal rustls error: no cipher suite negotiated"
    else pure (CipherSuite suitePtr)

-- | Get the SNI hostname sent by the client, if any.
--
-- The lookup starts with a 16-byte buffer. The buffer size is doubled for as
-- long as rustls reports 'FFI.resultInsufficientSize'. An empty hostname is
-- returned as 'Nothing'.
getSNIHostname :: HandshakeQuery Server (Maybe Text)
getSNIHostname = handshakeQuery \Connection' {conn, lenPtr} ->
  let attempt bufSize = allocaBytes (cSizeToInt bufSize) \buf -> do
        res <- FFI.serverConnectionGetSNIHostname (ConstPtr conn) buf bufSize lenPtr
        if res /= FFI.resultInsufficientSize
          then do
            rethrowR res
            written <- peek lenPtr
            !hostname <- T.peekCStringLen (castPtr buf, cSizeToInt written)
            pure $ if T.null hostname then Nothing else Just hostname
          else attempt (bufSize * 2)
   in attempt 16

-- | A certificate in DER encoding, stored as its raw bytes.
newtype DERCertificate = DERCertificate
  { unDERCertificate :: ByteString
  }
  deriving stock (Show, Eq, Ord)
  deriving stock (Generic)

-- | Get the @i@-th certificate provided by the peer.
--
-- Index @0@ is the end entity certificate. Higher indices are certificates in
-- the chain. Requesting an index higher than what is available returns
-- 'Nothing'.
--
-- The DER bytes are copied into a fresh 'ByteString', so the result does not
-- refer to memory owned by Rustls.
getPeerCertificate :: CSize -> HandshakeQuery side (Maybe DERCertificate)
getPeerCertificate idx = handshakeQuery \Connection' {conn, lenPtr} -> do
  cert <- FFI.connectionGetPeerCertificate (ConstPtr conn) idx
  if cert /= ConstPtr nullPtr
    then Just <$> readDER cert lenPtr
    else pure Nothing
  where
    -- Copy the DER encoding of a non-null certificate out of Rustls. The
    -- connection's scratch 'lenPtr' receives the DER length.
    readDER cert sizePtr = alloca \derPtrPtr -> do
      FFI.certificateGetDER cert derPtrPtr sizePtr >>= rethrowR
      ConstPtr derPtr <- peek derPtrPtr
      derLen <- cSizeToInt <$> peek sizePtr
      !der <- B.packCStringLen (castPtr derPtr, derLen)
      pure (DERCertificate der)

-- | Send a @close_notify@ warning alert. This informs the peer that the
-- connection is being closed.
--
-- The alert is first queued with Rustls, and then the TLS layer is run in
-- write mode.
sendCloseNotify :: (MonadIO m) => Connection side -> m ()
sendCloseNotify conn = liftIO (withConnection conn notify)
  where
    notify c@Connection' {conn = connPtr} =
      FFI.connectionSendCloseNotify connPtr >> runTLS c TLSWrite

-- | Read data from the Rustls 'Connection' into the given buffer.
--
-- Returns the byte count that Rustls reports for the read. Failures reported
-- by Rustls are rethrown via 'rethrowR'.
readPtr :: (MonadIO m) => Connection side -> Ptr Word8 -> CSize -> m CSize
readPtr conn buf len = liftIO (withConnection conn pull)
  where
    pull c@Connection' {conn = connPtr, lenPtr} = do
      -- Drive the TLS layer in write mode and then in read mode before asking
      -- Rustls for plaintext.
      mapM_ (runTLS c) [TLSWrite, TLSRead]
      FFI.connectionRead connPtr buf len lenPtr >>= rethrowR
      peek lenPtr

-- | Read data from the Rustls 'Connection' into a 'ByteString'. The result will
-- not be longer than the given length.
--
-- A buffer of the maximum size is allocated up front. It is then trimmed to
-- the number of bytes 'readPtr' reports.
readBS ::
  (MonadIO m) =>
  Connection side ->
  -- | Maximum result length. Note that a buffer of this size will be allocated.
  Int ->
  m ByteString
readBS conn maxLen =
  liftIO . BI.createAndTrim maxLen $ \buf ->
    cSizeToInt <$> readPtr conn buf (intToCSize maxLen)

-- | Write data to the Rustls 'Connection' from the given buffer.
--
-- Returns the number of bytes Rustls accepted from the buffer, which may be
-- fewer than requested. After handing the data to Rustls, the TLS layer is
-- run in write mode. Failures reported by Rustls are rethrown via 'rethrowR'.
writePtr :: (MonadIO m) => Connection side -> Ptr Word8 -> CSize -> m CSize
writePtr conn buf len = liftIO $ withConnection conn \c@Connection' {conn = connPtr, lenPtr} -> do
  rethrowR =<< FFI.connectionWrite connPtr buf len lenPtr
  -- Take the accepted-byte count out of 'lenPtr' immediately. 'lenPtr' is a
  -- scratch cell shared by this connection's FFI calls, and 'runTLS' gets the
  -- whole connection record. Reading it only after 'runTLS' would risk
  -- returning a value written by a different call.
  written <- peek lenPtr
  runTLS c TLSWrite
  pure written

-- | Write a 'ByteString' to the Rustls 'Connection'.
--
-- 'writePtr' may report fewer bytes written than requested. In that case this
-- keeps calling it on the not-yet-written suffix until the whole input has
-- been accepted.
writeBS :: (MonadIO m) => Connection side -> ByteString -> m ()
writeBS conn bs = liftIO $ BU.unsafeUseAsCStringLen bs go
  where
    go (buf, len) = do
      written <- cSizeToInt <$> writePtr conn (castPtr buf) (intToCSize len)
      when (written < len) $
        -- Skip only the bytes that were actually consumed. Advancing by the
        -- whole chunk length would skip unwritten data and read past the end
        -- of the buffer.
        go (buf `plusPtr` written, len - written)