module Hans.Layer.Dns (
DnsHandle
, runDnsLayer
, DnsException
, addNameServer
, removeNameServer
, HostName
, HostEntry(..)
, getHostByName
, getHostByAddr
) where
import Hans.Address.IP4
import Hans.Channel
import Hans.Layer
import Hans.Layer.Udp as Udp
import Hans.Message.Dns
import Hans.Message.Udp
import Hans.Timers
import Control.Concurrent ( forkIO, MVar, newEmptyMVar, takeMVar, putMVar )
import Control.Monad ( mzero, guard, when )
import Data.Bits ( shiftR, xor, (.&.), (.|.) )
import Data.Foldable ( foldl' )
import Data.List ( intercalate )
import Data.String ( fromString )
import Data.Typeable ( Typeable )
import Data.Word ( Word16 )
import MonadLib ( get, set )
import qualified Control.Exception as X
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as L
import qualified Data.Map.Strict as Map
-- | Channel over which DNS-layer actions are submitted.  Every exported
-- operation of this module works by sending a 'Dns' action down it.
type DnsHandle = Channel (Dns ())

-- | Start the DNS layer by forking a thread that runs 'loopLayer' over a
-- fresh 'emptyDnsState', executing the actions received on @h@.
-- Returns as soon as the thread is forked; its thread id is discarded.
runDnsLayer :: DnsHandle -> UdpHandle -> IO ()
runDnsLayer h udp = forkIO dnsLoop >> return ()
  where
    dnsLoop = loopLayer "dns" (emptyDnsState h udp) (receive h) id
-- | Failures reported to callers of 'getHostByName' and 'getHostByAddr'.
data DnsException
  = NoNameServers     -- ^ no name server is configured
  | OutOfServers      -- ^ a pending query timed out or ran out of servers to try
  | DoesNotExist      -- ^ NOTE(review): not raised anywhere in this module as shown
  | DnsRequestFailed  -- ^ a server replied with a non-success response code
  deriving (Show, Typeable)

instance X.Exception DnsException
-- | Register @addr@ as a name server.  The address goes to the front of
-- the server list, so subsequent lookups try it first.  Adding an address
-- that is already present keeps both copies.
addNameServer :: DnsHandle -> IP4 -> IO ()
addNameServer h addr = send h prepend
  where
    prepend =
      do st <- get
         let servers = addr : dnsNameServers st
         set $! st { dnsNameServers = servers }
-- | Remove every occurrence of @addr@ from the name server list.  Lookups
-- already in flight keep the server list they were created with.
removeNameServer :: DnsHandle -> IP4 -> IO ()
removeNameServer h addr = send h dropServer
  where
    dropServer =
      do st <- get
         let remaining = [ ns | ns <- dnsNameServers st, ns /= addr ]
         set $! st { dnsNameServers = remaining }
-- | Host names are plain dotted strings such as @"www.example.com"@.
type HostName = String

-- | Outcome of a forward or reverse lookup.
data HostEntry = HostEntry
  { hostName      :: HostName    -- ^ queried name, last CNAME target seen, or PTR target
  , hostAliases   :: [HostName]  -- ^ names displaced by CNAME answers
  , hostAddresses :: [IP4]       -- ^ A-record addresses, or the queried address (reverse)
  } deriving (Show)
-- | Forward lookup: resolve @host@ via its A records.  Blocks until the DNS
-- layer delivers an outcome, rethrowing any 'DnsException' it reports.
getHostByName :: DnsHandle -> HostName -> IO HostEntry
getHostByName h host =
  do res <- newEmptyMVar
     send h (getHostEntry res (FromHost host))
     either X.throwIO return =<< takeMVar res
-- | Reverse lookup: query the PTR record for @addr@.  Blocks until the DNS
-- layer delivers an outcome, rethrowing any 'DnsException' it reports.
getHostByAddr :: DnsHandle -> IP4 -> IO HostEntry
getHostByAddr h addr =
  do resultVar <- newEmptyMVar
     send h (getHostEntry resultVar (FromIP4 addr))
     outcome <- takeMVar resultVar
     either X.throwIO return outcome
-- | Actions executed by the DNS layer thread over 'DnsState'.
type Dns = Layer DnsState

-- | State owned by the DNS layer.
data DnsState = DnsState
  { dnsSelf        :: !DnsHandle                  -- ^ own handle, used to queue follow-up actions
  , dnsUdpHandle   :: !UdpHandle                  -- ^ UDP layer used to send and receive queries
  , dnsNameServers :: ![IP4]                      -- ^ configured servers, most recently added first
  , dnsReqId       :: !Word16                     -- ^ next request id (advanced by 'stepReqId')
  , dnsQueries     :: !(Map.Map Word16 DnsQuery)  -- ^ outstanding queries keyed by request id
  , dnsTimeout     :: !Milliseconds               -- ^ timeout armed for each query sent
  }
-- | Initial DNS-layer state.  It has no name servers and no outstanding
-- queries, a 180000 ms timeout, and the request-id generator seeded with 1.
-- The seed must be non-zero because zero is a fixed point of 'stepReqId'.
emptyDnsState :: DnsHandle -> UdpHandle -> DnsState
emptyDnsState h udp =
  DnsState { dnsSelf        = h
           , dnsUdpHandle   = udp
           , dnsNameServers = []
           , dnsQueries     = Map.empty
           , dnsReqId       = 1
           , dnsTimeout     = 180000
           }
-- | Advance the DNS request-id generator by one step.
--
-- This is a 16-bit Galois LFSR with tap mask @0xB400@ (feedback polynomial
-- x^16 + x^14 + x^13 + x^11 + 1).  It visits every non-zero 'Word16'
-- exactly once before repeating (period 65535).
--
-- The feedback must be combined with 'xor'.  Combining it with '.|.'
-- instead makes @0xFFFF@ a fixed point that the sequence seeded with 1
-- reaches after a few dozen steps.  From then on every new request would
-- get the same id and overwrite the previous entry in 'dnsQueries'.
--
-- Zero maps to itself, so the generator must be seeded with a non-zero
-- value.
stepReqId :: Word16 -> Word16
stepReqId w = (w `shiftR` 1) `xor` (negate (w .&. 0x1) .&. 0xB400)
-- | Allocate the current request id and store the query that @mk@ builds
-- for it among the outstanding queries.  Any existing entry with the same
-- id is replaced.  The id generator is then advanced and the allocated id
-- is returned.
registerRequest :: (Word16 -> DnsQuery) -> Dns Word16
registerRequest mk =
  do st <- get
     let reqId   = dnsReqId st
         queries = Map.insert reqId (mk reqId) (dnsQueries st)
     set st { dnsReqId = stepReqId reqId, dnsQueries = queries }
     return reqId
-- | Attach @timer@ to outstanding request @reqId@.  If the request is no
-- longer outstanding (already answered or dropped), the timer is cancelled
-- instead.
registerTimeout :: Word16 -> Timer -> Dns ()
registerTimeout reqId timer =
  do queries <- fmap dnsQueries get
     maybe (output (cancel timer))
           (\query -> updateRequest reqId query { qTimeout = Just timer })
           (Map.lookup reqId queries)
-- | Store @query@ under @reqId@, replacing any existing entry.
updateRequest :: Word16 -> DnsQuery -> Dns ()
updateRequest reqId query =
  get >>= \st -> set st { dnsQueries = Map.insert reqId query (dnsQueries st) }
-- | Fetch outstanding request @reqId@, failing with 'mzero' when no such
-- request is registered.
lookupRequest :: Word16 -> Dns DnsQuery
lookupRequest reqId =
  maybe mzero return . Map.lookup reqId . dnsQueries =<< get
-- | Forget outstanding request @reqId@.  This is a no-op if the request is
-- not present.
removeRequest :: Word16 -> Dns ()
removeRequest reqId =
  do st <- get
     let remaining = Map.delete reqId (dnsQueries st)
     set st { dnsQueries = remaining }
-- | What a lookup starts from: a host name (forward lookup) or an IPv4
-- address (reverse lookup).
data Source
  = FromHost HostName
  | FromIP4 IP4
  deriving (Show)
-- | Question types asked for a source: A records for host names, PTR
-- records for addresses.
sourceQType :: Source -> [QType]
sourceQType src = case src of
  FromHost _ -> [QType A]
  FromIP4 _  -> [QType PTR]
-- | Domain name to query for a source.  A host name is split into labels.
-- An address @a.b.c.d@ becomes the reverse-lookup name
-- @d.c.b.a.in-addr.arpa@, with each octet written in decimal.
sourceHost :: Source -> Name
sourceHost src = case src of
  FromHost host         -> toLabels host
  FromIP4 (IP4 a b c d) -> [ fromString (show octet) | octet <- [d, c, b, a] ]
                             ++ ["in-addr", "arpa"]
-- | Split a dotted host name into DNS labels.
--
-- Empty labels are dropped.  A fully-qualified name with a trailing dot
-- (@"example.com."@) therefore yields the same labels as @"example.com"@.
-- A stray @".."@ likewise does not produce a zero-length label in the
-- middle of the name.  In a DNS name the only empty label is the final
-- root label, which is not represented in this list.
toLabels :: String -> Name
toLabels = map fromString . filter (not . null) . splitDots
  where
    -- Split on every '.', keeping empty pieces; they are filtered above.
    splitDots str = case break (== '.') str of
      (label, _ : rest) -> label : splitDots rest
      (label, [])       -> [label]
-- | Start a lookup for @src@ whose outcome is delivered through @res@.
--
-- If no name server is configured, the caller immediately receives
-- 'NoNameServers' and the action stops with 'mzero'.  Otherwise a UDP
-- handler is installed on a port chosen by 'addUdpHandlerAnyPort', and
-- 'createRequest' is queued on this layer with the current server list.
getHostEntry :: DnsResult -> Source -> Dns ()
getHostEntry res src =
  do st <- get
     let servers = dnsNameServers st
         udp     = dnsUdpHandle st
         self    = dnsSelf st
     if null servers
       then output (putError res NoNameServers) >> mzero
       else output $
              do port <- addUdpHandlerAnyPort udp (serverResponse self src)
                 send self (createRequest res servers src port)
-- | Register a new query for @src@ and send it to the first server.
-- The outcome goes to @res@, the servers in @nss@ are tried in order, and
-- replies arrive on @port@.
--
-- (The original bound an unused @DnsState { .. } <- get@ here; it has
-- been dropped.)
createRequest :: DnsResult -> [IP4] -> Source -> UdpPort -> Dns ()
createRequest res nss src port =
  registerRequest (mkDnsQuery res nss port src) >>= sendRequest
-- | Send request @reqId@ to its next untried name server.
--
-- The head of 'qServers' is popped and recorded as 'qLastServer' (only
-- replies from that server are accepted by 'handleResponse'), and the
-- encoded query is sent to it.
--
-- When no servers remain, the request is unregistered, the UDP handler
-- installed for it is released, and the caller receives 'OutOfServers'.
-- Releasing the handler mirrors the clean-up done when a response
-- completes a request; previously this branch left it registered.
sendRequest :: Word16 -> Dns ()
sendRequest reqId =
  do query <- lookupRequest reqId
     case qServers query of
       server : rest ->
         do updateRequest reqId query { qServers    = rest
                                      , qLastServer = Just server
                                      }
            sendQuery server (qUdpPort query) reqId (qRequest query)
       [] ->
         do removeRequest reqId
            udp <- fmap dnsUdpHandle get
            output $ do removeUdpHandler udp (qUdpPort query)
                        putError (qResult query) OutOfServers
-- | React to the reply timer of request @reqId@ firing.
--
-- If untried name servers remain, the query is re-sent to the next one via
-- 'sendRequest', which also arms a fresh timer.  Otherwise the request is
-- unregistered, its UDP handler is released, and the caller receives
-- 'OutOfServers'.
--
-- Previously any timeout failed the lookup at once, ignoring the remaining
-- servers, and left the request's UDP handler registered.
--
-- If the request is no longer outstanding (e.g. a response already
-- completed it), 'lookupRequest' fails and nothing else happens.
expireRequest :: Word16 -> Dns ()
expireRequest reqId =
  do DnsQuery { .. } <- lookupRequest reqId
     if not (null qServers)
       then sendRequest reqId
       else do removeRequest reqId
               DnsState { .. } <- get
               output $ do removeUdpHandler dnsUdpHandle qUdpPort
                           putError qResult OutOfServers
-- | Process a datagram received on a query's UDP port.
--
-- The datagram is ignored (the action fails through 'guard' or
-- 'lookupRequest') unless all of the following hold:
--
-- * it comes from port 53;
-- * it parses as a DNS packet;
-- * its id matches an outstanding request;
-- * it is a response rather than a query;
-- * it comes from the server that request was last sent to.
--
-- An accepted packet completes the request, in this order:
--
-- 1. The caller receives the parsed 'HostEntry' on 'RespNoError', or
--    'DnsRequestFailed' otherwise.
-- 2. The request is removed.
-- 3. Its UDP handler is released.
-- 4. Its timer, if any, is cancelled.
handleResponse :: Source -> IP4 -> UdpPort -> S.ByteString -> Dns ()
handleResponse src srcIp srcPort bytes =
  do guard (srcPort == 53)
     packet <- liftRight (parseDNSPacket bytes)
     let hdr   = dnsHeader packet
         reqId = dnsId hdr
     query <- lookupRequest reqId
     guard (Just srcIp == qLastServer query && not (dnsQuery hdr))
     let reply
           | dnsRC hdr == RespNoError =
               putResult (qResult query) (parseHostEntry src (dnsAnswers packet))
           | otherwise =
               putError (qResult query) DnsRequestFailed
     output reply
     removeRequest reqId
     udp <- fmap dnsUdpHandle get
     output $ do removeUdpHandler udp (qUdpPort query)
                 maybe (return ()) cancel (qTimeout query)
-- | Turn the answer records of a reply into a 'HostEntry'.  The records
-- are interpreted according to the kind of lookup that produced them.
parseHostEntry :: Source -> [RR] -> HostEntry
parseHostEntry src answers = case src of
  FromHost host -> parseAddr host answers
  FromIP4 addr  -> parsePtr addr answers
-- | Build the entry for a forward lookup of @host@ from its answers.
--
-- Each A record's address is prepended to the address list, so addresses
-- end up in reverse answer order.  Each CNAME record makes its target the
-- entry's name and pushes the previous name onto the aliases.  All other
-- records are ignored.
parseAddr :: HostName -> [RR] -> HostEntry
parseAddr host answers = foldl' step start answers
  where
    start = HostEntry { hostName = host, hostAliases = [], hostAddresses = [] }

    step entry rr = case rrRData rr of
      RDA ip         -> entry { hostAddresses = ip : hostAddresses entry }
      RDCNAME target -> entry { hostName    = renderName target
                              , hostAliases = hostName entry : hostAliases entry }
      _              -> entry

    renderName labels = intercalate "." (map C8.unpack labels)
-- | Build the entry for a reverse lookup of @addr@.  The name comes from
-- the PTR answers (the last one wins; it stays empty if there is none), and
-- the address list holds just @addr@.
parsePtr :: IP4 -> [RR] -> HostEntry
parsePtr addr answers = foldl' step start answers
  where
    start = HostEntry { hostName = "", hostAliases = [], hostAddresses = [addr] }

    step entry rr = case rrRData rr of
      RDPTR target -> entry { hostName = intercalate "." (map C8.unpack target) }
      _            -> entry
-- | One-shot slot through which the DNS layer hands a lookup's outcome to
-- the caller blocked on it.
type DnsResult = MVar (Either DnsException HostEntry)

-- | Deliver a successful lookup result.
putResult :: DnsResult -> HostEntry -> IO ()
putResult var = putMVar var . Right

-- | Deliver a lookup failure.
putError :: DnsResult -> DnsException -> IO ()
putError var = putMVar var . Left
-- | Book-keeping for one outstanding lookup.
data DnsQuery = DnsQuery
  { qResult     :: DnsResult     -- ^ where the outcome is delivered
  , qUdpPort    :: !UdpPort      -- ^ local port whose handler receives replies
  , qRequest    :: L.ByteString  -- ^ encoded query packet, re-sent to each server
  , qServers    :: [IP4]         -- ^ name servers not yet tried
  , qLastServer :: Maybe IP4     -- ^ server most recently queried; only it may answer
  , qTimeout    :: Maybe Timer   -- ^ timer for the latest attempt, once registered
  }
-- | Build the book-keeping record for request @reqId@.  The query packet
-- for @src@ is encoded once and kept; no server has been tried yet and no
-- timer is armed.
mkDnsQuery :: DnsResult -> [IP4] -> UdpPort -> Source -> Word16 -> DnsQuery
mkDnsQuery res nss port src reqId =
  let packet = mkDNSPacket (sourceHost src) (sourceQType src) reqId
  in DnsQuery { qResult     = res
              , qUdpPort    = port
              , qRequest    = renderDNSPacket packet
              , qServers    = nss
              , qLastServer = Nothing
              , qTimeout    = Nothing
              }
-- | Build a standard recursive query (recursion desired) with id @reqId@.
-- It asks one IN-class question about @name@ per type in @qs@ and carries
-- no answer, authority, or additional records.
mkDNSPacket :: Name -> [QType] -> Word16 -> DNSPacket
mkDNSPacket name qs reqId =
  DNSPacket { dnsHeader            = header
            , dnsQuestions         = map question qs
            , dnsAnswers           = []
            , dnsAuthorityRecords  = []
            , dnsAdditionalRecords = []
            }
  where
    question qty = Query { qName = name, qType = qty, qClass = QClass IN }

    header = DNSHeader { dnsId     = reqId
                       , dnsQuery  = True
                       , dnsOpCode = OpQuery
                       , dnsAA     = False
                       , dnsTC     = False
                       , dnsRD     = True
                       , dnsRA     = False
                       , dnsRC     = RespNoError
                       }
-- | Send the encoded query @bytes@ for request @reqId@ to @nameServer@.
-- The datagram goes from local port @sp@ to port 53.
--
-- A 'dnsTimeout' timer is then armed that queues 'expireRequest' for this
-- request on the DNS layer.  The timer is attached to the request via
-- 'registerTimeout', so an accepted response can cancel it.
--
-- The expiry action used to be wrapped in a debugging
-- @finally (putStrLn "KILLED")@ that printed to stdout every time it ran;
-- that has been removed.
sendQuery :: IP4 -> UdpPort -> Word16 -> L.ByteString -> Dns ()
sendQuery nameServer sp reqId bytes =
  do DnsState { .. } <- get
     output $
       do sendUdp dnsUdpHandle nameServer (Just sp) 53 bytes
          expire <- delay dnsTimeout (send dnsSelf (expireRequest reqId))
          send dnsSelf (registerTimeout reqId expire)
-- | UDP handler installed for one outstanding lookup.  Each datagram it
-- receives is forwarded to the DNS layer as a 'handleResponse' action for
-- @src@.  The 'UdpPort' argument (presumably the local port the handler
-- was bound to) is ignored.
serverResponse :: DnsHandle -> Source -> UdpPort -> Udp.Handler
serverResponse dns src _localPort fromIp fromPort payload =
  send dns $ handleResponse src fromIp fromPort payload