{-# LANGUAGE CApiFFI #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE Trustworthy #-}
module Network.DNS
(
queryA
, queryAAAA
, queryCNAME
, querySRV
, queryTXT
, query
, DnsException(..)
, resIsReentrant
, queryRaw
, sendRaw
, mkQueryRaw
, decodeMessage
, encodeMessage
, mkQueryMsg
, Label
, Labels(..)
, IsLabels(..)
, Name(..)
, caseFoldName
, CharStr(..)
, IPv4(..)
, IPv6(..)
, TTL(..)
, Class(..)
, classIN
, Type(..)
, TypeSym(..)
, typeFromSym
, typeToSym
, Msg(..)
, MsgHeader(..)
, MsgHeaderFlags(..), QR(..)
, MsgQuestion(..)
, MsgRR(..)
, RData(..)
, rdType
, SRV(..)
)
where
import Control.Exception
import Data.Typeable (Typeable)
import Foreign.C
import Foreign.Marshal.Alloc
import Prelude
import qualified Data.ByteString as BS
import Compat
import Network.DNS.FFI
import Network.DNS.Message
-- | Errors raised by 'query' when a DNS message cannot be converted.
data DnsException
  = DnsEncodeException
    -- ^ The requested name could not be converted with 'toName'.
  | DnsDecodeException
    -- ^ The resolver's answer could not be decoded by 'decodeMessage'.
  deriving (Show, Typeable)

-- | Allows 'DnsException' to be thrown with 'throwIO' and caught with
-- 'catch' / 'try'.
instance Exception DnsException
-- | Resolve @name0@ for the given class and record type and return the
-- decoded response message.
--
-- Throws 'DnsEncodeException' when 'toName' rejects @name0@ and
-- 'DnsDecodeException' when the raw answer cannot be decoded. 'IOError's
-- raised by 'queryRaw' propagate unchanged.
query :: IsLabels n => Class -> n -> TypeSym -> IO (Msg n)
query cls name0 qtype =
    case toName name0 of
      Nothing   -> throwIO DnsEncodeException
      Just name -> do
        raw <- queryRaw cls name (typeFromSym qtype)
        -- 'evaluate' forces the Maybe inside IO, so anything raised while
        -- decoding surfaces here rather than in the caller.
        decoded <- evaluate (decodeMessage raw)
        maybe (throwIO DnsDecodeException) pure decoded
-- | Look up a name with @res_query(3)@ and return the raw answer bytes.
--
-- The call runs inside the resolver state supplied by 'withCResState'. The
-- answer is written into a zero-filled 0x10000-byte buffer. Failures are
-- reported as 'IOError's:
--
--   * @"res_init(3) failed"@ when 'c_res_opt_set_use_dnssec' returns
--     non-zero (presumably it initialises the state and enables DNSSEC;
--     confirm in "Network.DNS.FFI");
--   * @"res_query(3) message size overflow"@ when the reported length
--     exceeds the buffer;
--   * an errno-based error (via 'throwErrno'), or @"res_query(3) failed"@
--     if errno was left unset, when @res_query@ returns a negative length.
queryRaw :: Class -> Name -> Type -> IO BS.ByteString
queryRaw (Class cls) (Name name) qtype = withCResState $ \st ->
    allocaBytes bufCap $ \buf -> do
      _ <- c_memset buf 0 bufCap
      BS.useAsCString name $ \dname -> do
        initStatus <- c_res_opt_set_use_dnssec st
        when (initStatus /= 0) $
          fail "res_init(3) failed"
        -- Clear errno so a non-zero value afterwards is attributable to res_query.
        resetErrno
        answerLen <- c_res_query st dname (fromIntegral cls) typeCode buf bufCap
        when (answerLen > bufCap) $
          fail "res_query(3) message size overflow"
        lastErr <- getErrno
        when (answerLen < 0) $
          if lastErr /= eOK
            then throwErrno "res_query"
            else fail "res_query(3) failed"
        BS.packCStringLen (buf, fromIntegral answerLen)
  where
    -- Polymorphic so one constant serves as the allocation size, the
    -- memset length and the C answer-buffer length.
    bufCap :: Num a => a
    bufCap = 0x10000

    typeCode :: CInt
    typeCode = case qtype of
      Type w -> fromIntegral w
-- | Send an already-encoded DNS query message with @res_send(3)@ and return
-- the raw response bytes.
--
-- The call runs inside the resolver state supplied by 'withCResState'. The
-- response is read into a zero-filled 0x10000-byte buffer. Failures are
-- reported as 'IOError's:
--
--   * @"res_init(3) failed"@ when 'c_res_opt_set_use_dnssec' returns
--     non-zero;
--   * @"res_send(3) message size overflow"@ when the reported length
--     exceeds the buffer;
--   * an errno-based error (via 'throwErrno'), or @"res_send(3) failed"@
--     if errno was left unset, when @res_send@ returns a negative length.
sendRaw :: BS.ByteString -> IO BS.ByteString
sendRaw req = withCResState $ \resState ->
    allocaBytes responseCap $ \respBuf -> do
      _ <- c_memset respBuf 0 responseCap
      BS.useAsCStringLen req $ \(reqBuf, reqSize) -> do
        status <- c_res_opt_set_use_dnssec resState
        when (status /= 0) $
          fail "res_init(3) failed"
        -- Clear errno so a non-zero value afterwards is attributable to res_send.
        resetErrno
        respSize <- c_res_send resState reqBuf (fromIntegral reqSize) respBuf responseCap
        when (respSize > responseCap) $
          fail "res_send(3) message size overflow"
        savedErrno <- getErrno
        when (respSize < 0) $
          if savedErrno /= eOK
            then throwErrno "res_send"
            else fail "res_send(3) failed"
        BS.packCStringLen (respBuf, fromIntegral respSize)
  where
    -- Polymorphic so one constant serves as the allocation size, the
    -- memset length and the C answer-buffer length.
    responseCap :: Num a => a
    responseCap = 0x10000
-- | Build a single-question query message (not sent; see 'encodeMessage').
--
-- The header uses the fixed ID 31337 and sets RD (recursion desired) and AD.
-- It declares one question and one additional record. The additional record
-- is an @OPT@ pseudo-RR at the root name. Per RFC 6891, an OPT record's
-- CLASS field carries the requestor's UDP payload size (here 512), and bit 15
-- of its TTL field is the DO ("DNSSEC OK") flag (here set via 0x8000).
mkQueryMsg :: IsLabels n => Class -> n -> Type -> Msg n
mkQueryMsg cls l qtype =
    Msg header [question] [] [] [ednsOpt]
  where
    question = MsgQuestion l qtype cls

    header = MsgHeader
      { mhId      = 31337
      , mhFlags   = flags
      , mhQDCount = 1
      , mhANCount = 0
      , mhNSCount = 0
      , mhARCount = 1
      }

    flags = MsgHeaderFlags
      { mhQR     = IsQuery
      , mhOpcode = 0
      , mhAA     = False
      , mhTC     = False
      , mhRD     = True
      , mhRA     = False
      , mhZ      = False
      , mhAD     = True
      , mhCD     = False
      , mhRCode  = 0
      }

    ednsOpt = MsgRR
      { rrName  = fromLabels Root
      , rrClass = Class 512
      , rrTTL   = TTL 0x8000
      , rrData  = RDataOPT ""
      }
-- | Encode (but do not send) a query for the given class, name and type
-- using @res_mkquery(3)@, and return the encoded message bytes.
--
-- The call runs inside the resolver state supplied by 'withCResState'. The
-- message is written into a zero-filled 0x10000-byte buffer. Failures are
-- reported as 'IOError's:
--
--   * @"res_init(3) failed"@ when 'c_res_opt_set_use_dnssec' returns
--     non-zero;
--   * @"res_mkquery(3) message size overflow"@ when the reported length
--     exceeds the buffer;
--   * an errno-based error labelled @res_mkquery@, or
--     @"res_mkquery(3) failed"@ if errno was left unset, when
--     @res_mkquery@ returns a negative length.
mkQueryRaw :: Class -> Name -> Type -> IO BS.ByteString
mkQueryRaw (Class cls) (Name name) qtype = withCResState $ \stptr -> do
    allocaBytes max_msg_size $ \resptr -> do
        _ <- c_memset resptr 0 max_msg_size
        BS.useAsCString name $ \dn -> do
            rc1 <- c_res_opt_set_use_dnssec stptr
            unless (rc1 == 0) $
                fail "res_init(3) failed"
            -- Clear errno so a non-zero value afterwards is attributable to res_mkquery.
            resetErrno
            reslen <- c_res_mkquery stptr dn (fromIntegral cls) qtypeVal resptr max_msg_size
            unless (reslen <= max_msg_size) $
                fail "res_mkquery(3) message size overflow"
            errno <- getErrno
            when (reslen < 0) $ do
                -- Label the errno error with the call that actually failed.
                unless (errno == eOK) $
                    throwErrno "res_mkquery"
                fail "res_mkquery(3) failed"
            BS.packCStringLen (resptr, fromIntegral reslen)
  where
    -- Polymorphic so one constant serves as the allocation size, the
    -- memset length and the C buffer length.
    max_msg_size :: Num a => a
    max_msg_size = 0x10000

    qtypeVal :: CInt
    qtypeVal = case qtype of Type w -> fromIntegral w
-- | Normalise a 'Name' for comparison.
--
-- ASCII lowercase letters (@a@..@z@) are mapped to uppercase; all other bytes
-- are left unchanged. A trailing @.@ is appended unless one is already
-- present, so the empty name becomes @"."@.
caseFoldName :: Name -> Name
caseFoldName (Name raw) = Name dotted
  where
    upper = BS.map toUpperAscii raw

    -- "" has no "." suffix, so it takes the second branch and yields ".".
    dotted
      | "." `BS.isSuffixOf` upper = upper
      | otherwise                 = upper `mappend` "."

    toUpperAscii c
      | c >= 0x61 && c <= 0x7a = c - 0x20
      | otherwise              = c
-- | Look up the IPv4 (@A@) addresses of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only
-- answer-section records of class 1 whose folded owner name equals the
-- folded query name are returned, each paired with its TTL.
queryA :: Name -> IO [(TTL,IPv4)]
queryA n = do
    reply <- query classIN target TypeA
    pure (concatMap pick (msgAN reply))
  where
    target = caseFoldName n

    -- A failing guard falls through to the catch-all clause.
    pick MsgRR { rrData = RDataA addr, rrTTL = ttl, rrName = owner, rrClass = Class 1 }
      | caseFoldName owner == target = [(ttl, addr)]
    pick _ = []
-- | Look up the IPv6 (@AAAA@) addresses of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only
-- answer-section records of class 1 whose folded owner name equals the
-- folded query name are returned, each paired with its TTL.
queryAAAA :: Name -> IO [(TTL,IPv6)]
queryAAAA n = do
    reply <- query classIN target TypeAAAA
    pure $ do
      -- A record that does not match this pattern is skipped (list 'fail').
      MsgRR { rrData = RDataAAAA addr, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN reply
      if caseFoldName owner == target then [(ttl, addr)] else []
  where
    target = caseFoldName n
-- | Look up the canonical-name (@CNAME@) records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only
-- answer-section records of class 1 whose folded owner name equals the
-- folded query name are returned, each paired with its TTL.
queryCNAME :: Name -> IO [(TTL,Name)]
queryCNAME n = do
    -- Ask for CNAME explicitly. The filter below keeps only RDataCNAME
    -- answers, so querying any other type would return CNAMEs only when the
    -- server happens to include them.
    res <- query classIN n' TypeCNAME
    pure [ (ttl,cname) | MsgRR { rrData = RDataCNAME cname, rrTTL = ttl, rrName = n1, rrClass = Class 1 } <- msgAN res, caseFoldName n1 == n' ]
  where
    n' = caseFoldName n
-- | Look up the @TXT@ records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only
-- answer-section records of class 1 whose folded owner name equals the
-- folded query name are returned. Each record's character-strings are paired
-- with its TTL.
queryTXT :: Name -> IO [(TTL,[CharStr])]
queryTXT n = fmap extract (query classIN target TypeTXT)
  where
    target = caseFoldName n

    extract reply =
      [ (ttl, txts)
      | MsgRR { rrData = RDataTXT txts, rrTTL = ttl, rrName = owner, rrClass = Class 1 } <- msgAN reply
      , caseFoldName owner == target
      ]
-- | Look up the @SRV@ records of a name.
--
-- The name is normalised with 'caseFoldName' before querying. Only
-- answer-section records of class 1 whose folded owner name equals the
-- folded query name are returned, each paired with its TTL, in answer order.
querySRV :: Name -> IO [(TTL,SRV Name)]
querySRV n = do
    reply <- query classIN target TypeSRV
    let keep rr acc = case rr of
          MsgRR { rrData = RDataSRV srv, rrTTL = ttl, rrName = owner, rrClass = Class 1 }
            | caseFoldName owner == target -> (ttl, srv) : acc
          _ -> acc
    pure (foldr keep [] (msgAN reply))
  where
    target = caseFoldName n