{-# OPTIONS_GHC -Wno-missing-export-lists #-}

-- | Internal module, not subject to PVP.
module Rustls.Internal where

import Control.Concurrent (ThreadId)
import Control.Concurrent.MVar
import qualified Control.Exception as E
import Control.Monad (when)
import Control.Monad.Trans.Reader
import Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Unsafe as BU
import Data.Coerce (coerce)
import Data.Function (on)
import Data.Functor (void)
import Data.List.NonEmpty (NonEmpty)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Foreign as T
import Foreign hiding (void)
import Foreign.C
import GHC.Generics (Generic)
import qualified Network.Socket as NS
import Rustls.Internal.FFI (ConstPtr (..))
import qualified Rustls.Internal.FFI as FFI
import System.IO.Unsafe (unsafePerformIO)

-- | An ALPN protocol ID, held as a raw 'ByteString'.
--
-- Registered IDs are listed at
-- <https://www.iana.org/assignments/tls-extensiontype-values/tls-extensiontype-values.xhtml#alpn-protocol-ids>.
newtype ALPNProtocol = ALPNProtocol
  { -- | The raw protocol ID bytes.
    unALPNProtocol :: ByteString
  }
  deriving stock (Show, Eq, Ord)
  deriving stock (Generic)

-- | A TLS cipher suite supported by Rustls.
--
-- Wraps a 'ConstPtr' to the Rustls-side cipher suite description.
newtype CipherSuite
  = CipherSuite (ConstPtr FFI.SupportedCipherSuite)

-- | Get the IANA value from a cipher suite. The bytes are interpreted in
-- network order.
--
-- See <https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-4>
-- for a list.
cipherSuiteID :: CipherSuite -> Word16
cipherSuiteID = \(CipherSuite suitePtr) -> FFI.supportedCipherSuiteGetSuite suitePtr

-- | Two cipher suites are equal when their IANA IDs ('cipherSuiteID') match.
instance Eq CipherSuite where
  lhs == rhs = cipherSuiteID lhs == cipherSuiteID rhs

-- | Cipher suites are ordered by their IANA IDs ('cipherSuiteID').
instance Ord CipherSuite where
  compare lhs rhs = compare (cipherSuiteID lhs) (cipherSuiteID rhs)

-- | Get the text representation of a cipher suite.
--
-- 'FFI.hsSupportedCipherSuiteGetName' writes the name into a temporary
-- 'Str' slot allocated here, which is then converted with 'strToText'.
--
-- NOTE(review): 'unsafePerformIO' is used on the assumption that the name
-- Rustls reports for a given cipher suite never changes, so the result is
-- effectively pure. Confirm against the rustls-ffi docs.
showCipherSuite :: CipherSuite -> Text
showCipherSuite (CipherSuite cipherSuitePtr) =
  unsafePerformIO . alloca $ \strPtr -> do
    FFI.hsSupportedCipherSuiteGetName cipherSuitePtr strPtr
    peek strPtr >>= strToText

-- | Shows the cipher suite by its name, as given by 'showCipherSuite'.
instance Show CipherSuite where
  show cipherSuite = T.unpack (showCipherSuite cipherSuite)

-- | Rustls client config builder.
data ClientConfigBuilder = ClientConfigBuilder
  { -- | Where the client's root certificates come from.
    clientConfigRoots :: ClientRoots,
    -- | Supported 'FFI.TLSVersion's. An empty list selects good defaults.
    clientConfigTLSVersions :: [FFI.TLSVersion],
    -- | Supported 'CipherSuite's, most preferred first. An empty list
    -- selects good defaults.
    clientConfigCipherSuites :: [CipherSuite],
    -- | ALPN protocols.
    clientConfigALPNProtocols :: [ALPNProtocol],
    -- | Whether Server Name Indication is enabled. Defaults to 'True'.
    clientConfigEnableSNI :: Bool,
    -- | 'CertifiedKey's used for client authentication.
    --
    -- Clients that want to support both ECDSA and RSA certificates will want
    -- the ECDSA to go first in the list.
    clientConfigCertifiedKeys :: [CertifiedKey]
  }
  deriving stock (Show)
  deriving stock (Generic)

-- | How to look up root certificates.
data ClientRoots
  = -- | Read PEM-encoded root certificates from the file at this path.
    ClientRootsFromFile FilePath
  | -- | Use the given in-memory PEM-encoded certificates.
    ClientRootsInMemory [PEMCertificates]
  deriving stock (Generic)

-- | Opaque 'Show' instance: prints only the type name and never the file
-- path or certificate contents.
instance Show ClientRoots where
  show = const "ClientRoots"

-- | In-memory PEM-encoded certificates.
data PEMCertificates
  = -- | PEM-encoded certificates that must all be syntactically valid.
    PEMCertificatesStrict ByteString
  | -- | PEM-encoded certificates where syntactically invalid ones are ignored.
    --
    -- This may be useful on systems that have syntactically invalid root certificates.
    PEMCertificatesLax ByteString
  deriving stock (Show)
  deriving stock (Generic)

-- | A full certificate chain together with the private key for its leaf
-- certificate.
data CertifiedKey = CertifiedKey
  { -- | The certificate chain, PEM-encoded.
    certificateChain :: ByteString,
    -- | The leaf certificate's private key, PEM-encoded.
    privateKey :: ByteString
  }
  deriving stock (Generic)

-- | Opaque 'Show' instance: prints only the type name, so the certificate
-- chain and private key are never included in the output.
instance Show CertifiedKey where
  show = const "CertifiedKey"

-- | Assembled configuration for a Rustls client connection.
data ClientConfig = ClientConfig
  { -- | Managed pointer to the underlying Rustls client configuration.
    clientConfigPtr :: ForeignPtr FFI.ClientConfig,
    -- | An optional logging callback.
    --
    -- Because this is a record selector, it doubles as a setter:
    --
    -- >>> :{
    -- setLogCallback :: LogCallback -> ClientConfig -> ClientConfig
    -- setLogCallback logCallback clientConfig =
    --   clientConfig { clientConfigLogCallback = Just logCallback }
    -- :}
    clientConfigLogCallback :: Maybe LogCallback
  }

-- | How to verify TLS client certificates.
data ClientCertVerifier
  = -- | Verify client certificates against these root certificates.
    ClientCertVerifier [PEMCertificates]
  | -- | Verify client certificates against these root certificates when a
    -- certificate is presented, without rejecting clients that present none.
    ClientCertVerifierOptional [PEMCertificates]
  deriving stock (Show)
  deriving stock (Generic)

-- | Rustls server config builder.
--
-- (The summary previously said "client"; this is the server-side counterpart
-- of 'ClientConfigBuilder'.)
data ServerConfigBuilder = ServerConfigBuilder
  { -- | List of 'CertifiedKey's.
    serverConfigCertifiedKeys :: NonEmpty CertifiedKey,
    -- | Supported 'FFI.TLSVersion's. When empty, good defaults are
    -- used.
    serverConfigTLSVersions :: [FFI.TLSVersion],
    -- | Supported 'CipherSuite's in order of preference. When empty, good
    -- defaults are used.
    serverConfigCipherSuites :: [CipherSuite],
    -- | ALPN protocols.
    serverConfigALPNProtocols :: [ALPNProtocol],
    -- | Ignore the client's ciphersuite order. Defaults to 'False'.
    serverConfigIgnoreClientOrder :: Bool,
    -- | Optionally, a client cert verifier.
    serverConfigClientCertVerifier :: Maybe ClientCertVerifier
  }
  deriving stock (Show, Generic)

-- | Assembled configuration for a Rustls server connection.
data ServerConfig = ServerConfig
  { -- | Managed pointer to the underlying Rustls server configuration.
    serverConfigPtr :: ForeignPtr FFI.ServerConfig,
    -- | An optional logging callback.
    --
    -- Because this is a record selector, it doubles as a setter:
    --
    -- >>> :{
    -- setLogCallback :: LogCallback -> ServerConfig -> ServerConfig
    -- setLogCallback logCallback serverConfig =
    --   serverConfig { serverConfigLogCallback = Just logCallback }
    -- :}
    serverConfigLogCallback :: Maybe LogCallback
  }

-- | Rustls log level.
--
-- Constructors are listed, and therefore ordered by the derived 'Ord',
-- 'Enum' and 'Bounded' instances, from 'LogLevelError' to 'LogLevelTrace'.
data LogLevel
  = LogLevelError
  | LogLevelWarn
  | LogLevelInfo
  | LogLevelDebug
  | LogLevelTrace
  deriving stock (Show, Eq, Ord)
  deriving stock (Enum, Bounded)
  deriving stock (Generic)

-- | A Rustls connection logging callback.
newtype LogCallback = LogCallback
  { -- | The wrapped C function pointer.
    unLogCallback :: FunPtr FFI.LogCallback
  }

-- | A 'Monad' to get TLS connection information via 'Rustls.handshake'.
--
-- Internally this is @'ReaderT' Connection' 'IO'@. The @side@ index has a
-- nominal role (see the role annotation below), so a query for one side
-- cannot be 'coerce'd into a query for another.
newtype HandshakeQuery (side :: Side) a
  = HandshakeQuery (ReaderT Connection' IO a)
  deriving newtype (Functor, Applicative, Monad)

type role HandshakeQuery nominal _

-- | Turn an 'IO' action that reads the connection into a 'HandshakeQuery'.
handshakeQuery :: (Connection' -> IO a) -> HandshakeQuery side a
handshakeQuery action = HandshakeQuery (ReaderT action)

-- | TLS exception thrown by Rustls, identified by its numeric result code.
--
-- Use 'E.displayException' for a human-friendly representation.
newtype RustlsException = RustlsException
  { -- | The raw Rustls result code.
    rustlsErrorCode :: Word32
  }
  deriving stock (Show)

-- | Displays as @Rustls error: \<message\> (\<code\>)@, where the message is
-- looked up via 'resultMsg'.
instance E.Exception RustlsException where
  displayException (RustlsException code) = unwords [prefix, message, codeInParens]
    where
      prefix = "Rustls error:"
      message = T.unpack (resultMsg (FFI.Result code))
      codeInParens = "(" <> show code <> ")"

-- | Look up the message Rustls associates with a result code.
--
-- 'FFI.errorMsg' writes the message into a buffer of 'msgLen' bytes and
-- stores the number of bytes written in @lenPtr@. Exactly that many bytes
-- are then decoded with 'T.peekCStringLen'.
--
-- NOTE(review): assumes longer messages are truncated to the buffer size by
-- Rustls, and that the lookup is pure (which is what justifies
-- 'unsafePerformIO'). Confirm against the rustls-ffi docs.
resultMsg :: FFI.Result -> Text
resultMsg r = unsafePerformIO (alloca withLenPtr)
  where
    withLenPtr lenPtr = allocaBytes (cSizeToInt msgLen) $ \buf -> do
      FFI.errorMsg r buf msgLen lenPtr
      written <- peek lenPtr
      T.peekCStringLen (buf, cSizeToInt written)

    msgLen :: CSize
    msgLen = 1024 -- a bit pessimistic?

-- | Checks if the given 'RustlsException' represents a certificate error,
-- as classified by Rustls ('FFI.resultIsCertError').
isCertError :: RustlsException -> Bool
isCertError = toBool . FFI.resultIsCertError . FFI.Result . rustlsErrorCode

-- | Do nothing for 'FFI.resultOk'. Any other result is thrown as a
-- 'RustlsException' carrying its code.
rethrowR :: FFI.Result -> IO ()
rethrowR result@(FFI.Result code)
  | result == FFI.resultOk = pure ()
  | otherwise = E.throwIO (RustlsException code)

-- | Wrapper for exceptions thrown in a 'LogCallback'. The original exception
-- is kept as an 'E.SomeException'.
newtype RustlsLogException
  = RustlsLogException E.SomeException
  deriving stock (Show)
  deriving anyclass (E.Exception)

-- | Exception carrying a raw 'FFI.LogLevel', presumably one that has no
-- corresponding 'LogLevel' constructor.
--
-- A 'newtype' (not 'data') because there is a single constructor with a
-- single field, so no extra box is allocated.
newtype RustlsUnknownLogLevel = RustlsUnknownLogLevel FFI.LogLevel
  deriving stock (Show)
  deriving anyclass (E.Exception)

-- | Underlying data sources for Rustls.
class Backend b where
  -- | Read data from the backend into the given buffer.
  backendRead
    :: b
    -> Ptr Word8
    -- ^ Target buffer pointer.
    -> CSize
    -- ^ Target buffer length.
    -> IO CSize
    -- ^ Amount of bytes read.

  -- | Write data from the given buffer to the backend.
  backendWrite
    :: b
    -> Ptr Word8
    -- ^ Source buffer pointer.
    -> CSize
    -- ^ Source buffer length.
    -> IO CSize
    -- ^ Amount of bytes written.

-- | Reads and writes go straight to the socket via 'NS.recvBuf' and
-- 'NS.sendBuf'; the reported byte count is whatever those calls return.
instance Backend NS.Socket where
  backendRead sock ptr size = do
    received <- NS.recvBuf sock ptr (cSizeToInt size)
    pure (intToCSize received)
  backendWrite sock ptr size = do
    sent <- NS.sendBuf sock ptr (cSizeToInt size)
    pure (intToCSize sent)

-- | An in-memory 'Backend', described by a pair of callbacks.
data ByteStringBackend = ByteStringBackend
  { -- | Produce a 'ByteString' of at most the given length.
    bsbRead :: Int -> IO ByteString
  , -- | Consume a 'ByteString'.
    bsbWrite :: ByteString -> IO ()
  }
  deriving stock Generic

-- | This instance will silently truncate 'ByteString's which are too long.
instance Backend ByteStringBackend where
  -- Ask 'bsbRead' for at most @len@ bytes and copy them into @buf@. If the
  -- callback returns more than @len@ bytes anyway, the excess is dropped.
  -- Returns the number of bytes copied.
  backendRead ByteStringBackend {bsbRead} buf len = do
    bs <- bsbRead (cSizeToInt len)
    BU.unsafeUseAsCStringLen bs \(bsPtr, bsLen) -> do
      let copyLen = bsLen `min` cSizeToInt len
      -- An empty 'ByteString' may be backed by a null pointer, and passing a
      -- null pointer to memcpy is undefined behaviour even for a zero length,
      -- so only copy when there is something to copy.
      when (copyLen > 0) $
        copyBytes buf (castPtr bsPtr) copyLen
      pure $ intToCSize copyLen

  -- Copy the whole source buffer into a fresh 'ByteString', hand it to
  -- 'bsbWrite', and report that all @len@ bytes were written.
  backendWrite ByteStringBackend {bsbWrite} buf len = do
    bsbWrite =<< B.packCStringLen (castPtr buf, cSizeToInt len)
    pure len

-- | Type-level tag telling whether a 'Connection' is the client end or the
-- server end. It is used as the (promoted) index of 'Connection'.
data Side
  = Client
  | Server

-- | A Rustls connection, indexed by its 'Side'.
--
-- The connection state lives in an 'MVar'; 'withConnection' takes it for the
-- duration of an operation.
newtype Connection (side :: Side) = Connection (MVar Connection')

-- A nominal role means 'coerce' cannot turn a connection of one 'Side' into a
-- connection of the other.
type role Connection nominal

-- | State behind a 'Connection', existentially quantified over its 'Backend'.
data Connection' = forall b.
  (Backend b) =>
  Connection'
  { -- | The underlying rustls connection.
    conn :: Ptr FFI.Connection
  , -- | Transport used for the raw TLS bytes.
    backend :: b
  , -- | NOTE(review): not used in this excerpt; presumably scratch space
    -- for a length out-parameter of an FFI call. Confirm at its use site.
    lenPtr :: Ptr CSize
  , -- | Requests to the background thread (see 'IOMsgReq').
    ioMsgReq :: MVar IOMsgReq
  , -- | Responses from the background thread (see 'IOMsgRes').
    ioMsgRes :: MVar IOMsgRes
  , -- | Presumably the background thread that runs the FFI I/O calls.
    interactThread :: ThreadId
  }

-- | Run an action with exclusive access to the connection state.
--
-- The 'MVar' is held for the whole action via 'withMVar'. Other users of the
-- same 'Connection' therefore wait until the action finishes, and the state is
-- put back even if the action throws.
withConnection :: Connection side -> (Connection' -> IO a) -> IO a
withConnection (Connection stateVar) action = withMVar stateVar action

-- | Direction of a single backend operation requested via 'IOMsgReq'.
data ReadOrWrite
  = Read
  | Write

-- GHC delays asynchronous exceptions thrown to a thread that is inside a
-- (non-interruptible) FFI call until that call returns. In particular, when a
-- (safe) FFI call invokes a Haskell callback, the whole call, including the
-- callback, cannot be cancelled. As usages of this library will most likely
-- involve actual I/O (which really should be cancellable), we invoke the
-- respective FFI functions (which in turn call back into Haskell) in a
-- separate thread, and interact with that thread via message passing (see the
-- 'IOMsgReq' and 'IOMsgRes' types).

-- | Messages sent to the background thread.
data IOMsgReq
  = Request ReadOrWrite
  -- ^ Ask the background thread to start a read or a write FFI call. It
  -- should respond with 'UsingBuffer'.
  | Done FFI.IOResult
  -- ^ Tell the background thread that we are done with the buffer, and with
  -- which result.

-- | Messages sent from the background thread.
data IOMsgRes
  = UsingBuffer (Ptr Word8) CSize (Ptr CSize)
  -- ^ Hand over a buffer (pointer and length) for the backend operation. It
  -- is filled by a 'Read' or drained by a 'Write'. The last field is where
  -- the number of bytes actually transferred must be stored.
  | DoneFFI
  -- ^ The FFI call has finished.

-- | Serve one backend read or write on behalf of the background thread.
--
-- The exchange goes as follows:
--
-- 1. Send 'Request' so the background thread starts the matching FFI call.
-- 2. Wait for it to hand over a buffer with 'UsingBuffer'.
-- 3. Run the backend operation on that buffer.
-- 4. Store the transferred byte count at the supplied pointer.
-- 5. Report the outcome with 'Done' and wait for 'DoneFFI'.
--
-- Everything runs under 'E.uninterruptibleMask' except the backend operation
-- itself, which is run with the outer masking state restored so it can be
-- cancelled. If the backend operation throws, 'FFI.ioResultErr' is reported
-- (and 'DoneFFI' awaited) before the exception propagates.
interactTLS :: Connection' -> ReadOrWrite -> IO ()
interactTLS Connection' {backend, ioMsgReq, ioMsgRes} readOrWrite =
  E.uninterruptibleMask $ \restore -> do
    putMVar ioMsgReq (Request readOrWrite)
    UsingBuffer bufPtr bufLen countPtr <- takeMVar ioMsgRes
    -- Only the (possibly blocking) backend call is guarded; on failure we
    -- still complete the handshake with the background thread.
    transferred <-
      restore (transfer bufPtr bufLen) `E.onException` finish FFI.ioResultErr
    poke countPtr transferred
    finish FFI.ioResultOk
  where
    -- The backend operation selected by the requested direction.
    transfer :: Ptr Word8 -> CSize -> IO CSize
    transfer = case readOrWrite of
      Read -> backendRead backend
      Write -> backendWrite backend

    -- Report the result and wait until the background FFI call has finished.
    finish :: FFI.IOResult -> IO ()
    finish result = do
      putMVar ioMsgReq (Done result)
      DoneFFI <- takeMVar ioMsgRes
      pure ()

-- | What 'runTLS' should drive the connection through.
data RunTLSMode
  = TLSHandshake
  | TLSRead
  | TLSWrite
  deriving stock (Eq)

-- | Drive the TLS state machine of a connection.
--
-- * 'TLSHandshake' runs while rustls reports that it is still handshaking.
--   Each round does one write step and then one read step (both are always
--   attempted). The loop stops once handshaking is over, or once neither step
--   wanted to do I/O.
-- * 'TLSRead' completes any pending handshake, then keeps reading for as long
--   as rustls wants to read.
-- * 'TLSWrite' completes any pending handshake, then keeps writing for as
--   long as rustls wants to write.
runTLS :: Connection' -> RunTLSMode -> IO ()
runTLS c@Connection' {conn} mode = case mode of
  TLSHandshake -> repeatWhile do
    handshaking <- toBool <$> FFI.connectionIsHandshaking (ConstPtr conn)
    if handshaking
      then do
        wantedWrite <- runWrite
        wantedRead <- runRead
        pure (wantedWrite || wantedRead)
      else pure False
  TLSRead -> do
    runTLS c TLSHandshake
    repeatWhile runRead
  TLSWrite -> do
    runTLS c TLSHandshake
    repeatWhile runWrite
  where
    -- If rustls wants to read, pull TLS bytes from the backend and process
    -- them. Returns whether a read was wanted.
    runRead :: IO Bool
    runRead = do
      wantsRead <- toBool <$> FFI.connectionWantsRead (ConstPtr conn)
      if not wantsRead
        then pure False
        else do
          interactTLS c Read
          result <- FFI.connectionProcessNewPackets conn
          -- Try to notify our peer that we encountered a TLS error. This is
          -- best effort: synchronous failures are ignored, and the original
          -- error is rethrown below.
          when (result /= FFI.resultOk) $
            ignoreSyncExceptions (void runWrite)
          rethrowR result
          pure True

    -- If rustls wants to write, push one batch of TLS bytes to the backend.
    -- Returns whether a write was wanted.
    runWrite :: IO Bool
    runWrite = do
      wantsWrite <- toBool <$> FFI.connectionWantsWrite (ConstPtr conn)
      when wantsWrite $ interactTLS c Write
      pure wantsWrite

    -- Repeat a step until it returns False.
    repeatWhile :: IO Bool -> IO ()
    repeatWhile step = do
      again <- step
      when again (repeatWhile step)

-- | Convert a C @size_t@ to an 'Int' via 'fromIntegral'.
--
-- There is no range check: values above @maxBound :: Int@ wrap around.
cSizeToInt :: CSize -> Int
cSizeToInt size = fromIntegral size
{-# INLINE cSizeToInt #-}

-- | Convert an 'Int' to a C @size_t@ via 'fromIntegral'.
--
-- There is no range check: negative inputs wrap around to large values.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n
{-# INLINE intToCSize #-}

-- | Copy an FFI string (pointer plus byte length) into a 'Text'.
--
-- The bytes are decoded by 'T.peekCStringLen', which expects UTF-8.
strToText :: FFI.Str -> IO Text
strToText (FFI.Str ptr size) =
  let cstrLen = (ptr, cSizeToInt size)
   in T.peekCStringLen cstrLen

-- | Run an action and discard any exception it throws.
--
-- NOTE: this catches 'E.SomeException', so asynchronous exceptions (e.g.
-- 'E.ThreadKilled') are swallowed as well. Use 'ignoreSyncExceptions' when
-- those must propagate.
ignoreExceptions :: IO () -> IO ()
ignoreExceptions action = do
  _ <- E.try action :: IO (Either E.SomeException ())
  pure ()

-- | Run an action and discard any synchronous exception it throws.
--
-- Exceptions that are 'E.SomeAsyncException's are rethrown, so cancellation
-- still works.
ignoreSyncExceptions :: IO () -> IO ()
ignoreSyncExceptions action = action `E.catch` handler
  where
    handler :: E.SomeException -> IO ()
    handler se = case E.fromException se of
      Just asyncEx@(E.SomeAsyncException _) -> E.throwIO asyncEx
      Nothing -> pure ()