module Network.DNS.Resolver (
FileOrNumericHost(..), ResolvConf(..), defaultResolvConf
, ResolvSeed, makeResolvSeed
, Resolver(..), withResolver, withResolvers
, lookup, lookupAuth, lookupRaw
) where
import Control.Applicative
import Control.Exception
import Data.Char
import Data.Int
import Data.List hiding (find, lookup)
import Network.BSD
import Network.DNS.Decode
import Network.DNS.Encode
import Network.DNS.Internal
import Network.Socket hiding (send, sendTo, recv, recvFrom)
import Network.Socket.ByteString.Lazy
import Prelude hiding (lookup)
import System.Random
import System.Timeout
#if mingw32_HOST_OS == 1
import Network.Socket (send)
import qualified Data.ByteString.Lazy.Char8 as LB
import Control.Monad (when)
#endif
-- | Where the resolver obtains the address of the DNS server to query.
data FileOrNumericHost
    = RCFilePath FilePath
      -- ^ Path to a resolv.conf-style file; the address is taken from its
      --   first @nameserver@ entry (see 'makeResolvSeed').
    | RCHostName HostName
      -- ^ A numeric host address used directly; it is resolved with
      --   @AI_NUMERICHOST@ (see 'makeAddrInfo'), so it must not be a name.
-- | User-facing resolver configuration. Start from 'defaultResolvConf' and
-- override individual fields.
data ResolvConf = ResolvConf {
    resolvInfo    :: FileOrNumericHost
    -- ^ Source of the DNS server address.
  , resolvTimeout :: Int
    -- ^ How long to wait for a reply to one attempt, in microseconds
    --   (the value is eventually passed to 'System.Timeout.timeout').
  , resolvRetry   :: Int
    -- ^ Number of times a query is sent before the lookup gives up.
  , resolvBufsize :: Integer
    -- ^ Buffer size. NOTE(review): not read by the lookup functions in
    --   this module; presumably meant as the UDP receive buffer size —
    --   confirm.
  }
-- | The default configuration:
--
-- * the server address comes from @\/etc\/resolv.conf@,
--
-- * each attempt waits 3 seconds for a reply,
--
-- * a query is sent at most 5 times,
--
-- * the buffer size is 512.
defaultResolvConf :: ResolvConf
defaultResolvConf = ResolvConf
    { resolvInfo    = RCFilePath systemResolvConf
    , resolvTimeout = 3 * microsPerSecond
    , resolvRetry   = 5
    , resolvBufsize = 512
    }
  where
    systemResolvConf = "/etc/resolv.conf"
    -- Timeouts are expressed in microseconds.
    microsPerSecond  = 1000 * 1000
-- | A server address resolved once up front, together with the lookup
-- settings. Built by 'makeResolvSeed' and turned into a live 'Resolver' by
-- 'withResolver' or 'withResolvers'.
data ResolvSeed = ResolvSeed {
    addrInfo  :: AddrInfo
    -- ^ Family, socket type, protocol and address used to open and connect
    --   the resolver's socket.
  , rsTimeout :: Int
    -- ^ Copied from 'resolvTimeout' (microseconds).
  , rsRetry   :: Int
    -- ^ Copied from 'resolvRetry'.
  , rsBufsize :: Integer
    -- ^ Copied from 'resolvBufsize'.
  }
-- | A resolver bound to one socket. When built by 'withResolver' or
-- 'withResolvers', the socket is already connected to the DNS server.
data Resolver = Resolver {
    genId      :: IO Int
    -- ^ Produces the identifier for each query ('makeResolver' uses
    --   'getRandom').
  , dnsSock    :: Socket
    -- ^ Socket that queries are sent on and replies are read from.
  , dnsTimeout :: Int
    -- ^ Per-attempt reply timeout in microseconds.
  , dnsRetry   :: Int
    -- ^ Number of send attempts per query.
  , dnsBufsize :: Integer
    -- ^ NOTE(review): not read by 'lookupRaw' in this module; presumably
    --   the receive buffer size — confirm.
  }
-- | Turn a 'ResolvConf' into a 'ResolvSeed'. The server address is resolved
-- once here (see 'makeAddrInfo'); the timeout, retry and buffer-size
-- settings are copied unchanged.
--
-- For 'RCFilePath', the file is read completely. The server address is the
-- second word of the first line whose first word is @nameserver@. If the
-- file has no such line, an 'IOError' naming the file is thrown.
makeResolvSeed :: ResolvConf -> IO ResolvSeed
makeResolvSeed conf = ResolvSeed <$> addr
                                 <*> pure (resolvTimeout conf)
                                 <*> pure (resolvRetry conf)
                                 <*> pure (resolvBufsize conf)
  where
    addr = case resolvInfo conf of
        RCHostName numhost -> makeAddrInfo numhost
        RCFilePath file    -> nameserverIn file >>= makeAddrInfo
    -- Forcing the length makes the lazy readFile reach EOF, which closes
    -- its handle now instead of leaving it semi-closed. Splitting lines on
    -- whitespace tolerates tabs, repeated spaces and trailing text after
    -- the address, and skips bare "nameserver" lines that have no value.
    nameserverIn file = do
        cs <- readFile file
        _ <- evaluate (length cs)
        case [ host | "nameserver" : host : _ <- map words (lines cs) ] of
            host : _ -> return host
            []       -> ioError $ userError $
                "makeResolvSeed: no nameserver entry found in " ++ file
-- | Look up the 'AddrInfo' for the DNS ("domain") service over UDP at the
-- given host. @AI_NUMERICHOST@ is set, so the host must be a numeric address
-- literal. Only the first result of 'getAddrInfo' is used.
makeAddrInfo :: HostName -> IO AddrInfo
makeAddrInfo addr = do
    hints <- udpHints <$> getProtocolNumber "udp"
    info : _ <- getAddrInfo (Just hints) (Just addr) (Just "domain")
    return info
  where
    udpHints proto = defaultHints
        { addrFlags      = [AI_ADDRCONFIG, AI_NUMERICHOST, AI_PASSIVE]
        , addrSocketType = Datagram
        , addrProtocol   = proto
        }
-- | Open a socket for the seed, connect it, and run the action with the
-- resulting 'Resolver'. The socket is closed with 'sClose' when the action
-- returns or throws.
withResolver :: ResolvSeed -> (Resolver -> IO a) -> IO a
withResolver seed func = bracket acquire sClose use
  where
    acquire  = openSocket seed
    use sock = connectSocket sock seed >> func (makeResolver seed sock)
-- | Like 'withResolver', but for several seeds at once. The action receives
-- the resolvers in the same order as the seeds.
--
-- Each socket is acquired inside its own 'withResolver' bracket. If opening
-- or connecting a later socket fails, every socket opened before it is
-- still closed. If closing one socket throws, the enclosing brackets still
-- close the others.
withResolvers :: [ResolvSeed] -> ([Resolver] -> IO a) -> IO a
withResolvers seeds func = go seeds []
  where
    -- Nest one bracket per seed. 'acc' holds the resolvers opened so far,
    -- most recent first, so it is reversed before being handed to 'func'.
    go []       acc = func (reverse acc)
    go (s:rest) acc = withResolver s $ \rlv -> go rest (rlv : acc)
-- | Create an unconnected socket whose family, type and protocol come from
-- the seed's 'AddrInfo'.
openSocket :: ResolvSeed -> IO Socket
openSocket seed = do
    let info = addrInfo seed
    socket (addrFamily info) (addrSocketType info) (addrProtocol info)
-- | Connect the socket to the server address recorded in the seed.
-- 'lookupRaw' relies on this: it sends and receives on the socket without
-- naming a peer.
connectSocket :: Socket -> ResolvSeed -> IO ()
connectSocket sock = connect sock . addrAddress . addrInfo
-- | Wrap a socket and the seed's settings into a 'Resolver'. Query
-- identifiers are drawn with 'getRandom'.
makeResolver :: ResolvSeed -> Socket -> Resolver
makeResolver seed sock = Resolver
    { genId      = getRandom
    , dnsSock    = sock
    , dnsTimeout = tmo
    , dnsRetry   = attempts
    , dnsBufsize = bufsize
    }
  where
    ResolvSeed { rsTimeout = tmo, rsRetry = attempts, rsBufsize = bufsize } = seed
-- | Draw a query identifier in the inclusive range 0 to 0xFFFF (any 16-bit
-- unsigned value), using the global random generator.
getRandom :: IO Int
getRandom = getStdRandom $ randomR (0, 0xFFFF)
-- | Run 'lookupRaw'. On success, take the records in the selected section
-- whose type equals the queried type and return their 'rdata'. Errors from
-- 'lookupRaw' are passed through unchanged.
lookupSection :: (DNSFormat -> [ResourceRecord])
              -> Resolver
              -> Domain
              -> TYPE
              -> IO (Either DNSError [RDATA])
lookupSection section rlv dom typ = fmap extract <$> lookupRaw rlv dom typ
  where
    extract msg = [ rdata rr | rr <- section msg, rrtype rr == typ ]
-- | Look up records of the given type in the answer section of the reply.
lookup :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookup rlv dom typ = lookupSection answer rlv dom typ
-- | Look up records of the given type in the authority section of the reply.
lookupAuth :: Resolver -> Domain -> TYPE -> IO (Either DNSError [RDATA])
lookupAuth rlv dom typ = lookupSection authority rlv dom typ
-- | Send a query for the given domain and type and return the decoded reply.
--
-- One identifier is drawn from 'genId', and the same query is sent up to
-- 'dnsRetry' times. Each attempt waits at most 'dnsTimeout' microseconds for
-- a reply. A reply whose identifier does not match the query counts as a
-- failed attempt.
--
-- When every attempt fails, the result depends on the last attempt:
--
-- * @Left SequenceNumberMismatch@ if it received a mismatched reply,
--
-- * @Left TimeoutExpired@ if it timed out.
--
-- A non-positive 'dnsRetry' sends nothing and yields @Left TimeoutExpired@.
lookupRaw :: Resolver -> Domain -> TYPE -> IO (Either DNSError DNSFormat)
lookupRaw rlv dom typ = do
    seqno <- genId rlv
    let query = composeQuery seqno [q]
    loop query (check seqno) 0 False
  where
    -- 'mismatch' records whether the previous attempt got a reply with the
    -- wrong identifier, so the final error can say why we gave up.
    loop query checkSeqno cnt mismatch
      -- '>=' rather than '==' so that a negative retry count terminates.
      | cnt >= retry = return . Left $
            if mismatch then SequenceNumberMismatch else TimeoutExpired
      | otherwise = do
          sendAll sock query
          response <- timeout tm (receive sock)
          case response of
            Nothing -> loop query checkSeqno (cnt + 1) False
            Just res
              | checkSeqno res -> return $ Right res
              | otherwise      -> loop query checkSeqno (cnt + 1) True
    sock  = dnsSock rlv
    tm    = dnsTimeout rlv
    retry = dnsRetry rlv
    q     = makeQuestion dom typ
    check seqno res = identifier (header res) == seqno
#if mingw32_HOST_OS == 1
    -- Windows build: a local sendAll on top of the String-based 'send',
    -- repeating until every byte of the lazy ByteString has been sent.
    sendAll sock bs = do
        sent <- send sock (LB.unpack bs)
        when (sent < fromIntegral (LB.length bs)) $
            sendAll sock (LB.drop (fromIntegral sent) bs)
#endif