{-# language BangPatterns #-}
{-# language BinaryLiterals #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language EmptyCase #-}
module Network.Icmp.Ping.Hosts
( hosts
, range
) where
import Control.Applicative ((<|>))
import Control.Concurrent (threadWaitReadSTM,threadWaitWriteSTM)
import Control.Concurrent.STM.TVar (readTVar,registerDelay)
import Control.Exception (onException,mask)
import Data.Bits (unsafeShiftL, unsafeShiftR, (.&.), (.|.), testBit)
import Data.Functor (($>))
import Data.Primitive (PrimArray,MutableByteArray)
import Data.Word (Word64,Word8,Word16,Word32)
import Foreign.C.Error (Errno(..),eACCES,eAGAIN,eWOULDBLOCK)
import Foreign.C.Types (CSize(..))
import GHC.Clock (getMonotonicTimeNSec)
import GHC.Exts (RealWorld)
import GHC.IO (IO(..))
import Net.Types (IPv4(..),IPv4Range)
import Network.Icmp.Common (IcmpException(..))
import Network.Icmp.Marshal (peekIcmpHeaderPayload,peekIcmpHeaderType)
import Network.Icmp.Marshal (peekIcmpHeaderSequenceNumber)
import Network.Icmp.Marshal (sizeOfIcmpHeader,pokeIcmpHeader)
import Network.Icmp.Ping.Debug (debug)
import Posix.Socket (SocketAddressInternet(..))
import System.Endian (toBE32)
import System.Posix.Types (Fd(..))
import Unsafe.Coerce (unsafeCoerce)
import qualified Control.Monad.STM as STM
import qualified Data.Map.Unboxed.Unboxed as MUU
import qualified Data.Primitive as PM
import qualified Data.Set.Unboxed as SU
import qualified Linux.Socket as SCK
import qualified Posix.Socket as SCK
import qualified Net.IPv4 as IPv4
-- | Size in bytes of every packet this module sends or accepts as a reply:
-- an ICMP header followed by a 4-byte payload (used here to carry the
-- target's IPv4 address).
fullPacketSize :: Int
fullPacketSize = sizeOfIcmpHeader + payloadSize
  where
    payloadSize = 4
-- | Wait up to @maxWaitTime@ microseconds (the unit of 'registerDelay') for
-- @sock@ to become readable.
--
-- Returns 'True' if the socket became readable and 'False' if the timeout
-- elapsed first.
--
-- The read-interest registration is always deregistered, including when an
-- asynchronous exception interrupts the wait. Registration happens under
-- 'mask', and the interruptible wait runs under 'restore' guarded by
-- 'onException'.
waitForRead ::
     Int
  -> Fd
  -> IO Bool
waitForRead !maxWaitTime !sock = mask $ \restore -> do
  (isReadyAction,deregister) <- threadWaitReadSTM sock
  -- Succeeds (with False) only once the delay TVar has flipped to True.
  let timedOut delay = do
        isDone <- readTVar delay
        STM.check isDone
        pure False
      waitEither = do
        delay <- registerDelay maxWaitTime
        -- orElse-style choice: readiness wins if both are available.
        STM.atomically ((isReadyAction $> True) <|> timedOut delay)
  isContentReady <- restore waitEither `onException` deregister
  deregister
  pure isContentReady
-- | Block until @sock@ is readable or writable.
--
-- Returns 'True' when it is readable and 'False' when it is writable. If
-- both are ready, readability wins because it is the left operand of '<|>'.
--
-- Both registrations are always deregistered, including when an
-- asynchronous exception interrupts the wait. Registration happens under
-- 'mask', and only the blocking STM transaction runs under 'restore'.
waitForReadWrite :: Fd -> IO Bool
waitForReadWrite sock = mask $ \restore -> do
  (isReadyRead,deregisterRead) <- threadWaitReadSTM sock
  -- If registering write interest fails, do not leak the read registration.
  (isReadyWrite,deregisterWrite) <- threadWaitWriteSTM sock `onException` deregisterRead
  let deregisterBoth = deregisterRead >> deregisterWrite
  r <- restore (STM.atomically ((isReadyRead $> True) <|> (isReadyWrite $> False)))
    `onException` deregisterBoth
  deregisterBoth
  pure r
-- | Ping every address from @IPv4.lowerInclusive r@ through
-- @IPv4.upperInclusive r@.
--
-- This is 'hosts' applied to the set of addresses in the range; @pause@ and
-- the result have the same meaning as there.
range ::
     Int
  -> IPv4Range
  -> IO (Either IcmpException (MUU.Map IPv4 Word64))
range !pause !r = hosts pause addresses
  where
    lo = getIPv4 (IPv4.lowerInclusive r)
    hi = getIPv4 (IPv4.upperInclusive r)
    addresses = coerceIPv4Set (SU.enumFromTo lo hi)
-- | Reinterpret a set of raw 'Word32' addresses as a set of 'IPv4'.
--
-- Safety assumptions (NOTE(review): confirm against "Net.Types"):
--
-- * 'IPv4' is a newtype over 'Word32', so both element types share one
--   in-memory representation inside the unboxed set.
-- * The 'Ord' instance of 'IPv4' agrees with that of 'Word32', so the
--   set's sorted-order invariant still holds after the cast.
--
-- 'unsafeCoerce' is used instead of @Data.Coerce.coerce@, presumably because
-- the set's element role does not allow a representational coercion.
coerceIPv4Set :: SU.Set Word32 -> SU.Set IPv4
coerceIPv4Set wordSet = unsafeCoerce wordSet
-- | Send one ICMP echo request to every address in @theHosts@ over a single
-- ICMP datagram socket. After the last request, keep listening for replies
-- for @pause@ microseconds.
--
-- On success, the result maps each host that answered to its round-trip time
-- in nanoseconds, measured with 'getMonotonicTimeNSec'. Hosts that never
-- answered are absent.
--
-- The socket is closed on every exit path, including when an exception is
-- thrown while pinging. If closing fails, 'IcmpExceptionClose' is returned in
-- place of any other result.
hosts ::
     Int
  -> SU.Set IPv4
  -> IO (Either IcmpException (MUU.Map IPv4 Word64))
hosts !pause !theHosts = mask $ \restore -> do
  opened <- SCK.uninterruptibleSocket SCK.internet SCK.datagram SCK.icmp
  case opened of
    Left (Errno e) -> pure (Left (IcmpExceptionSocket e))
    Right sock -> do
      let pingAll = do
            let targets = SU.toArray theHosts
            -- One scratch buffer, reused for every send and receive.
            !scratch <- PM.newByteArray fullPacketSize
            (statuses, outcome) <- MUU.adjustManyInline
              (\adjust -> hostsStepA scratch sock pause targets (PM.sizeofPrimArray targets) adjust)
              (MUU.fromSet (const initialStatus) theHosts)
            pure (fmap (const (MUU.mapMaybe completedDuration statuses)) outcome)
      durations <- restore pingAll `onException` SCK.uninterruptibleClose sock
      closed <- SCK.uninterruptibleClose sock
      pure $ case closed of
        Left (Errno e) -> Left (IcmpExceptionClose e)
        Right _ -> durations
  where
    -- Bit 47 is set only in the "complete" status, so only hosts whose
    -- reply was matched keep their (duration) entry.
    completedDuration w
      | testBit w 47 = Just (extractTimestamp w)
      | otherwise = Nothing
-- | First phase of 'hosts'. Send one echo request to each of the first
-- @hostsLen@ addresses in @hostsArr@, in index order, while handling any
-- replies that arrive in the meantime.
--
-- How requests and replies are matched:
--
-- * Each request carries the host's index (truncated to 16 bits) as its
--   sequence number and the host's address as its payload.
-- * A reply is credited to the host named in its payload, and only if that
--   host is still pending with the same sequence number.
-- * A credited host's entry becomes complete, holding the nanoseconds that
--   elapsed since its request was sent.
--
-- Whenever the socket is readable, replies are received before the next
-- request is sent. Once every request has gone out, control passes to
-- 'hostsStepB'.
--
-- Error handling:
--
-- * @EACCES@ from the send skips that host.
-- * @EAGAIN@/@EWOULDBLOCK@ from the non-blocking receive or send means the
--   readiness notification was spurious. The same step is retried.
-- * Any other errno aborts with the corresponding 'IcmpException'.
-- * A short send aborts with 'IcmpExceptionSendBytes'.
hostsStepA :: MutableByteArray RealWorld -> Fd -> Int -> PrimArray IPv4 -> Int -> (IPv4 -> (Word64 -> IO Word64) -> IO ()) -> IO (Either IcmpException ())
hostsStepA !buffer !sock !pause !hostsArr !hostsLen adjust = go 0 where
  go !ix = if ix < hostsLen
    then do
      debug "waiting for read-write"
      waitForReadWrite sock >>= \case
        True -> do
          debug "ready for read"
          r <- SCK.uninterruptibleReceiveFromMutableByteArray_ sock buffer 0 (intToCSize fullPacketSize) SCK.dontWait
          case r of
            Left err
              -- Readiness can be reported spuriously (see select(2), BUGS);
              -- nothing is queued, so just wait again.
              | wouldBlock err -> go ix
            Left (Errno e) -> pure (Left (IcmpExceptionReceive e))
            Right receivedBytes -> do
              -- Packets of any other size cannot be one of our echo replies.
              if receivedBytes == intToCSize fullPacketSize
                then recordReply
                else pure ()
              go ix
        False -> do
          debug "ready for write"
          let host = PM.indexPrimArray hostsArr ix
          PM.setByteArray buffer 0 sizeOfIcmpHeader (0 :: Word8)
          pokeIcmpHeader buffer (intToWord16 ix) (getIPv4 host)
          let sockaddr = SCK.encodeSocketAddressInternet
                (SocketAddressInternet { port = 0, address = toBE32 (getIPv4 host) })
          mwriteError <- SCK.uninterruptibleSendToMutableByteArray sock buffer 0 (intToCSize fullPacketSize) SCK.dontWait sockaddr
          case mwriteError of
            Left err
              | err == eACCES -> go (ix + 1)
              -- Spurious write readiness: nothing was sent, so retry this
              -- host. The buffer is rebuilt on the retry.
              | wouldBlock err -> go ix
            Left (Errno e) -> pure (Left (IcmpExceptionSend e))
            Right sentBytes -> if sentBytes == intToCSize fullPacketSize
              then do
                start <- getMonotonicTimeNSec
                adjust host (\_ -> pure (pendingStatus (intToWord16 ix) start))
                go (ix + 1)
              else pure (Left (IcmpExceptionSendBytes sentBytes))
    else hostsStepB buffer sock pause adjust =<< getMonotonicTimeNSec
  -- Credit the reply currently in @buffer@ to the pending host it answers.
  -- Only ICMP type 0 (echo reply) is considered, as in 'hostsStepB'.
  recordReply = peekIcmpHeaderType buffer >>= \case
    0 -> do
      payload' <- peekIcmpHeaderPayload buffer
      adjust (IPv4 payload') $ \w -> case extractStatus w of
        0b01 -> do
          sequenceNumber' <- peekIcmpHeaderSequenceNumber buffer
          if sequenceNumber' == extractSequenceNumber w
            then do
              end <- getMonotonicTimeNSec
              pure (completeStatus ((end .&. 0x3FFFFFFFFFFF) - extractTimestamp w))
            else pure w
        _ -> pure w
    _ -> pure ()
  -- True for the errno values a non-blocking call reports when it would block.
  wouldBlock err = err == eAGAIN || err == eWOULDBLOCK
-- | Second phase of 'hosts'. After every request has been sent, keep
-- receiving replies until @pause@ microseconds have elapsed since
-- @initialTime@ (a 'getMonotonicTimeNSec' reading).
--
-- Only ICMP type 0 (echo reply) packets of exactly 'fullPacketSize' bytes
-- are considered. Such a reply completes the entry of the host named in its
-- payload, provided that host is still pending with the same sequence
-- number. All other packets are ignored.
--
-- Error handling:
--
-- * @EAGAIN@/@EWOULDBLOCK@ from the non-blocking receive is treated as a
--   spurious readiness notification and ignored.
-- * Any other receive error aborts with 'IcmpExceptionReceive'.
hostsStepB :: MutableByteArray RealWorld -> Fd -> Int -> (IPv4 -> (Word64 -> IO Word64) -> IO ()) -> Word64 -> IO (Either IcmpException ())
hostsStepB !buffer !sock !pause !adjust !initialTime = go initialTime where
  go !currentTime = do
    debug "Step B iteration"
    -- Nanoseconds elapsed, converted to microseconds to match @pause@.
    let remainingMicroseconds = pause - word64ToInt (div (currentTime - initialTime) 1000)
    if remainingMicroseconds > 0
      then do
        isReady <- waitForRead remainingMicroseconds sock
        if isReady
          then do
            r <- SCK.uninterruptibleReceiveFromMutableByteArray_ sock buffer 0 (intToCSize fullPacketSize) SCK.dontWait
            case r of
              Left err
                -- Readiness can be reported spuriously (see select(2), BUGS);
                -- nothing is queued, so keep waiting for the remaining time.
                | wouldBlock err -> go =<< getMonotonicTimeNSec
              Left (Errno e) -> pure (Left (IcmpExceptionReceive e))
              Right receivedBytes -> if receivedBytes == intToCSize fullPacketSize
                then do
                  payload' <- peekIcmpHeaderPayload buffer
                  end <- getMonotonicTimeNSec
                  peekIcmpHeaderType buffer >>= \case
                    0 -> do
                      adjust (IPv4 payload') $ \w -> case extractStatus w of
                        0b01 -> do
                          sequenceNumber' <- peekIcmpHeaderSequenceNumber buffer
                          if sequenceNumber' == extractSequenceNumber w
                            then pure (completeStatus ((end .&. 0x3FFFFFFFFFFF) - extractTimestamp w))
                            else pure w
                        _ -> pure w
                      go end
                    _ -> go end
                else go =<< getMonotonicTimeNSec
          else pure (Right ())
      else pure (Right ())
  -- True for the errno values a non-blocking call reports when it would block.
  wouldBlock err = err == eAGAIN || err == eWOULDBLOCK
-- | Encode the entry for a host whose request has just been sent.
--
-- Field layout:
--
-- * bits 63..48: the request's sequence number
-- * bits 47..46: status @0b01@ (pending)
-- * bits 45..0: the send time, truncated to 46 bits
pendingStatus :: Word16 -> Word64 -> Word64
pendingStatus seqNum timestamp = seqField .|. statusField .|. timeField
  where
    seqField = word16ToWord64 seqNum `unsafeShiftL` 48
    statusField = 0x400000000000
    timeField = timestamp .&. 0x3FFFFFFFFFFF
-- | Encode the entry for a host whose reply has been matched.
--
-- Layout: status bits 47..46 are @0b11@ (complete), and bits 45..0 hold the
-- measured round-trip time in nanoseconds.
--
-- The duration is masked to 46 bits so it can never spill into the status
-- or sequence-number fields. Callers compute it as the difference of two
-- 46-bit-truncated clock readings. That difference underflows if the clock
-- crosses a 2^46 ns boundary between send and receive, and masking the
-- wrapped difference to 46 bits recovers the true elapsed time.
completeStatus :: Word64 -> Word64
completeStatus duration = 0xC00000000000 .|. (duration .&. 0x3FFFFFFFFFFF)
-- | Entry for a host with no request recorded yet: every field is zero,
-- so its status bits are @0b00@.
initialStatus :: Word64
initialStatus = 0x0000000000000000
-- | The two status bits (47..46) of an entry, as a value in @0..3@:
-- @0b00@ initial, @0b01@ pending, @0b11@ complete.
extractStatus :: Word64 -> Word64
extractStatus w = unsafeShiftR w 46 .&. 0b11
-- | The sequence number stored in bits 63..48 of a pending entry.
extractSequenceNumber :: Word64 -> Word16
extractSequenceNumber = word64ToWord16 . (`unsafeShiftR` 48)
-- | The low 46 bits of an entry: the send time while pending, or the
-- round-trip duration once complete.
extractTimestamp :: Word64 -> Word64
extractTimestamp = (.&. 0x3FFFFFFFFFFF)
-- | Keep the low 16 bits of a 'Word64' (truncating conversion).
word64ToWord16 :: Word64 -> Word16
word64ToWord16 w = fromIntegral w

-- | Widen a 'Word16' to a 'Word64' (lossless).
word16ToWord64 :: Word16 -> Word64
word16ToWord64 w = fromIntegral w

-- | Convert an 'Int' to a 'Word16', wrapping modulo 2^16.
intToWord16 :: Int -> Word16
intToWord16 i = fromIntegral i

-- | Convert a 'Word64' to an 'Int' (wraps for values above 'maxBound' :: 'Int').
word64ToInt :: Word64 -> Int
word64ToInt w = fromIntegral w

-- | Convert an 'Int' to a 'CSize' (negative values wrap).
intToCSize :: Int -> CSize
intToCSize i = fromIntegral i