{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Middleware.Delegate
( delegateTo
, delegateToProxy
, simpleProxy
, ProxySettings(..)
, RequestPredicate
)
where
import           Control.Exception            (SomeException, handle,
                                               toException)
import           Control.Monad.IO.Class       (MonadIO, liftIO)
import qualified Data.ByteString              as BS
import qualified Data.ByteString.Builder      as BB
import qualified Data.ByteString.Char8        as C8
import qualified Data.ByteString.Lazy.Char8   as LC8
import           Data.Monoid                  ((<>))
import           Data.String                  (IsString)

import           Blaze.ByteString.Builder     (fromByteString)
import           Control.Concurrent.Async     (race_)
import           Data.CaseInsensitive         (mk)
import           Data.Conduit                 (ConduitT, Flush (..), Void,
                                               mapOutput, runConduit, yield,
                                               (.|))
import           Data.Conduit.Network         (appSink, appSource)
import           Data.Default                 (Default (..))
import           Data.Streaming.Network       (ClientSettings,
                                               clientSettingsTCP,
                                               runTCPClient)
import           Network.HTTP.Client          (Manager, Request (..),
                                               Response (..), parseRequest,
                                               withResponse)
import           Network.HTTP.Client.Conduit  (bodyReaderSource)
import           Network.HTTP.Conduit         (requestBodySourceChunkedIO,
                                               requestBodySourceIO)
import           Network.HTTP.Types           (hContentType,
                                               internalServerError500,
                                               status304, status500)
import           Network.HTTP.Types.Header    (hHost)
import qualified Network.Wai                  as Wai
import           Network.Wai.Conduit          (responseRawSource,
                                               responseSource,
                                               sourceRequestBody)
-- | Decides, per incoming WAI request, whether it should be delegated
-- (see 'delegateTo').
type RequestPredicate = Wai.Request -> Bool
-- | Middleware that sends every request matching the predicate to @alt@
-- and every other request to the wrapped application.
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo alt f actual req
  | f req = alt req
  | otherwise = actual req
-- | 'delegateTo' with 'simpleProxy' as the alternative application.
-- Requests matching the predicate are proxied to 'proxyHost' through the
-- given 'Manager'; all other requests reach the wrapped application.
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings mgr = delegateTo (simpleProxy settings mgr)
-- | Configuration for 'simpleProxy' and 'delegateToProxy'.
data ProxySettings =
  ProxySettings
    { proxyOnException :: SomeException -> Wai.Response
      -- ^ Builds the response sent to the client when the upstream request
      -- fails with an exception.
    , proxyTimeout :: Int
      -- ^ NOTE(review): not read by 'simpleProxy' in this module, and its
      -- unit is not established here — confirm intended use.
    , proxyHost :: BS.ByteString
      -- ^ Host that proxied requests are sent to; it is also used as the
      -- upstream @Host@ header.
    , proxyRedirectCount :: Int
      -- ^ Used as the upstream request's http-client 'redirectCount'.
    }
-- | Default settings:
--
-- * 'proxyHost' is @\"localhost\"@.
-- * 'proxyRedirectCount' is @0@.
-- * 'proxyTimeout' is @15@.
-- * Exceptions are reported as a plain-text 500 response.
instance Default ProxySettings where
  def =
    ProxySettings
      { proxyOnException = onException
      , proxyTimeout = 15
      , proxyHost = "localhost"
      , proxyRedirectCount = 0
      }
    where
      -- 500 response whose body is the exception rendered with 'show'.
      -- The text is UTF-8 encoded to match the declared charset
      -- (Char8 packing would truncate non-ASCII characters).
      onException :: SomeException -> Wai.Response
      onException e =
        Wai.responseLBS
          internalServerError500
          [(hContentType, "text/plain; charset=utf-8")]
          (BB.toLazyByteString (BB.stringUtf8 (show e)))
-- | A WAI application that forwards each request to 'proxyHost' through the
-- given http-client 'Manager' and streams the upstream response back.
--
-- * @CONNECT@ requests get a raw response that tunnels bytes between the
--   client and the target named in the request path ('handleConnect').
--   Backends that cannot send raw responses answer with a plain-text 500.
-- * All other requests are re-issued against
--   @scheme://proxyHost@ followed by the raw path and query string. The
--   scheme is @https@ when the incoming connection is secure and @http@
--   otherwise.
--   * The method and body are copied.
--   * Request headers are filtered with 'dropUpstreamHeaders', and @Host@ is
--     set to 'proxyHost'.
--   * The upstream body is passed through without decompression.
--   * An @X-Via-Proxy: yes@ response header is added.
--
-- Exceptions raised while building or sending the upstream request are
-- turned into a response with 'proxyOnException'.
simpleProxy
  :: ProxySettings
  -> Manager
  -> Wai.Application
simpleProxy settings manager req respond
  | Wai.requestMethod req == "CONNECT" =
      respond $ responseRawSource (handleConnect req) connectUnsupported
  | otherwise =
      -- NOTE(review): if an exception escapes after 'respond' has already
      -- been called (e.g. while streaming the body), this handler calls
      -- 'respond' a second time.
      handle (respond . proxyOnException settings) respondUpstream
  where
    -- Fallback used only by backends without raw-response support.
    connectUnsupported :: Wai.Response
    connectUnsupported =
      Wai.responseLBS
        status500
        [("Content-Type", "text/plain")]
        "method CONNECT is not supported"

    newHost :: BS.ByteString
    newHost = proxyHost settings

    -- The upstream scheme mirrors whether the client connection was secure.
    scheme :: String
    scheme
      | Wai.isSecure req = "https"
      | otherwise = "http"

    effectiveUrl :: String
    effectiveUrl =
      scheme ++ "://" ++ C8.unpack newHost
        ++ C8.unpack (Wai.rawPathInfo req <> Wai.rawQueryString req)

    -- Stream the client's body upstream. A chunked body stays chunked;
    -- a known-length body keeps its length.
    upstreamBody = case Wai.requestBodyLength req of
      Wai.ChunkedBody ->
        requestBodySourceChunkedIO (sourceRequestBody req)
      Wai.KnownLength len ->
        requestBodySourceIO (fromIntegral len) (sourceRequestBody req)

    -- NOTE(review): 'proxyTimeout' is not applied to the upstream request.
    respondUpstream :: IO Wai.ResponseReceived
    respondUpstream = do
      -- Parsed here, inside the 'handle' above, so a URL that fails to
      -- parse is reported via 'proxyOnException' instead of escaping.
      baseReq <- parseRequest effectiveUrl
      let proxyReq =
            baseReq
              { method = Wai.requestMethod req
              , requestHeaders =
                  (hHost, newHost)
                    : filter dropUpstreamHeaders (Wai.requestHeaders req)
              , redirectCount = proxyRedirectCount settings
              , requestBody = upstreamBody
              , decompress = const False
              , host = newHost
              }
      withResponse proxyReq manager $ \res -> do
        let headers = (mk "X-Via-Proxy", "yes") : responseHeaders res
            downstream =
              mapOutput
                (Chunk . fromByteString)
                (bodyReaderSource (responseBody res))
        respond $ responseSource (responseStatus res) headers downstream
-- | Tunnel a @CONNECT@ request.
--
-- Opens a TCP connection to the target given by the request path (see
-- 'toClientSettings') and tells the client the tunnel is up with
-- @HTTP/1.1 200 OK@. It then copies bytes in both directions until either
-- direction finishes; 'race_' cancels the other one at that point.
handleConnect
  :: Wai.Request
  -> ConduitT () C8.ByteString IO ()    -- ^ bytes arriving from the client
  -> ConduitT C8.ByteString Void IO ()  -- ^ bytes sent back to the client
  -> IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) $ \upstream -> do
    runConduit $ yield "HTTP/1.1 200 OK\r\n\r\n" .| toClient
    race_
      (runConduit $ fromClient .| appSink upstream)
      (runConduit $ appSource upstream .| toClient)
-- | Port used for a @CONNECT@ target that names no (parsable) port.
-- It is 443 when the client connection is secure, otherwise 80, the
-- standard HTTP port. The previous value of 90 was a typo.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80
-- | TCP client settings for the @CONNECT@ target in the request path,
-- written as @host@ or @host:port@.
--
-- The host is everything before the first @:@. The port is parsed from the
-- text after it. When there is no @:@, or the remainder does not begin with
-- an integer, 'defaultClientPort' is used instead.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req = clientSettingsTCP port targetHost
  where
    (targetHost, portPart) = C8.break (== ':') (Wai.rawPathInfo req)
    -- 'C8.drop 1' skips the ':' separator; on an empty remainder it yields
    -- "", which 'C8.readInt' rejects, so the default port is chosen.
    port = maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 portPart))
-- | Keep-predicate for the request headers copied to the upstream request.
--
-- Returns 'False' (drop) for headers the outgoing request must not inherit
-- verbatim, and 'True' (keep) for everything else.
--
-- * @host@ is replaced by 'proxyHost'.
-- * @content-length@ and @transfer-encoding@ describe the client's framing.
--   The upstream body is rebuilt as a streamed request body, so copying
--   them risks conflicting or duplicated framing headers.
-- * @content-encoding@ is also dropped.
--
-- Matching is case-insensitive when @a@ is a case-insensitive type such as
-- 'HeaderName'; for plain strings it is exact.
dropUpstreamHeaders :: (Eq a, IsString a) => (a, b) -> Bool
dropUpstreamHeaders (name, _) =
  name
    `notElem` [ "content-encoding"
              , "content-length"
              , "host"
              , "transfer-encoding"
              ]