-- SPDX-License-Identifier: Apache-2.0

-- Copyright (C) 2023 Bin Jin. All Rights Reserved.
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns        #-}

{-| Instead of running the @hprox@ binary directly, you can use this library
    to run HProx in front of an arbitrary WAI 'Application'.
-}

module Network.HProx
  ( CertFile (..)
  , Config (..)
  , defaultConfig
  , getConfig
  , run
  ) where

import Data.ByteString.Char8       qualified as BS8
import Data.List                   (isSuffixOf, (\\))
import Data.String                 (fromString)
import Data.Version                (showVersion)
import Network.HTTP.Client.TLS     (newTlsManager)
import Network.HTTP.Types          qualified as HT
import Network.TLS                 qualified as TLS
import Network.TLS.Extra.Cipher    qualified as TLS
import Network.TLS.SessionManager  qualified as SM
import Network.Wai                 (Application, rawPathInfo)
import Network.Wai.Handler.Warp
    (InvalidRequest (..), defaultSettings, defaultShouldDisplayException,
    runSettings, setBeforeMainLoop, setHost, setLogger, setNoParsePath,
    setOnException, setPort, setServerName)
import Network.Wai.Handler.WarpTLS
    (OnInsecure (..), WarpTLSException, onInsecure, runTLS, tlsAllowedVersions,
    tlsCiphers, tlsServerHooks, tlsSessionManager, tlsSettings)
import System.Posix.User
    (UserEntry (..), getEffectiveUserID, getUserEntryForName, setGroupID,
    setGroups, setUserID)

import Control.Exception    (Exception (..))
import GHC.IO.Exception     (IOErrorType (..))
import Network.HTTP2.Client qualified as H2
import System.IO.Error      (ioeGetErrorType)

#ifdef QUIC_ENABLED
import Control.Concurrent.Async     (mapConcurrently_)
import Data.List                    (find)
import Network.QUIC                 qualified as Q
import Network.QUIC.Internal        qualified as Q
import Network.Wai.Handler.Warp     (setAltSvc)
import Network.Wai.Handler.WarpQUIC (runQUIC)
#endif

import Control.Monad
import Data.Maybe
import Options.Applicative

import Network.HProx.DoH
import Network.HProx.Impl
import Network.HProx.Log
import Paths_hprox

-- | Configuration of HProx, see @hprox --help@ for details
--
-- Each field corresponds to one command line option; the option name is
-- given in the field documentation.
data Config = Config
  { _bind     :: Maybe String           -- ^ @--bind@: IP address to listen on ('Nothing' = all interfaces)
  , _port     :: Int                    -- ^ @--port@: TCP port number
  , _ssl      :: [(String, CertFile)]   -- ^ @--tls@: hostname / certificate pairs; empty means plain HTTP
  , _user     :: Maybe String           -- ^ @--user@: user to setuid to after binding the port
  , _auth     :: Maybe FilePath         -- ^ @--auth@: password file with colon separated user/password lines
  , _ws       :: Maybe String           -- ^ @--ws@: remote host handling websocket requests
  , _rev      :: Maybe String           -- ^ @--rev@: remote host for reverse proxying
  , _doh      :: Maybe String           -- ^ @--doh@: DNS server used for DNS-over-HTTPS
  , _naive    :: Bool                   -- ^ @--naive@: add naiveproxy compatible padding (requires TLS)
  , _name     :: BS8.ByteString         -- ^ @--name@: value of the @Server@ response header
  , _loglevel :: LogLevel               -- ^ @--loglevel@: logging level
#ifdef QUIC_ENABLED
  , _quic     :: Maybe Int              -- ^ @--quic@: UDP port for QUIC (HTTP/3)
#endif
  }

-- | Default value of 'Config', same as running @hprox@ without arguments
--
-- Port 3000 on all interfaces, no TLS, no authentication, no websocket,
-- reverse-proxy or DNS-over-HTTPS forwarding, no naiveproxy padding, the
-- @Server@ header set to @hprox@ and logging at 'INFO'.
defaultConfig :: Config
defaultConfig = Config
    { _bind     = Nothing
    , _port     = 3000
    , _ssl      = []
    , _user     = Nothing
    , _auth     = Nothing
    , _ws       = Nothing
    , _rev      = Nothing
    , _doh      = Nothing
    , _naive    = False
    , _name     = "hprox"
    , _loglevel = INFO
#ifdef QUIC_ENABLED
    , _quic     = Nothing
#endif
    }

-- | Certificate file pairs
--
-- A certificate (chain) file together with the private key file that
-- belongs to it.
data CertFile = CertFile
  { certfile :: FilePath  -- ^ path of the certificate file
  , keyfile  :: FilePath  -- ^ path of the matching private key file
  }

-- | Load the certificate and private key named by a 'CertFile'.
--
-- If loading fails, the returned action itself throws (via 'error') when it
-- runs. A misconfigured certificate therefore aborts startup. Previously the
-- failure was hidden in an unevaluated thunk and only surfaced when the
-- credential was first used.
readCert :: CertFile -> IO TLS.Credential
readCert (CertFile c k) = TLS.credentialLoadX509 c k >>= either loadFailed pure
  where
    -- Include both paths so the operator can tell which pair is broken.
    loadFailed err = error ("failed to load certificate " ++ show c ++ " with key " ++ show k ++ ": " ++ err)

-- | Split a list at every occurrence of a separator element.
--
-- The result always has at least one chunk. Leading, trailing or adjacent
-- separators produce empty chunks, e.g. @splitBy ':' "a::b"@ is
-- @["a", "", "b"]@ and @splitBy ':' ""@ is @[""]@.
splitBy :: Eq a => a -> [a] -> [[a]]
splitBy sep xs = chunk : rest
  where
    (chunk, remainder) = break (== sep) xs
    -- Drop the separator itself and keep splitting what follows it.
    rest = case remainder of
        []       -> []
        _ : more -> splitBy sep more

-- | Command line parser for 'Config', including @--help@ and @--version@.
parser :: ParserInfo Config
parser = info (helper <*> ver <*> config) (fullDesc <> progDesc desc)
  where
    -- Parse a @hostname:certfile:keyfile@ triple given to @--tls@.
    parseSSL s = case splitBy ':' s of
        [host, cert, key] -> Right (host, CertFile cert key)
        _                 -> Left "invalid format for ssl certificates"

    desc = "a lightweight HTTP proxy server, and more"
    ver = infoOption (showVersion version) (long "version" <> help "show version")

    -- Field order must match the 'Config' constructor.
    config = Config <$> bind
                    <*> port
                    <*> ssl
                    <*> user
                    <*> auth
                    <*> ws
                    <*> rev
                    <*> doh
                    <*> naive
                    <*> name
                    <*> loglevel
#ifdef QUIC_ENABLED
                    <*> quic
#endif

    bind = optional $ strOption
        ( long "bind"
       <> short 'b'
       <> metavar "bind_ip"
       <> help "ip address to bind on (default: all interfaces)")

    port = option auto
        ( long "port"
       <> short 'p'
       <> metavar "port"
       <> value 3000
       <> showDefault
       <> help "port number")

    -- Fixed: the metavar previously read "hostname:cerfile:keyfile".
    ssl = many $ option (eitherReader parseSSL)
        ( long "tls"
       <> short 's'
       <> metavar "hostname:certfile:keyfile"
       <> help "enable TLS and specify a domain and associated TLS certificate (can be specified multiple times for multiple domains)")

    user = optional $ strOption
        ( long "user"
       <> short 'u'
       <> metavar "nobody"
       <> help "setuid after binding port")

    auth = optional $ strOption
        ( long "auth"
       <> short 'a'
       <> metavar "userpass.txt"
       <> help "password file for proxy authentication (plain text file with lines each containing a colon separated user/password pair)")

    ws = optional $ strOption
        ( long "ws"
       <> metavar "remote-host:port"
       <> help "remote host to handle websocket requests (port 443 indicates HTTPS remote server)")

    rev = optional $ strOption
        ( long "rev"
       <> metavar "remote-host:port"
       <> help "remote host for reverse proxy (port 443 indicates HTTPS remote server)")

    doh = optional $ strOption
        ( long "doh"
       <> metavar "dns-server:port"
       <> help "enable DNS-over-HTTPS(DoH) support (53 will be used if port is not specified)")

    naive = switch
        ( long "naive"
       <> help "add naiveproxy compatible padding (requires TLS)")

    name = strOption
        ( long "name"
       <> metavar "server-name"
       <> value "hprox"
       <> showDefault
       <> help "specify the server name for the 'Server' header")

    loglevel = option (maybeReader logLevelReader)
        ( long "loglevel"
       <> metavar "<trace|debug|info|warn|error|none>"
       <> value INFO
       <> help "specify the logging level (default: info)")

#ifdef QUIC_ENABLED
    quic = optional $ option auto
        ( long "quic"
       <> short 'q'
       <> metavar "port"
       <> help "enable QUIC (HTTP/3) on UDP port")
#endif

-- | Drop privileges to the named user. Installed as the warp
-- before-main-loop hook when @--user@ is given, i.e. after the port has been
-- bound.
--
-- When the process runs as root, it first replaces the supplementary group
-- list and the group ID with the user's primary group, then changes the user
-- ID. Otherwise it would keep root's groups (including gid 0) after the
-- switch. The group changes must come first because they need privileges
-- that are gone once the user ID has changed.
--
-- When not running as root, only 'setUserID' is attempted, as before. This
-- also makes a repeated call harmless after privileges are already dropped.
setuid :: String -> IO ()
setuid user = do
    entry <- getUserEntryForName user
    euid  <- getEffectiveUserID
    when (euid == 0) $ do
        let gid = userGroupID entry
        setGroups [gid]
        setGroupID gid
    setUserID (userID entry)

-- | Read 'Config' from command line arguments
--
-- Runs the hprox option parser with the default parser preferences. As with
-- any optparse-applicative program, invalid arguments, @--help@ and
-- @--version@ print a message and exit.
getConfig :: IO Config
getConfig = customExecParser defaultPrefs parser

-- | Run HProx in front of fallback 'Application', with specified 'Config'
--
-- Requests pass through, in order: the health check, HTTPS enforcement (only
-- when TLS is enabled), the forward proxy, the reverse proxy and finally the
-- fallback application. When @--doh@ is set, DNS-over-HTTPS handling is
-- placed in front of all of these.
run :: Application -- ^ fallback application
    -> Config      -- ^ configuration
    -> IO ()
run fallback Config{..} = withLogger (LogStdout 4096) _loglevel $ \logger -> do
    logger INFO $ "hprox " <> toLogStr (showVersion version) <> " started"
    logger INFO $ "bind to TCP port " <> toLogStr (fromMaybe "[::]" _bind) <> ":" <> toLogStr _port

    let certfiles = _ssl

    -- Load a credential for every configured certificate. The first entry is
    -- the primary domain; the others are only reachable through SNI.
    certs <- mapM (readCert . snd) certfiles
    smgr <- SM.newSessionManager SM.defaultConfig

    let isSSL = not (null certfiles)
        -- Lazily bound: only forced on TLS code paths, where certfiles is non-empty.
        (primaryHost, primaryCert) = case certfiles of
            primary : _ -> primary
            []          -> error "hprox: primary certificate requested without any --tls certificate"
        otherCerts = drop 1 $ zip (map fst certfiles) certs

    when isSSL $ do
        logger INFO $ "read " <> toLogStr (show $ length certs) <> " certificates"
        logger INFO $ "primary domain: " <> toLogStr primaryHost
        logger INFO $ "other domains: " <> toLogStr (unwords $ map fst otherCerts)

    let settings = setHost (fromString (fromMaybe "*6" _bind)) $
                   setPort _port $
                   setLogger warpLogger $
                   setOnException exceptionHandler $
                   setNoParsePath True $
                   setServerName _name $
                   maybe id (setBeforeMainLoop . setuid) _user
                   defaultSettings

        -- Exceptions are only logged at DEBUG verbosity or finer. Ordinary
        -- connection teardown (EOF, HTTP/2 and TLS errors, peer-closed
        -- connections) is ignored, and wrapped BadThingHappen exceptions are
        -- unwrapped before classification.
        exceptionHandler req ex
            | _loglevel > DEBUG                                 = return ()
            | not (defaultShouldDisplayException ex)            = return ()
            | Just (ioeGetErrorType -> EOF) <- fromException ex = return ()
            | Just (H2.BadThingHappen ex') <- fromException ex  = exceptionHandler req ex'
            | Just (_ :: H2.HTTP2Error) <- fromException ex     = return ()
#ifdef QUIC_ENABLED
            | Just (Q.BadThingHappen ex') <- fromException ex   = exceptionHandler req ex'
            | Just (_ :: Q.QUICException) <- fromException ex   = return ()
#endif
            | Just (_ :: WarpTLSException) <- fromException ex  = return ()
            | Just ConnectionClosedByPeer <- fromException ex   = return ()
            | otherwise                                         =
                logger DEBUG $ "exception: " <> toLogStr (displayException ex) <>
                    maybe "" ((" from: " <>) . logRequest) req

        -- Per-request access log at TRACE level; health checks are not logged.
        warpLogger req status _
            | rawPathInfo req == "/.hprox/health" = return ()
            | otherwise                           =
                logger TRACE $ "(" <> toLogStr (HT.statusCode status) <> ") " <> logRequest req

        tlsset' = tlsSettings (certfile primaryCert) (keyfile primaryCert)
        hooks = (tlsServerHooks tlsset') { TLS.onServerNameIndication = onSNI }

        -- https://www.ssllabs.com/ssltest
        weak_ciphers = [ TLS.cipher_ECDHE_RSA_AES256CBC_SHA384
                       , TLS.cipher_ECDHE_RSA_AES256CBC_SHA
                       , TLS.cipher_AES256CCM_SHA256
                       , TLS.cipher_AES256GCM_SHA384
                       , TLS.cipher_AES256_SHA256
                       , TLS.cipher_AES256_SHA1
                       ]

        tlsset = tlsset'
            { tlsServerHooks     = hooks
            , onInsecure         = AllowInsecure
            , tlsAllowedVersions = [TLS.TLS13, TLS.TLS12]
            , tlsCiphers         = TLS.ciphersuite_strong \\ weak_ciphers
            , tlsSessionManager  = Just smgr
            }

        logAndFail msg = logger WARN (toLogStr msg) >> fail msg

        -- The primary host gets 'mempty' credentials. Presumably the TLS
        -- library then falls back to the certificate from tlsSettings;
        -- TODO confirm against the tls ServerHooks documentation.
        onSNI Nothing = logAndFail "SNI: unspecified"
        onSNI (Just host)
          | checkSNI host primaryHost = return mempty
          | otherwise                 = lookupSNI host otherCerts

        lookupSNI host [] = logAndFail ("SNI: unknown hostname (" ++ show host ++ ")")
        lookupSNI host ((p, cert) : cs)
          | checkSNI host p = return (TLS.Credentials [cert])
          | otherwise       = lookupSNI host cs

        -- A "*.example.com" pattern matches any host ending in ".example.com"
        -- (any depth, but not "example.com" itself); other patterns must match exactly.
        checkSNI host pat = case pat of
            '*' : '.' : p -> ('.' : p) `isSuffixOf` host
            p             -> host == p

#ifdef QUIC_ENABLED
        alpn _ = return . fromMaybe "" . find (== "h3")
        altsvc qport = BS8.concat ["h3=\":", BS8.pack $ show qport ,"\""]

        -- QUIC only serves the primary certificate.
        quicset qport = Q.defaultServerConfig
            { Q.scAddresses      = [(fromString (fromMaybe "0.0.0.0" _bind), fromIntegral qport)]
            , Q.scVersions       = [Q.Version1, Q.Version2]
            , Q.scCredentials    = TLS.Credentials (take 1 certs)
            , Q.scCiphers        = Q.scCiphers Q.defaultServerConfig \\ weak_ciphers
            , Q.scALPN           = Just alpn
            , Q.scUse0RTT        = True
            , Q.scSessionManager = smgr
            }

        runner | not isSSL           = runSettings settings
               | Just qport <- _quic = \app -> do
                    logger INFO $ "bind to UDP port " <> toLogStr (fromMaybe "0.0.0.0" _bind) <> ":" <> toLogStr qport
                    mapConcurrently_ ($ app)
                        [ runQUIC (quicset qport) settings
                        , runTLS tlsset (setAltSvc (altsvc qport) settings)
                        ]
               | otherwise           = runTLS tlsset settings
#else
        runner | isSSL     = runTLS tlsset settings
               | otherwise = runSettings settings
#endif

    pauth <- case _auth of
        Nothing -> return Nothing
        Just f  -> do
            logger INFO $ "read username and passwords from " <> toLogStr f
            -- Strip a trailing '\r' so files with CRLF line endings still
            -- match. Lines without a ':' are ignored.
            let stripCR l = fromMaybe l (BS8.stripSuffix "\r" l)
            Just . flip elem . filter (isJust . BS8.elemIndex ':') . map stripCR . BS8.lines
                <$> BS8.readFile f
    manager <- newTlsManager

    let pset = ProxySettings pauth (Just _name) (BS8.pack <$> _ws) (BS8.pack <$> _rev) (_naive && isSSL) logger
        proxy = healthCheckProvider $
                (if isSSL then forceSSL pset else id) $
                httpProxy pset manager $
                reverseProxy pset manager $
                fallback

    forM_ _ws  $ \target -> logger INFO $ "websocket redirect: " <> toLogStr target
    forM_ _rev $ \target -> logger INFO $ "reverse proxy: " <> toLogStr target
    forM_ _doh $ \target -> logger INFO $ "DNS-over-HTTPS redirect: " <> toLogStr target

    case _doh of
        Nothing  -> runner proxy
        Just doh -> createResolver doh (\resolver -> runner (dnsOverHTTPS resolver proxy))