module Network.DNS.Client
( module Network.DNS.Types
, resolve
, resolveAsync
, DNSError(..)
) where
import Data.Word
import Data.List (nub)
import Data.Maybe (fromMaybe)
import Data.Time
import Control.Monad (when)
import Control.Timeout
import Control.Concurrent (forkIO)
import System.IO.Unsafe (unsafePerformIO)
import System.Random (mkStdGen, random, Random, randomR)
import Control.Concurrent.STM
import qualified Data.Binary.Get as G
import qualified Data.Map as Map
import Network.Socket hiding (sendTo, recvFrom)
import Network.Socket.ByteString
import qualified Data.Binary.Put as P
import qualified Data.ByteString as B
import qualified Data.ByteString.Lazy as BL
import Network.DNS.Common
import Network.DNS.ResolveConfParse
import Network.DNS.Types
-- | Build the header for an outgoing query carrying transaction ID @qid@.
--
-- All other positional 'Header' fields are fixed: 'False', 'QUERY',
-- 'False', 'False', 'True', 'False', 'NoError', then the counts 1, 0, 0, 0.
-- NOTE(review): presumably the 'True' is the recursion-desired flag and the
-- counts are question/answer/authority/additional; confirm against
-- "Network.DNS.Types".
queryHeader :: Word16 -> Header
queryHeader = \qid ->
  Header qid
         False QUERY False False True False
         NoError
         1 0 0 0
-- | Errors delivered to resolution callbacks.
data DNSError = Timeout
                -- ^ The last attempt for the last candidate name timed out.
              | AnswerNotIncluded
                -- ^ A 'NoError' reply arrived, but no record of the requested
                --   type was found for the queried name (following CNAMEs).
              | DNSError ResponseCode
                -- ^ A non-'NoError' response code from a server reply, or a
                --   synthesised 'NXDomain' when no candidate name could be sent.
              deriving (Show, Eq)
-- | Book-keeping for one outstanding query.
--
-- Callers construct this with 'undefined' in 'infRequest' and/or
-- 'infTimeout' and fill those fields in before use, so the fields must
-- stay lazy.
data InflightRequest =
  InflightRequest { infRequest :: B.ByteString
                    -- ^ Serialised question; 'transmit' prepends a fresh header
                    --   on every send.
                  , infCallback :: (Either DNSError ([String], ([Entry], [Entry], [Entry])) -> IO ())
                    -- ^ Receives the final outcome: an error, or the queried name
                    --   together with the reply's three record sections.
                  , infCurrent :: [String]
                    -- ^ Labels of the candidate name currently being queried.
                  , infSearch :: [[String]]
                    -- ^ Candidate names still to try after the current one.
                  , infAttempts :: Int
                    -- ^ Attempt counter, compared against 'rcAttempts'.
                  , infTimeout :: TimeoutTag
                    -- ^ Tag of the timeout armed for the most recent send.
                  , infType :: DNSType
                    -- ^ Record type being requested.
                  , infSpecificNameserver :: Maybe Nameserver
                    -- ^ When set, always use this server instead of round-robin.
                  }
-- | Per-server state, shared between request submission, the reader
-- thread and timeout handlers.
data Nameserver = Nameserver { nsAddress :: Word32
                               -- ^ IPv4 address, used in 'SockAddrInet' when sending.
                             , nsUp :: TVar Bool
                               -- ^ Set to 'False' after too many timeouts; a
                               --   'NoError' reply or a probe sets it back to 'True'.
                             , nsInflight :: TVar (Map.Map Word16 InflightRequest)
                               -- ^ Outstanding requests keyed by transaction ID.
                             , nsTimeouts :: TVar Int
                               -- ^ Timeouts counted since the last 'NoError' reply.
                             }
-- | Resolver settings (built by 'resolverConfigFromResolvConf') together
-- with the mutable state shared by all requests.
data ResolverConfig = ResolverConfig
  { rcNameservers :: [Nameserver]  -- ^ Servers to query, duplicates removed.
  , rcSearchPath :: [[String]]     -- ^ Search-domain suffixes, as label lists.
  , rcNdots :: Int                 -- ^ Dot threshold for trying the search path first (default 1).
  , rcAttempts :: Int              -- ^ Attempt budget compared with 'infAttempts' (default 2).
  , rcTimeout :: Int               -- ^ Per-send timeout given to 'addTimeoutAtomic' (default 5; presumably seconds).
  , rcRobin :: TVar Int            -- ^ Current round-robin position in 'rcNameservers'.
  , rcSeeds :: TVar [Int]          -- ^ Stream of seeds for picking transaction IDs.
  }
-- | Limit on outstanding requests per nameserver (2048). 'submit4' blocks
-- with STM 'retry' while a server's in-flight map is over the limit.
maxInflightPerServer :: Int
maxInflightPerServer = 2 ^ (11 :: Int)
-- | The single UDP socket used for every outgoing query. 'getResolveConfig'
-- forks 'readerThread' on this same socket to receive replies.
--
-- This is a top-level global created through 'unsafePerformIO', so it MUST
-- be NOINLINE. Without the pragma, GHC is free to inline or duplicate the
-- definition. That could create several sockets, so queries would be sent
-- from one socket while the reader listens on another.
globalSocket :: Socket
globalSocket = unsafePerformIO $ do
  s <- socket AF_INET Datagram 0
  -- Port 0: let the OS choose the source port.
  bindSocket s $ SockAddrInet (PortNum 0) iNADDR_ANY
  return s
{-# NOINLINE globalSocket #-}
-- | Claim the in-flight request for transaction @qid@ sent to @addr@.
--
-- Finds the configured nameserver with that address. Then, in one STM
-- transaction, it removes the entry for @qid@ from the server's in-flight
-- map and cancels that entry's timeout. Returns 'Nothing' if the address
-- is not a configured server or the ID is not outstanding. Because removal
-- is atomic, a reply and a timeout for the same request cannot both claim it.
lookupReply :: ResolverConfig
            -> Word16
            -> Word32
            -> IO (Maybe (InflightRequest, Nameserver))
lookupReply config qid addr =
  case [ns | ns <- rcNameservers config, nsAddress ns == addr] of
    [] -> return Nothing
    ns : _ -> do
      claimed <- atomically $ claim (nsInflight ns)
      return $ fmap (\inf -> (inf, ns)) claimed
  where
    claim var = do
      table <- readTVar var
      let (found, table') = Map.updateLookupWithKey (\_ _ -> Nothing) qid table
      writeTVar var table'
      case found of
        Just inf -> cancelTimeout (infTimeout inf) >> return ()
        Nothing -> return ()
      return found
-- | Receive loop for replies on @sock@; it never returns.
--
-- Each datagram (at most 1500 bytes) is parsed. The loop drops packets
-- that fail to parse, are not responses, or are truncated. Otherwise the
-- reply is matched to its in-flight request by transaction ID and source
-- address, then dispatched on the response code:
--
-- * 'ServerError' is treated as transient (retried).
-- * 'NoError' is delivered.
-- * Any other code is a failure.
readerThread :: Socket -> IO ()
readerThread sock = do
  (bytes, SockAddrInet _ addr) <- recvFrom sock 1500
  either (const $ return ()) (dispatch addr) (parsePacket bytes)
  readerThread sock
  where
    dispatch addr (Packet header _ ans nses additional)
      | not (headIsResponse header) || headIsTruncated header = return ()
      | otherwise = do
          config <- getResolveConfig
          matched <- lookupReply config (headId header) addr
          case matched of
            Nothing -> return ()
            Just (inf, ns) ->
              case headResponseCode header of
                ServerError -> handleTransientError config inf $ DNSError ServerError
                NoError -> handleReply config inf ns (ans, nses, additional)
                code -> handleFailure config inf $ DNSError code
-- | Process-wide resolver configuration, populated lazily by
-- 'getResolveConfig'.
--
-- NOINLINE is mandatory for this 'unsafePerformIO' global. If GHC inlined
-- it, each use site could get its own fresh TVar. The configuration would
-- then never be cached, and 'getResolveConfig' would fork a new reader
-- thread on every call.
globalConfig :: TVar (Maybe ResolverConfig)
globalConfig = unsafePerformIO $ newTVarIO Nothing
{-# NOINLINE globalConfig #-}
-- | Build a fresh 'ResolverConfig' from @/etc/resolv.conf@.
--
-- Duplicate nameservers are removed. Missing options default to ndots 1,
-- attempts 2 and timeout 5. Every server starts marked up, with an empty
-- in-flight map and a zero timeout count. Transaction-ID seeds are an
-- infinite, lazily read stream of big-endian 32-bit words from
-- @/dev/urandom@.
--
-- A resolv.conf parse error is not handled here: the @Right@ pattern fails
-- and the IO action throws.
resolverConfigFromResolvConf :: IO ResolverConfig
resolverConfigFromResolvConf = do
  Right resolvconf <- parseResolveConf "/etc/resolv.conf"
  robin <- atomically $ newTVar 0
  let toNameserver ip = do
        up <- atomically $ newTVar True
        inflight <- atomically $ newTVar Map.empty
        timeouts <- atomically $ newTVar 0
        return $ Nameserver ip up inflight timeouts
  urandom <- BL.readFile "/dev/urandom"
  -- Decode one 4-byte word at a time, so the list stays lazy.
  -- A single self-recursive 'G.Get' parser over the whole infinite stream
  -- only works with the old lazy Get monad. With the strict, incremental
  -- Get of binary >= 0.6, 'G.runGet' would never return.
  let wordsFrom :: BL.ByteString -> [Int]
      wordsFrom bytes =
        fromIntegral (G.runGet G.getWord32be (BL.take 4 bytes))
          : wordsFrom (BL.drop 4 bytes)
      seeds = wordsFrom urandom
  tseeds <- atomically $ newTVar seeds
  ns <- mapM toNameserver $ nub $ resolveNameservers resolvconf
  return $ ResolverConfig ns
                          (resolveSearch resolvconf)
                          (fromMaybe 1 (resolveNdots resolvconf))
                          (fromMaybe 2 (resolveAttempts resolvconf))
                          (fromMaybe 5 (resolveTimeout resolvconf))
                          robin
                          tseeds
-- | Pick a nameserver by round-robin.
--
-- Advances 'rcRobin' modulo the number of servers. Starting at the
-- current position, it returns the first server whose 'nsUp' flag is set.
-- If every server is marked down, the server at the current position is
-- returned anyway.
--
-- NOTE(review): an empty 'rcNameservers' list makes this fail ('mod' by
-- zero, 'head' of an empty list).
selectNameserver :: ResolverConfig -> STM Nameserver
selectNameserver config = do
  let servers = rcNameservers config
      count = length servers
  pos <- readTVar (rcRobin config)
  writeTVar (rcRobin config) $ (pos + 1) `mod` count
  let rotated = drop pos servers ++ take pos servers
  flags <- mapM (readTVar . nsUp) rotated
  return $ case [s | (True, s) <- zip flags rotated] of
    s : _ -> s
    [] -> head rotated
-- | Uniform 'Word16' values, generated through the 'Int' instance and
-- converted back.
--
-- NOTE(review): this is an orphan instance. Newer releases of the random
-- package presumably provide @Random Word16@ themselves, which would turn
-- this into a duplicate-instance error. Confirm against the pinned version.
instance Random Word16 where
  randomR (lo, hi) g =
    let (n, g') = randomR (fromIntegral lo :: Int, fromIntegral hi) g
    in (fromIntegral n, g')
  random = randomR (minBound, maxBound)
-- | STM half of 'transmit': register @inf@ with a nameserver.
--
-- Uses the request's pinned server if it has one, otherwise
-- 'selectNameserver'. Blocks (STM 'retry') while the server already holds
-- 'maxInflightPerServer' outstanding requests. Takes one seed from
-- 'rcSeeds' and draws transaction IDs from it until it finds one not in
-- use on that server. Arms a timeout through @mtag@ that will call
-- 'handleTimeout', then stores the request with the new timeout tag and
-- 'infAttempts' incremented by one.
--
-- Returns the chosen transaction ID and the server's address.
submit4 :: ResolverConfig
        -> InflightRequest
        -> (IO () -> STM TimeoutTag)
        -> STM (Word16, Word32)
submit4 config inf mtag = do
  ns <- maybe (selectNameserver config) return (infSpecificNameserver inf)
  inflight <- readTVar $ nsInflight ns
  -- '>=' keeps the map at or below the limit. '>' let it reach limit + 1.
  when (Map.size inflight >= maxInflightPerServer) retry
  seeds <- readTVar (rcSeeds config)
  -- Lazy, like the previous head/tail, but with an explicit message if the
  -- (normally infinite) seed stream ever runs dry.
  let (seed, rest) = case seeds of
        s : ss -> (s, ss)
        [] -> error "submit4: transaction-ID seed stream exhausted"
  writeTVar (rcSeeds config) rest
  -- The map is capped well below 65536 entries, so this search terminates.
  let pickId gen =
        let (candidate, gen') = random gen
        in if Map.member candidate inflight then pickId gen' else candidate
      qid = pickId (mkStdGen seed)
      addr = nsAddress ns
  tag <- mtag $ handleTimeout config qid addr
  let inf' = inf { infTimeout = tag, infAttempts = infAttempts inf + 1 }
  writeTVar (nsInflight ns) $ Map.insert qid inf' inflight
  return (qid, addr)
-- | Threshold for taking a server out of rotation. When a server's timeout
-- counter (read before the current timeout is counted) is greater than
-- this, 'handleTimeout' marks the server down and schedules a probe.
maxTimeoutsPerServer :: Int
maxTimeoutsPerServer = 5
-- | Completion callback for a 'probeNameserver' query.
--
-- On a 'Timeout', another probe is scheduled via @addTimeout 60@. Any
-- other outcome (a reply, or a non-timeout error) marks the server up again.
probeCallback :: Nameserver
              -> Either DNSError ([String], ([Entry], [Entry], [Entry])) -> IO ()
probeCallback ns (Left Timeout) = addTimeout 60 (probeNameserver ns) >> return ()
probeCallback ns _ = atomically $ writeTVar (nsUp ns) True
-- | Send one test query (@www.google.com@, type 'A') pinned to @ns@.
-- The outcome is handled by 'probeCallback'.
probeNameserver :: Nameserver -> IO ()
probeNameserver ns = do
  config <- getResolveConfig
  transmit config probe
  where
    probeLabels = ["www", "google", "com"]
    Just probeName = serialiseDNSName probeLabels
    question = B.concat . BL.toChunks . P.runPut $ serialiseQuestion probeName A
    -- infTimeout stays 'undefined' here; submit4 replaces it before use.
    probe = InflightRequest question (probeCallback ns) probeLabels [probeLabels] 1 undefined A (Just ns)
-- | Timeout action armed by 'submit4' for transaction @qid@ sent to @addr@.
--
-- Does nothing if the request has already been claimed (for example, its
-- reply arrived). Otherwise it increments the server's timeout counter.
-- If the counter was already above 'maxTimeoutsPerServer' and the server
-- is still marked up, the server is marked down and a 'probeNameserver'
-- is scheduled via @addTimeoutAtomic 60@. Finally the request is passed to
-- 'handleTransientError' with 'Timeout'.
handleTimeout :: ResolverConfig
              -> Word16
              -> Word32
              -> IO ()
handleTimeout config qid addr = do
  claimed <- lookupReply config qid addr
  case claimed of
    Nothing -> return ()
    Just (inf, ns) -> do
      previous <- atomically $ do
        n <- readTVar (nsTimeouts ns)
        writeTVar (nsTimeouts ns) (n + 1)
        return n
      when (previous > maxTimeoutsPerServer) $ markDown ns
      handleTransientError config inf Timeout
  where
    markDown ns = do
      schedule <- addTimeoutAtomic 60
      atomically $ do
        up <- readTVar (nsUp ns)
        when up $ do
          writeTVar (nsUp ns) False
          _ <- schedule (probeNameserver ns)
          return ()
-- | Deliver a 'NoError' reply.
--
-- Marks the server up and resets its timeout counter in one transaction,
-- then passes the queried name and the reply's three record sections to
-- the request's callback.
handleReply :: ResolverConfig -> InflightRequest -> Nameserver -> ([Entry], [Entry], [Entry]) -> IO ()
handleReply _ inf ns sections = do
  atomically $ writeTVar (nsUp ns) True >> writeTVar (nsTimeouts ns) 0
  infCallback inf (Right (infCurrent inf, sections))
-- | Send (or resend) a request.
--
-- Registers it with a nameserver via 'submit4', arming a timeout of
-- 'rcTimeout'. Then sends a freshly built header plus the stored question
-- to port 53 of that server, over 'globalSocket'.
transmit :: ResolverConfig -> InflightRequest -> IO ()
transmit config inf = do
  arm <- addTimeoutAtomic (fromIntegral (rcTimeout config))
  (qid, addr) <- atomically (submit4 config inf arm)
  let headerChunks = BL.toChunks . P.runPut . serialiseHeader $ queryHeader qid
      packet = B.concat headerChunks `B.append` infRequest inf
  _ <- sendTo globalSocket packet (SockAddrInet 53 addr)
  return ()
-- | Start querying the next candidate name in 'infSearch'.
--
-- * With no candidates left, the callback receives @DNSError NXDomain@.
-- * A candidate that cannot be serialised is skipped.
-- * Otherwise the question is serialised, the attempt counter is reset to
--   1, and the request is transmitted with the candidate as 'infCurrent'.
sendInflight :: ResolverConfig -> InflightRequest -> IO ()
sendInflight config inf =
  case infSearch inf of
    [] -> infCallback inf $ Left $ DNSError NXDomain
    target : remaining ->
      let next = inf { infSearch = remaining, infAttempts = 1 }
      in case serialiseDNSName target of
           Nothing -> sendInflight config next
           Just name ->
             transmit config $ next
               { infRequest = B.concat $ BL.toChunks $ P.runPut $ serialiseQuestion name $ infType inf
               , infCurrent = target
               }
-- | Handle a transient error ('Timeout' or 'ServerError').
--
-- Resends the request, or hands it to 'handleFailure' once its attempt
-- budget is spent.
--
-- 'submit4' already increments 'infAttempts' on every send, so the
-- counter must not be bumped again here. Incrementing in both places
-- advanced it by two per retry, so only roughly half of 'rcAttempts' sends
-- were made; the default of 2 happened to hide this. With a single
-- increment, a candidate that starts at 1 (as 'sendInflight' sets it) is
-- sent 'rcAttempts' times, and always at least once.
handleTransientError :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleTransientError config inf result
  | infAttempts inf > rcAttempts config = handleFailure config inf result
  | otherwise = transmit config inf
-- | Give up on the current candidate name.
--
-- Moves on to the next search candidate if any remain; otherwise reports
-- @result@ to the callback.
handleFailure :: ResolverConfig -> InflightRequest -> DNSError -> IO ()
handleFailure config inf result
  | null (infSearch inf) = infCallback inf (Left result)
  | otherwise = sendInflight config inf
-- | Records from one reply, keyed by owner name and record type. Each
-- record is paired with the time built by 'answersToDB': receipt time plus
-- the entry's seconds field (presumably its TTL).
type DNSDB = Map.Map ([String], DNSType) [(UTCTime, RR)]
-- | Index the three record sections of a reply into a 'DNSDB'.
--
-- Each entry's seconds field is added to @currentTime@ to give an absolute
-- time. NOTE(review): presumably that field is the TTL; confirm against
-- "Network.DNS.Types". Records that share a key are concatenated.
answersToDB :: UTCTime -> ([Entry], [Entry], [Entry]) -> DNSDB
answersToDB currentTime (ans, nses, additional) =
  Map.unionsWith (++) [sectionMap ans, sectionMap nses, sectionMap additional]
  where
    sectionMap :: [Entry] -> DNSDB
    sectionMap entries =
      Map.fromListWith (++)
        [ ((owner, rrToType rr), [(expiry secs, rr)]) | (owner, secs, rr) <- entries ]
    expiry :: Word32 -> UTCTime
    expiry secs = addUTCTime (fromIntegral secs) currentTime
-- | Look up records of type @ty@ for @host@, following CNAMEs.
--
-- If a name has no record of the requested type, its first CNAME record
-- (if any) is followed. After 16 CNAME hops the lookup returns @[]@, which
-- also stops CNAME loops.
--
-- The CNAME case is now total. A CNAME key whose first record is not an
-- 'RRCNAME' (or an empty list) yields @[]@ instead of a pattern-match crash.
dbGet :: [String] -> DNSType -> DNSDB -> [(UTCTime, RR)]
dbGet host ty db = follow host (0 :: Int)
  where
    maxHops = 16
    follow name hops
      | hops >= maxHops = []
      | otherwise =
          case Map.lookup (name, ty) db of
            Just records -> records
            Nothing ->
              case Map.lookup (name, CNAME) db of
                Just ((_, RRCNAME target) : _) -> follow target (hops + 1)
                _ -> []
-- | Adapt a public result callback @cb@ into a raw request callback.
--
-- Errors are passed through unchanged. A reply is indexed with
-- 'answersToDB' at the current time and searched with 'dbGet' for the
-- queried name and @ty@. An empty result is reported as 'AnswerNotIncluded'.
parseQuery :: DNSType
           -> (Either DNSError [(UTCTime, RR)] -> IO ())
           -> Either DNSError ([String], ([Entry], [Entry], [Entry]))
           -> IO ()
parseQuery ty cb = either (cb . Left) deliver
  where
    deliver (host, answers) = do
      now <- getCurrentTime
      case dbGet host ty (answersToDB now answers) of
        [] -> cb (Left AnswerNotIncluded)
        records -> cb (Right records)
-- | Return the process-wide configuration, creating it on first use.
--
-- The fast path reads 'globalConfig'. On a miss, a configuration is built
-- outside STM and installed only if no other thread has installed one in
-- the meantime. The thread that installs it also forks 'readerThread' on
-- 'globalSocket'. A thread that loses the race discards its copy and
-- returns the installed one.
getResolveConfig :: IO ResolverConfig
getResolveConfig = atomically (readTVar globalConfig) >>= maybe initialise return
  where
    initialise = do
      fresh <- resolverConfigFromResolvConf
      (installed, chosen) <- atomically $ do
        current <- readTVar globalConfig
        case current of
          Just existing -> return (False, existing)
          Nothing -> do
            writeTVar globalConfig (Just fresh)
            return (True, fresh)
      when installed $ forkIO (readerThread globalSocket) >> return ()
      return chosen
-- | Resolve @hostname@ for record type @ty@ and block until the outcome is
-- known. This is a synchronous wrapper around 'resolveAsync'.
resolve :: DNSType
        -> String
        -> IO (Either DNSError [(UTCTime, RR)])
resolve ty hostname = do
  box <- atomically newEmptyTMVar
  resolveAsync ty hostname $ \outcome -> atomically (putTMVar box outcome)
  atomically (takeTMVar box)
-- | Start resolving @host@ for record type @ty@. @cb@ receives the outcome
-- and may be invoked from another thread (e.g. 'readerThread').
--
-- Which candidate names are tried:
--
-- * If the name has no trailing dot, has fewer dots than 'rcNdots', and a
--   search path is configured: each search-path suffix first, then the
--   name as given.
-- * Otherwise: only the name as given.
resolveAsync :: DNSType
             -> String
             -> (Either DNSError [(UTCTime, RR)] -> IO ())
             -> IO ()
resolveAsync ty host cb = do
  config <- getResolveConfig
  let labels = splitDNSName host
      wrappedCb = parseQuery ty cb
      -- infRequest and infTimeout are placeholders; sendInflight and
      -- submit4 fill them in before they are used.
      inf = InflightRequest undefined wrappedCb labels [] 1 undefined ty Nothing
      -- Pattern match instead of 'last', so an empty host cannot crash.
      absolute = case reverse host of
        '.' : _ -> True
        _ -> False
      -- n labels contain n - 1 dots (assuming splitDNSName splits on '.').
      -- The original `length labels 1` did not type-check.
      dots = length labels - 1
      searchPath = rcSearchPath config
      names
        | not absolute && dots < rcNdots config && not (null searchPath)
        = map (labels ++) searchPath ++ [labels]
        | otherwise = [labels]
  sendInflight config $ inf { infSearch = names }