-- SPDX-License-Identifier: Apache-2.0

-- Copyright (C) 2023 Bin Jin. All Rights Reserved.
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns        #-}

{-| Instead of running @hprox@ binary directly, you can use this library
    to run HProx in front of arbitrary WAI 'Application'.
-}

module Network.HProx
  ( CertFile (..)
  , Config (..)
  , LogLevel (..)
  , defaultConfig
  , getConfig
  , run
  ) where

import Data.ByteString.Char8       qualified as BS8
import Data.HashMap.Strict         qualified as HM
import Data.Ord                    (Down (..))
import Data.String                 (fromString)
import Data.Version                (showVersion)
import Network.HTTP.Client.TLS     (newTlsManager)
import Network.HTTP.Types          qualified as HT
import Network.TLS                 qualified as TLS
import Network.TLS.Extra.Cipher    qualified as TLS
import Network.TLS.SessionManager  qualified as SM
import Network.Wai                 (Application, rawPathInfo)
import Network.Wai.Handler.Warp
    (InvalidRequest (..), defaultSettings, defaultShouldDisplayException,
    runSettings, setHost, setLogger, setNoParsePath, setOnException, setPort,
    setServerName)
import Network.Wai.Handler.WarpTLS
    (OnInsecure (..), WarpTLSException, onInsecure, runTLS, tlsAllowedVersions,
    tlsCiphers, tlsServerHooks, tlsSessionManager, tlsSettings)

import Control.Exception    (Exception (..))
import GHC.IO.Exception     (IOErrorType (..))
import Network.HTTP2.Client qualified as H2
import System.IO.Error      (ioeGetErrorType)

#ifdef QUIC_ENABLED
import Control.Concurrent.Async     (mapConcurrently_)
import Network.QUIC                 qualified as Q
import Network.QUIC.Internal        qualified as Q
import Network.Wai.Handler.Warp     (setAltSvc)
import Network.Wai.Handler.WarpQUIC (runQUIC)
#endif

import Control.Monad
import Data.List
import Data.Maybe
import Options.Applicative

import Network.HProx.DoH
import Network.HProx.Impl
import Network.HProx.Log
import Network.HProx.Util
import Paths_hprox

-- | Configuration of HProx, see @hprox --help@ for details.
--
-- Every field corresponds to one command-line option parsed by 'getConfig';
-- 'defaultConfig' holds the values used when no option is given.
data Config = Config
  { _bind     :: Maybe String
    -- ^ IP address to bind to (@--bind@); 'Nothing' means all interfaces
  , _port     :: Int
    -- ^ TCP port number (@--port@)
  , _ssl      :: [(String, CertFile)]
    -- ^ @(hostname, certificate files)@ pairs from @--tls@; non-empty enables TLS
  , _auth     :: Maybe FilePath
    -- ^ password file for proxy authentication (@--auth@)
  , _ws       :: Maybe String
    -- ^ @remote-host:port@ that handles WebSocket requests (@--ws@)
  , _rev      :: [(Maybe BS8.ByteString, BS8.ByteString, BS8.ByteString)]
    -- ^ reverse-proxy rules from @--rev@: @(optional Host domain, path prefix, remote-host:port)@
  , _doh      :: Maybe String
    -- ^ DNS server used for DNS-over-HTTPS (@--doh@)
  , _hide     :: Bool
    -- ^ never send @Proxy Authentication Required@ (@--hide@)
  , _naive    :: Bool
    -- ^ add naiveproxy-compatible padding (@--naive@)
  , _name     :: BS8.ByteString
    -- ^ value of the @Server@ response header (@--name@)
  , _log      :: String
    -- ^ logging target: @none@, @stdout@, @stderr@, or a file path (@--log@)
  , _loglevel :: LogLevel
    -- ^ minimum logging level (@--loglevel@)
#ifdef QUIC_ENABLED
  , _quic     :: Maybe Int
    -- ^ UDP port for QUIC (HTTP/3), if enabled (@--quic@)
#endif
  }

-- | Default value of 'Config', same as running @hprox@ without arguments.
--
-- Written with record syntax so each default is visibly tied to its field;
-- the values mirror the @value@ defaults declared in the command-line parser.
defaultConfig :: Config
defaultConfig = Config
    { _bind     = Nothing
    , _port     = 3000
    , _ssl      = []
    , _auth     = Nothing
    , _ws       = Nothing
    , _rev      = []
    , _doh      = Nothing
    , _hide     = False
    , _naive    = False
    , _name     = "hprox"
    , _log      = "stdout"
    , _loglevel = INFO
#ifdef QUIC_ENABLED
    , _quic     = Nothing
#endif
    }

-- | Certificate file pairs: a certificate file together with its private key.
data CertFile = CertFile
  { certfile :: FilePath  -- ^ path to the certificate file
  , keyfile  :: FilePath  -- ^ path to the matching private key file
  }

-- | Load an X.509 certificate (chain) and its private key from disk.
--
-- A load failure is raised in 'IO' right away, and the message names the
-- offending files. The previous @either error id <\$> ...@ form returned a
-- successful action whose /result/ was a lazy 'error' thunk. Such an error
-- only surfaces when the credential is first forced, which may be long after
-- start-up (e.g. only when a client asks for that hostname).
readCert :: CertFile -> IO TLS.Credential
readCert (CertFile c k) = TLS.credentialLoadX509 c k >>= either failLoad return
  where
    failLoad err = ioError $ userError $
        "failed to load certificate " ++ show c ++ " with key " ++ show k ++ ": " ++ err

-- | Command-line parser for 'Config', including @--help@ and @--version@.
parser :: ParserInfo Config
parser = info (helper <*> ver <*> config) (fullDesc <> progDesc desc)
  where
    -- Parse a @--tls@ argument of the form @hostname:certfile:keyfile@.
    parseSSL :: String -> Either String (String, CertFile)
    parseSSL s = case splitBy ':' s of
        [host, cert, key] -> Right (host, CertFile cert key)
        _                 -> Left "invalid format for ssl certificates"

    -- Parse @[/prefix/]remote-host:port@.
    -- When the argument starts with '/', everything up to and including the
    -- last '/' is the prefix; otherwise the prefix defaults to "/".
    -- An empty remote target (e.g. "/prefix/") is rejected.
    parseRev0 :: String -> Maybe (Maybe BS8.ByteString, BS8.ByteString, BS8.ByteString)
    parseRev0 s = do
        let (prefix, remote) = case s of
                '/' : _ -> splitAt (length (dropWhileEnd (/= '/') s)) s
                _       -> ("/", s)
        guard (not (null remote))
        return (Nothing, BS8.pack prefix, BS8.pack remote)

    -- Parse @[//domain/][/prefix/]remote-host:port@. The optional domain
    -- must be non-empty and must be terminated by '/'.
    parseRev :: String -> Maybe (Maybe BS8.ByteString, BS8.ByteString, BS8.ByteString)
    parseRev ('/' : '/' : s) = do
        ind <- elemIndex '/' s
        let (domain, other) = splitAt ind s
        guard (not (null domain))
        (_, prefix, remote) <- parseRev0 other
        return (Just (BS8.pack domain), prefix, remote)
    parseRev s = parseRev0 s

    desc :: String
    desc = "a lightweight HTTP proxy server, and more"

    ver :: Parser (a -> a)
    ver = infoOption (showVersion version) (long "version" <> help "Display the version information")

    config :: Parser Config
    config = Config <$> bind
                    <*> port
                    <*> ssl
                    <*> auth
                    <*> ws
                    <*> rev
                    <*> doh
                    <*> hide
                    <*> naive
                    <*> name
                    <*> logging
                    <*> loglevel
#ifdef QUIC_ENABLED
                    <*> quic
#endif

    bind :: Parser (Maybe String)
    bind = optional $ strOption
        ( long "bind"
       <> short 'b'
       <> metavar "bind_ip"
       <> help "Specify the IP address to bind to (default: all interfaces)")

    port :: Parser Int
    port = option auto
        ( long "port"
       <> short 'p'
       <> metavar "port"
       <> value 3000
       <> showDefault
       <> help "Specify the port number")

    ssl :: Parser [(String, CertFile)]
    ssl = many $ option (eitherReader parseSSL)
        ( long "tls"
       <> short 's'
       <> metavar "hostname:certfile:keyfile"
       <> help "Enable TLS and specify a domain with its associated TLS certificate (can be specified multiple times for multiple domains)")

    auth :: Parser (Maybe FilePath)
    auth = optional $ strOption
        ( long "auth"
       <> short 'a'
       <> metavar "userpass.txt"
       <> help "Specify the password file for proxy authentication. Plaintext passwords should be in the format 'user:pass' and will be automatically Argon2-hashed by hprox. Ensure that the password file with plaintext password is writable")

    ws :: Parser (Maybe String)
    ws = optional $ strOption
        ( long "ws"
       <> metavar "remote-host:port"
       <> help "Specify the remote host to handle WebSocket requests (port 443 indicates an HTTPS remote server)")

    rev :: Parser [(Maybe BS8.ByteString, BS8.ByteString, BS8.ByteString)]
    rev = many $ option (maybeReader parseRev)
        ( long "rev"
       <> metavar "[//domain/][/prefix/]remote-host:port"
       <> help "Specify the remote host for reverse proxy (port 443 indicates an HTTPS remote server). An optional '//domain/' will only process requests with the 'Host: domain' header, and an optional '/prefix/' can be specified as a prefix to be matched (and stripped in proxied request)")

    doh :: Parser (Maybe String)
    doh = optional $ strOption
        ( long "doh"
       <> metavar "dns-server:port"
       <> help "Enable DNS-over-HTTPS (DoH) support (port 53 will be used if not specified)")

    hide :: Parser Bool
    hide = switch
        ( long "hide"
       <> help "Never send 'Proxy Authentication Required' response. Note that this might break the use of HTTPS proxy in browsers")

    naive :: Parser Bool
    naive = switch
        ( long "naive"
       <> help "Add naiveproxy-compatible padding (requires TLS)")

    name :: Parser BS8.ByteString
    name = strOption
        ( long "name"
       <> metavar "server-name"
       <> value "hprox"
       <> showDefault
       <> help "Specify the server name for the 'Server' header")

    logging :: Parser String
    logging = strOption
        ( long "log"
       <> metavar "<none|stdout|stderr|file>"
       <> value "stdout"
       <> showDefault
       <> help "Specify the logging type")

    -- The default is spelled out in the help text instead of using
    -- 'showDefault', which would print the constructor name rather than the
    -- lowercase spelling accepted on the command line.
    loglevel :: Parser LogLevel
    loglevel = option (maybeReader logLevelReader)
        ( long "loglevel"
       <> metavar "<trace|debug|info|warn|error|none>"
       <> value INFO
       <> help "Specify the logging level (default: info)")

#ifdef QUIC_ENABLED
    quic :: Parser (Maybe Int)
    quic = optional $ option auto
        ( long "quic"
       <> short 'q'
       <> metavar "port"
       <> help "Enable QUIC (HTTP/3) on UDP port")
#endif

-- | Translate the @--log@ argument into a log destination.
--
-- The keywords @none@, @stdout@ and @stderr@ are recognised. Any other value
-- is treated as the path of a log file, which is not rotated. Every
-- destination that takes a size receives the same 4096 (presumably the
-- buffer size).
getLoggerType :: String -> LogType' LogStr
getLoggerType target = case target of
    "none"   -> LogNone
    "stdout" -> LogStdout bufferSize
    "stderr" -> LogStderr bufferSize
    path     -> LogFileNoRotate path bufferSize
  where
    bufferSize :: Int
    bufferSize = 4096

-- | Read 'Config' from command line arguments.
--
-- Uses the default parser preferences. On invalid input, or on
-- @--help@/@--version@, the process prints a message and exits.
getConfig :: IO Config
getConfig = customExecParser defaultPrefs parser

-- | Run HProx in front of fallback 'Application', with specified 'Config'
run :: Application -- ^ fallback application
    -> Config      -- ^ configuration
    -> IO ()
run :: Application -> Config -> IO ()
run Application
fallback Config{Bool
Int
String
[(String, CertFile)]
[(Maybe ByteString, ByteString, ByteString)]
Maybe String
ByteString
LogLevel
_loglevel :: LogLevel
_log :: String
_name :: ByteString
_naive :: Bool
_hide :: Bool
_doh :: Maybe String
_rev :: [(Maybe ByteString, ByteString, ByteString)]
_ws :: Maybe String
_auth :: Maybe String
_ssl :: [(String, CertFile)]
_port :: Int
_bind :: Maybe String
_loglevel :: Config -> LogLevel
_log :: Config -> String
_name :: Config -> ByteString
_naive :: Config -> Bool
_hide :: Config -> Bool
_doh :: Config -> Maybe String
_rev :: Config -> [(Maybe ByteString, ByteString, ByteString)]
_ws :: Config -> Maybe String
_auth :: Config -> Maybe String
_ssl :: Config -> [(String, CertFile)]
_port :: Config -> Int
_bind :: Config -> Maybe String
..} = LogType' LogStr
-> LogLevel -> ((LogLevel -> LogStr -> IO ()) -> IO ()) -> IO ()
withLogger (String -> LogType' LogStr
getLoggerType String
_log) LogLevel
_loglevel forall a b. (a -> b) -> a -> b
$ \LogLevel -> LogStr -> IO ()
logger -> do
    LogLevel -> LogStr -> IO ()
logger LogLevel
INFO forall a b. (a -> b) -> a -> b
$ LogStr
"hprox " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr (Version -> String
showVersion Version
version) forall a. Semigroup a => a -> a -> a
<> LogStr
" started"
    LogLevel -> LogStr -> IO ()
logger LogLevel
INFO forall a b. (a -> b) -> a -> b
$ LogStr
"bind to TCP port " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr (forall a. a -> Maybe a -> a
fromMaybe String
"[::]" Maybe String
_bind) forall a. Semigroup a => a -> a -> a
<> LogStr
":" forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr Int
_port

    let certfiles :: [(String, CertFile)]
certfiles = [(String, CertFile)]
_ssl

    [Credential]
certs <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (CertFile -> IO Credential
readCertforall b c a. (b -> c) -> (a -> b) -> a -> c
.forall a b. (a, b) -> b
snd) [(String, CertFile)]
certfiles
    SessionManager
smgr <- Config -> IO SessionManager
SM.newSessionManager Config
SM.defaultConfig

    let isSSL :: Bool
isSSL = Bool -> Bool
not (forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(String, CertFile)]
certfiles)
        (String
primaryHost, CertFile
primaryCert) = forall a. [a] -> a
head [(String, CertFile)]
certfiles
        otherCerts :: [(String, Credential)]
otherCerts = forall a. [a] -> [a]
tail forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(String, CertFile)]
certfiles) [Credential]
certs

    forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
isSSL forall a b. (a -> b) -> a -> b
$ do
        LogLevel -> LogStr -> IO ()
logger LogLevel
INFO forall a b. (a -> b) -> a -> b
$ LogStr
"read " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr (forall a. Show a => a -> String
show forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Int
length [Credential]
certs) forall a. Semigroup a => a -> a -> a
<> LogStr
" certificates"
        LogLevel -> LogStr -> IO ()
logger LogLevel
INFO forall a b. (a -> b) -> a -> b
$ LogStr
"primary domain: " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr String
primaryHost
        LogLevel -> LogStr -> IO ()
logger LogLevel
INFO forall a b. (a -> b) -> a -> b
$ LogStr
"other domains: " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr ([String] -> String
unwords forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(String, Credential)]
otherCerts)

    let settings :: Settings
settings = HostPreference -> Settings -> Settings
setHost (forall a. IsString a => String -> a
fromString (forall a. a -> Maybe a -> a
fromMaybe String
"*6" Maybe String
_bind)) forall a b. (a -> b) -> a -> b
$
                   Int -> Settings -> Settings
setPort Int
_port forall a b. (a -> b) -> a -> b
$
                   (Request -> Status -> Maybe Integer -> IO ())
-> Settings -> Settings
setLogger forall {p}. Request -> Status -> p -> IO ()
warpLogger forall a b. (a -> b) -> a -> b
$
                   (Maybe Request -> SomeException -> IO ()) -> Settings -> Settings
setOnException Maybe Request -> SomeException -> IO ()
exceptionHandler forall a b. (a -> b) -> a -> b
$
                   Bool -> Settings -> Settings
setNoParsePath Bool
True forall a b. (a -> b) -> a -> b
$
                   ByteString -> Settings -> Settings
setServerName ByteString
_name Settings
defaultSettings

        exceptionHandler :: Maybe Request -> SomeException -> IO ()
exceptionHandler Maybe Request
req SomeException
ex
            | LogLevel
_loglevel forall a. Ord a => a -> a -> Bool
> LogLevel
DEBUG                                 = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Bool -> Bool
not (SomeException -> Bool
defaultShouldDisplayException SomeException
ex)            = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Just (IOError -> IOErrorType
ioeGetErrorType -> IOErrorType
EOF) <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Just (H2.BadThingHappen SomeException
ex') <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex  = Maybe Request -> SomeException -> IO ()
exceptionHandler Maybe Request
req SomeException
ex'
            | Just (HTTP2Error
_ :: H2.HTTP2Error) <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex     = forall (m :: * -> *) a. Monad m => a -> m a
return ()
#ifdef QUIC_ENABLED
            | Just (Q.BadThingHappen ex') <- fromException ex   = exceptionHandler req ex'
            | Just (_ :: Q.QUICException) <- fromException ex   = return ()
#endif
            | Just (WarpTLSException
_ :: WarpTLSException) <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex  = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Just InvalidRequest
ConnectionClosedByPeer <- forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex   = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Bool
otherwise                                         =
                LogLevel -> LogStr -> IO ()
logger LogLevel
DEBUG forall a b. (a -> b) -> a -> b
$ LogStr
"exception: " forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr (forall e. Exception e => e -> String
displayException SomeException
ex) forall a. Semigroup a => a -> a -> a
<>
                    (if forall a. Maybe a -> Bool
isJust Maybe Request
req then LogStr
" from: " forall a. Semigroup a => a -> a -> a
<> Request -> LogStr
logRequest (forall a. HasCallStack => Maybe a -> a
fromJust Maybe Request
req) else LogStr
"")

        warpLogger :: Request -> Status -> p -> IO ()
warpLogger Request
req Status
status p
_
            | Request -> ByteString
rawPathInfo Request
req forall a. Eq a => a -> a -> Bool
== ByteString
"/.hprox/health" = forall (m :: * -> *) a. Monad m => a -> m a
return ()
            | Bool
otherwise                           =
                LogLevel -> LogStr -> IO ()
logger LogLevel
TRACE forall a b. (a -> b) -> a -> b
$ LogStr
"(" forall a. Semigroup a => a -> a -> a
<> forall msg. ToLogStr msg => msg -> LogStr
toLogStr (Status -> Int
HT.statusCode Status
status) forall a. Semigroup a => a -> a -> a
<> LogStr
") " forall a. Semigroup a => a -> a -> a
<> Request -> LogStr
logRequest Request
req

        tlsset' :: TLSSettings
tlsset' = String -> String -> TLSSettings
tlsSettings (CertFile -> String
certfile CertFile
primaryCert) (CertFile -> String
keyfile CertFile
primaryCert)
        hooks :: ServerHooks
hooks = (TLSSettings -> ServerHooks
tlsServerHooks TLSSettings
tlsset') { onServerNameIndication :: Maybe String -> IO Credentials
TLS.onServerNameIndication = forall {m :: * -> *}. MonadFail m => Maybe String -> m Credentials
onSNI }

        -- https://www.ssllabs.com/ssltest
        weak_ciphers :: [Cipher]
weak_ciphers = [ Cipher
TLS.cipher_ECDHE_RSA_AES256CBC_SHA384
                       , Cipher
TLS.cipher_ECDHE_RSA_AES256CBC_SHA
                       , Cipher
TLS.cipher_AES256CCM_SHA256
                       , Cipher
TLS.cipher_AES256GCM_SHA384
                       , Cipher
TLS.cipher_AES256_SHA256
                       , Cipher
TLS.cipher_AES256_SHA1
                       ]

        tlsset :: TLSSettings
tlsset = TLSSettings
tlsset'
            { tlsServerHooks :: ServerHooks
tlsServerHooks     = ServerHooks
hooks
            , onInsecure :: OnInsecure
onInsecure         = OnInsecure
AllowInsecure
            , tlsAllowedVersions :: [Version]
tlsAllowedVersions = [Version
TLS.TLS13, Version
TLS.TLS12]
            , tlsCiphers :: [Cipher]
tlsCiphers         = [Cipher]
TLS.ciphersuite_strong forall a. Eq a => [a] -> [a] -> [a]
\\ [Cipher]
weak_ciphers
            , tlsSessionManager :: Maybe SessionManager
tlsSessionManager  = forall a. a -> Maybe a
Just SessionManager
smgr
            }

        onSNI :: Maybe String -> m Credentials
onSNI Maybe String
Nothing = forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SNI: unspecified"
        onSNI (Just String
host)
          | String -> String -> Bool
checkSNI String
host String
primaryHost = forall (m :: * -> *) a. Monad m => a -> m a
return forall a. Monoid a => a
mempty
          | Bool
otherwise                 = forall {m :: * -> *}.
MonadFail m =>
String -> [(String, Credential)] -> m Credentials
lookupSNI String
host [(String, Credential)]
otherCerts

        lookupSNI :: String -> [(String, Credential)] -> m Credentials
lookupSNI String
host [] = forall (m :: * -> *) a. MonadFail m => String -> m a
fail (String
"SNI: unknown hostname (" forall a. [a] -> [a] -> [a]
++ forall a. Show a => a -> String
show String
host forall a. [a] -> [a] -> [a]
++ String
")")
        lookupSNI String
host ((String
p, Credential
cert) : [(String, Credential)]
cs)
          | String -> String -> Bool
checkSNI String
host String
p = forall (m :: * -> *) a. Monad m => a -> m a
return ([Credential] -> Credentials
TLS.Credentials [Credential
cert])
          | Bool
otherwise       = String -> [(String, Credential)] -> m Credentials
lookupSNI String
host [(String, Credential)]
cs

        checkSNI :: String -> String -> Bool
checkSNI String
host String
pat = case String
pat of
            Char
'*' : Char
'.' : String
p -> (Char
'.' forall a. a -> [a] -> [a]
: String
p) forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` String
host
            String
p             -> String
host forall a. Eq a => a -> a -> Bool
== String
p

#ifdef QUIC_ENABLED
        -- ALPN selection for QUIC: answer "h3" when the client offers it,
        -- otherwise an empty protocol name. The version argument is ignored.
        alpn _ = return . fromMaybe "" . find (== "h3")
        -- Alt-Svc header value advertising HTTP/3 on UDP port @qport@,
        -- e.g. @h3=":443"@.
        altsvc qport = BS8.concat ["h3=\":", BS8.pack $ show qport, "\""]

        -- QUIC server configuration for UDP port @qport@ on the configured
        -- bind address (0.0.0.0 when none is given).
        quicset qport = Q.defaultServerConfig
            { Q.scAddresses      = [(fromString (fromMaybe "0.0.0.0" _bind), fromIntegral qport)]
            , Q.scVersions       = [Q.Version1, Q.Version2]
              -- Only the first certificate is offered over QUIC; 'take 1'
              -- keeps this total instead of crashing via 'head' on [].
            , Q.scCredentials    = TLS.Credentials (take 1 certs)
            , Q.scCiphers        = Q.scCiphers Q.defaultServerConfig \\ weak_ciphers
            , Q.scALPN           = Just alpn
              -- NOTE(review): 0-RTT early data can be replayed by an attacker;
              -- confirm accepting it is acceptable for proxied requests.
            , Q.scUse0RTT        = True
            , Q.scSessionManager = smgr
            }

        -- Serve the application: plain HTTP when not using TLS; with TLS and
        -- a QUIC port, HTTP/3 and HTTPS run concurrently (the HTTPS side sets
        -- Alt-Svc pointing at the QUIC port); otherwise HTTPS only.
        runner :: Application -> IO ()
        runner | not isSSL           = runSettings settings
               | Just qport <- _quic = \app -> do
                    logger INFO $ "bind to UDP port " <> toLogStr (fromMaybe "0.0.0.0" _bind) <> ":" <> toLogStr qport
                    mapConcurrently_ ($ app)
                        [ runQUIC (quicset qport) settings
                        , runTLS tlsset (setAltSvc (altsvc qport) settings)
                        ]
               | otherwise           = runTLS tlsset settings
#else
        -- Serve the application over TLS when @isSSL@, plain HTTP otherwise.
        runner :: Application -> IO ()
        runner | isSSL     = runTLS tlsset settings
               | otherwise = runSettings settings
#endif

    -- Optional password database. When a password file is configured, every
    -- parseable line becomes a (user, salted password) entry. The result is a
    -- predicate accepting a "user:password" string (split at the first ':')
    -- iff that user exists and 'verifyPassword' accepts the password part.
    pauth <- case _auth of
        Nothing -> return Nothing
        Just f  -> do
            logger INFO $ "read username and passwords from " <> toLogStr f
            -- Pair each line with its 1-based line number so parse failures can
            -- be reported without echoing the line itself (it may contain a
            -- plaintext password). Blank lines are skipped rather than reported.
            numberedLines <- filter (not . BS8.null . snd) . zip [1 :: Int ..] . BS8.lines
                <$> BS8.readFile f
            let userList = map snd numberedLines
                -- A line without exactly two ':' separators is presumably not
                -- yet in salted form; any such line triggers a rewrite below.
                anyPlaintext = any (\line -> length (BS8.elemIndices ':' line) /= 2) userList
                processUser (lineNo, userpass) = case passwordReader userpass of
                    Nothing           -> do
                        logger WARN $ "unable to parse line " <> toLogStr (show lineNo)
                            <> " from password file " <> toLogStr f
                        return Nothing
                    Just (user, pass) -> do
                        salted <- hashPasswordWithRandomSalt pass
                        logger TRACE $ "parsed user (with salted password) from password file: "
                            <> toLogStr (passwordWriter user salted)
                        return $ Just (user, salted)
            -- Later lines take precedence when a user name appears more than once.
            passwordByUser <- HM.fromList . catMaybes <$> mapM processUser numberedLines
            when anyPlaintext $ do
                logger INFO $ "writing back to password file " <> toLogStr f
                -- NOTE(review): the file is overwritten in place (not atomically)
                -- and only successfully parsed users are written back, so
                -- unparseable lines are dropped from it.
                BS8.writeFile f (BS8.unlines [ passwordWriter u p | (u, p) <- HM.toList passwordByUser ])
            let verify line = do
                    idx <- BS8.elemIndex ':' line
                    let user = BS8.take idx line
                        pass = BS8.drop (idx + 1) line
                    targetPass <- HM.lookup user passwordByUser
                    return $ verifyPassword targetPass pass
            return $ Just (\line -> verify line == Just True)

    manager <- newTlsManager

    -- Reverse-proxy rules ordered so entries whose first component is 'Just'
    -- come first, and within each group those with a longer second component
    -- come first ('sortOn' is stable otherwise). NOTE(review): presumably the
    -- triple is (host, path prefix, target), so host-specific and
    -- longest-prefix rules take precedence — confirm against ProxySettings.
    let revSorted = sortOn (\(mhost, prefix, _) -> Down (isJust mhost, BS8.length prefix)) _rev
        -- Settings shared by every proxy middleware; the @_naive@ flag only
        -- takes effect together with TLS.
        pset :: ProxySettings
        pset = ProxySettings
            pauth
            (Just _name)
            (BS8.pack <$> _ws)
            revSorted
            _hide
            (_naive && isSSL)
            logger
        -- Middleware stack, outermost first: healthCheckProvider, forceSSL
        -- (only when isSSL), httpProxy, reverseProxy, around the fallback app.
        proxy :: Application
        proxy = ( healthCheckProvider
                . (if isSSL then forceSSL pset else id)
                . httpProxy pset manager
                . reverseProxy pset manager
                ) fallback

    -- Report the optional features that are enabled. 'forM_' over a 'Maybe'
    -- runs the action only for 'Just', avoiding the partial 'fromJust'.
    forM_ _ws $ \ws ->
        logger INFO $ "websocket redirect: " <> toLogStr ws
    unless (null revSorted) $
        logger INFO $ "reverse proxy: " <> toLogStr (show revSorted)
    forM_ _doh $ \doh ->
        logger INFO $ "DNS-over-HTTPS redirect: " <> toLogStr doh

    -- With a DoH upstream, wrap the proxy in the DNS-over-HTTPS middleware
    -- for the lifetime of the resolver; otherwise serve the proxy directly.
    case _doh of
        Nothing  -> runner proxy
        Just doh -> createResolver doh (\resolver -> runner (dnsOverHTTPS resolver proxy))