{-# LANGUAGE OverloadedStrings #-}
module Network.Wai.Middleware.Delegate
( delegateTo
, delegateToProxy
, simpleProxy
, ProxySettings(..)
, RequestPredicate
)
where
import           Control.Exception            (SomeException, handle,
                                               toException)
import           Control.Monad.IO.Class       (MonadIO, liftIO)
import qualified Data.ByteString              as BS
import qualified Data.ByteString.Builder      as BB
import qualified Data.ByteString.Char8        as C8
import qualified Data.ByteString.Lazy.Char8   as LC8
import           Data.Monoid                  ((<>))
import           Data.String                  (IsString)

import           Blaze.ByteString.Builder     (fromByteString)
import           Control.Concurrent.Async     (race_)
import           Data.CaseInsensitive         (mk)
import           Data.Conduit                 (ConduitT, Flush (..), Void,
                                               mapOutput, runConduit, yield,
                                               (.|))
import           Data.Conduit.Network         (appSink, appSource)
import           Data.Default                 (Default (..))
import           Data.Streaming.Network       (ClientSettings, clientSettingsTCP,
                                               runTCPClient)
import           Network.HTTP.Client          (Manager, Request (..),
                                               Response (..), parseRequest,
                                               withResponse)
import           Network.HTTP.Client.Conduit  (bodyReaderSource)
import           Network.HTTP.Conduit         (requestBodySourceChunkedIO,
                                               requestBodySourceIO)
import           Network.HTTP.Types           (hContentType,
                                               internalServerError500, status304,
                                               status500)
import           Network.HTTP.Types.Header    (hHost)
import qualified Network.Wai                  as Wai
import           Network.Wai.Conduit          (responseRawSource, responseSource,
                                               sourceRequestBody)
-- | Decides, from the incoming request alone, whether that request should be
-- handed to the alternative application instead of the wrapped one
-- (see 'delegateTo').
type RequestPredicate = Wai.Request -> Bool
-- | Middleware that sends a request to @alt@ when the predicate accepts it,
-- and to the wrapped application otherwise.
delegateTo :: Wai.Application -> RequestPredicate -> Wai.Middleware
delegateTo alt f actual req = target req
  where
    target = if f req then alt else actual
-- | Middleware that forwards requests accepted by the predicate through
-- 'simpleProxy' (using the given settings and connection manager); all
-- other requests continue to the wrapped application.
delegateToProxy :: ProxySettings -> Manager -> RequestPredicate -> Wai.Middleware
delegateToProxy settings mgr shouldProxy = delegateTo proxyApp shouldProxy
  where
    proxyApp = simpleProxy settings mgr
-- | Configuration consumed by 'simpleProxy'.
data ProxySettings = ProxySettings
  { proxyOnException :: SomeException -> Wai.Response
    -- ^ Turns an exception raised while forwarding a request into the
    -- response sent back to the client.
  , proxyTimeout     :: Int
    -- ^ Set to 15 by the 'Default' instance.
    -- NOTE(review): 'simpleProxy' does not currently read this field, and
    -- its unit is not stated anywhere (presumably seconds) — confirm.
  , proxyHost        :: BS.ByteString
    -- ^ Upstream host that every proxied request is sent to; also sent as
    -- the outgoing @Host@ header.
  }
-- | Default settings: proxy to @localhost@, 'proxyTimeout' set to 15, and
-- report any exception as a @500 Internal Server Error@ whose plain-text
-- body is the 'show'n exception.
instance Default ProxySettings where
  def = ProxySettings
    { proxyOnException = onException
    , proxyTimeout = 15
    , proxyHost = "localhost"
    }
    where
      onException :: SomeException -> Wai.Response
      onException e =
        Wai.responseLBS internalServerError500
          [ (hContentType, "text/plain; charset=utf-8") ] $
          -- Encode as UTF-8 so the body matches the declared charset;
          -- 'C8.pack' keeps only the low 8 bits of each Char, which is not
          -- UTF-8 for any non-ASCII character in the exception text.
          BB.toLazyByteString (BB.stringUtf8 (show e))
-- | A 'Wai.Application' that forwards every request it receives to
-- 'proxyHost'.
--
-- * @CONNECT@ requests are tunnelled with 'handleConnect'. If the server
--   backend cannot provide a raw connection, a 500 response with the body
--   @method CONNECT is not supported@ is sent instead.
--
-- * Any other request is replayed against
--   @scheme://proxyHost/rawPath?rawQuery@. The scheme is @https@ when the
--   incoming request was secure and @http@ otherwise.
--
--   The method is preserved. The @Host@ header is replaced with
--   'proxyHost', and @content-encoding@ / @content-length@ are removed
--   from the forwarded headers. Redirects are not followed, and the
--   upstream body is neither decompressed nor altered. It is streamed back
--   with the upstream status and headers plus @X-Via-Proxy: yes@.
--
-- Exceptions raised while building or performing the upstream request are
-- turned into a response with 'proxyOnException'.
--
-- NOTE(review): 'proxyTimeout' is not applied to the upstream request.
-- NOTE(review): an exception raised after 'respond' has started streaming
-- the upstream body makes the handler call 'respond' a second time — confirm
-- how the server handles that.
simpleProxy
  :: ProxySettings
  -> Manager
  -> Wai.Application
simpleProxy settings manager req respond
  | Wai.requestMethod req == "CONNECT" =
      respond $ responseRawSource (handleConnect req)
        (Wai.responseLBS status500 [("Content-Type", "text/plain")] "method CONNECT is not supported")
  | otherwise =
      -- 'parseRequest' runs inside the handler so that a request path which
      -- does not form a valid URL yields the configured error response
      -- instead of an uncaught exception.
      handle (respond . proxyOnException settings) $ do
        proxyReq' <- parseRequest effectiveUrl
        -- 'host' and 'port' are left as parsed from the URL. Overwriting
        -- 'host' with the raw 'proxyHost' would break a value such as
        -- "name:8080", because the whole string would be used as the
        -- hostname to connect to.
        let proxyReq = proxyReq'
              { method = Wai.requestMethod req
              , requestHeaders =
                  (hHost, newHost) : filter dropUpstreamHeaders (Wai.requestHeaders req)
              , redirectCount = 0
              , requestBody = upstreamBody
              , decompress = const False
              }
        withResponse proxyReq manager $ \res -> do
          let body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ responseBody res
              headers = (mk "X-Via-Proxy", "yes") : responseHeaders res
          respond $ responseSource (responseStatus res) headers body
  where
    scheme
      | Wai.isSecure req = "https"
      | otherwise = "http"
    newHost = proxyHost settings
    effectiveUrl =
      scheme ++ "://" ++ C8.unpack newHost
        ++ C8.unpack (Wai.rawPathInfo req <> Wai.rawQueryString req)
    -- Stream the client's body upstream, keeping chunked vs. known-length.
    upstreamBody = case Wai.requestBodyLength req of
      Wai.ChunkedBody ->
        requestBodySourceChunkedIO (sourceRequestBody req)
      Wai.KnownLength l ->
        requestBodySourceIO (fromIntegral l) (sourceRequestBody req)
-- | Serve a @CONNECT@ tunnel.
--
-- Opens a TCP connection to the target chosen by 'toClientSettings' and
-- writes a bare @HTTP/1.1 200 OK@ status line to the client. It then
-- copies bytes client → server and server → client concurrently with
-- 'race_', so the tunnel ends as soon as either direction finishes.
handleConnect
  :: Wai.Request
  -> ConduitT () C8.ByteString IO ()
  -> ConduitT C8.ByteString Void IO ()
  -> IO ()
handleConnect req fromClient toClient =
  runTCPClient (toClientSettings req) tunnel
  where
    tunnel server = do
      runConduit (yield "HTTP/1.1 200 OK\r\n\r\n" .| toClient)
      race_ (clientToServer server) (serverToClient server)
    clientToServer server = runConduit (fromClient .| appSink server)
    serverToClient server = runConduit (appSource server .| toClient)
-- | Port used for a @CONNECT@ target that carries no parsable port.
-- Returns 443 when the incoming request is secure, and 80 (the standard
-- HTTP port; this was previously mistyped as 90) otherwise.
defaultClientPort :: Wai.Request -> Int
defaultClientPort req
  | Wai.isSecure req = 443
  | otherwise = 80
-- | Connection target for a @CONNECT@ tunnel, derived from the request's raw
-- path. The path is split at its first @:@:
--
-- * the text before the colon is the host;
-- * the integer prefix after it (as read by 'C8.readInt') is the port.
--
-- When there is no colon, or no integer follows it, 'defaultClientPort' is
-- used.
--
-- NOTE(review): splitting at the first @:@ mis-parses bracketed IPv6
-- literals such as @[::1]:443@ — confirm whether those must be supported.
toClientSettings :: Wai.Request -> ClientSettings
toClientSettings req = clientSettingsTCP portNumber hostPart
  where
    (hostPart, portSuffix) = C8.break (== ':') (Wai.rawPathInfo req)
    -- An empty suffix stays empty after 'C8.drop' and 'C8.readInt' returns
    -- Nothing for it, so the no-colon case also falls back to the default.
    portNumber =
      maybe (defaultClientPort req) fst (C8.readInt (C8.drop 1 portSuffix))
-- | Header filter for forwarded requests. It keeps a header unless its name
-- is @content-encoding@, @content-length@ or @host@.
--
-- The comparison uses the header-name type's own 'Eq'. It is therefore
-- case-insensitive for WAI's case-insensitive header names, and exact for
-- plain strings.
dropUpstreamHeaders :: (Eq a, IsString a) => (a, b) -> Bool
dropUpstreamHeaders (name, _) =
  all (/= name) ["content-encoding", "content-length", "host"]