{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}

module Network.Wai.Application.Classic.RevProxy (revProxyApp) where

#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad
import Control.Monad.IO.Class (liftIO)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (uncons)
import Data.ByteString.Builder (Builder)
import qualified Data.ByteString.Char8 as BS hiding (uncons)
import Data.Conduit
import qualified Network.HTTP.Client as H
import Network.HTTP.Types
import Network.Wai.Application.Classic.Conduit
import Network.Wai.Application.Classic.EventSource
import Network.Wai.Application.Classic.Field
import Network.Wai.Application.Classic.Header
import Network.Wai.Application.Classic.Path
import Network.Wai.Application.Classic.Types
import Network.Wai.Conduit

----------------------------------------------------------------

-- | Relay a request to an upstream server as a reverse proxy.
--
--   The incoming WAI request is converted to an http-client request with
--   'reqToHReq', sent using the connection manager stored in the
--   'RevProxyAppSpec', and the upstream response is streamed back to the
--   client with the upstream status code.
--
--   Response headers are first filtered with 'headerToBeRelay' and then
--   passed through 'addVia'.  The body is converted by 'toSource', which
--   picks the conversion based on the response's @Content-Type@.
--
--   No exception handler is installed here: anything thrown while talking
--   to the upstream server propagates to the caller.
revProxyApp :: ClassicAppSpec -> RevProxyAppSpec -> RevProxyRoute -> Application
revProxyApp cspec spec route req respond =
    H.withResponse httpClientRequest mgr proxy
  where
    -- Turn the upstream response into a streaming WAI response.  This runs
    -- inside 'H.withResponse', so the body is read while the upstream
    -- response is still open.
    proxy hrsp = do
        let status     = H.responseStatus hrsp
            hdr        = fixHeader $ H.responseHeaders hrsp
            clientBody = H.responseBody hrsp
            -- Looked up in the already-fixed header list.
            ct         = lookup hContentType hdr
            src        = toSource ct clientBody
        respond $ responseSource status hdr src

    -- The request to send upstream.
    httpClientRequest = reqToHReq req route
    -- The http-client manager shared through the application spec.
    mgr = revProxyManager spec
    -- Drop the headers that must not be relayed, then apply 'addVia'.
    fixHeader = addVia cspec req . filter headerToBeRelay

-- | Decide whether a header is copied across the proxy.
--
--   Returns 'False' for @Transfer-Encoding@, @Accept-Encoding@,
--   @Content-Length@ and @Content-Encoding@, and 'True' for every other
--   header.  The header value is ignored.  The same predicate is applied to
--   request headers in 'reqToHReq' and to response headers in 'revProxyApp'.
headerToBeRelay :: Header -> Bool
headerToBeRelay (k,_)
      | k == hTransferEncoding = False
      | k == hAcceptEncoding   = False
      | k == hContentLength    = False
      | k == hContentEncoding  = False -- See H.decompress.
      | otherwise              = True

----------------------------------------------------------------

-- | Build the http-client request that is sent upstream for a WAI request.
--
--   Host and port come from the route.  The method, the body and the
--   headers accepted by 'headerToBeRelay' are taken from the incoming
--   request; the kept headers are then passed through 'addForwardedFor'.
--   The path is rewritten from the route's source and destination, and the
--   query string is forwarded without its leading @?@.
reqToHReq :: Request -> RevProxyRoute -> H.Request
reqToHReq req route = H.defaultRequest {
    H.host           = revProxyDomain route
  , H.port           = revProxyPort route
  , H.secure         = False -- FIXME: upstream is HTTP only
  , H.requestHeaders = addForwardedFor req $ filter headerToBeRelay hdr
  , H.path           = path'
  , H.queryString    = dropQuestion query
  , H.requestBody    = bodyToHBody len body
  , H.method         = requestMethod req
  , H.proxy          = Nothing
--  , H.rawBody        = False
  , H.decompress     = const True        -- decompress for every content type
  , H.checkResponse  = \_ _ -> return () -- never reject an upstream response
  , H.redirectCount  = 0                 -- relay redirects, do not follow them
  }
  where
    path = rawPathInfo req
    src = revProxySrc route
    dst = revProxyDst route
    hdr = requestHeaders req
    query = rawQueryString req
    len = requestBodyLength req
    body = getRequestBodyChunk req
    -- NOTE(review): presumably '<\>' strips the route's source prefix from
    -- the request path and '</>' prepends the destination; confirm against
    -- Network.Wai.Application.Classic.Path.
    path' = dst </> (path <\> src)
    -- Remove a single leading '?' if there is one.
    dropQuestion q = case BS.uncons q of
        Just (63, q') -> q' -- '?' is 63
        _             -> q

-- | Wrap a WAI request-body reader as an http-client streaming body.
--
--   A 'ChunkedBody' becomes 'H.RequestBodyStreamChunked'.  A 'KnownLength'
--   body becomes 'H.RequestBodyStream' with the length converted by
--   'fromIntegral'.  In both cases the reader action itself is what
--   http-client is given to pull chunks from.
bodyToHBody :: RequestBodyLength -> IO ByteString -> H.RequestBody
bodyToHBody ChunkedBody src       = H.RequestBodyStreamChunked ($ src)
bodyToHBody (KnownLength len) src = H.RequestBodyStream (fromIntegral len) ($ src)

----------------------------------------------------------------

-- | Choose how an upstream response body is turned into a conduit source.
--
--   The first argument is the response's @Content-Type@ value, if any.  A
--   @text/event-stream@ body is converted with 'bodyToEventSource'; every
--   other body, or a response without a @Content-Type@, is converted with
--   'bodyToSource'.
--
--   Only the media type is compared: parameters such as @; charset=utf-8@
--   are ignored, so an event stream that declares a charset is still
--   handled as an event stream.
#if MIN_VERSION_conduit(1,3,0)
toSource :: Maybe ByteString -> H.BodyReader -> ConduitT () (Flush Builder) IO ()
#else
toSource :: Maybe ByteString -> H.BodyReader -> Source IO (Flush Builder)
#endif
toSource (Just ct)
  | mediaType ct == "text/event-stream" = bodyToEventSource
  where
    -- The Content-Type value up to the first ';' or whitespace.
    mediaType = BS.takeWhile (\c -> c /= ';' && c /= ' ' && c /= '\t')
toSource _                              = bodyToSource

-- | Stream an http-client response body as a conduit source.
--
--   Chunks are read with 'H.brRead' and each non-empty chunk is yielded as
--   a 'Chunk' of 'Builder'.  The source ends at the first empty chunk; no
--   'Flush' is ever emitted.
#if MIN_VERSION_conduit(1,3,0)
bodyToSource :: H.BodyReader -> ConduitT () (Flush Builder) IO ()
#else
bodyToSource :: H.BodyReader -> Source IO (Flush Builder)
#endif
bodyToSource br = loop
  where
    loop = do
        bs <- liftIO $ H.brRead br
        -- An empty chunk ends the loop.
        unless (BS.null bs) $ do
            yield $ Chunk $ byteStringToBuilder bs
            loop
{-

FIXME:
badGateway :: ClassicAppSpec -> Request-> SomeException -> IO Response
badGateway cspec req _ =
    return $ responseBuilder st hdr bdy
  where
    hdr = addServer cspec textPlainHeader
    bdy = byteStringToBuilder "Bad Gateway\r\n"
    st = badGateway502
-}