{-# LANGUAGE OverloadedStrings #-}

module Network.DNS.Decode.Internal (
    getResponse
  , getDNSFlags
  , getHeader
  , getResourceRecord
  , getResourceRecords
  , getDomain
  , getMailbox
  ) where

import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as BS
import Data.IP (IP(..), toIPv4, toIPv6b)
import qualified Safe

import Network.DNS.Imports
import Network.DNS.StateBinary
import Network.DNS.Types

----------------------------------------------------------------

-- | Decode a complete DNS message.
--
-- Reads the header, then the four 16-bit section counts, and then decodes
-- the question, answer, authority and additional sections in that order,
-- each with the number of entries announced by its count.
getResponse :: SGet DNSMessage
getResponse = do
    header  <- getHeader
    -- All four counts appear on the wire before any section body.
    qdcount <- getInt16
    ancount <- getInt16
    nscount <- getInt16
    arcount <- getInt16
    questions  <- getQueries qdcount
    answers    <- getResourceRecords ancount
    authority  <- getResourceRecords nscount
    additional <- getResourceRecords arcount
    return $ DNSMessage header questions answers authority additional

----------------------------------------------------------------

-- | Decode the 16-bit flags word of the message header.
--
-- The decoder uses the following bits, where bit 15 is the most significant:
--
-- * 15: query/response
-- * 14-11: opcode
-- * 10: authoritative answer
-- * 9: truncation
-- * 8: recursion desired
-- * 7: recursion available
-- * 5: authentic data
--
-- The whole word is passed to 'toRCODEforHeader' for the response code.
-- Fails when the opcode field does not map to a known opcode value.
getDNSFlags :: SGet DNSFlags
getDNSFlags = get16 >>= decodeFlags
  where
    decodeFlags w =
        case Safe.toEnumMay (fromIntegral (shiftR w 11 .&. 0x0f)) of
            Nothing  -> fail $ "Unsupported flags: 0x" ++ showHex w ""
            Just opc -> pure $ DNSFlags
                (if testBit w 15 then QR_Response else QR_Query)
                opc
                (testBit w 10)                       -- authoritative answer
                (testBit w 9)                        -- truncation
                (testBit w 8)                        -- recursion desired
                (testBit w 7)                        -- recursion available
                (toRCODEforHeader $ fromIntegral w)  -- response code
                (testBit w 5)                        -- authentic data

----------------------------------------------------------------

-- | Decode the message header prefix: the 16-bit identifier followed by the
-- flags word.
getHeader :: SGet DNSHeader
getHeader = do
    ident <- get16
    flags <- getDNSFlags
    return $ DNSHeader ident flags

----------------------------------------------------------------

-- | Decode @n@ consecutive question entries (none when @n <= 0@).
getQueries :: Int -> SGet [Question]
getQueries n = traverse (const getQuery) [1 .. n]

-- | Decode a 16-bit resource record TYPE field.
getTYPE :: SGet TYPE
getTYPE = fmap toTYPE get16

-- | Decode a 16-bit EDNS0 option code.
getOptCode :: SGet OptCode
getOptCode = fmap toOptCode get16

getQuery :: SGet Question
getQuery = Question <$> getDomain
                    <*> getTYPE
                    <*  ignoreClass

-- | Decode @n@ consecutive resource records (none when @n <= 0@).
getResourceRecords :: Int -> SGet [ResourceRecord]
getResourceRecords n = sequenceA (replicate n getResourceRecord)

-- | Decode one resource record.
--
-- The fields are read in this order: owner name, TYPE, 16-bit CLASS, 32-bit
-- TTL, and 16-bit RDATA length. The RDATA is then decoded according to the
-- TYPE.
getResourceRecord :: SGet ResourceRecord
getResourceRecord = do
    owner  <- getDomain
    rrtype <- getTYPE
    ResourceRecord owner rrtype
        <$> get16                           -- CLASS
        <*> get32                           -- TTL
        <*> (getInt16 >>= getRData rrtype)  -- RDLENGTH, then RDATA

-- | Decode the RDATA of a resource record of the given 'TYPE'.
--
-- @len@ is the RDATA length announced in the record header. The NS, MX,
-- CNAME, DNAME, SOA, PTR and SRV clauses ignore @len@ and let the embedded
-- names determine how much is read. RDATA of any type not matched below is
-- kept as raw bytes in 'UnknownRData'.
getRData :: TYPE -> Int -> SGet RData
getRData NS _ = RD_NS <$> getDomain
getRData MX _ = RD_MX <$> decodePreference <*> getDomain
  where
    decodePreference = get16
getRData CNAME _ = RD_CNAME <$> getDomain
getRData DNAME _ = RD_DNAME <$> getDomain
-- TXT RDATA is one or more <character-string>s (RFC 1035, section 3.3.14).
-- Each is a length octet followed by that many bytes. Every length octet is
-- stripped and the strings are concatenated. Dropping only the first octet
-- would leave the length octets of the second and later strings embedded in
-- the result. A length octet that overruns the RDATA takes whatever bytes
-- remain.
getRData TXT len = (RD_TXT . B.concat . characterStrings) <$> getNByteString len
  where
    characterStrings bs = case B.uncons bs of
        Nothing        -> []
        Just (l, rest) ->
            let (str, rest') = B.splitAt (fromIntegral l) rest
            in  str : characterStrings rest'
getRData A len
  | len == 4  = (RD_A . toIPv4) <$> getNBytes len
  | otherwise = fail "IPv4 addresses must be 4 bytes long"
getRData AAAA len
  | len == 16 = (RD_AAAA . toIPv6b) <$> getNBytes len
  | otherwise = fail "IPv6 addresses must be 16 bytes long"
getRData SOA _ = RD_SOA    <$> getDomain
                           <*> getMailbox
                           <*> decodeSerial
                           <*> decodeRefesh
                           <*> decodeRetry
                           <*> decodeExpire
                           <*> decodeMinimum
  where
    decodeSerial  = get32
    decodeRefesh  = get32
    decodeRetry   = get32
    decodeExpire  = get32
    decodeMinimum = get32
getRData PTR _ = RD_PTR <$> getDomain
getRData SRV _ = RD_SRV <$> decodePriority
                           <*> decodeWeight
                           <*> decodePort
                           <*> getDomain
  where
    decodePriority = get16
    decodeWeight   = get16
    decodePort     = get16
-- OPT RDATA is a sequence of options. Each option is a 16-bit code, a 16-bit
-- length, and that many payload bytes. Options are decoded until the
-- remaining byte count reaches exactly zero. A negative remainder means the
-- option lengths disagree with the RDATA length.
getRData OPT ol = RD_OPT <$> decode' ol
  where
    decode' :: Int -> SGet [OData]
    decode' l
        | l  < 0 = fail $ "decodeOPTData: length inconsistency (" ++ show l ++ ")"
        | l == 0 = pure []
        | otherwise = do
            optCode <- getOptCode
            optLen <- getInt16
            dat <- getOData optCode optLen
            -- 4 = 2 (code) + 2 (length) header bytes per option
            (dat:) <$> decode' (l - optLen - 4)
--
getRData TLSA len = RD_TLSA <$> decodeUsage
                               <*> decodeSelector
                               <*> decodeMType
                               <*> decodeADF
  where
    decodeUsage    = get8
    decodeSelector = get8
    decodeMType    = get8
    -- 3 = the three one-byte fields above
    decodeADF      = getNByteString (len - 3)
--
getRData DS len = RD_DS <$> decodeTag
                           <*> decodeAlg
                           <*> decodeDtyp
                           <*> decodeDval
  where
    decodeTag  = get16
    decodeAlg  = get8
    decodeDtyp = get8
    -- 4 = 2 (key tag) + 1 (algorithm) + 1 (digest type)
    decodeDval = getNByteString (len - 4)
--
getRData NULL len = const RD_NULL <$> getNByteString len
--
getRData DNSKEY len = RD_DNSKEY <$> decodeKeyFlags
                                <*> decodeKeyProto
                                <*> decodeKeyAlg
                                <*> decodeKeyBytes
  where
    decodeKeyFlags  = get16
    decodeKeyProto  = get8
    decodeKeyAlg    = get8
    -- 4 = 2 (flags) + 1 (protocol) + 1 (algorithm)
    decodeKeyBytes  = getNByteString (len - 4)
--
getRData NSEC3PARAM len = RD_NSEC3PARAM <$> decodeHashAlg
                                <*> decodeFlags
                                <*> decodeIterations
                                <*> decodeSalt
  where
    decodeHashAlg    = get8
    decodeFlags      = get8
    decodeIterations = get16
    -- 5 = 1 (hash alg) + 1 (flags) + 2 (iterations) + 1 (salt length).
    -- The explicit salt length must account for the rest of the RDATA.
    decodeSalt       = do
        let n = len - 5
        slen <- get8
        guard $ fromIntegral slen == n
        if (n == 0)
        then return B.empty
        else getNByteString n
--
getRData _  len = UnknownRData <$> getNByteString len

-- | Decode the payload of one EDNS0 option, given its option code and payload
-- length in bytes.
--
-- A 'ClientSubnet' payload is read as four fields:
--
-- * 16-bit address family
-- * 8-bit source prefix length
-- * 8-bit scope prefix length
-- * the remaining address bytes
--
-- The address is zero-padded, or truncated, to 4 bytes for family 1 and to
-- 16 bytes for family 2. Any other family fails. Every other option code is
-- kept as raw bytes.
getOData :: OptCode -> Int -> SGet OData
getOData ClientSubnet len = do
    family   <- getInt16
    srcMask  <- get8
    scopeMsk <- get8
    -- 4 = 2 (family) + 1 (source mask) + 1 (scope mask)
    addr <- map fromIntegral . B.unpack <$> getNByteString (len - 4)
    let padded k = take k (addr ++ repeat 0)
    ip <- case family of
        1 -> return $ IPv4 (toIPv4 (padded 4))
        2 -> return $ IPv6 (toIPv6b (padded 16))
        _ -> fail "Unsupported address family"
    return $ OD_ClientSubnet srcMask scopeMsk ip
getOData code len = UnknownOData code <$> getNByteString len

----------------------------------------------------------------

-- | Decode a possibly compressed domain name in which every label is
-- terminated by @'.'@.
--
-- The length of the input returned by 'getInput' bounds the valid
-- compression-pointer offsets.
getDomain :: SGet Domain
getDomain = getInput >>= \msg -> getDomain' '.' (B.length msg) 0

-- | Decode a possibly compressed mailbox name, such as an SOA RNAME.
--
-- The first label is separated from the rest by @'\@'@ instead of @'.'@. The
-- length of the input returned by 'getInput' bounds the valid
-- compression-pointer offsets.
getMailbox :: SGet Mailbox
getMailbox = getInput >>= \msg -> getDomain' '@' (B.length msg) 0

-- | Get a domain name, using sep1 as the separator between the 1st and 2nd
-- label.  Subsequent labels (and always the trailing label) are terminated
-- with a ".".
--
-- @lim@ is the length of the whole message, and a compression pointer must
-- be smaller than it. @loopcnt@ counts the labels and pointers followed so
-- far. It bounds the recursion so that pointer loops terminate.
--
-- Decoded names are cached by message offset (via 'push' and 'pop'). Only
-- the @'.'@-separated form is cached. A mailbox (@sep1 == '\@'@) has a
-- different textual form for the same wire bytes, so it must neither read
-- nor populate that cache. Otherwise a mailbox would leak into later domain
-- lookups, or a domain into a mailbox lookup.
getDomain' :: Char -> Int -> Int -> SGet ByteString
getDomain' sep1 lim loopcnt
  -- 127 is the logical limitation of pointers.
  | loopcnt >= 127 = fail "pointer recursion limit exceeded"
  | otherwise      = do
      pos <- getPosition
      c <- getInt8
      let n = getValue c
      getdomain pos c n
  where
    -- Whether this decode may share the offset -> name cache.
    cacheable = sep1 == '.'

    getdomain pos c n
      | c == 0 = return "." -- Perhaps the root domain?
      | isPointer c = do
          d <- getInt8
          let offset = n * 256 + d
          when (offset >= lim) $ fail "pointer is too large"
          mo <- if cacheable then pop offset else return Nothing
          case mo of
              Just o  -> push pos o >> return o
              Nothing -> do
                  o <- followPointer offset
                  when cacheable $ push pos o
                  return o
      -- As for now, extended labels have no use.
      -- This may change some time in the future.
      | isExtLabel c = return ""
      | otherwise = do
          hs <- getNByteString n
          -- The remaining labels are always '.'-separated.
          ds <- getDomain' '.' lim (loopcnt + 1)
          let dom = case ds of -- avoid trailing ".."
                  "." -> hs `BS.append` "."
                  _   -> hs `BS.append` BS.singleton sep1 `BS.append` ds
          when cacheable $ push pos dom
          return dom

    -- Decode the name stored at absolute message offset @offset@. The
    -- sub-parser runs over the whole message and first skips @offset@ bytes.
    -- Running it over a suffix instead would make the positions it records,
    -- and any further pointers it follows, relative to the pointer target
    -- rather than to the start of the message.
    followPointer offset = do
        msg <- getInput
        let parser = getNByteString offset >> getDomain' sep1 lim (loopcnt + 1)
        case runSGet parser msg of
            Left (DecodeError err) -> fail err
            Left err               -> fail $ show err
            Right o                -> return (fst o)

    getValue c = c .&. 0x3f
    isPointer c = testBit c 7 && testBit c 6
    isExtLabel c = not (testBit c 7) && testBit c 6

-- | Read a 16-bit CLASS field and discard it.
ignoreClass :: SGet ()
ignoreClass = get16 >> return ()