-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.

module Network.HProx.DoH
  ( createResolver
  , dnsOverHTTPS
  ) where

import Data.ByteString.Base64.URL qualified as Base64
import Data.ByteString.Char8      qualified as BS8
import Data.ByteString.Lazy       qualified as LBS
import Network.DNS
    (DNSHeader(..), DNSMessage(..), Question(..), ResolvConf(..), Resolver)
import Network.DNS                qualified as DNS
import Network.HTTP.Types         qualified as HT

import Network.Wai

import Network.HProx.Util

-- | Create a DNS 'Resolver' that forwards queries to the upstream server
-- described by @remote@, and run @handle@ with it.
--
-- @remote@ is split into host and port by 'parseHostPortWithDefault', which
-- is given 53 as the port to use when none is specified.
createResolver :: String -> (Resolver -> IO a) -> IO a
createResolver remote handle =
    DNS.makeResolvSeed conf >>= \seed -> DNS.withResolver seed handle
  where
    conf = DNS.defaultResolvConf { resolvInfo = upstream }

    -- NOTE(review): the port is narrowed to 'PortNumber' with 'fromIntegral';
    -- if parseHostPortWithDefault does not range-check, an out-of-range port
    -- would wrap silently — confirm.
    upstream =
        let (host, port) = parseHostPortWithDefault 53 (BS8.pack remote)
        in DNS.RCHostPort (BS8.unpack host) (fromIntegral port)

-- | WAI middleware that serves DNS-over-HTTPS queries.
--
-- A request is handled here only when its path is exactly @dns-query@ and it
-- arrived over a secure connection ('isSecure'). Every other request is
-- passed unchanged to the wrapped application.
dnsOverHTTPS :: Resolver -> Middleware
dnsOverHTTPS resolver fallback req respond =
    if isDoHRequest
        then handleDoH resolver req respond
        else fallback req respond
  where
    isDoHRequest = isSecure req && pathInfo req == ["dns-query"]

-- | Answer a single DNS-over-HTTPS request using @resolver@.
--
-- Two request shapes are accepted:
--
--   * @GET@ whose query string is exactly @dns=<base64url, unpadded>@;
--   * @POST@ with a known body length of at most 4096 bytes, where the body is
--     the wire-format query.
--
-- In both cases the decoded message must carry exactly one question. That
-- question is looked up with 'DNS.lookupRaw', and the upstream answer is
-- returned as @application/dns-message@, with its identifier rewritten to
-- the one in the client's query. Anything else gets a 400 response with
-- a plain-text error.
handleDoH :: Resolver -> Application
handleDoH resolver req respond
    | requestMethod req == "GET"
    , [("dns", Just dnsStr)] <- queryString req
    , Right wire <- Base64.decodeUnpadded dnsStr
    , Just (qid, q) <- decodeQuery wire
    = handleQuery qid q
    | requestMethod req == "POST"
    , KnownLength len <- requestBodyLength req
    , len <= 4096
    = do
        -- 'getRequestBodyChunk' yields only the next chunk, and a body may be
        -- delivered in several chunks. Read it to the end instead. The
        -- declared length has been bounded above.
        wire <- LBS.toStrict <$> strictRequestBody req
        maybe (respond errorResp) (uncurry handleQuery) (decodeQuery wire)
    | otherwise = respond errorResp
  where
    errorResp :: Response
    errorResp =
        responseLBS HT.status400
            [("Content-Type", "text/plain")]
            "invalid dns-over-https request"

    -- Decode a wire-format DNS query. Only messages with exactly one
    -- question are accepted; the query identifier is returned alongside it.
    decodeQuery :: BS8.ByteString -> Maybe (DNS.Identifier, Question)
    decodeQuery wire = case DNS.decode wire of
        Right (DNSMessage { question = [q], header = DNSHeader { identifier = qid } }) ->
            Just (qid, q)
        _otherwise -> Nothing

    -- Resolve the question upstream and reply with the raw answer, carrying
    -- the client's identifier.
    -- NOTE(review): an upstream lookup failure is reported as 400 (client
    -- error); a 5xx status may be more accurate — confirm with clients first.
    handleQuery :: DNS.Identifier -> Question -> IO ResponseReceived
    handleQuery ident Question{..} = do
        resp <- DNS.lookupRaw resolver qname qtype
        respond $ case resp of
            Left _ -> errorResp
            Right dnsResp@DNSMessage{ header = hdr } ->
                let encoded = DNS.encode (dnsResp { header = hdr { identifier = ident } })
                in responseKnownLength HT.status200
                       [("Content-Type", "application/dns-message")]
                       (LBS.fromStrict encoded)