-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

{-| Instead of running @hprox@ binary directly, you can use this library
    to run HProx in front of arbitrary WAI 'Application'.
-}

module Network.HProx
  ( Config(..)
  , CertFile(..)
  , defaultConfig
  , getConfig
  , run
  ) where

import           Control.Monad                       (when)
import qualified Data.ByteString.Char8               as BS8
import           Data.List                           (isSuffixOf)
import           Data.String                         (fromString)
import           Network.HTTP.Client.TLS             (newTlsManager)
import           Network.TLS                         as TLS
import           Network.Wai                         (Application,
                                                      modifyResponse)
import           Network.Wai.Handler.Warp            (HostPreference,
                                                      defaultSettings,
                                                      runSettings,
                                                      setBeforeMainLoop,
                                                      setHost, setNoParsePath,
                                                      setOnException, setPort,
                                                      setServerName)
import           Network.Wai.Handler.WarpTLS         (OnInsecure (..),
                                                      onInsecure, runTLS,
                                                      tlsServerHooks,
                                                      tlsSettings)
import           Network.Wai.Middleware.Gzip         (def, gzip)
import           Network.Wai.Middleware.StripHeaders (stripHeaders)
import           System.Posix.User                   (UserEntry (..),
                                                      getEffectiveUserID,
                                                      getUserEntryForName,
                                                      setGroupID, setGroups,
                                                      setUserID)

import           Data.Maybe
import           Data.Version                        (showVersion)
import           Options.Applicative

import           Network.HProx.DoH
import           Network.HProx.Impl                  (ProxySettings (..),
                                                      forceSSL, httpProxy,
                                                      reverseProxy)
import           Paths_hprox                         (version)

-- | Configuration of HProx, see @hprox --help@ for details
data Config = Config
  { _bind :: Maybe HostPreference
    -- ^ address to bind on (@--bind@); 'Nothing' binds all interfaces
  , _port :: Int
    -- ^ port to listen on (@--port@)
  , _ssl  :: [(String, CertFile)]
    -- ^ TLS hostnames with their certificate\/key files (@--tls@); the first
    --   entry is the primary certificate, and TLS is enabled iff non-empty
  , _user :: Maybe String
    -- ^ user to switch to after binding the port (@--user@)
  , _auth :: Maybe FilePath
    -- ^ proxy password file with one @user:password@ pair per line (@--auth@)
  , _ws   :: Maybe String
    -- ^ @remote-host:port@ handling websocket requests (@--ws@)
  , _rev  :: Maybe String
    -- ^ @remote-host:port@ for the reverse proxy (@--rev@)
  , _doh  :: Maybe String
    -- ^ @dns-server[:port]@ enabling DNS-over-HTTPS (@--doh@)
  }

-- | Default value of 'Config', same as running @hprox@ without arguments:
-- listen on port 3000 on all interfaces, plain HTTP, no authentication and
-- no optional features.
defaultConfig :: Config
defaultConfig = Config
    { _bind = Nothing
    , _port = 3000
    , _ssl  = []
    , _user = Nothing
    , _auth = Nothing
    , _ws   = Nothing
    , _rev  = Nothing
    , _doh  = Nothing
    }

-- | Certificate file pairs
data CertFile = CertFile
  { certfile :: FilePath -- ^ path to the X.509 certificate file
  , keyfile  :: FilePath -- ^ path to the matching private key file
  }

-- | Load an X.509 certificate \/ private key pair from disk.
--
-- A load failure is raised as an 'IOError' right away, inside 'IO'. The
-- previous @either error id \<$\> …@ form only built a lazy 'error' thunk.
-- That thunk fired the first time the credential was used in a TLS
-- handshake, where the server's exception handler silently discarded it.
readCert :: CertFile -> IO TLS.Credential
readCert (CertFile c k) = TLS.credentialLoadX509 c k >>= either loadFailed return
  where
    loadFailed err =
        fail ("failed to load certificate " ++ show c ++ " with key " ++ show k ++ ": " ++ err)

-- | Split a list on every occurrence of a separator.
--
-- Always returns at least one chunk. Empty chunks are kept, so
-- @splitBy ':' \"\" == [\"\"]@ and @splitBy ':' \"a:\" == [\"a\", \"\"]@.
splitBy :: Eq a => a -> [a] -> [[a]]
splitBy c = go
  where
    go xs = case break (== c) xs of
        (chunk, [])       -> [chunk]
        (chunk, _ : rest) -> chunk : go rest

-- | Command-line parser for 'Config', including @--help@ and @--version@.
parser :: ParserInfo Config
parser = info (helper <*> ver <*> config) (fullDesc <> progDesc desc)
  where
    -- Parse the @hostname:certfile:keyfile@ triple passed to @--tls@.
    parseSSL :: String -> Either String (String, CertFile)
    parseSSL s = case splitBy ':' s of
        [host, cert, key] -> Right (host, CertFile cert key)
        _                 -> Left "invalid format for ssl certificates"

    desc :: String
    desc = "a lightweight HTTP proxy server, and more"

    ver :: Parser (a -> a)
    ver = infoOption (showVersion version) (long "version" <> help "show version")

    config :: Parser Config
    config = Config <$> bind
                    <*> (fromMaybe 3000 <$> port)
                    <*> ssl
                    <*> user
                    <*> auth
                    <*> ws
                    <*> rev
                    <*> doh

    bind :: Parser (Maybe HostPreference)
    bind = optional $ fromString <$> strOption
        ( long "bind"
       <> short 'b'
       <> metavar "bind_ip"
       <> help "ip address to bind on (default: all interfaces)")

    port :: Parser (Maybe Int)
    port = optional $ option auto
        ( long "port"
       <> short 'p'
       <> metavar "port"
       <> help "port number (default 3000)")

    -- May be repeated; the first occurrence becomes the primary certificate.
    ssl :: Parser [(String, CertFile)]
    ssl = many $ option (eitherReader parseSSL)
        ( long "tls"
       <> short 's'
       <> metavar "hostname:certfile:keyfile"
       <> help "enable TLS and specify a domain and associated TLS certificate (can be specified multiple times for multiple domains)")

    user :: Parser (Maybe String)
    user = optional $ strOption
        ( long "user"
       <> short 'u'
       <> metavar "nobody"
       <> help "setuid after binding port")

    auth :: Parser (Maybe String)
    auth = optional $ strOption
        ( long "auth"
       <> short 'a'
       <> metavar "userpass.txt"
       <> help "password file for proxy authentication (plain text file with lines each containing a colon separated user/password pair)")

    ws :: Parser (Maybe String)
    ws = optional $ strOption
        ( long "ws"
       <> metavar "remote-host:port"
       <> help "remote host to handle websocket requests (port 443 indicates HTTPS remote server)")

    rev :: Parser (Maybe String)
    rev = optional $ strOption
        ( long "rev"
       <> metavar "remote-host:port"
       <> help "remote host for reverse proxy (port 443 indicates HTTPS remote server)")

    doh :: Parser (Maybe String)
    doh = optional $ strOption
        ( long "doh"
       <> metavar "dns-server:port"
       <> help "enable DNS-over-HTTPS(DoH) support (53 will be used if port is not specified)")


-- | Switch the process to the given user, typically after binding the port.
--
-- When running as root, this first resets the supplementary group list and
-- the group id to the user's primary group. Once the uid has changed, the
-- process no longer has the privilege to do so, and root's groups would be
-- retained. When not running as root, only the uid is changed, as before.
setuid :: String -> IO ()
setuid user = do
    entry <- getUserEntryForName user
    isRoot <- (== 0) <$> getEffectiveUserID
    let gid = userGroupID entry
    when isRoot $ do
        setGroups [gid]
        setGroupID gid
    setUserID (userID entry)

-- | Read 'Config' from command line arguments
--
-- On invalid arguments, @--help@ or @--version@, this prints the
-- corresponding message and exits (standard 'execParser' behaviour).
getConfig :: IO Config
getConfig = execParser parser

-- | Run HProx in front of fallback 'Application', with specified 'Config'
--
-- TLS is enabled when '_ssl' is non-empty. The first entry is the primary
-- certificate, passed to 'tlsSettings'. The remaining entries are chosen
-- by the SNI hostname the client sends. Clients that send no SNI are rejected.
run :: Application -- ^ fallback application
    -> Config      -- ^ configuration
    -> IO ()
run fallback Config{..} = do
    -- Load every configured key pair up front (the primary one included).
    certs <- mapM (readCert . snd) _ssl

    let isSSL = not (null _ssl)

        -- Hostnames paired with their loaded credentials, minus the primary
        -- entry, which 'tlsSettings' loads from file itself.
        otherCerts = drop 1 (zip (map fst _ssl) certs)

        settings = setHost (fromMaybe "*6" _bind) $
                   setPort _port $
                   setOnException (\_ _ -> return ()) $
                   setNoParsePath True $
                   setServerName "Apache" $
                   maybe id (setBeforeMainLoop . setuid) _user
                   defaultSettings

        -- TLS settings that serve the primary certificate and consult
        -- 'onSNI' for other hostnames.
        tlsFor primaryHost primaryCert =
            let base  = tlsSettings (certfile primaryCert) (keyfile primaryCert)
                hooks = (tlsServerHooks base) { onServerNameIndication = onSNI primaryHost }
            in  base { tlsServerHooks = hooks, onInsecure = AllowInsecure }

        -- For the primary host, return no extra credentials, so the handshake
        -- presumably uses the certificate configured via 'tlsSettings'.
        -- Other hosts are looked up in 'otherCerts'.
        onSNI :: String -> Maybe String -> IO Credentials
        onSNI _ Nothing = fail "SNI: unspecified"
        onSNI primaryHost (Just host)
          | checkSNI host primaryHost = return mempty
          | otherwise                 = lookupSNI host otherCerts

        -- First credential whose pattern matches the host, in '_ssl' order.
        lookupSNI :: String -> [(String, Credential)] -> IO Credentials
        lookupSNI host [] = fail ("SNI: unknown hostname (" ++ show host ++ ")")
        lookupSNI host ((p, cert) : cs)
          | checkSNI host p = return (TLS.Credentials [cert])
          | otherwise       = lookupSNI host cs

        -- A pattern "*.example.com" matches any host ending in ".example.com"
        -- (at any depth, but not "example.com" itself). Any other pattern
        -- must equal the host exactly.
        checkSNI :: String -> String -> Bool
        checkSNI host pat = case pat of
            '*' : '.' : p -> ('.' : p) `isSuffixOf` host
            p             -> host == p

        -- Plain HTTP when no certificate is configured, otherwise TLS with the
        -- first entry as primary (pattern match instead of partial 'head').
        runner = case _ssl of
            []                             -> runSettings
            (primaryHost, primaryCert) : _ -> runTLS (tlsFor primaryHost primaryCert)

    -- Optional proxy authentication. The password file is read once at
    -- startup, and each line containing ':' is an accepted "user:password"
    -- entry. Presumably it is compared with what the client sends; see
    -- 'ProxySettings'.
    pauth <- case _auth of
        Nothing -> return Nothing
        Just f  -> Just . flip elem . filter (BS8.elem ':') . map dropCR . BS8.lines
                       <$> BS8.readFile f
    manager <- newTlsManager

    let pset  = ProxySettings pauth Nothing (BS8.pack <$> _ws) (BS8.pack <$> _rev)
        -- Middleware stack, outermost first: HTTPS redirect (TLS only),
        -- header stripping, gzip, forward proxy, reverse proxy, then fallback.
        proxy = (if isSSL then forceSSL pset else id) $
                modifyResponse (stripHeaders ["Server", "Date"]) $
                gzip def $
                httpProxy pset manager $
                reverseProxy pset manager fallback

    case _doh of
        Nothing  -> runner settings proxy
        Just doh -> createResolver doh (\resolver -> runner settings (dnsOverHTTPS resolver proxy))
  where
    -- Strip a trailing carriage return so password files saved with CRLF
    -- line endings still match.
    dropCR :: BS8.ByteString -> BS8.ByteString
    dropCR line
      | "\r" `BS8.isSuffixOf` line = BS8.init line
      | otherwise                  = line