{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}

{- |
Copyright   : (c) 2018-2021 Tim Emiola
SPDX-License-Identifier: BSD3
Maintainer  : Tim Emiola <tim.emiola@gmail.com>

Provides a [WAI](https://hackage.haskell.org/package/wai) middleware that
delegates handling of requests.

Provides three combinators for delegating requests, along with supporting data
types:

* 'delegateTo': delegates handling of requests matching a predicate to a
  delegate Application

* 'delegateToProxy': delegates handling of requests matching a predicate to a
  different host

* 'simpleProxy': a simple reverse-proxy Application, based on proxyApp of
  http-proxy by Erik de Castro Lopo/Michael Snoyman
-}
module Network.Wai.Middleware.Delegate
  ( -- * Middleware
    delegateTo
  , delegateToProxy
  , simpleProxy

    -- * Configuration
  , ProxySettings (..)

    -- * Aliases
  , RequestPredicate
  )
where

import Blaze.ByteString.Builder (fromByteString)
import Control.Concurrent.Async (race_)
import Control.Exception
  ( SomeException
  , handle
  , toException
  )
import Control.Monad (unless)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy.Char8 as LC8
import Data.CaseInsensitive (mk)
import Data.Conduit
  ( ConduitM
  , ConduitT
  , Flush (..)
  , Void
  , await
  , mapOutput
  , runConduit
  , yield
  , ($$+)
  , ($$++)
  , (.|)
  )
import Data.Conduit.Network (appSink, appSource)
import Data.Default (Default (..))
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Data.Monoid ((<>))
import Data.Streaming.Network
  ( ClientSettings
  , clientSettingsTCP
  , runTCPClient
  )
import Data.String (IsString)
import Network.HTTP.Client
  ( BodyReader
  , GivesPopper
  , Manager
  , Request (..)
  , RequestBody (..)
  , Response (..)
  , brRead
  , parseRequest
  , withResponse
  )
import Network.HTTP.Types
  ( Header
  , HeaderName
  , hContentType
  , internalServerError500
  , status304
  , status500
  )
import Network.HTTP.Types.Header (hHost)
import qualified Network.Wai as Wai
import Network.Wai.Conduit
  ( responseRawSource
  , responseSource
  , sourceRequestBody
  )


{- | Type alias for a function that determines if a request should be handled by
 a delegate.

 'delegateTo' and 'delegateToProxy' hand a request to the delegate when the
 predicate returns 'True', and to the wrapped application otherwise.
-}
type RequestPredicate = Wai.Request -> Bool


{- | Create a middleware that handles all requests matching a predicate by
 delegating to an alternate Application.

 Requests for which the predicate is 'False' are passed unchanged to the
 application the middleware wraps.
-}
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo alt shouldDelegate actual req =
  let target = if shouldDelegate req then alt else actual
   in target req


{- | Creates a middleware that handles all requests matching a predicate by
 proxying them, via 'simpleProxy' and the given 'Manager', to the host
 specified by 'ProxySettings'.
-}
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings = delegateTo . simpleProxy settings


-- | Settings that configure the proxy endpoint used by 'simpleProxy' and
-- 'delegateToProxy'. Start from 'def' and override the fields you need.
data ProxySettings = ProxySettings
  { proxyOnException :: SomeException -> Wai.Response
  -- ^ Builds the response sent to the client when an exception is raised
  -- while proxying. The default responds with a 500 whose body is the
  -- 'show'n exception.
  , proxyTimeout :: Int
  -- ^ Timeout value in seconds. Default value: 15
  --
  -- NOTE(review): 'simpleProxy' does not currently read this field; upstream
  -- timeouts are governed by the 'Manager' passed in. Confirm whether it
  -- should set the upstream request's response timeout.
  , proxyHost :: BS.ByteString
  -- ^ The host being proxied. It is used as the authority of the upstream URL
  -- and as the @Host@ header of the upstream request. Default: @localhost@.
  , proxyRedirectCount :: Int
  -- ^ The number of redirects to follow. 0 means none, which is the default.
  }


-- | Defaults: report exceptions as a plain-text 500, a 15 second timeout,
-- proxy to @localhost@, and follow no redirects.
instance Default ProxySettings where
  def =
    ProxySettings
      { -- defaults to returning internal server error showing the error in the body
        proxyOnException = onException
      , -- default to 15 seconds
        proxyTimeout = 15
      , proxyHost = "localhost"
      , proxyRedirectCount = 0
      }
    where
      -- The body is declared as UTF-8, so the shown exception is UTF-8
      -- encoded. Packing it with C8.pack would truncate every non-ASCII Char
      -- to a single byte and emit an invalid UTF-8 body.
      onException :: SomeException -> Wai.Response
      onException e =
        Wai.responseLBS
          internalServerError500
          [(hContentType, "text/plain; charset=utf-8")]
          (BB.toLazyByteString (BB.stringUtf8 (show e)))


-- | A Wai Application that acts as a http/https proxy.
--
-- Ordinary requests are re-issued to 'proxyHost' with the same method, raw
-- path, query string and body. The scheme is @https@ when the incoming request
-- is secure and @http@ otherwise. The upstream status and headers, plus an
-- @X-Via-Proxy: yes@ header, are streamed back without decompressing the body.
-- Any exception raised while building or performing the upstream request is
-- turned into a response by 'proxyOnException'.
--
-- @CONNECT@ requests are tunnelled over a raw TCP connection (see
-- 'handleConnect'). Servers that cannot send raw responses answer them with a
-- 500 instead.
--
-- NOTE(review): 'proxyTimeout' is not consulted here; upstream timeouts are
-- whatever the supplied 'Manager' uses.
simpleProxy ::
  ProxySettings ->
  Manager ->
  Wai.Application
simpleProxy settings manager req respond
  -- CONNECT is how clients reach secure sites through a proxy. Its target is
  -- an authority (host:port), not a URI we can re-issue, so bytes are tunnelled.
  | Wai.requestMethod req == "CONNECT" =
      respond $
        responseRawSource
          (handleConnect req)
          (Wai.responseLBS status500 [("Content-Type", "text/plain")] "method CONNECT is not supported")
  | otherwise =
      -- parseRequest runs inside 'handle' so that a malformed upstream URL is
      -- reported through 'proxyOnException' like any other proxying failure.
      handle (respond . onException) $ do
        parsed <- parseRequest effectiveUrl
        withResponse (toUpstream parsed) manager relay
  where
    onException :: SomeException -> Wai.Response
    onException = proxyOnException settings . toException

    scheme :: String
    scheme
      | Wai.isSecure req = "https"
      | otherwise = "http"

    newHost :: BS.ByteString
    newHost = proxyHost settings

    -- proxyHost may include a port (e.g. "localhost:8080"); parseRequest splits
    -- it into the request's separate host and port fields.
    effectiveUrl :: String
    effectiveUrl =
      scheme ++ "://" ++ C8.unpack newHost ++ C8.unpack (Wai.rawPathInfo req <> Wai.rawQueryString req)

    -- Copy the incoming request's method, headers and body onto the parsed
    -- upstream request. The parsed 'host' is deliberately kept: overwriting it
    -- with the whole proxyHost would break hosts that carry a port.
    toUpstream :: Request -> Request
    toUpstream parsed =
      parsed
        { method = Wai.requestMethod req
        , requestHeaders = (hHost, newHost) : filter dropUpstreamHeaders (Wai.requestHeaders req)
        , -- follow at most proxyRedirectCount redirects; with the default of 0,
          -- redirect responses are passed back to the client unchanged
          redirectCount = proxyRedirectCount settings
        , requestBody = upstreamBody
        , -- don't modify the response to ensure consistency with the response headers
          decompress = const False
        }

    upstreamBody :: RequestBody
    upstreamBody =
      case Wai.requestBodyLength req of
        Wai.ChunkedBody -> requestBodySourceChunked (sourceRequestBody req)
        Wai.KnownLength len -> requestBodySource (fromIntegral len) (sourceRequestBody req)

    -- Stream the upstream response back to the client.
    relay :: Response BodyReader -> IO Wai.ResponseReceived
    relay res =
      respond $
        responseSource
          (responseStatus res)
          (viaProxy : responseHeaders res)
          (mapOutput (Chunk . fromByteString) (bodyReaderSource (responseBody res)))

    viaProxy :: Header
    viaProxy = (mk "X-Via-Proxy", "yes")


-- | Tunnel a @CONNECT@ request. Opens a TCP connection to the target named by
-- the request (see 'toClientSettings') and sends the client a bare
-- @HTTP/1.1 200 OK@. It then relays bytes in both directions until either
-- direction finishes; the other is then cancelled.
handleConnect ::
  Wai.Request ->
  ConduitT () C8.ByteString IO () ->
  ConduitT C8.ByteString Void IO () ->
  IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) tunnel
  where
    tunnel server = do
      runConduit (yield "HTTP/1.1 200 OK\r\n\r\n" .| toClient)
      let clientToServer = runConduit (fromClient .| appSink server)
          serverToClient = runConduit (appSource server .| toClient)
      race_ clientToServer serverToClient


-- | Port used for a @CONNECT@ target that does not name one (or names an
-- unparseable one). It is 443 when the incoming request is secure, and
-- otherwise the standard HTTP port 80. The previous value, 90, is not a
-- standard HTTP port.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80


-- | TCP client settings for a @CONNECT@ target of the form @host[:port]@, read
-- from 'Wai.rawPathInfo'. It falls back to 'defaultClientPort' when no port is
-- present or it does not parse as an integer.
--
-- NOTE(review): the target is split at the first ':', so IPv6 literals such as
-- @[::1]:443@ are not handled.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req =
  let (hostName, portPart) = C8.break (== ':') (Wai.rawPathInfo req)
      -- portPart is either empty or starts with ':'; drop that separator
      targetPort = maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 portPart))
   in clientSettingsTCP targetPort hostName


-- | Filter predicate for incoming request headers. It yields 'True' (keep) for
-- every header except those listed in 'preservedHeaders', which are thereby
-- removed before the request is sent upstream.
dropUpstreamHeaders :: (HeaderName, b) -> Bool
dropUpstreamHeaders = (`notElem` preservedHeaders) . fst


-- | Request headers that are NOT copied to the upstream request.
-- 'dropUpstreamHeaders' filters them out, so, despite the name, these are the
-- headers that get removed. 'simpleProxy' supplies its own @Host@ header.
-- Presumably http-client derives the length from the request body — confirm.
preservedHeaders :: [HeaderName]
preservedHeaders =
  [ "content-encoding"
  , "content-length"
  , "host"
  ]


-- | A conduit source of 'ByteString' chunks running in 'IO'. It is the input
-- type of 'srcToPopperIO', 'requestBodySource' and 'requestBodySourceChunked'.
type Source' = ConduitT () ByteString IO ()


-- | Adapt a conduit source into http-client's popper interface. Each call to
-- the popper pulls the next non-empty chunk from the source, and returns an
-- empty 'ByteString' once the source is exhausted. The partially consumed
-- source is kept in an 'IORef' between calls.
srcToPopperIO :: Source' -> GivesPopper ()
srcToPopperIO src needsPopper = do
  (sealed, ()) <- src $$+ pure ()
  ref <- newIORef sealed
  let nextChunk :: IO ByteString
      nextChunk = do
        (rest, chunk) <- readIORef ref >>= ($$++ await)
        writeIORef ref rest
        maybe (pure BS.empty) skipEmpty chunk
      -- empty chunks would read as end-of-body, so skip past them
      skipEmpty bs
        | BS.null bs = nextChunk
        | otherwise = pure bs
  needsPopper nextChunk


-- | A streaming request body of known length (in bytes), fed from a conduit
-- source.
requestBodySource :: Int64 -> Source' -> RequestBody
requestBodySource size src = RequestBodyStream size (srcToPopperIO src)


-- | A chunked streaming request body fed from a conduit source, for requests
-- whose length is not known up front.
requestBodySourceChunked :: Source' -> RequestBody
requestBodySourceChunked src = RequestBodyStreamChunked (srcToPopperIO src)


-- | Stream the chunks of an http-client 'BodyReader' as a conduit source,
-- treating the first empty chunk as the end of the body.
bodyReaderSource ::
  (MonadIO m) =>
  BodyReader ->
  ConduitT i ByteString m ()
bodyReaderSource reader = pump
  where
    pump =
      liftIO (brRead reader) >>= \chunk ->
        unless (BS.null chunk) (yield chunk >> pump)