{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language LambdaCase #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language StandaloneDeriving #-}
{-# language UnboxedTuples #-}
module Socket.Datagram.IPv4.Spoof
(
Socket(..)
, Endpoint(..)
, Message(..)
, withSocket
, sendMutableByteArray
, SocketException(..)
, SendException(..)
) where
import Control.Concurrent (threadWaitWrite)
import Control.Exception (Exception,throwIO,mask,onException)
import Data.Bits (unsafeShiftR,complement,(.&.))
import Data.Kind (Type)
import Data.Primitive (MutableByteArray(..))
import Data.Word (Word16,Word8,Word64,Word32)
import Foreign.C.Error (Errno(..),eWOULDBLOCK,eAGAIN,eMFILE,eNFILE,eACCES,ePERM)
import Foreign.C.Types (CInt,CSize)
import GHC.Exts (RealWorld,touch#)
import GHC.IO (IO(..))
import Net.Types (IPv4(..))
import Socket (SocketUnrecoverableException(..),Interruptibility(..))
import Socket.Datagram (SendException(..))
import Socket.Datagram.IPv4.Undestined.Internal (Message(..))
import Socket.Debug (debug,whenDebugging)
import Socket.IPv4 (Endpoint(..))
import System.Posix.Types (Fd)
import Text.Printf (printf)
import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
import qualified GHC.Exts as E
import qualified Socket as SCK
-- | A raw IPv4 socket wrapping its file descriptor. 'withSocket' creates
-- these with the close-on-exec and non-blocking flags set.
newtype Socket = Socket Fd
deriving stock instance Eq Socket
deriving stock instance Ord Socket
deriving stock instance Show Socket
-- | Failures to open the raw socket in 'withSocket' that are reported to
-- the caller as 'Left' instead of being thrown.
data SocketException :: Type where
  -- | The kernel refused to create the raw socket for lack of permission.
  SocketPermissionDenied :: SocketException
  -- | The process or system file-descriptor limit has been reached.
  SocketFileDescriptorLimit :: SocketException
  deriving stock (Show)
  deriving anyclass (Exception)
-- | Open a raw IPv4 socket (close-on-exec, non-blocking), run the callback
-- with it, and close it afterwards.
--
-- Asynchronous exceptions are masked while the descriptor is opened and
-- closed; the callback itself runs with the caller's masking state
-- restored. If the callback throws, the descriptor is closed (close errors
-- ignored) and the exception propagates. If opening fails,
-- 'handleSocketException' decides between returning 'Left' and throwing.
-- A failed final close is thrown as 'SocketUnrecoverableException'.
withSocket ::
     (Socket -> IO a)
  -> IO (Either SocketException a)
withSocket f = mask $ \restore -> do
  debug "withSocket: opening raw socket"
  opened <- S.uninterruptibleSocket S.internet rawType S.rawProtocol
  debug "withSocket: opened raw socket"
  either (handleSocketException SCK.functionWithSocket) (useThenClose restore) opened
  where
    -- Raw socket type with close-on-exec and non-blocking flags applied.
    rawType = L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.raw
    -- Run the callback unmasked, then close the descriptor ourselves.
    useThenClose restore fd = do
      result <- restore (f (Socket fd)) `onException` S.uninterruptibleErrorlessClose fd
      closed <- S.uninterruptibleClose fd
      case closed of
        Right _ -> pure (Right result)
        Left err -> throwIO
          (SocketUnrecoverableException
            moduleSocketDatagramIPv4Spoof
            SCK.functionWithSocket
            ["close", describeErrorCode err])
-- | Send one UDP datagram over a raw socket, writing the IPv4 and UDP
-- headers by hand so that the packet carries @theSource@ as its source.
--
-- The payload is the @len@ bytes of @thePayload@ starting at byte offset
-- @off@. The packet is assembled in a freshly allocated pinned buffer, and
-- the caller's array is only read.
--
-- Results:
--
-- * @Right ()@ when the kernel reports the whole packet as sent.
-- * @Left ('SendTruncated' n)@ when it reports some other byte count @n@.
-- * Otherwise, the result of 'handleSendException': @Left 'SendBroadcasted'@
--   on @EACCES@, and a thrown 'SocketUnrecoverableException' for any other
--   error.
--
-- On @EAGAIN@/@EWOULDBLOCK@ it waits until the socket is writable and
-- retries once.
--
-- NOTE(review): the IPv4 total-length and UDP length fields are 16 bits
-- wide and are filled with a truncating conversion. Oversized payloads are
-- not rejected here.
sendMutableByteArray ::
     Socket -- ^ raw socket
  -> Endpoint -- ^ source written into the packet headers
  -> Endpoint -- ^ destination
  -> MutableByteArray RealWorld -- ^ payload array
  -> Int -- ^ payload offset in bytes
  -> Int -- ^ payload length in bytes
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArray (Socket !s) !theSource !theRemote !thePayload !off !len = do
  let ipHeaderSz = cintToInt L.sizeofIpHeader
  let totalHeaderSz = cintToInt (L.sizeofIpHeader + L.sizeofUdpHeader)
  let totalPacketSz = len + totalHeaderSz
  -- One spare byte past the packet, zeroed along with everything else, so
  -- the buffer is never short when the checksum pads an odd-length datagram.
  buf <- PM.newPinnedByteArray (totalPacketSz + 1)
  PM.setByteArray buf 0 (totalPacketSz + 1) (0 :: Word8)
  let addr = PM.mutableByteArrayContents buf
  -- IPv4 header. 4 * 16 + 5 puts version 4 in the high nibble and an IHL of
  -- 5 (32-bit words) in the low nibble.
  L.pokeIpHeaderVersionIhl addr (4 * 16 + 5)
  L.pokeIpHeaderTypeOfService addr 0
  L.pokeIpHeaderTotalLength addr (S.hostToNetworkShort (intToWord16 totalPacketSz))
  L.pokeIpHeaderIdentifier addr 0
  L.pokeIpHeaderFragmentOffset addr 0
  L.pokeIpHeaderTimeToLive addr 64
  L.pokeIpHeaderProtocol addr (cintToWord8 (S.getProtocol S.udp))
  L.pokeIpHeaderChecksum addr 0
  let src = S.hostToNetworkLong (getIPv4 (address theSource))
  L.pokeIpHeaderSourceAddress addr src
  let dst = S.hostToNetworkLong (getIPv4 (address theRemote))
  L.pokeIpHeaderDestinationAddress addr dst
  -- UDP header immediately follows the IPv4 header.
  let udpAddr = PM.plusAddr addr ipHeaderSz
  L.pokeUdpHeaderSourcePort udpAddr (S.hostToNetworkShort (port theSource))
  L.pokeUdpHeaderDestinationPort udpAddr (S.hostToNetworkShort (port theRemote))
  let udpLen = cintToInt L.sizeofUdpHeader + len
  L.pokeUdpHeaderLength udpAddr (S.hostToNetworkShort (intToWord16 udpLen))
  PM.copyMutableByteArray buf totalHeaderSz thePayload off len
  -- The UDP checksum field is still zero (from setByteArray) while the sum
  -- over the UDP header and payload is computed.
  L.pokeUdpHeaderChecksum udpAddr . S.hostToNetworkShort =<< udpChecksum src dst buf ipHeaderSz udpLen
  -- Keep buf alive until every write through its raw address has happened.
  touchMutableByteArray buf
  debug ("spoof send mutable: about to send to " ++ show theRemote)
  whenDebugging $ do
    d <- PM.newByteArray totalPacketSz
    PM.copyMutableByteArray d 0 buf 0 totalPacketSz
    x <- PM.unsafeFreezeByteArray d
    debug ("raw packet: " ++ (foldMap (printf "%.2x ") (E.toList x)))
  -- Always send the whole assembled packet from offset 0. Both the first
  -- attempt and the retry use this; @off@/@len@ describe the caller's
  -- payload array, not @buf@.
  let sendPacket =
        S.uninterruptibleSendToMutableByteArray s buf 0
          (intToCSize totalPacketSz)
          mempty
          (S.encodeSocketAddressInternet (endpointToSocketAddressInternet theRemote))
  e1 <- sendPacket
  debug ("spoof send mutable: just sent to " ++ show theRemote)
  case e1 of
    Left err1 -> if err1 == eWOULDBLOCK || err1 == eAGAIN
      then do
        debug ("send mutable: waiting to for write ready to send to " ++ show theRemote)
        threadWaitWrite s
        e2 <- sendPacket
        case e2 of
          Left err2 -> do
            debug ("send mutable: encountered error after sending")
            handleSendException "sendMutableByteArray" err2
          Right sz -> if csizeToInt sz == totalPacketSz
            then pure (Right ())
            else pure (Left (SendTruncated (csizeToInt sz)))
      else do
        debug "spoof send mutable: sent on first try but got error code"
        handleSendException "sendMutableByteArray" err1
    Right sz -> if csizeToInt sz == totalPacketSz
      then do
        debug ("send mutable: success")
        pure (Right ())
      else pure (Left (SendTruncated (csizeToInt sz)))
-- | Compute the UDP checksum (RFC 768) of the datagram occupying bytes
-- @[off, off + len)@ of @payload@, in host byte order.
--
-- @src@ and @dst@ are the IPv4 addresses as they are stored in the IP
-- header, already converted with @hostToNetworkLong@. Together with the
-- UDP protocol number and @len@, they form the pseudo-header. The UDP
-- checksum field inside the buffer must be zero when this runs.
--
-- Details:
--
-- * An odd trailing byte is summed as if followed by a zero pad byte.
-- * Carries are folded back in until the sum fits in 16 bits.
-- * A computed checksum of zero is returned as @0xFFFF@, because zero on
--   the wire means "no checksum".
--
-- Assumes @off@ is even: the buffer is read as 16-bit words starting at
-- word index @off / 2@.
udpChecksum ::
     Word32 -- ^ source address, network byte order
  -> Word32 -- ^ destination address, network byte order
  -> MutableByteArray RealWorld -- ^ buffer holding the UDP header and payload
  -> Int -- ^ byte offset of the UDP header in the buffer
  -> Int -- ^ UDP length (header plus payload) in bytes
  -> IO Word16
udpChecksum src dst payload off len = do
  -- Pseudo-header. Each address contributes its two 16-bit halves, swapped
  -- back with hostToNetworkShort so that they are summed as big-endian words.
  let sum0 = word16ToWord64 (S.hostToNetworkShort (word32ToWord16 src))
  debug ("udp checksum source lower: " ++ printf "%.8X" sum0)
  let sum1 = sum0 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 (unsafeShiftR src 16)))
  debug ("udp checksum source lower+upper: " ++ printf "%.8X" sum1)
  let sum2 = sum1 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 dst))
      sum3 = sum2 + word16ToWord64 (S.hostToNetworkShort (word32ToWord16 (unsafeShiftR dst 16)))
  debug ("udp checksum source+dest lower+upper: " ++ printf "%.8X" sum3)
  let sum4 = sum3 + word16ToWord64 (cintToWord16 (S.getProtocol S.udp))
      sum5 = sum4 + word16ToWord64 (intToWord16 len)
  debug ("udp checksum pseudoheader without carries: " ++ printf "%.8X" sum5)
  -- Word index one past the last whole 16-bit word of the datagram.
  let halfLen = unsafeShiftR (len + off) 1
  debug ("udp checksum start offset: " ++ show (unsafeShiftR off 1))
  debug ("udp checksum last offset: " ++ show halfLen)
  let go :: Int -> Word64 -> IO Word64
      go !ix !acc = if ix < halfLen
        then do
          w16 <- PM.readByteArray payload ix :: IO Word16
          debug ("udp checksum payload iteration " ++ show ix ++ ": " ++ printf "%.8X" acc)
          go (ix + 1) (word16ToWord64 (S.hostToNetworkShort w16) + acc)
        else pure acc
  wordSum <- go (unsafeShiftR off 1) sum5
  -- The whole-word loop stops short of an odd final byte. That byte is the
  -- high-order half of a last word whose low-order half is zero padding.
  r <- if odd len
    then do
      lastByte <- PM.readByteArray payload (off + len - 1) :: IO Word8
      pure (wordSum + fromIntegral lastByte * 256)
    else pure wordSum
  -- One's-complement end-around carry. A single fold can itself carry past
  -- 16 bits, so repeat until the value fits.
  let fold16 :: Word64 -> Word64
      fold16 x
        | x > 0xFFFF = fold16 ((x .&. 0xFFFF) + unsafeShiftR x 16)
        | otherwise = x
      cksum = word64ToWord16 (complement (fold16 r))
  -- RFC 768: a computed checksum of zero is transmitted as all ones.
  pure (if cksum == 0 then 0xFFFF else cksum)
-- | Convert an 'Endpoint' into the socket-address record passed to
-- @sendto@, converting the port and the IPv4 address to network byte order.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet (Endpoint {address = ip, port = p}) =
  S.SocketAddressInternet
    { address = S.hostToNetworkLong (getIPv4 ip)
    , port = S.hostToNetworkShort p
    }
-- Fixed-type wrappers around 'fromIntegral' for the numeric types used in
-- socket calls and header fields. Widening conversions preserve the value.
-- Narrowing or sign-changing conversions keep only the low-order bits.

-- | 'Int' to 'CInt'.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | 'Int' to 'CSize'.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n
-- | 'CSize' to 'Int'.
csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n
-- | 'CInt' to 'Int'.
cintToInt :: CInt -> Int
cintToInt n = fromIntegral n
-- | 'CInt' to 'Word8' (low 8 bits).
cintToWord8 :: CInt -> Word8
cintToWord8 n = fromIntegral n
-- | 'Int' to 'Word16' (low 16 bits).
intToWord16 :: Int -> Word16
intToWord16 n = fromIntegral n
-- | 'CInt' to 'Word16' (low 16 bits).
cintToWord16 :: CInt -> Word16
cintToWord16 n = fromIntegral n
-- | 'Word16' to 'Word64'.
word16ToWord64 :: Word16 -> Word64
word16ToWord64 w = fromIntegral w
-- | 'Word64' to 'Word16' (low 16 bits).
word64ToWord16 :: Word64 -> Word16
word64ToWord16 w = fromIntegral w
-- | 'Word32' to 'Word16' (low 16 bits).
word32ToWord16 :: Word32 -> Word16
word32ToWord16 w = fromIntegral w
-- | Keep the array reachable up to this point in the IO sequence.
-- 'sendMutableByteArray' calls this after writing through the raw address
-- from 'PM.mutableByteArrayContents', which does not by itself keep the
-- array alive for the garbage collector.
touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray (MutableByteArray x) = touchMutableByteArray# x
-- | 'IO' wrapper around the @touch#@ primop for an unboxed mutable array.
touchMutableByteArray# :: E.MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# x = IO $ \s -> case touch# x s of s' -> (# s', () #)
-- | Module name recorded in every 'SocketUnrecoverableException' thrown
-- from this module.
moduleSocketDatagramIPv4Spoof :: String
moduleSocketDatagramIPv4Spoof = "Socket.Datagram.IPv4.Spoof"
-- | Classify an @errno@ from opening the socket in 'withSocket'.
--
-- * @EPERM@ and @EACCES@ become 'SocketPermissionDenied'. @socket(2)@
--   documents @EACCES@ for a denied socket type or protocol.
-- * @EMFILE@ and @ENFILE@ (per-process and system-wide descriptor limits)
--   become 'SocketFileDescriptorLimit'.
-- * Any other code is thrown as a 'SocketUnrecoverableException' naming
--   this module and @func@.
handleSocketException :: String -> Errno -> IO (Either SocketException a)
{-# INLINE handleSocketException #-}
handleSocketException func e
  | e == ePERM || e == eACCES = pure (Left SocketPermissionDenied)
  | e == eMFILE || e == eNFILE = pure (Left SocketFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketDatagramIPv4Spoof
      func
      [describeErrorCode e]
-- | Render an 'Errno' as @"error code N"@, where @N@ is its raw numeric
-- value. The text goes into 'SocketUnrecoverableException' details.
describeErrorCode :: Errno -> String
describeErrorCode (Errno code) = "error code " <> show code
-- | Classify an @errno@ from a send call.
--
-- * @EACCES@ becomes @'Left' 'SendBroadcasted'@. For UDP, @sendto(2)@
--   reports it when sending to a broadcast address without permission.
-- * Every other code is thrown as a 'SocketUnrecoverableException' naming
--   this module and @func@.
handleSendException :: String -> Errno -> IO (Either (SendException i) a)
{-# INLINE handleSendException #-}
handleSendException func e =
  if e /= eACCES
    then throwIO
      (SocketUnrecoverableException
        moduleSocketDatagramIPv4Spoof
        func
        [describeErrorCode e])
    else pure (Left SendBroadcasted)