-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.Impl
  ( ProxySettings (..)
  , forceSSL
  , healthCheckProvider
  , httpConnectProxy
  , httpGetProxy
  , httpProxy
  , logRequest
  , pacProvider
  , reverseProxy
  ) where

import Control.Applicative        ((<|>))
import Control.Concurrent.Async   (concurrently)
import Control.Exception          (SomeException, try)
import Control.Monad              (unless, void, when)
import Control.Monad.IO.Class     (liftIO)
import Data.Binary.Builder        qualified as BB
import Data.ByteString            qualified as BS
import Data.ByteString.Base64     (decodeLenient)
import Data.ByteString.Char8      qualified as BS8
import Data.ByteString.Lazy       qualified as LBS
import Data.ByteString.Lazy.Char8 qualified as LBS8
import Data.CaseInsensitive       qualified as CI
import Data.Conduit.Binary        qualified as CB
import Data.Conduit.Network       qualified as CN
import Network.HTTP.Client        qualified as HC
import Network.HTTP.ReverseProxy
    (ProxyDest (..), SetIpHeader (..), WaiProxyResponse (..),
    defaultWaiProxySettings, waiProxyToSettings, wpsSetIpHeader,
    wpsUpgradeToRaw)
import Network.HTTP.Types         qualified as HT
import Network.HTTP.Types.Header  qualified as HT

import Data.Conduit
import Data.Maybe
import Network.Wai
import Network.Wai.Middleware.Gzip
import Network.Wai.Middleware.StripHeaders

import Network.HProx.Log
import Network.HProx.Util

-- | Runtime configuration shared by the proxy middlewares in this module.
data ProxySettings = ProxySettings
  { proxyAuth    :: Maybe (BS.ByteString -> Bool)
    -- ^ Validator applied to the Base64-decoded @Proxy-Authorization@
    -- credentials; 'Nothing' means every request is accepted.
  , passPrompt   :: Maybe BS.ByteString
    -- ^ Realm text used in the 407 challenge; @"hprox"@ is used when absent.
  , wsRemote     :: Maybe BS.ByteString
    -- ^ Upstream address for requests that are upgraded to a raw connection
    -- (parsed with 'parseHostPortWithDefault', default port 80).
  , revRemote    :: Maybe BS.ByteString
    -- ^ Upstream address for reverse proxying (parsed with
    -- 'parseHostPortWithDefault', default port 80); 'Nothing' disables it.
  , naivePadding :: Bool
    -- ^ Whether CONNECT tunnels may use the naiveproxy padding protocol.
  , logger       :: Logger
    -- ^ Logging callback.
  }

-- | Render a single-line description of a request for the log.
--
-- The line contains the method, the request target, the HTTP version, a
-- @(tls)@ marker when the connection is secure, and the peer address.
-- For requests that are neither @CONNECT@ nor absolute-form (@http://...@)
-- the @Host@ header (or @(no-host)@ when it is missing) is prepended to the
-- path.
logRequest :: Request -> LogStr
logRequest req = toLogStr (requestMethod req) <>
    " " <> hostname <> toLogStr (rawPathInfo req) <>
    " " <> toLogStr (show $ httpVersion req) <>
    " " <> (if isSecure req then "(tls) " else "")
    <> toLogStr (show $ remoteHost req)
  where
    isConnect = requestMethod req == "CONNECT"
    isGet = "http://" `BS.isPrefixOf` rawPathInfo req
    -- CONNECT and absolute-form targets already carry the host.
    hostname | isConnect || isGet = ""
             | otherwise          = toLogStr (fromMaybe "(no-host)" $ requestHeaderHost req)

-- | Complete forward-proxy middleware.
--
-- A request is offered to 'pacProvider' first, then to 'httpGetProxy', then
-- to 'httpConnectProxy', and finally to the wrapped application.
httpProxy :: ProxySettings -> HC.Manager -> Middleware
httpProxy set mgr = pacProvider . httpGetProxy set mgr . httpConnectProxy set

-- | Middleware that only lets secure requests reach the wrapped application.
--
-- Requests that 'redirectWebsocket' selects are also passed through
-- unchanged; every other insecure request is answered by 'redirectToSSL'.
forceSSL :: ProxySettings -> Middleware
forceSSL pset app req respond
    | isSecure req               = app req respond
    | redirectWebsocket pset req = app req respond
    | otherwise                  = redirectToSSL req respond

-- | Application that tells the client to retry over TLS.
--
-- When the request has a @Host@ header, responds with a 301 redirect to the
-- @https://@ URL on that host. The path and query string are kept when the
-- request target is in origin form (starts with @/@); other targets (for
-- example absolute-form or @CONNECT@ authority-form) redirect to the bare
-- host. Without a @Host@ header, responds with 426 Upgrade Required.
redirectToSSL :: Application
redirectToSSL req respond
    | Just host <- requestHeaderHost req = respond $ responseKnownLength
        HT.status301
        [("Location", BS.concat ["https://", host, target])]
        ""
    | otherwise                          = respond $ responseKnownLength
        (HT.mkStatus 426 "Upgrade Required")
        [("Upgrade", "TLS/1.0, HTTP/1.1"), ("Connection", "Upgrade")]
        ""
  where
    -- Only origin-form targets can be appended verbatim after the host.
    target
        | "/" `BS.isPrefixOf` rawPathInfo req =
            rawPathInfo req `BS.append` rawQueryString req
        | otherwise                           = BS.empty

-- | True for header names that start with @proxy@ (case-insensitive).
isProxyHeader :: HT.HeaderName -> Bool
isProxyHeader h = "proxy" `BS.isPrefixOf` CI.foldedCase h

-- | True for header names that start with @x-forwarded@ (case-insensitive).
isForwardedHeader :: HT.HeaderName -> Bool
isForwardedHeader h = "x-forwarded" `BS.isPrefixOf` CI.foldedCase h

-- | True for header names that start with @cf-@ (case-insensitive) and for
-- @cdn-loop@.
isCDNHeader :: HT.HeaderName -> Bool
isCDNHeader h = "cf-" `BS.isPrefixOf` CI.foldedCase h || h == "cdn-loop"

-- | True for request headers that are removed before a request is forwarded
-- upstream: proxy, forwarded and CDN headers plus @X-Real-IP@ and
-- @X-Scheme@.
isToStripHeader :: HT.HeaderName -> Bool
isToStripHeader h =
    isProxyHeader h || isForwardedHeader h || isCDNHeader h
        || h == "X-Real-IP" || h == "X-Scheme"

-- | Decide whether a request is allowed to use the proxy.
--
-- * With no 'proxyAuth' validator configured, every request is allowed.
-- * With a validator but no @Proxy-Authorization@ header, the request is
--   rejected.
-- * Otherwise the last space-separated token of the header value is
--   leniently Base64-decoded and handed to the validator.
checkAuth :: ProxySettings -> Request -> Bool
checkAuth ProxySettings{..} req =
    case (proxyAuth, lookup HT.hProxyAuthorization (requestHeaders req)) of
        (Nothing, _)                -> True
        (Just _, Nothing)           -> False
        (Just validate, Just value) -> validate (decodeCredentials value)
  where
    -- Everything after the last space is the encoded credential.
    decodeCredentials = decodeLenient . snd . BS8.spanEnd (/= ' ')

-- | True when the request should be forwarded to the websocket upstream:
-- the default 'wpsUpgradeToRaw' predicate accepts it and 'wsRemote' is
-- configured.
redirectWebsocket :: ProxySettings -> Request -> Bool
redirectWebsocket ProxySettings{..} req =
    wpsUpgradeToRaw defaultWaiProxySettings req && isJust wsRemote

-- | The 407 response sent to unauthenticated proxy clients.
--
-- It carries a @Proxy-Authenticate: Basic@ challenge whose realm is
-- 'passPrompt', or @hprox@ when no prompt is configured.
proxyAuthRequiredResponse :: ProxySettings -> Response
proxyAuthRequiredResponse ProxySettings{..} = responseKnownLength
    HT.status407
    [(HT.hProxyAuthenticate, "Basic realm=\"" `BS.append` prompt `BS.append` "\"")]
    ""
  where
    prompt = fromMaybe "hprox" passPrompt

-- | Middleware serving a proxy auto-config file at @/.hprox/config.pac@.
--
-- The advertised proxy host comes from @X-Forwarded-Host@, falling back to
-- the @Host@ header; if neither is present the request goes to the wrapped
-- application. The proxy is advertised as @HTTPS@ when @X-Forwarded-Proto@
-- is @https@ (or, without that header, when the connection is secure) and
-- as @PROXY@ otherwise. A default port (443 or 80 respectively) is appended
-- when the host contains no colon.
pacProvider :: Middleware
pacProvider fallback req respond
    | pathInfo req == [".hprox", "config.pac"],
      Just host' <- lookup "x-forwarded-host" (requestHeaders req) <|> requestHeaderHost req =
        let issecure = case lookup "x-forwarded-proto" (requestHeaders req) of
                Just proto -> proto == "https"
                Nothing    -> isSecure req
            scheme = if issecure then "HTTPS" else "PROXY"
            defaultPort = if issecure then ":443" else ":80"
            host | 58 `BS.elem` host' = host' -- ':'
                 | otherwise          = host' `BS.append` defaultPort
        in respond $ responseKnownLength
               HT.status200
               [("Content-Type", "application/x-ns-proxy-autoconfig")] $
               LBS8.unlines [ "function FindProxyForURL(url, host) {"
                            , LBS8.fromChunks ["  return \"", scheme, " ", host, "\";"]
                            , "}"
                            ]
    | otherwise = fallback req respond

-- | Middleware answering @/.hprox/health@ with a plain-text @okay@ (200);
-- every other request goes to the wrapped application.
healthCheckProvider :: Middleware
healthCheckProvider fallback req respond
    | pathInfo req == [".hprox", "health"] =
        respond $ responseKnownLength
            HT.status200
            [("Content-Type", "text/plain")]
            "okay"
    | otherwise = fallback req respond

-- | Middleware that forwards every request to 'revRemote'.
--
-- When 'revRemote' is 'Nothing' the wrapped application is returned
-- untouched. Otherwise the remote is parsed with a default port of 80, the
-- request's @Host@ is rewritten to the remote host, and headers selected by
-- 'isToStripHeader' are removed. A remote on port 443 is contacted securely;
-- for any other port the @Server@ and @Date@ response headers are stripped.
reverseProxy :: ProxySettings -> HC.Manager -> Middleware
reverseProxy ProxySettings{..} mgr fallback =
    case revRemote of
        Nothing     -> fallback
        Just remote -> proxyTo (parseHostPortWithDefault 80 remote)
  where
    -- Do not let the proxy library add its own client-IP header.
    settings = defaultWaiProxySettings { wpsSetIpHeader = SIHNone }

    proxyTo :: (BS.ByteString, Int) -> Application
    proxyTo (revHost, revPort) =
        appWrapper $ waiProxyToSettings (return . proxyResponseFor) settings mgr
      where
        (revWrapper, appWrapper)
            | revPort == 443 = (WPRModifiedRequestSecure, id)
            | otherwise      = (WPRModifiedRequest, modifyResponse (stripHeaders ["Server", "Date"]))

        proxyResponseFor req = revWrapper nreq (ProxyDest revHost revPort)
          where
            nreq = req
              { requestHeaders = hdrs
              , requestHeaderHost = Just revHost
              }

            -- Replace the Host header and drop everything that would
            -- reveal the proxy hop.
            hdrs = (HT.hHost, revHost) : [ (hdn, hdv)
                                         | (hdn, hdv) <- requestHeaders req
                                         , not (isToStripHeader hdn) && hdn /= HT.hHost
                                         ]

-- | Middleware implementing plain (non-CONNECT) HTTP forward proxying.
--
-- For each request, in order:
--
-- * requests selected by 'redirectWebsocket' are forwarded to 'wsRemote'
--   (securely when its port is 443);
-- * requests that are not proxy requests go to the wrapped application;
-- * authorised proxy requests are forwarded to their destination with
--   'isToStripHeader' headers removed;
-- * unauthorised proxy requests are logged and answered with
--   'proxyAuthRequiredResponse'.
--
-- A non-CONNECT request is a proxy request when its target is absolute-form
-- (@http://...@), or when it has a @Host@ and either carries a proxy header
-- or is a secure HTTP/2 request with @X-Scheme: http@. Forwarded proxy
-- responses are gzip-compressed via 'gzip'.
httpGetProxy :: ProxySettings -> HC.Manager -> Middleware
httpGetProxy pset@ProxySettings{..} mgr fallback =
    appWrapper $ waiProxyToSettings (return . proxyResponseFor) settings mgr
  where
    -- Do not let the proxy library add its own client-IP header.
    settings = defaultWaiProxySettings { wpsSetIpHeader = SIHNone }

    -- gzip is applied only to requests that end up being forwarded.
    appWrapper = ifRequest isGetProxy (gzip def)

    isGetProxy req = case proxyResponseFor req of
        WPRModifiedRequest _ _ -> True
        _                      -> False

    proxyResponseFor req
        | redirectWebsocket pset req
        , Just remote <- wsRemote = websocketResponse remote
        | otherwise               = case proxyTarget of
            Nothing -> WPRApplication fallback
            Just ((host, port), newRawPath)
                | checkAuth pset req ->
                    WPRModifiedRequest (forwarded newRawPath) (ProxyDest host port)
                | otherwise ->
                    pureLogger logger WARN ("unauthorized request: " <> logRequest req) $
                    WPRResponse (proxyAuthRequiredResponse pset)
      where
        websocketResponse remote
            | wsPort == 443 = WPRProxyDestSecure wsDest
            | otherwise     = WPRProxyDest wsDest
          where
            (wsHost, wsPort) = parseHostPortWithDefault 80 remote
            wsDest           = ProxyDest wsHost wsPort

        isConnect = requestMethod req == "CONNECT"
        rawPath = rawPathInfo req

        rawPathPrefix :: BS.ByteString
        rawPathPrefix = "http://"

        defaultPort :: Int
        defaultPort = 80

        hostHeader = parseHostPortWithDefault defaultPort <$> requestHeaderHost req

        isRawPathProxy = rawPathPrefix `BS.isPrefixOf` rawPath
        hasProxyHeader = any (isProxyHeader . fst) (requestHeaders req)
        scheme = lookup "X-Scheme" (requestHeaders req)
        isHTTP2Proxy = HT.httpMajor (httpVersion req) >= 2 && scheme == Just "http" && isSecure req

        -- Destination and rewritten path of a proxy request, or Nothing
        -- when the request is not a proxy request. Requests recognised
        -- only through headers need a Host to name their destination;
        -- without one they are handed to the wrapped application.
        proxyTarget
            | isConnect                      = Nothing
            | isRawPathProxy                 =
                Just (parseHostPortWithDefault defaultPort hostPortP, newRawPathP)
            | isHTTP2Proxy || hasProxyHeader =
                (\dest -> (dest, rawPath)) <$> hostHeader
            | otherwise                      = Nothing
          where
            -- Split "http://host[:port]/path" into "host[:port]" and "/path".
            (hostPortP, newRawPathP) = BS8.span (/= '/') $
                BS.drop (BS.length rawPathPrefix) rawPath

        forwarded newRawPath = req
          { rawPathInfo = newRawPath
          , requestHeaders = filter (not . isToStripHeader . fst) $ requestHeaders req
          }

-- | Middleware implementing HTTP @CONNECT@ tunnelling.
--
-- A request is handled here when its method is @CONNECT@ and a host/port
-- can be parsed from the request target or, failing that, from the @Host@
-- header; anything else goes to the wrapped application. Unauthorised
-- requests (per 'checkAuth') are logged and answered with
-- 'proxyAuthRequiredResponse'.
--
-- For HTTP versions below 2 the tunnel runs over 'responseRaw' and the
-- @200 OK@ status line is written by hand; for HTTP/2 and above it runs as
-- a streaming response body. When 'naivePadding' is enabled the response
-- carries a random @Padding@ header, and if the request also carries a
-- @Padding@ header the first 'countPaddings' chunks in each direction are
-- framed with the naiveproxy padding protocol.
httpConnectProxy :: ProxySettings -> Middleware
httpConnectProxy pset@ProxySettings{..} fallback req respond =
    case connectTarget of
        Nothing -> fallback req respond
        Just target
            | checkAuth pset req -> respondResponse target
            | otherwise          -> do
                logger WARN $ "unauthorized request: " <> logRequest req
                respond (proxyAuthRequiredResponse pset)
  where
    -- Tunnel destination; Nothing when this is not a usable CONNECT request.
    connectTarget :: Maybe (BS.ByteString, Int)
    connectTarget
        | requestMethod req /= "CONNECT" = Nothing
        | otherwise                      =
            parseHostPort (rawPathInfo req) <|> (requestHeaderHost req >>= parseHostPort)

    -- Sent instead of the tunnel when the server cannot do 'responseRaw'.
    backup = responseKnownLength HT.status500 [("Content-Type", "text/plain")]
        "HTTP CONNECT tunneling detected, but server does not support responseRaw"

    tryAndCatchAll :: IO a -> IO (Either SomeException a)
    tryAndCatchAll = try

    respondResponse :: (BS.ByteString, Int) -> IO ResponseReceived
    respondResponse target
        | HT.httpMajor (httpVersion req) < 2 = respond $ responseRaw (handleConnect target True) backup
        | not naivePadding                   = respond $ responseStream HT.status200 [] streaming
        | otherwise                          = do
            padding <- randomPadding
            respond $ responseStream HT.status200 [("Padding", padding)] streaming
      where
        streaming write flush = do
            -- Flush first so the client sees the 200 before any data flows.
            flush
            handleConnect target False (getRequestBodyChunk req) (\bs -> write (BB.fromByteString bs) >> flush)

    -- Largest payload per padded frame.
    maximumLength = 65535 - 3 - 255
    -- Number of leading chunks per direction that are padded.
    countPaddings = 8

    addStreamPadding = isJust (lookup "Padding" (requestHeaders req)) && naivePadding

    -- see: https://github.com/klzgrad/naiveproxy/#padding-protocol-an-informal-specification
    --
    -- Frames each of the next @n@ chunks as: 2-byte big-endian payload
    -- length, 1-byte padding length, payload, zero padding. Chunks longer
    -- than 'maximumLength' are split. After @n@ frames data passes through
    -- untouched. Stops on end of input or on an empty chunk.
    addPadding :: Int -> ConduitT BS.ByteString BS.ByteString IO ()
    addPadding 0 = awaitForever yield
    addPadding n = do
        mbs <- await
        case mbs of
            Nothing -> return ()
            Just bs | BS.null bs -> return ()
            Just bs -> do
                let (bs0, bs1) = BS.splitAt maximumLength bs
                unless (BS.null bs1) $ leftover bs1
                let len = BS.length bs0
                paddingLen <- liftIO randomPaddingLength
                let header = mconcat (map (BB.singleton . fromIntegral) [len `div` 256, len `mod` 256, paddingLen])
                    body   = BB.fromByteString bs0
                    tailer = BB.fromByteString (BS.replicate paddingLen 0)
                yield $ LBS.toStrict $ BB.toLazyByteString (header <> body <> tailer)
                addPadding (n - 1)

    -- Inverse of 'addPadding': reads @n@ frames, yielding only their
    -- payloads, then passes data through untouched. Stops when a header or
    -- frame is truncated.
    removePadding :: Int -> ConduitT BS.ByteString BS.ByteString IO ()
    removePadding 0 = awaitForever yield
    removePadding n = do
        header <- CB.take 3
        case LBS.unpack header of
            [b0, b1, b2] -> do
                let len = fromIntegral b0 * 256 + fromIntegral b1
                    paddingLen = fromIntegral b2
                bs <- CB.take (fromIntegral (len + paddingLen))
                if LBS.length bs /= len + paddingLen
                    then return ()
                    else yield (LBS.toStrict $ LBS.take len bs) >> removePadding (n - 1)
            _ -> return ()

    -- Hand-written status line for the raw (HTTP/1.x) tunnel.
    yieldHttp1Response
        | naivePadding = do
            padding <- BB.fromByteString <$> liftIO randomPadding
            yield $ LBS.toStrict $ BB.toLazyByteString ("HTTP/1.1 200 OK\r\nPadding: " <> padding <> "\r\n\r\n")
        | otherwise    = yield "HTTP/1.1 200 OK\r\n\r\n"

    -- Connects to the destination and pumps data in both directions until
    -- either side finishes. Exceptions raised while pumping are discarded
    -- on purpose: a broken tunnel just ends.
    handleConnect :: (BS.ByteString, Int) -> Bool -> IO BS.ByteString -> (BS.ByteString -> IO ()) -> IO ()
    handleConnect (host, port) http1 fromClient' toClient' = CN.runTCPClient (CN.clientSettings port host) $ \server ->
        let toServer = CN.appSink server
            fromServer = CN.appSource server
            -- An empty chunk from the client marks end of input.
            fromClient = do
                bs <- liftIO fromClient'
                unless (BS.null bs) (yield bs >> fromClient)
            toClient = awaitForever (liftIO . toClient')

            clientToServer | addStreamPadding = fromClient .| removePadding countPaddings .| toServer
                           | otherwise        = fromClient .| toServer

            serverToClient | addStreamPadding = fromServer .| addPadding countPaddings .| toClient
                           | otherwise        = fromServer .| toClient
        in do
            when http1 $ runConduit $ yieldHttp1Response .| toClient
            void $ tryAndCatchAll $ concurrently
                (runConduit clientToServer)
                (runConduit serverToClient)