{-# LANGUAGE OverloadedStrings #-}
module Network.DNS.Decode.Internal (
getResponse
, getDNSFlags
, getHeader
, getResourceRecord
, getResourceRecords
, getDomain
, getMailbox
) where
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS
import Data.IP (IP(..), toIPv4, toIPv6b)
import qualified Safe
import Network.DNS.Imports
import Network.DNS.StateBinary
import Network.DNS.Types
getResponse :: SGet DNSMessage
getResponse = do
hd <- getHeader
qdCount <- getInt16
anCount <- getInt16
nsCount <- getInt16
arCount <- getInt16
DNSMessage hd <$> getQueries qdCount
<*> getResourceRecords anCount
<*> getResourceRecords nsCount
<*> getResourceRecords arCount
getDNSFlags :: SGet DNSFlags
getDNSFlags = do
word <- get16
maybe (fail $ "Unsupported flags: 0x" ++ showHex word "") pure (toFlags word)
where
toFlags :: Word16 -> Maybe DNSFlags
toFlags flgs = do
oc <- getOpcode flgs
let rc = getRcode flgs
return $ DNSFlags (getQorR flgs)
oc
(getAuthAnswer flgs)
(getTrunCation flgs)
(getRecDesired flgs)
(getRecAvailable flgs)
rc
(getAuthenData flgs)
getQorR w = if testBit w 15 then QR_Response else QR_Query
getOpcode w = Safe.toEnumMay (fromIntegral (shiftR w 11 .&. 0x0f))
getAuthAnswer w = testBit w 10
getTrunCation w = testBit w 9
getRecDesired w = testBit w 8
getRecAvailable w = testBit w 7
getRcode w = toRCODEforHeader $ fromIntegral w
getAuthenData w = testBit w 5
getHeader :: SGet DNSHeader
getHeader =
DNSHeader <$> decodeIdentifier <*> getDNSFlags
where
decodeIdentifier = get16
getQueries :: Int -> SGet [Question]
getQueries n = replicateM n getQuery
getTYPE :: SGet TYPE
getTYPE = toTYPE <$> get16
getOptCode :: SGet OptCode
getOptCode = toOptCode <$> get16
getQuery :: SGet Question
getQuery = Question <$> getDomain
<*> getTYPE
<* ignoreClass
getResourceRecords :: Int -> SGet [ResourceRecord]
getResourceRecords n = replicateM n getResourceRecord
getResourceRecord :: SGet ResourceRecord
getResourceRecord = do
dom <- getDomain
typ <- getTYPE
cls <- decodeCLASS
ttl <- decodeTTL
len <- decodeRLen
dat <- getRData typ len
return $ ResourceRecord dom typ cls ttl dat
where
decodeCLASS = get16
decodeTTL = get32
decodeRLen = getInt16
getRData :: TYPE -> Int -> SGet RData
getRData NS _ = RD_NS <$> getDomain
getRData MX _ = RD_MX <$> decodePreference <*> getDomain
where
decodePreference = get16
getRData CNAME _ = RD_CNAME <$> getDomain
getRData DNAME _ = RD_DNAME <$> getDomain
getRData TXT len = (RD_TXT . ignoreLength) <$> getNByteString len
where
ignoreLength = BS.drop 1
getRData A len
| len == 4 = (RD_A . toIPv4) <$> getNBytes len
| otherwise = fail "IPv4 addresses must be 4 bytes long"
getRData AAAA len
| len == 16 = (RD_AAAA . toIPv6b) <$> getNBytes len
| otherwise = fail "IPv6 addresses must be 16 bytes long"
getRData SOA _ = RD_SOA <$> getDomain
<*> getMailbox
<*> decodeSerial
<*> decodeRefesh
<*> decodeRetry
<*> decodeExpire
<*> decodeMinimum
where
decodeSerial = get32
decodeRefesh = get32
decodeRetry = get32
decodeExpire = get32
decodeMinimum = get32
getRData PTR _ = RD_PTR <$> getDomain
getRData SRV _ = RD_SRV <$> decodePriority
<*> decodeWeight
<*> decodePort
<*> getDomain
where
decodePriority = get16
decodeWeight = get16
decodePort = get16
getRData OPT ol = RD_OPT <$> decode' ol
where
decode' :: Int -> SGet [OData]
decode' l
| l < 0 = fail $ "decodeOPTData: length inconsistency (" ++ show l ++ ")"
| l == 0 = pure []
| otherwise = do
optCode <- getOptCode
optLen <- getInt16
dat <- getOData optCode optLen
(dat:) <$> decode' (l - optLen - 4)
getRData TLSA len = RD_TLSA <$> decodeUsage
<*> decodeSelector
<*> decodeMType
<*> decodeADF
where
decodeUsage = get8
decodeSelector = get8
decodeMType = get8
decodeADF = getNByteString (len - 3)
getRData DS len = RD_DS <$> decodeTag
<*> decodeAlg
<*> decodeDtyp
<*> decodeDval
where
decodeTag = get16
decodeAlg = get8
decodeDtyp = get8
decodeDval = getNByteString (len - 4)
getRData NULL len = const RD_NULL <$> getNByteString len
getRData DNSKEY len = RD_DNSKEY <$> decodeKeyFlags
<*> decodeKeyProto
<*> decodeKeyAlg
<*> decodeKeyBytes
where
decodeKeyFlags = get16
decodeKeyProto = get8
decodeKeyAlg = get8
decodeKeyBytes = getNByteString (len - 4)
getRData NSEC3PARAM len = RD_NSEC3PARAM <$> decodeHashAlg
<*> decodeFlags
<*> decodeIterations
<*> decodeSalt
where
decodeHashAlg = get8
decodeFlags = get8
decodeIterations = get16
decodeSalt = do
let n = len - 5
slen <- get8
guard $ fromIntegral slen == n
if (n == 0)
then return B.empty
else getNByteString n
getRData _ len = UnknownRData <$> getNByteString len
getOData :: OptCode -> Int -> SGet OData
getOData ClientSubnet len = do
fam <- getInt16
srcMask <- get8
scpMask <- get8
rawip <- fmap fromIntegral . B.unpack <$> getNByteString (len - 4)
ip <- case fam of
1 -> pure . IPv4 . toIPv4 $ take 4 (rawip ++ repeat 0)
2 -> pure . IPv6 . toIPv6b $ take 16 (rawip ++ repeat 0)
_ -> fail "Unsupported address family"
pure $ OD_ClientSubnet srcMask scpMask ip
getOData opc len = UnknownOData opc <$> getNByteString len
getDomain :: SGet Domain
getDomain = do
lim <- B.length <$> getInput
getDomain' '.' lim 0
getMailbox :: SGet Mailbox
getMailbox = do
lim <- B.length <$> getInput
getDomain' '@' lim 0
getDomain' :: Char -> Int -> Int -> SGet ByteString
getDomain' sep1 lim loopcnt
| loopcnt >= 127 = fail "pointer recursion limit exceeded"
| otherwise = do
pos <- getPosition
c <- getInt8
let n = getValue c
getdomain pos c n
where
getdomain pos c n
| c == 0 = return "."
| isPointer c = do
d <- getInt8
let offset = n * 256 + d
when (offset >= lim) $ fail "pointer is too large"
mo <- pop offset
case mo of
Nothing -> do
target <- B.drop offset <$> getInput
case runSGet (getDomain' sep1 lim (loopcnt + 1)) target of
Left (DecodeError err) -> fail err
Left err -> fail $ show err
Right o -> push pos (fst o) >> return (fst o)
Just o -> push pos o >> return o
| isExtLabel c = return ""
| otherwise = do
hs <- getNByteString n
ds <- getDomain' '.' lim (loopcnt + 1)
let dom = case ds of
"." -> hs `BS.append` "."
_ -> hs `BS.append` BS.singleton sep1 `BS.append` ds
push pos dom
return dom
getValue c = c .&. 0x3f
isPointer c = testBit c 7 && testBit c 6
isExtLabel c = not (testBit c 7) && testBit c 6
ignoreClass :: SGet ()
ignoreClass = () <$ get16