{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Network.Wai.Middleware.Delegate
(
delegateTo
, delegateToProxy
, simpleProxy
, ProxySettings (..)
, RequestPredicate
)
where
import Blaze.ByteString.Builder (fromByteString)
import Control.Concurrent.Async (race_)
import Control.Exception
( SomeException
, handle
, toException
)
import Control.Monad (unless)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy.Char8 as LC8
import Data.CaseInsensitive (mk)
import Data.Conduit
( ConduitM
, ConduitT
, Flush (..)
, Void
, await
, mapOutput
, runConduit
, yield
, ($$+)
, ($$++)
, (.|)
)
import Data.Conduit.Network (appSink, appSource)
import Data.Default (Default (..))
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Data.Monoid ((<>))
import Data.Streaming.Network
( ClientSettings
, clientSettingsTCP
, runTCPClient
)
import Data.String (IsString)
import Network.HTTP.Client
( BodyReader
, GivesPopper
, Manager
, Request (..)
, RequestBody (..)
, Response (..)
, brRead
, parseRequest
, withResponse
)
import Network.HTTP.Types
( Header
, HeaderName
, hContentType
, internalServerError500
, status304
, status500
)
import Network.HTTP.Types.Header (hHost)
import qualified Network.Wai as Wai
import Network.Wai.Conduit
( responseRawSource
, responseSource
, sourceRequestBody
)
-- | Decides, for each incoming WAI request, whether it should be delegated
-- ('True') or handled by the wrapped application ('False').
type RequestPredicate = Wai.Request -> Bool
-- | @delegateTo alt shouldDelegate@ is a middleware that hands every request
-- satisfying @shouldDelegate@ to @alt@, and every other request to the
-- application it wraps.
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo alt shouldDelegate actual req =
  (if shouldDelegate req then alt else actual) req
-- | Middleware that forwards requests matching the predicate through
-- 'simpleProxy' (configured by the given 'ProxySettings' and 'Manager').
-- All other requests reach the wrapped application unchanged.
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings = delegateTo . simpleProxy settings
-- | Configuration consumed by 'simpleProxy' and 'delegateToProxy'.
-- The 'Default' instance supplies a ready-made value.
data ProxySettings = ProxySettings
  { proxyOnException :: SomeException -> Wai.Response
    -- ^ Converts an exception raised while talking to the upstream server
    -- into the response returned to the client.
  , proxyTimeout :: Int
    -- ^ Timeout value. NOTE(review): not read by 'simpleProxy' as written;
    -- confirm its unit and whether it should configure the upstream request.
  , proxyHost :: BS.ByteString
    -- ^ Upstream host that requests are forwarded to. It is used in the
    -- target URL, as the forwarded @Host@ header, and as http-client's
    -- 'host' field.
  , proxyRedirectCount :: Int
    -- ^ Value given to http-client's 'redirectCount' for upstream requests.
  }
-- | Default settings:
--
-- * forward to @localhost@;
-- * 'proxyTimeout' is 15;
-- * 'proxyRedirectCount' is 0;
-- * failures become a @500 Internal Server Error@ whose
--   @text/plain; charset=utf-8@ body is the 'show'n exception.
instance Default ProxySettings where
  def = ProxySettings plainTextError 15 "localhost" 0
    where
      plainTextError :: SomeException -> Wai.Response
      plainTextError err =
        Wai.responseLBS
          internalServerError500
          [(hContentType, "text/plain; charset=utf-8")]
          (LC8.pack (show err))
-- | A WAI application that re-issues each incoming request against
-- @'proxyHost' settings@ using the given http-client 'Manager', then streams
-- the upstream response back to the client.
--
-- * @CONNECT@ requests are answered with 'responseRawSource' driving
--   'handleConnect'. A plain-text 500 response is passed as the backup
--   response.
-- * Every other request is sent to
--   @scheme://proxyHost<rawPathInfo><rawQueryString>@, where @scheme@ is
--   @https@ when the incoming request is secure and @http@ otherwise.
--   The method is kept. Headers are kept minus those filtered out by
--   'dropUpstreamHeaders', with a @Host: proxyHost@ header prepended.
--   'redirectCount' comes from 'proxyRedirectCount'. Decompression is
--   disabled, so the body is relayed as received. The response gains an
--   @X-Via-Proxy: yes@ header.
-- * Exceptions raised by the upstream exchange are turned into a response by
--   'proxyOnException'. URL parsing happens before that handler is
--   installed, so a 'parseRequest' failure propagates to the caller.
--
-- NOTE(review): the body is streamed while 'respond' runs. An exception
-- thrown at that point reaches the handler, which calls 'respond' a second
-- time — confirm this is acceptable.
simpleProxy ::
  ProxySettings ->
  Manager ->
  Wai.Application
simpleProxy settings manager req respond
  | Wai.requestMethod req == "CONNECT" = respond tunnelResponse
  | otherwise = do
      -- Parsed outside 'handle' on purpose: a malformed target URL is not
      -- routed through 'proxyOnException'.
      baseRequest <- parseRequest targetUrl
      handle (respond . onFailure) $
        withResponse (toUpstream baseRequest) manager relayUpstream
  where
    tunnelResponse :: Wai.Response
    tunnelResponse =
      responseRawSource
        (handleConnect req)
        ( Wai.responseLBS
            status500
            [("Content-Type", "text/plain")]
            "method CONNECT is not supported"
        )

    -- The explicit signature pins the exception type caught by 'handle'.
    onFailure :: SomeException -> Wai.Response
    onFailure = proxyOnException settings . toException

    targetHost :: BS.ByteString
    targetHost = proxyHost settings

    targetUrl :: String
    targetUrl =
      concat
        [ if Wai.isSecure req then "https" else "http"
        , "://"
        , C8.unpack targetHost
        , C8.unpack (Wai.rawPathInfo req <> Wai.rawQueryString req)
        ]

    toUpstream :: Request -> Request
    toUpstream base =
      base
        { method = Wai.requestMethod req
        , requestHeaders =
            (hHost, targetHost) : filter dropUpstreamHeaders (Wai.requestHeaders req)
        , redirectCount = proxyRedirectCount settings
        , requestBody = upstreamBody
        , -- keep the body byte-identical to what the upstream headers describe
          decompress = const False
        , host = targetHost
        }

    upstreamBody :: RequestBody
    upstreamBody =
      case Wai.requestBodyLength req of
        Wai.ChunkedBody ->
          requestBodySourceChunked (sourceRequestBody req)
        Wai.KnownLength len ->
          requestBodySource (fromIntegral len) (sourceRequestBody req)

    relayUpstream :: Response BodyReader -> IO Wai.ResponseReceived
    relayUpstream res =
      respond $
        responseSource
          (responseStatus res)
          ((mk "X-Via-Proxy", "yes") : responseHeaders res)
          (mapOutput (Chunk . fromByteString) (bodyReaderSource (responseBody res)))
-- | Tunnels a @CONNECT@ request.
--
-- Opens a TCP connection to the target derived from the request by
-- 'toClientSettings' and writes @HTTP/1.1 200 OK@ plus an empty line to the
-- client. It then copies client→server and server→client concurrently with
-- 'race_', which stops both directions as soon as either one finishes.
handleConnect ::
  Wai.Request ->
  ConduitT () C8.ByteString IO () ->
  ConduitT C8.ByteString Void IO () ->
  IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) $ \conn -> do
    -- Acknowledge the tunnel before relaying any payload.
    pump (yield "HTTP/1.1 200 OK\r\n\r\n") toClient
    race_
      (pump fromClient (appSink conn))
      (pump (appSource conn) toClient)
  where
    -- Drain a source into a sink.
    pump ::
      ConduitT () C8.ByteString IO () ->
      ConduitT C8.ByteString Void IO () ->
      IO ()
    pump source sink = runConduit (source .| sink)
-- | Port used for a @CONNECT@ target that does not name one explicitly.
-- It is 443 when the incoming request is secure ('Wai.isSecure'); otherwise
-- it is 80, the standard HTTP port. The previous value of 90 was a typo and
-- made port-less targets dial the wrong service.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80
-- | TCP client settings for a @CONNECT@ target taken from the request's raw
-- path.
--
-- The path is split at the first @:@; the part before it is the host. The
-- digits after it are read with 'C8.readInt'. If no @:@ is present, or no
-- number follows it, 'defaultClientPort' is used instead.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req = clientSettingsTCP port hostPart
  where
    (hostPart, portPart) = C8.break (== ':') (Wai.rawPathInfo req)
    port = maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 portPart))
-- | Predicate for 'filter' over request headers. It returns 'True' (keep)
-- for any header whose name is not in 'preservedHeaders', so the listed
-- headers are removed.
dropUpstreamHeaders :: (HeaderName, b) -> Bool
dropUpstreamHeaders = (`notElem` preservedHeaders) . fst
-- | Header names that 'dropUpstreamHeaders' rejects, i.e. headers stripped
-- from the client's request before it is forwarded.
-- NOTE(review): despite the name, these headers are not preserved; @Host@ is
-- re-added by 'simpleProxy'.
preservedHeaders :: [HeaderName]
preservedHeaders =
  [ "content-encoding"
  , "content-length"
  , "host"
  ]
-- | An input-less conduit producing strict 'ByteString' chunks in 'IO'.
type Source' = ConduitT () ByteString IO ()
-- | Adapts a conduit source into an http-client 'GivesPopper'.
--
-- The source is sealed once, and the resumable remainder is kept in an
-- 'IORef'. Each call to the popper pulls chunks until it finds a non-empty
-- one and returns it. Once the source is exhausted it returns 'BS.empty',
-- which marks the end of the body.
srcToPopperIO :: Source' -> GivesPopper ()
srcToPopperIO src usePopper = do
  (sealed0, ()) <- src $$+ return ()
  cursor <- newIORef sealed0
  let popper :: BodyReader
      popper = do
        (sealed, next) <- readIORef cursor >>= ($$++ await)
        writeIORef cursor sealed
        case next of
          Nothing -> return BS.empty
          Just chunk
            -- An empty chunk would be read as end-of-body, so skip it.
            | BS.null chunk -> popper
            | otherwise -> return chunk
  usePopper popper
-- | A streaming request body of the given, already-known length, fed from a
-- conduit source through 'srcToPopperIO'.
requestBodySource :: Int64 -> Source' -> RequestBody
requestBodySource size src = RequestBodyStream size (srcToPopperIO src)
-- | A chunked (length-unknown) streaming request body fed from a conduit
-- source through 'srcToPopperIO'.
requestBodySourceChunked :: Source' -> RequestBody
requestBodySourceChunked src = RequestBodyStreamChunked (srcToPopperIO src)
-- | Turns an http-client 'BodyReader' into a conduit source. It reads chunks
-- with 'brRead' and yields them until an empty chunk (end of body) comes back.
bodyReaderSource ::
  (MonadIO m) =>
  BodyReader ->
  ConduitT i ByteString m ()
bodyReaderSource br = go
  where
    go = do
      chunk <- liftIO (brRead br)
      if BS.null chunk
        then return ()
        else yield chunk >> go