-- SPDX-License-Identifier: Apache-2.0
--
-- Copyright (C) 2023 Bin Jin. All Rights Reserved.
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards   #-}

module Network.HProx.DoH
  ( createResolver
  , dnsOverHTTPS
  ) where

import qualified Data.ByteString.Base64.URL as Base64
import qualified Data.ByteString.Char8      as BS8
import qualified Data.ByteString.Lazy       as LBS
import           Network.DNS                (DNSHeader (..), DNSMessage (..),
                                             Question (..), ResolvConf (..),
                                             Resolver)
import qualified Network.DNS                as DNS
import qualified Network.HTTP.Types         as HT

import           Network.Wai

import           Network.HProx.Util

-- | Run an action with a DNS 'Resolver' that talks to the given upstream
-- server.
--
-- The @remote@ string is split into host and port by
-- 'parseHostPortWithDefault', with 53 used as the default port. The
-- resolver is created with 'DNS.withResolver' and handed to @handle@.
createResolver :: String -> (Resolver -> IO a) -> IO a
createResolver remote handle = do
    seed <- DNS.makeResolvSeed conf
    DNS.withResolver seed handle
  where
    -- Host and port of the upstream server, as parsed from @remote@.
    (h, p) = parseHostPortWithDefault 53 (BS8.pack remote)
    info   = DNS.RCHostPort (BS8.unpack h) (fromIntegral p)

    -- Library defaults, with only the upstream server overridden.
    conf = DNS.defaultResolvConf { resolvInfo = info }

-- | WAI middleware that serves DNS-over-HTTPS on the @/dns-query@ path.
--
-- A request is handled by 'handleDoH' only when its path is exactly
-- @dns-query@ and 'isSecure' reports it as secure; every other request
-- is passed on to the wrapped application unchanged.
dnsOverHTTPS :: Resolver -> Middleware
dnsOverHTTPS resolver fallback req respond
    | pathInfo req == ["dns-query"] && isSecure req = handleDoH resolver req respond
    | otherwise                                     = fallback req respond

-- | Answer a single DNS-over-HTTPS request using the given resolver.
--
-- Two request shapes are accepted:
--
-- * @GET@ whose query string is exactly one @dns@ parameter holding an
--   unpadded base64url-encoded DNS message;
--
-- * @POST@ with a declared body length of at most 4096 bytes, whose body
--   is the raw DNS message.
--
-- In both cases the message must decode and carry exactly one question.
-- Any other request, as well as a failed upstream lookup, is answered
-- with a 400 response.
handleDoH :: Resolver -> Application
handleDoH resolver req respond
    | requestMethod req == "GET"
    , [("dns", Just dnsStr)] <- queryString req
    , Right dnsQuery <- Base64.decodeUnpadded dnsStr
    = answer dnsQuery
    | requestMethod req == "POST"
    , KnownLength len <- requestBodyLength req
    , len <= 4096
    = do
        -- 'getRequestBodyChunk' yields only the next chunk, which may be
        -- shorter than the whole body, so read the body to completion.
        -- The length guard above keeps this read bounded.
        dnsQuery <- LBS.toStrict <$> strictRequestBody req
        answer dnsQuery
    | otherwise = respond errorResp
  where
    errorResp = responseLBS HT.status400
        [("Content-Type", "text/plain")]
        "invalid dns-over-https request"

    -- Decode a wire-format query and resolve it; reject anything that
    -- does not decode or does not contain exactly one question.
    answer raw = case DNS.decode raw of
        Right DNSMessage { question = [q]
                         , header = DNSHeader { identifier = ident } } ->
            handleQuery ident q
        _ -> respond errorResp

    -- Look the question up and send the raw answer back, with the
    -- response identifier overwritten by the one from the client query.
    handleQuery ident Question { qname = name, qtype = ty } = do
        resp <- DNS.lookupRaw resolver name ty
        respond $ case resp of
            Left _ -> errorResp
            Right dnsResp ->
                let respHeader = header dnsResp
                    encoded    = DNS.encode dnsResp
                        { header = respHeader { identifier = ident } }
                in responseLBS HT.status200
                       [ ("Content-Type", "application/dns-message")
                       , ("Content-Length", BS8.pack $ show (BS8.length encoded))
                       ]
                       (LBS.fromStrict encoded)