{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE CPP #-}
-- | Support for making connections via the OpenSSL library.
module Network.HTTP.Client.OpenSSL
    ( withOpenSSL
    , newOpenSSLManager
    , opensslManagerSettings
    , defaultMakeContext
    , OpenSSLSettings(..)
    , defaultOpenSSLSettings
    ) where

import Network.HTTP.Client
import Network.HTTP.Client.Internal
import Control.Exception
import Control.Monad.IO.Class
import Network.Socket.ByteString (sendAll, recv)
import OpenSSL
import qualified Data.ByteString as S
import qualified Network.Socket as N
import qualified OpenSSL.Session as SSL
import qualified OpenSSL.X509.SystemStore as SSL (contextLoadSystemCerts)
import Foreign.Storable (sizeOf)

-- | Create a new 'Manager' using 'opensslManagerSettings' and 'defaultMakeContext'
-- with 'defaultOpenSSLSettings'.
--
-- A single 'SSL.SSLContext' is built up front and handed to
-- 'opensslManagerSettings' as an action that simply returns it, so every
-- connection opened by the returned manager uses that one context.
newOpenSSLManager :: MonadIO m => m Manager
newOpenSSLManager = liftIO $ do
  -- sharing an SSL context between threads (without modifying it) is safe:
  -- https://github.com/openssl/openssl/issues/2165
  ctx <- defaultMakeContext defaultOpenSSLSettings
  newManager $ opensslManagerSettings (pure ctx)

-- | 'ManagerSettings' whose TLS connections (direct, and tunnelled through a
-- proxy) are made with OpenSSL.
--
-- Note that it is the caller's responsibility to pass in an appropriate context.
-- The context action is run each time the 'managerTlsConnection' or
-- 'managerTlsProxyConnection' action is run, not once per connection. The
-- resulting context is then shared by every connection the returned callback
-- opens.
--
-- Every TLS connection sets SNI and enables hostname validation for the
-- server name. 'SSL.ConnectionAbruptlyTerminated' is treated as retryable.
-- 'IOException's and OpenSSL exceptions raised while performing a request
-- are rethrown as 'HttpExceptionRequest' with 'InternalException'.
opensslManagerSettings :: IO SSL.SSLContext -> ManagerSettings
opensslManagerSettings mkContext = defaultManagerSettings
    { managerTlsConnection = do
        ctx <- mkContext
        return $ \ha' host' port' ->
            withSocket noSocketSetup ha' host' port' $ \sock ->
                makeSSLConnection ctx sock host'

    , managerTlsProxyConnection = do
        ctx <- mkContext
        return $ \connstr checkConn serverName _ha host' port' ->
            -- The socket goes to host'/port' (presumably the proxy), with no
            -- pre-resolved address. TLS below is negotiated for serverName.
            withSocket noSocketSetup Nothing host' port' $ \sock -> do
                -- Talk plain text first: send the tunnel request (connstr) and
                -- let the caller-supplied checkConn validate the reply.
                plainConn <- makeConnection
                    (recv sock bufSize)
                    (sendAll sock)
                    -- no-op close: the socket is closed by the TLS
                    -- connection created below
                    (return ())
                connectionWrite plainConn connstr
                checkConn plainConn
                -- Upgrade the same socket to TLS for the target server.
                makeSSLConnection ctx sock serverName

    , managerRetryableException = \se ->
        case fromException se of
            Just (_ :: SSL.ConnectionAbruptlyTerminated) -> True
            Nothing -> managerRetryableException defaultManagerSettings se

    , managerWrapException = \req ->
        let wrap :: SomeException -> SomeException
            wrap se
                | isWrappable se =
                    toException (HttpExceptionRequest req (InternalException se))
                | otherwise = se
        in handle (throwIO . wrap)
    }
  where
    -- | withSocket's per-socket setup hook; nothing to configure here.
    noSocketSetup :: N.Socket -> IO ()
    noSocketSetup = const (return ())

    -- | Exceptions that get wrapped into an 'HttpException' by
    -- 'managerWrapException'; all others propagate unchanged.
    isWrappable :: SomeException -> Bool
    isWrappable se = or
        [ is (fromException se :: Maybe IOException)
        , is (fromException se :: Maybe SSL.SomeSSLException)
        , is (fromException se :: Maybe SSL.ConnectionAbruptlyTerminated)
        , is (fromException se :: Maybe SSL.ProtocolError)
        ]
      where
        is :: Maybe e -> Bool
        is = maybe False (const True)

    -- | Run the TLS client handshake on an already-connected socket and wrap
    -- the session as an http-client 'Connection'. Closing that connection
    -- closes the socket.
    makeSSLConnection :: SSL.SSLContext -> N.Socket -> String -> IO Connection
    makeSSLConnection ctx sock host = do
        ssl <- SSL.connection ctx sock
        -- SNI: tell the server which host name we are connecting to.
        SSL.setTlsextHostName ssl host
        -- Ask OpenSSL to check the peer certificate against 'host'.
        SSL.enableHostnameValidation ssl host
        SSL.connect ssl
        makeConnection
            -- Handling SSL.ConnectionAbruptlyTerminated as a stream end
            -- (some sites terminate SSL connection right after returning the data).
            (SSL.read ssl bufSize `catch` \(_ :: SSL.ConnectionAbruptlyTerminated) ->
                return S.empty)
            (SSL.write ssl)
            (N.close sock)

-- | Number of bytes requested per read from the socket or TLS session.
--
-- Same as Data.ByteString.Lazy.Internal.defaultChunkSize: 32 KiB minus
-- twice the size of an 'Int'.
bufSize :: Int
bufSize = 32 * 1024 - overhead
  where
    overhead :: Int
    -- sizeOf never inspects its argument; 0 just fixes the type without
    -- resorting to 'undefined'.
    overhead = 2 * sizeOf (0 :: Int)

-- | Create a fresh 'SSL.SSLContext' configured from 'OpenSSLSettings'.
--
-- The steps run in this order:
--
--  * set the verification mode
--  * set the cipher list
--  * add each option
--  * run the certificate-loading action on the context
defaultMakeContext :: OpenSSLSettings -> IO SSL.SSLContext
defaultMakeContext OpenSSLSettings{..} = do
    ctx <- SSL.context
    SSL.contextSetVerificationMode ctx osslSettingsVerifyMode
    SSL.contextSetCiphers ctx osslSettingsCiphers
    mapM_ (SSL.contextAddOption ctx) osslSettingsOptions
    osslSettingsLoadCerts ctx
    return ctx

-- | Settings consumed by 'defaultMakeContext' when building an 'SSL.SSLContext'.
data OpenSSLSettings = OpenSSLSettings
    { osslSettingsOptions :: [SSL.SSLOption]
      -- ^ Options added to the context, one at a time, with 'SSL.contextAddOption'.
    , osslSettingsVerifyMode :: SSL.VerificationMode
      -- ^ Peer verification mode, applied with 'SSL.contextSetVerificationMode'.
    , osslSettingsCiphers :: String
      -- ^ OpenSSL cipher list string, applied with 'SSL.contextSetCiphers'.
    , osslSettingsLoadCerts :: SSL.SSLContext -> IO ()
      -- ^ Action run last on the new context; 'defaultOpenSSLSettings' uses it
      -- to load the system certificate store.
    }

-- | Default OpenSSL settings. In particular:
--
--  * SSLv2 and SSLv3 are disabled
--  * Peer certificates are verified ('SSL.VerifyPeer'); hostname validation
--    itself is enabled on each connection by 'opensslManagerSettings'
--  * @DEFAULT@ cipher list
--  * Certificates loaded from OS-specific store
--
-- Note that these settings might change in the future.
defaultOpenSSLSettings :: OpenSSLSettings
defaultOpenSSLSettings = OpenSSLSettings
    { osslSettingsOptions =
        [ SSL.SSL_OP_ALL -- enable bug workarounds
        , SSL.SSL_OP_NO_SSLv2
        , SSL.SSL_OP_NO_SSLv3
        ]
    , osslSettingsVerifyMode = SSL.VerifyPeer
        -- vpFailIfNoPeerCert and vpClientOnce are only relevant for servers
        { SSL.vpFailIfNoPeerCert = False
        , SSL.vpClientOnce = False
        , SSL.vpCallback = Nothing
        }
    , osslSettingsCiphers = "DEFAULT"
    , osslSettingsLoadCerts = SSL.contextLoadSystemCerts
    }