{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Network.Wai.Middleware.Delegate
(
delegateTo
, delegateToProxy
, simpleProxy
, ProxySettings (..)
, RequestPredicate
)
where
import Blaze.ByteString.Builder (fromByteString)
import Control.Concurrent.Async (race_)
import Control.Exception
( SomeException
, handle
, toException
)
import Control.Monad (unless)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy.Char8 as LC8
import Data.CaseInsensitive (mk)
import Data.Conduit
( ConduitM
, ConduitT
, Flush (..)
, Void
, await
, mapOutput
, runConduit
, yield
, ($$+)
, ($$++)
, (.|)
)
import Data.Conduit.Network (appSink, appSource)
import Data.Default (Default (..))
import Data.IORef (newIORef, readIORef, writeIORef)
import Data.Int (Int64)
import Data.Monoid ((<>))
import Data.Streaming.Network
( ClientSettings
, clientSettingsTCP
, runTCPClient
)
import Data.String (IsString)
import Network.HTTP.Client
( BodyReader
, GivesPopper
, Manager
, Request (..)
, RequestBody (..)
, Response (..)
, brRead
, parseRequest
, withResponse
)
import Network.HTTP.Types
( Header
, HeaderName
, hContentType
, internalServerError500
, status304
, status500
)
import Network.HTTP.Types.Header (hHost)
import qualified Network.Wai as Wai
import Network.Wai.Conduit
( responseRawSource
, responseSource
, sourceRequestBody
)
-- | A predicate over incoming WAI requests. The middleware in this module
-- uses it to decide whether a request is handed to the delegate application.
type RequestPredicate = Wai.Request -> Bool
-- | Build a middleware that routes a request to the alternate application
-- when the predicate holds, and to the wrapped application otherwise.
--
-- Arguments, in order: the alternate (delegate) application, the predicate
-- that selects requests for it, and then the usual 'Wai.Middleware'
-- argument, i.e. the application being wrapped.
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo alt f actual req
  | f req = alt req
  | otherwise = actual req
-- | Build a middleware that sends every request matching the predicate to
-- 'simpleProxy', configured with the given settings and HTTP manager.
-- Requests that do not match go to the wrapped application.
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings mgr = delegateTo (simpleProxy settings mgr)
-- | Configuration for the proxying application.
data ProxySettings = ProxySettings
  { -- | Converts an exception raised while proxying into the response that
    -- is sent to the client.
    proxyOnException :: SomeException -> Wai.Response
  , -- | Timeout setting.
    -- NOTE(review): no code visible in this module reads this field, and
    -- its unit is not established here; confirm the intended use.
    proxyTimeout :: Int
  , -- | Host that requests are forwarded to. It is used both to build the
    -- upstream URL and as the value of the outgoing @Host@ header.
    proxyHost :: BS.ByteString
  , -- | Copied into the upstream request's @redirectCount@ field.
    proxyRedirectCount :: Int
  }
-- | Default settings: exceptions become a 500 plain-text response carrying
-- the shown exception, the timeout value is 15, the target host is
-- @localhost@ and the redirect count is 0.
instance Default ProxySettings where
  def =
    ProxySettings
      { proxyOnException = onException
      , proxyTimeout = 15
      , proxyHost = "localhost"
      , proxyRedirectCount = 0
      }
    where
      -- Render the exception with 'show' as the body of a 500 response.
      onException :: SomeException -> Wai.Response
      onException e =
        Wai.responseLBS
          internalServerError500
          [(hContentType, "text/plain; charset=utf-8")]
          (LC8.pack (show e))
-- | A WAI application that forwards each request to 'proxyHost'.
--
-- * @CONNECT@ requests are answered with a raw response driven by
--   'handleConnect'; a 500 @text/plain@ response is supplied as the
--   fallback argument of 'responseRawSource'.
-- * Every other request is re-issued against
--   @scheme:\/\/proxyHost ++ rawPathInfo ++ rawQueryString@, where the scheme
--   is @https@ when 'Wai.isSecure' holds and @http@ otherwise. The upstream
--   response is streamed back with its status and headers, plus an added
--   @X-Via-Proxy: yes@ header.
--
-- Any exception raised while building or running the upstream request,
-- including a URL that 'parseRequest' rejects, is turned into a response
-- with 'proxyOnException'.
simpleProxy ::
  ProxySettings ->
  Manager ->
  Wai.Application
simpleProxy settings manager req respond
  | Wai.requestMethod req == "CONNECT" =
      respond $
        responseRawSource
          (handleConnect req)
          (Wai.responseLBS status500 [("Content-Type", "text/plain")] "method CONNECT is not supported")
  | otherwise =
      -- NOTE(review): if the failure happens after 'respond' has already
      -- been invoked for the upstream response, this handler calls
      -- 'respond' a second time; confirm that is acceptable for the server.
      handle (respond . onException) $ do
        proxyReq' <- parseRequest effectiveUrl
        let proxyReq =
              proxyReq'
                { method = Wai.requestMethod req
                , requestHeaders = addHostHeader $ filter dropUpstreamHeaders $ Wai.requestHeaders req
                , redirectCount = proxyRedirectCount settings
                , requestBody = upstreamBody
                , -- Pass the upstream body through untouched so it stays
                  -- consistent with the forwarded response headers.
                  decompress = const False
                , host = newHost
                }
        withResponse proxyReq manager $ \res -> do
          let body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ responseBody res
              headers = (mk "X-Via-Proxy", "yes") : responseHeaders res
          respond $ responseSource (responseStatus res) headers body
  where
    onException :: SomeException -> Wai.Response
    onException = proxyOnException settings

    scheme :: String
    scheme
      | Wai.isSecure req = "https"
      | otherwise = "http"

    rawUrl = Wai.rawPathInfo req <> Wai.rawQueryString req

    effectiveUrl :: String
    effectiveUrl = scheme ++ "://" ++ C8.unpack (proxyHost settings) ++ C8.unpack rawUrl

    newHost = proxyHost settings

    -- Prepend a Host header naming the proxied host.
    addHostHeader = (:) (hHost, newHost)

    -- Stream the client's body upstream, keeping its declared framing.
    upstreamBody =
      case Wai.requestBodyLength req of
        Wai.ChunkedBody ->
          requestBodySourceChunked (sourceRequestBody req)
        Wai.KnownLength l ->
          requestBodySource (fromIntegral l) (sourceRequestBody req)
-- | Serve a @CONNECT@ request by opening a TCP connection to the target
-- derived from the request (see 'toClientSettings') and relaying bytes in
-- both directions.
--
-- The literal @HTTP/1.1 200 OK@ status line is written to the client first.
-- The two relay directions then run concurrently under 'race_', so the
-- tunnel ends as soon as either direction finishes.
handleConnect ::
  Wai.Request ->
  ConduitT () C8.ByteString IO () ->
  ConduitT C8.ByteString Void IO () ->
  IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) $ \ad -> do
    runConduit $ yield "HTTP/1.1 200 OK\r\n\r\n" .| toClient
    race_
      (runConduit $ fromClient .| appSink ad)
      (runConduit $ appSource ad .| toClient)
-- | Port used for a @CONNECT@ target that does not carry a usable port:
-- 443 when the incoming request is secure, otherwise 80, the standard
-- default port for plain HTTP.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80
-- | Derive TCP client settings from the raw path of a request, which is
-- split at the first @\':\'@ into a host part and a port part.
--
-- The port is the integer 'C8.readInt' finds at the start of the text after
-- the colon. When there is no colon, or no integer follows it,
-- 'defaultClientPort' is used instead.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req = clientSettingsTCP port target
  where
    (target, rest) = C8.break (== ':') (Wai.rawPathInfo req)
    -- Dropping one byte skips the colon; for an empty remainder this is a
    -- no-op and 'C8.readInt' yields Nothing, selecting the default port.
    port = maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 rest))
-- | Filter predicate over request headers: 'True' for a header whose name
-- is NOT listed in 'preservedHeaders'. When used with 'filter', the listed
-- headers are therefore the ones removed from the forwarded request.
dropUpstreamHeaders :: (HeaderName, b) -> Bool
dropUpstreamHeaders (k, _) = k `notElem` preservedHeaders
-- | Header names consulted by 'dropUpstreamHeaders'. Despite the name, a
-- header listed here makes that predicate return 'False', so it is filtered
-- out of the forwarded request rather than kept.
preservedHeaders :: [HeaderName]
preservedHeaders = ["content-encoding", "content-length", "host"]
-- | A conduit source of strict 'ByteString' chunks running in 'IO'.
type Source' = ConduitT () ByteString IO ()
-- | Adapt a conduit source to the popper interface expected by streaming
-- request bodies.
--
-- The source is sealed once and the resumable remainder is kept in an
-- 'IORef'. Each call of the popper pulls one chunk and stores the new
-- remainder. Empty chunks are skipped, because the popper returns an empty
-- 'ByteString' only once the source is exhausted.
srcToPopperIO :: Source' -> GivesPopper ()
srcToPopperIO src f = do
  (rsrc0, ()) <- src $$+ return ()
  irsrc <- newIORef rsrc0
  let popper :: IO ByteString
      popper = do
        rsrc <- readIORef irsrc
        (rsrc', mres) <- rsrc $$++ await
        writeIORef irsrc rsrc'
        case mres of
          Nothing -> return BS.empty
          Just bs
            | BS.null bs -> popper
            | otherwise -> return bs
  f popper
-- | Build a streaming request body with a declared length, in bytes, from
-- a conduit source.
requestBodySource :: Int64 -> Source' -> RequestBody
requestBodySource size = RequestBodyStream size . srcToPopperIO
-- | Build a chunked streaming request body from a conduit source.
requestBodySourceChunked :: Source' -> RequestBody
requestBodySourceChunked = RequestBodyStreamChunked . srcToPopperIO
-- | Turn an http-client 'BodyReader' into a conduit that yields each chunk
-- it reads, stopping at the first empty chunk.
bodyReaderSource ::
  (MonadIO m) =>
  BodyReader ->
  ConduitT i ByteString m ()
bodyReaderSource br =
  loop
  where
    loop = do
      bs <- liftIO $ brRead br
      unless (BS.null bs) $ do
        yield bs
        loop