{-# Language BlockArguments, LambdaCase #-}
{-|
Module      : Hookup
Description : Network connections generalized over TLS and SOCKS
Copyright   : (c) Eric Mertens, 2016
License     : ISC
Maintainer  : emertens@gmail.com

This module provides a uniform interface to network connections
with optional support for TLS and SOCKS.

This library is careful to support both IPv4 and IPv6. It will attempt to
connect to all of the addresses that a domain name resolves to until one
succeeds, returning the first successful connection.

Use 'connect' and 'close' to establish and close network connections.

Use 'recv', 'recvLine', and 'send' to receive and transmit data on an
open network connection.

TLS and SOCKS parameters can be provided. When both are provided a connection
will first be established to the SOCKS server and then the TLS connection will
be established through that proxy server. This is most useful when connecting
through a dynamic port forward of an SSH client via the @-D@ flag.

-}
module Hookup
  (
  -- * Connections
  Connection,
  connect,
  connectWithSocket,
  close,
  upgradeTls,

  -- * Reading and writing data
  recv,
  recvLine,
  send,
  putBuf,

  -- * Configuration
  ConnectionParams(..),
  SocksParams(..),
  SocksAuthentication(..),
  TlsParams(..),
  TlsVerify(..),
  PEM.PemPasswordSupply(..),
  defaultTlsParams,

  -- * Errors
  ConnectionFailure(..),
  CommandReply(..)

  -- * SSL Information
  , getClientCertificate
  , getPeerCertificate
  , getPeerCertFingerprintSha1
  , getPeerCertFingerprintSha256
  , getPeerCertFingerprintSha512
  , getPeerPubkeyFingerprintSha1
  , getPeerPubkeyFingerprintSha256
  , getPeerPubkeyFingerprintSha512
  ) where

import           Control.Concurrent
import           Control.Exception
import           Control.Monad (when, unless)
import           System.IO.Error (isDoesNotExistError, ioeGetErrorString)
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import           Data.Foldable (for_, traverse_)
import           Data.List (intercalate, partition)
import           Data.Maybe (fromMaybe, mapMaybe)
import           Foreign.C.String (withCStringLen)
import           Foreign.Ptr (nullPtr)
import           Network.Socket (AddrInfo, HostName, PortNumber, SockAddr, Socket, Family)
import qualified Network.Socket as Socket
import qualified Network.Socket.ByteString as SocketB
import           OpenSSL.Session (SSL, SSLContext)
import qualified OpenSSL as SSL
import qualified OpenSSL.Session as SSL
import           OpenSSL.X509.SystemStore (contextLoadSystemCerts)
import           OpenSSL.X509 (X509)
import qualified OpenSSL.X509 as X509
import qualified OpenSSL.PEM as PEM
import qualified OpenSSL.EVP.Digest as Digest
import           Data.Attoparsec.ByteString (Parser)
import qualified Data.Attoparsec.ByteString as Parser

import           Hookup.Concurrent (concurrentAttempts)
import           Hookup.OpenSSL
import           Hookup.Socks5

-- | Parameters for 'connect'.
--
-- Common default for the TLS field: 'defaultTlsParams'
--
-- When a 'SocksParams' is provided the connection will be established
-- using a SOCKS (version 5) proxy.
--
-- When a 'TlsParams' is provided the connection will negotiate TLS at
-- connect time in order to protect the stream.
--
-- The binding hostname can be used to force the connection to use a
-- particular interface or IP protocol version: only destination addresses
-- whose address family matches a resolved bind address are attempted.
data ConnectionParams = ConnectionParams
  { cpHost  :: HostName          -- ^ Destination host
  , cpPort  :: PortNumber        -- ^ Destination TCP port
  , cpSocks :: Maybe SocksParams -- ^ Optional SOCKS parameters
  , cpTls   :: Maybe TlsParams   -- ^ Optional TLS parameters
  , cpBind  :: Maybe HostName    -- ^ Source address to bind
  }
  deriving Show

-- | SOCKS connection parameters: where the proxy server lives and how
-- to authenticate to it.
data SocksParams = SocksParams
  { spHost :: HostName            -- ^ SOCKS server host
  , spPort :: PortNumber          -- ^ SOCKS server port
  , spAuth :: SocksAuthentication -- ^ SOCKS authentication method
  }
  deriving Show

-- | Authentication method offered to the SOCKS server during the
-- initial handshake.
data SocksAuthentication
  -- | no credentials
  = NoSocksAuthentication
  -- | RFC 1929 username and password
  | UsernamePasswordSocksAuthentication ByteString ByteString
  deriving Show

-- | TLS connection parameters. These parameters are passed to
-- OpenSSL when making a secure connection.
data TlsParams = TlsParams
  { tpClientCertificate        :: Maybe FilePath   -- ^ Path to client certificate
  , tpClientPrivateKey         :: Maybe FilePath   -- ^ Path to client private key
  , tpClientPrivateKeyPassword :: Maybe ByteString -- ^ Private key decryption password
  , tpServerCertificate        :: Maybe FilePath   -- ^ Path to CA certificate bundle
  , tpCipherSuite              :: String           -- ^ OpenSSL cipher suite name (e.g. @\"HIGH\"@)
  , tpCipherSuiteTls13         :: Maybe String     -- ^ OpenSSL cipher suites for TLS 1.3
  , tpVerify                   :: TlsVerify        -- ^ How (and against which hostname) the server certificate is verified
  }
  deriving Show

-- | Server certificate verification mode.
data TlsVerify
  -- | Use the connection hostname to verify
  = VerifyDefault
  -- | No verification
  | VerifyNone
  -- | Use the given hostname to verify
  | VerifyHostname String
  deriving Show

-- | Type for errors that can be thrown by this package.
data ConnectionFailure
  -- | Failure during 'getAddrInfo' resolving remote host, or no pairing of
  -- source and destination addresses with a matching address family
  = HostnameResolutionFailure HostName String
  -- | Failure during 'connect' to remote host
  | ConnectionFailure [ConnectError]
  -- | Failure during 'recvLine'
  | LineTooLong
  -- | Incomplete line during 'recvLine'
  | LineTruncated
  -- | Socks command rejected by server by given reply code
  | SocksError CommandReply
  -- | Socks authentication method was not accepted
  | SocksAuthenticationMethodRejected
  -- | Socks authentication credentials were not accepted
  | SocksAuthenticationCredentialsRejected
  -- | Username or password were too long
  | SocksBadAuthenticationCredentials
  -- | Socks server sent an invalid message or no message.
  | SocksProtocolError
  -- | Domain name was too long for SOCKS protocol
  | SocksBadDomainName
  deriving Show

-- | Human-readable rendering of every 'ConnectionFailure' for
-- 'displayException'; 'show' keeps the derived representation.
instance Exception ConnectionFailure where
  displayException = \case
      LineTruncated -> "connection closed while reading line"
      LineTooLong   -> "line length exceeded maximum"
      ConnectionFailure errs ->
        "connection attempt failed due to: " ++
        intercalate ", " (map displayException errs)
      HostnameResolutionFailure hostName reason ->
        "hostname resolution failed (" ++ hostName ++ "): "  ++ reason
      SocksAuthenticationMethodRejected ->
        "SOCKS authentication method rejected"
      SocksAuthenticationCredentialsRejected ->
        "SOCKS authentication credentials rejected"
      SocksBadAuthenticationCredentials ->
        "SOCKS authentication credentials too long"
      SocksProtocolError ->
        "SOCKS server protocol error"
      SocksBadDomainName ->
        "SOCKS domain name length limit exceeded"
      SocksError reply ->
        "SOCKS command rejected: " ++ describeReply reply
    where
      -- Text for each SOCKS reply code; unknown codes show their number.
      describeReply = \case
        Succeeded         -> "succeeded"
        GeneralFailure    -> "general SOCKS server failure"
        NotAllowed        -> "connection not allowed by ruleset"
        NetUnreachable    -> "network unreachable"
        HostUnreachable   -> "host unreachable"
        ConnectionRefused -> "connection refused"
        TTLExpired        -> "TTL expired"
        CmdNotSupported   -> "command not supported"
        AddrNotSupported  -> "address type not supported"
        CommandReply code -> "unknown reply " ++ show code

-- | An 'IOError' raised while trying to reach one particular address,
-- tagged with that address.
data ConnectError
  = ConnectError SockAddr IOError
  deriving Show

-- | Renders as @address: error@.
instance Exception ConnectError where
  displayException (ConnectError target err) =
    concat [show target, ": ", displayException err]

-- | Default values for TLS that use no client certificates, use
-- system CA root, @\"HIGH\"@ cipher suite, and which validate hostnames.
defaultTlsParams :: TlsParams
defaultTlsParams = TlsParams
  { tpCipherSuite              = "HIGH"
  , tpCipherSuiteTls13         = Nothing
  , tpVerify                   = VerifyDefault
  , tpServerCertificate        = Nothing -- use system provided CAs
  , tpClientCertificate        = Nothing
  , tpClientPrivateKey         = Nothing
  , tpClientPrivateKeyPassword = Nothing
  }

------------------------------------------------------------------------
-- Opening sockets
------------------------------------------------------------------------

-- | Open a socket using the given parameters either directly or
-- via a SOCKS server. In the proxied case the socket to the proxy is
-- closed if the SOCKS negotiation fails.
openSocket :: ConnectionParams -> IO Socket
openSocket params = maybe direct viaProxy (cpSocks params)
  where
    dstHost  = cpHost params
    dstPort  = cpPort params
    bindAddr = cpBind params

    direct = openSocket' dstHost dstPort bindAddr

    viaProxy sp =
      bracketOnError
        (openSocket' (spHost sp) (spPort sp) bindAddr)
        Socket.close
        \proxySock ->
          do socksConnect proxySock dstHost dstPort (spAuth sp)
             pure proxySock

-- | Run a parser against bytes read from the socket, one byte per
-- 'SocketB.recv', so that nothing past the end of the parsed message is
-- consumed. Anything other than a complete parse with no leftover input
-- raises 'SocksProtocolError'.
netParse :: Show a => Socket -> Parser a -> IO a
netParse sock parser =
  Parser.parseWith (SocketB.recv sock 1) parser B.empty >>= \case
    Parser.Done leftover value | B.null leftover -> pure value
    _ -> throwIO SocksProtocolError

-- | Perform a SOCKS5 handshake over a socket already connected to the
-- proxy, then ask the proxy to CONNECT to the given destination host and
-- port.
--
-- All locally checkable parameters (credential and domain-name lengths)
-- are validated before anything is sent, so invalid input fails fast
-- without a wasted handshake.
--
-- Throws 'SocksBadAuthenticationCredentials', 'SocksBadDomainName',
-- 'SocksAuthenticationMethodRejected',
-- 'SocksAuthenticationCredentialsRejected', 'SocksProtocolError', or
-- 'SocksError' with the server's reply code.
socksConnect :: Socket -> HostName -> PortNumber -> SocksAuthentication -> IO ()
socksConnect sock host port auth =
  do validateCredentials
     unless (B.length dnBytes < 256)
       (throwIO SocksBadDomainName)

     authenticate

     SocketB.sendAll sock $
       buildRequest Request
         { reqCommand  = Connect
         , reqAddress  = Address (DomainName dnBytes) port
         }

     response <- netParse sock parseResponse
     unless (rspReply response == Succeeded)
       (throwIO (SocksError (rspReply response)))
  where
    -- The destination is always sent as a domain name; its length must fit
    -- in the single length byte checked above.
    dnBytes = B8.pack host

    -- Username and password lengths must each fit in one byte.
    validateCredentials =
      case auth of
        NoSocksAuthentication -> pure ()
        UsernamePasswordSocksAuthentication u p ->
          unless (B.length u < 256 && B.length p < 256)
            (throwIO SocksBadAuthenticationCredentials)

    -- Offer exactly one authentication method and insist the server
    -- selects it.
    offerMethod method =
      do SocketB.sendAll sock $
           buildClientHello ClientHello { cHelloMethods = [method] }
         hello <- netParse sock parseServerHello
         unless (sHelloMethod hello == method)
           (throwIO SocksAuthenticationMethodRejected)

    authenticate =
      case auth of
        NoSocksAuthentication ->
          offerMethod AuthNoAuthenticationRequired
        UsernamePasswordSocksAuthentication u p ->
          do offerMethod AuthUsernamePassword
             SocketB.sendAll sock $
               buildPlainAuthentication PlainAuthentication
                 { plainUsername = u, plainPassword = p }
             status <- netParse sock parsePlainAuthenticationReply
             -- A status of 0 means the credentials were accepted.
             unless (plainStatus status == 0)
               (throwIO SocksAuthenticationCredentialsRejected)

-- | Resolve the destination (and optional bind address), pair them up by
-- address family, and race staggered connection attempts across the
-- candidates, returning the first socket that connects.
openSocket' ::
  HostName       {- ^ destination      -} ->
  PortNumber     {- ^ destination port -} ->
  Maybe HostName {- ^ source           -} ->
  IO Socket      {- ^ connected socket -}
openSocket' h p mbBind =
  do mbSrc <- traverse (resolve Nothing) mbBind
     dst   <- resolve (Just p) h
     case interleaveAddressFamilies (matchBindAddrs mbSrc dst) of
       [] -> throwIO (HostnameResolutionFailure h "No source/destination address family match")
       candidates ->
         concurrentAttempts connAttemptDelay Socket.close
           [ connectToAddrInfo src target | (src, target) <- candidates ]
           >>= either (throwIO . ConnectionFailure . mapMaybe fromException) pure

-- | 'Socket.getAddrInfo' hints: stream sockets only, and the service is
-- always a numeric port.
hints :: AddrInfo
hints = Socket.defaultHints { Socket.addrSocketType = socketType, Socket.addrFlags = flags }
  where
    socketType = Socket.Stream
    flags      = [Socket.AI_NUMERICSERV]

-- | Resolve a host (and optional port) using 'hints'. A does-not-exist
-- lookup failure becomes 'HostnameResolutionFailure'; any other 'IOError'
-- is rethrown unchanged.
resolve :: Maybe PortNumber -> HostName -> IO [AddrInfo]
resolve mbPort host =
  try (Socket.getAddrInfo (Just hints) (Just host) (fmap show mbPort)) >>= \case
    Right addrs -> pure addrs
    Left err
      | isDoesNotExistError err ->
          throwIO (HostnameResolutionFailure host (ioeGetErrorString err))
      | otherwise -> throwIO err -- unexpected

-- | When no bind address is specified return the full list of destination
-- addresses with no bind address specified.
--
-- When bind addresses are specified return a subset of the destination list
-- matched up with the first address from the bind list that has the
-- correct address family. Destinations with no family match are dropped.
matchBindAddrs :: Maybe [AddrInfo] -> [AddrInfo] -> [(Maybe SockAddr, AddrInfo)]
matchBindAddrs Nothing dst = map (\d -> (Nothing, d)) dst
matchBindAddrs (Just src) dst = concatMap pairWithSource dst
  where
    pairWithSource d =
      case filter (sameFamily d) src of
        s : _ -> [(Just (Socket.addrAddress s), d)]
        []    -> []
    sameFamily d s = Socket.addrFamily d == Socket.addrFamily s

-- | Delay handed to 'concurrentAttempts' between connection attempts:
-- 150 ms expressed in microseconds (presumably a 'threadDelay'-style unit).
connAttemptDelay :: Int
connAttemptDelay = 150000

-- | Alternate list of addresses between IPv6 and other (IPv4) addresses,
-- starting with IPv6. Once either group runs out, the remainder of the
-- other group follows in its original order.
interleaveAddressFamilies :: [(Maybe SockAddr, AddrInfo)] -> [(Maybe SockAddr, AddrInfo)]
interleaveAddressFamilies candidates = alternate v6 nonV6
  where
    (v6, nonV6) = partition isV6 candidates

    isV6 (_, ai) = Socket.addrFamily ai == Socket.AF_INET6

    alternate (a:as) (b:bs) = a : b : alternate as bs
    alternate as     []     = as
    alternate []     bs     = bs

-- | Create a socket, optionally bind it to the source address, and
-- connect to the service identified by the given 'AddrInfo'. Returns the
-- connected socket, or closes it on failure.
--
-- Any 'IOError' is rethrown as a 'ConnectError' tagged with the
-- destination address. This includes failure to create the socket (for
-- example an unsupported address family). 'openSocket'' keeps only
-- 'ConnectError's when reporting failures, so such errors would otherwise
-- be silently dropped from the report.
connectToAddrInfo :: Maybe SockAddr -> AddrInfo -> IO Socket
connectToAddrInfo mbSrc info =
  handle (throwIO . ConnectError addr) $
    bracketOnError (socket' info) Socket.close \s ->
      do traverse_ (bind' s) mbSrc
         Socket.connect s addr
         pure s
  where
    addr = Socket.addrAddress info

-- | A version of 'Socket.bind' that doesn't bother binding on the wildcard
-- address (IPv4 @0.0.0.0@ or IPv6 @::@, any port). The effect of binding on
-- a wildcard address in this library is to pick an address family. Because
-- of the matching done earlier this is unnecessary for client connections
-- and causes a local port to be unnecessarily fixed early.
bind' :: Socket -> SockAddr -> IO ()
bind' s a
  | isWildcard a = pure ()
  | otherwise    = Socket.bind s a
  where
    isWildcard (Socket.SockAddrInet _ 0)                = True
    isWildcard (Socket.SockAddrInet6 _ _ (0,0,0,0) _)   = True
    isWildcard _                                        = False

-- | Open a 'Socket' using the family, socket type and protocol from an
-- 'AddrInfo'.
socket' :: AddrInfo -> IO Socket
socket' ai = Socket.socket fam sockType proto
  where
    fam      = Socket.addrFamily ai
    sockType = Socket.addrSocketType ai
    proto    = Socket.addrProtocol ai


------------------------------------------------------------------------
-- Generalization of Socket
------------------------------------------------------------------------

-- | Transport underlying a 'Connection'.
data NetworkHandle
  -- | TLS session, with the client certificate returned by 'startTls' (if any)
  = SSL (Maybe X509) SSL
  -- | Plain socket
  | Socket Socket

-- | Wrap the socket produced by the creation action in a 'NetworkHandle',
-- starting TLS (verified against 'cpHost') when 'cpTls' is set.
openNetworkHandle ::
  ConnectionParams {- ^ parameters             -} ->
  IO Socket        {- ^ socket creation action -} ->
  IO NetworkHandle {- ^ open network handle    -}
openNetworkHandle params mkSocket = maybe plain secure (cpTls params)
  where
    plain      = Socket <$> mkSocket
    secure tls = uncurry SSL <$> startTls tls (cpHost params) mkSocket

-- | Close a network handle. For TLS, a unidirectional shutdown is sent
-- first. The underlying socket is closed even if the shutdown throws,
-- which can happen when the peer has already gone away; otherwise the
-- file descriptor would leak. Any exception from the shutdown still
-- propagates.
closeNetworkHandle :: NetworkHandle -> IO ()
closeNetworkHandle (Socket s) = Socket.close s
closeNetworkHandle (SSL _ s) =
  SSL.shutdown s SSL.Unidirectional
    `finally` traverse_ Socket.close (SSL.sslSocket s)

-- | Transmit the whole buffer over the handle.
networkSend :: NetworkHandle -> ByteString -> IO ()
networkSend = \case
  Socket sock -> SocketB.sendAll sock
  SSL _ ssl   -> SSL.write ssl

-- | Read up to the given number of bytes from the handle.
networkRecv :: NetworkHandle -> Int -> IO ByteString
networkRecv = \case
  Socket sock -> SocketB.recv sock
  SSL _ ssl   -> SSL.read ssl

------------------------------------------------------------------------
-- Sockets with a receive buffer
------------------------------------------------------------------------

-- | A connection to a network service along with its read buffer
-- used for line-oriented protocols. The connection could be a plain
-- network connection, SOCKS connected, or TLS.
data Connection
  = Connection
      {-# UNPACK #-} !(MVar ByteString)    -- ^ bytes received but not yet returned
      {-# UNPACK #-} !(MVar NetworkHandle) -- ^ underlying transport

-- | Open network connection to TCP service specified by
-- the given parameters, going through SOCKS and/or TLS as configured.
--
-- The resulting connection MUST be closed with 'close' to avoid leaking
-- resources.
--
-- Throws 'IOError', 'SSL.ProtocolError', or 'ConnectionFailure'
-- (including its 'SocksError' case)
connect ::
  ConnectionParams {- ^ parameters      -} ->
  IO Connection    {- ^ open connection -}
connect params =
  do handle' <- openNetworkHandle params (openSocket params)
     bufVar  <- newMVar B.empty
     hVar    <- newMVar handle'
     pure (Connection bufVar hVar)

-- | Create a new 'Connection' using an already connected socket.
-- This will attempt to start TLS if configured but will ignore
-- any SOCKS server settings as it is assumed that the socket
-- is already actively connected to the intended service.
--
-- Throws 'SSL.ProtocolError'
connectWithSocket ::
  ConnectionParams {- ^ parameters       -} ->
  Socket           {- ^ connected socket -} ->
  IO Connection    {- ^ open connection  -}
connectWithSocket params sock =
  do nh     <- openNetworkHandle params (pure sock)
     bufVar <- newMVar B.empty
     Connection bufVar <$> newMVar nh

-- | Close the network connection by closing its underlying handle.
--
-- The handle's 'MVar' is held while closing, so this waits for an
-- in-progress 'upgradeTls' to finish before closing.
close ::
  Connection {- ^ open connection -} ->
  IO ()
close (Connection _ hVar) = withMVar hVar closeNetworkHandle

-- | Receive the next chunk from the stream.
--
-- Buffered bytes take priority. If the read buffer is non-empty, its entire
-- contents are returned (even if longer than the requested size) and the
-- buffer is emptied, without touching the network. Otherwise a single
-- receive of up to the requested number of bytes is performed on the
-- underlying handle.
--
-- Throws: 'IOError', 'SSL.ConnectionAbruptlyTerminated', 'SSL.ProtocolError'
recv ::
  Connection    {- ^ open connection              -} ->
  Int           {- ^ maximum underlying recv size -} ->
  IO ByteString {- ^ next chunk from stream       -}
recv (Connection bufVar hVar) n =
  modifyMVar bufVar \buffered ->
    do chunk <-
         if B.null buffered
           then do h <- readMVar hVar
                   networkRecv h n
           else pure buffered
       pure (B.empty, chunk)

-- | Receive a line from the network connection. Both
-- @"\\r\\n"@ and @"\\n"@ are recognized, and the terminator is not
-- included in the result. Bytes following the terminator stay in the
-- read buffer for later 'recv' / 'recvLine' calls.
--
-- Returning 'Nothing' means that the peer has closed its half of
-- the connection.
--
-- Unterminated lines will raise a 'LineTruncated' exception. This
-- can happen if the peer transmits some data and closes its end
-- without transmitting a line terminator.
--
-- 'LineTooLong' is raised once at least the maximum line length has been
-- accumulated without finding a terminator. The limit is checked before
-- each additional receive, so a line completed by the final receive can
-- exceed it.
--
-- When an exception is raised, the read buffer keeps the contents it had
-- before the call. Bytes received during the failed call are dropped.
--
-- Throws: 'SSL.ConnectionAbruptlyTerminated', 'SSL.ProtocolError', 'ConnectionFailure', 'IOError'
recvLine ::
  Connection            {- ^ open connection            -} ->
  Int                   {- ^ maximum line length        -} ->
  IO (Maybe ByteString) {- ^ next line or end-of-stream -}
recvLine (Connection bufVar hVar) n =
  modifyMVar bufVar \buffered ->
    do h <- readMVar hVar
       search h (B.length buffered) buffered []
  where
    -- total  : combined length of newest and every chunk in older
    -- newest : most recent chunk; only it can hold a newline, because every
    --          older chunk has already been searched
    -- older  : previously searched chunks, most recent first
    --
    -- Yields the new buffer contents (bytes after the newline) and the line.
    search ::
      NetworkHandle -> Int -> ByteString -> [ByteString] ->
      IO (ByteString, Maybe ByteString)
    search h total newest older =
      case B8.elemIndex '\n' newest of
        Just i ->
          let (lineTail, rest) = B.splitAt i newest
              line = cleanEnd (B.concat (reverse (lineTail : older)))
          in pure (B.drop 1 rest, Just line) -- rest starts with the '\n'
        Nothing
          | total >= n -> throwIO LineTooLong
          | otherwise ->
              do more <- networkRecv h n
                 if not (B.null more)
                   then search h (total + B.length more) more (newest : older)
                   -- an empty receive means the peer closed the connection
                   else if total == 0
                     then pure (B.empty, Nothing)
                     else throwIO LineTruncated

-- | Prepend a 'ByteString' to the read buffer, so that its bytes are the
-- first ones returned by the next 'recv' or 'recvLine'. This is useful for
-- returning the unused part of a 'recv' to the buffer.
putBuf ::
  Connection {- ^ connection         -} ->
  ByteString {- ^ new head of buffer -} ->
  IO ()
putBuf (Connection bufVar _) bs =
  modifyMVar_ bufVar \old ->
    -- force the append now instead of storing a thunk in the MVar
    let combined = B.append bs old
    in combined `seq` pure combined

-- | Remove a single trailing @'\\r'@, if present, so that a line terminated
-- by @"\\r\\n"@ comes out the same as one terminated by @"\\n"@.
cleanEnd :: ByteString -> ByteString
cleanEnd bs =
  case B8.unsnoc bs of
    Just (initial, '\r') -> initial
    _                    -> bs

-- | Send bytes on the network connection. The whole chunk is transmitted,
-- which might take multiple underlying sends.
--
-- The handle is read with 'readMVar', so the 'MVar' is not held while the
-- data is being sent.
--
-- Throws: 'IOError', 'SSL.ProtocolError'
send ::
  Connection {- ^ open connection -} ->
  ByteString {- ^ chunk           -} ->
  IO ()
send (Connection _ hVar) bs =
  readMVar hVar >>= \h -> networkSend h bs

-- | Start TLS on a connection that is currently a plain socket, using the
-- given TLS parameters and hostname.
--
-- If the connection already uses TLS, both the handle and the read buffer
-- are left unchanged. Otherwise the socket's handle is replaced by the new
-- TLS session and the read buffer is cleared. Bytes already buffered were
-- received before the handshake, so they are dropped rather than being
-- returned alongside data read over TLS. Both the buffer 'MVar' and the
-- handle 'MVar' are held for the whole upgrade. If 'startTls' throws, both
-- keep their previous contents.
upgradeTls ::
  TlsParams  {- ^ connection params     -} ->
  String     {- ^ hostname              -} ->
  Connection {- ^ connection to upgrade -} ->
  IO ()
upgradeTls tp hostname (Connection bufVar hVar) =
  modifyMVar_ bufVar \buf ->
    modifyMVar hVar \case
      h@SSL{} -> pure (h, buf)
      Socket s ->
        do (cert, ssl) <- startTls tp hostname (pure s)
           pure (SSL cert ssl, B.empty)

------------------------------------------------------------------------

-- | Initiate a TLS session, destined for the given hostname, over the socket
-- produced by the given action. Returns the client certificate that was
-- loaded from the parameters (if any) alongside the connected session.
--
-- Peer verification is installed unless 'tpVerify' is 'VerifyNone'.
-- 'VerifyDefault' verifies against @hostname@, and 'VerifyHostname' against
-- the name it carries. SNI always uses @hostname@ and is omitted when
-- @hostname@ is a numeric address.
startTls ::
  TlsParams {- ^ connection params      -} ->
  String    {- ^ hostname               -} ->
  IO Socket {- ^ socket creation action -} ->
  IO (Maybe X509, SSL) {- ^ (client certificate, connected TLS) -}
startTls tp hostname mkSocket =
  SSL.withOpenSSL do
    ctx <- SSL.context

    -- cipher configuration
    SSL.contextSetCiphers ctx (tpCipherSuite tp)
    for_ (tpCipherSuiteTls13 tp) (contextSetTls13Ciphers ctx)

    -- peer verification: choose the name to verify against, if any
    let verifyName =
          case tpVerify tp of
            VerifyNone       -> Nothing
            VerifyDefault    -> Just hostname
            VerifyHostname h -> Just h
    for_ verifyName \name ->
      do installVerification ctx name
         SSL.contextSetVerificationMode ctx verifyPeer

    SSL.contextAddOption    ctx SSL.SSL_OP_ALL
    SSL.contextRemoveOption ctx SSL.SSL_OP_DONT_INSERT_EMPTY_FRAGMENTS

    -- trusted CAs and optional client credentials
    setupCaCertificates ctx (tpServerCertificate tp)
    clientCert <- traverse (setupCertificate ctx) (tpClientCertificate tp)
    for_ (tpClientPrivateKey tp) \keyPath ->
      withDefaultPassword ctx (tpClientPrivateKeyPassword tp)
        (SSL.contextSetPrivateKeyFile ctx keyPath)

    -- The socket is created only now, so that exceptions raised while
    -- configuring the context above cannot leak its file descriptor.
    ssl <- mkSocket >>= SSL.connection ctx

    -- SNI is only sent for names, never for numeric addresses
    isName <- not <$> isIpAddress hostname
    when isName (SSL.setTlsextHostName ssl hostname)

    SSL.connect ssl
    pure (clientCert, ssl)

-- | Report whether the host string is a numeric address rather than a name.
-- 'Socket.getAddrInfo' is asked to parse it with 'Socket.AI_NUMERICHOST'.
-- Success means the string is numeric; any 'IOError' is taken to mean it
-- is not.
isIpAddress :: HostName -> IO Bool
isIpAddress host =
  do result <- try (Socket.getAddrInfo (Just numericOnly) (Just host) Nothing)
     pure case result :: Either IOError [AddrInfo] of
       Left  _ -> False
       Right _ -> True
  where
    numericOnly = Socket.defaultHints { Socket.addrFlags = [Socket.AI_NUMERICHOST] }

-- | Load the trusted CA certificates into the context. When a file is given,
-- it is loaded with 'SSL.contextSetCAFile' inside 'withDefaultPassword'
-- with no password. Otherwise 'contextLoadSystemCerts' is used.
setupCaCertificates :: SSLContext -> Maybe FilePath -> IO ()
setupCaCertificates ctx =
  maybe (contextLoadSystemCerts ctx)
        (withDefaultPassword ctx Nothing . SSL.contextSetCAFile ctx)

-- | Read a PEM certificate from the given file, install it as the context's
-- own certificate, and return it.
setupCertificate :: SSLContext -> FilePath -> IO X509
setupCertificate ctx path =
  do pemText <- readFile path -- throws IOError if the file cannot be read
     cert    <- PEM.readX509 pemText
     SSL.contextSetCertificate ctx cert
     pure cert

-- | Verification mode that 'startTls' installs whenever verification is
-- requested. It enables 'SSL.vpFailIfNoPeerCert' and 'SSL.vpClientOnce'
-- and sets no custom verification callback.
verifyPeer :: SSL.VerificationMode
verifyPeer =
  SSL.VerifyPeer
    { SSL.vpCallback         = Nothing
    , SSL.vpClientOnce       = True
    , SSL.vpFailIfNoPeerCert = True
    }

-- | Get the certificate presented by the peer, if the connection uses TLS
-- and a peer certificate is available. Plain socket connections always
-- yield 'Nothing'.
getPeerCertificate :: Connection -> IO (Maybe X509.X509)
getPeerCertificate (Connection _ hVar) = withMVar hVar peerCert
  where
    peerCert (SSL _ ssl) = SSL.getPeerCertificate ssl
    peerCert Socket{}    = pure Nothing

-- | Get the client certificate this side presented, i.e. the one loaded
-- from 'tpClientCertificate' when TLS was started. Yields 'Nothing' for
-- plain socket connections and for TLS connections that were not given a
-- client certificate. (This is not the peer's certificate; use
-- 'getPeerCertificate' for that.)
getClientCertificate :: Connection -> IO (Maybe X509.X509)
getClientCertificate (Connection _ hVar) =
  do nh <- readMVar hVar
     pure case nh of
       Socket{}   -> Nothing
       SSL cert _ -> cert

-- | SHA-1 digest of the DER encoding of the peer's certificate. Returns
-- 'Nothing' when no peer certificate is available.
getPeerCertFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha1 = getPeerCertFingerprint "sha1"

-- | SHA-256 digest of the DER encoding of the peer's certificate. Returns
-- 'Nothing' when no peer certificate is available.
getPeerCertFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha256 = getPeerCertFingerprint "sha256"

-- | SHA-512 digest of the DER encoding of the peer's certificate. Returns
-- 'Nothing' when no peer certificate is available.
getPeerCertFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerCertFingerprintSha512 = getPeerCertFingerprint "sha512"

-- | Digest of the DER-encoded peer certificate, computed with the OpenSSL
-- digest algorithm of the given name (e.g. @"sha256"@). Returns 'Nothing'
-- when there is no peer certificate or when 'Digest.getDigestByName' finds
-- no algorithm of that name. The digest is fully evaluated before it is
-- returned.
getPeerCertFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerCertFingerprint name conn =
  getPeerCertificate conn >>= \case
    Nothing -> pure Nothing
    Just cert ->
      do der      <- X509.writeDerX509 cert
         mbDigest <- Digest.getDigestByName name
         pure $! maybe Nothing (\alg -> Just $! Digest.digestLBS alg der) mbDigest

-- | SHA-1 digest of the peer certificate's public key (as produced by
-- 'getPubKeyDer'). Returns 'Nothing' when no peer certificate is available.
getPeerPubkeyFingerprintSha1 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha1 = getPeerPubkeyFingerprint "sha1"

-- | SHA-256 digest of the peer certificate's public key (as produced by
-- 'getPubKeyDer'). Returns 'Nothing' when no peer certificate is available.
getPeerPubkeyFingerprintSha256 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha256 = getPeerPubkeyFingerprint "sha256"

-- | SHA-512 digest of the peer certificate's public key (as produced by
-- 'getPubKeyDer'). Returns 'Nothing' when no peer certificate is available.
getPeerPubkeyFingerprintSha512 :: Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprintSha512 = getPeerPubkeyFingerprint "sha512"

-- | Digest of the peer certificate's public key, computed over the bytes
-- returned by 'getPubKeyDer' (presumably the DER-encoded key; confirm)
-- with the OpenSSL digest algorithm of the given name. Returns 'Nothing'
-- when there is no peer certificate or when 'Digest.getDigestByName' finds
-- no algorithm of that name. The digest is fully evaluated before it is
-- returned.
getPeerPubkeyFingerprint :: String -> Connection -> IO (Maybe ByteString)
getPeerPubkeyFingerprint name conn =
  do mbCert <- getPeerCertificate conn
     case mbCert of
       Nothing   -> pure Nothing
       Just cert ->
         do keyDer   <- getPubKeyDer cert
            mbDigest <- Digest.getDigestByName name
            pure $! case mbDigest of
              Nothing  -> Nothing
              Just alg -> Just $! Digest.digestBS alg keyDer