{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}
module Linux.Socket
(
uninterruptibleReceiveMultipleMessageA
, uninterruptibleReceiveMultipleMessageB
, SocketFlags(..)
, LST.headerInclude
, LST.dontWait
, LST.truncate
, LST.controlTruncate
, LST.closeOnExec
, LST.nonblocking
, applySocketFlags
, LST.sizeofUdpHeader
, LST.pokeUdpHeaderSourcePort
, LST.pokeUdpHeaderDestinationPort
, LST.pokeUdpHeaderLength
, LST.pokeUdpHeaderChecksum
, LST.sizeofIpHeader
, LST.pokeIpHeaderVersionIhl
, LST.pokeIpHeaderTypeOfService
, LST.pokeIpHeaderTotalLength
, LST.pokeIpHeaderIdentifier
, LST.pokeIpHeaderFragmentOffset
, LST.pokeIpHeaderTimeToLive
, LST.pokeIpHeaderProtocol
, LST.pokeIpHeaderChecksum
, LST.pokeIpHeaderSourceAddress
, LST.pokeIpHeaderDestinationAddress
) where
import Prelude hiding (truncate)
import Control.Monad (when)
import Data.Bits ((.|.))
import Data.Primitive (MutableByteArray(..),Addr(..),ByteArray(..))
import Data.Primitive (MutableUnliftedArray(..),UnliftedArray)
import Foreign.C.Error (Errno,getErrno)
import Foreign.C.Types (CInt(..),CSize(..),CUInt(..))
import GHC.Exts (Ptr,RealWorld,ByteArray#,MutableByteArray#,Addr#,MutableArrayArray#,Int(I#))
import GHC.Exts (shrinkMutableByteArray#,touch#,nullAddr#)
import GHC.IO (IO(..))
import Linux.Socket.Types (SocketFlags(..))
import Posix.Socket (Type(..),MessageFlags(..),Message(Receive))
import System.Posix.Types (Fd(..),CSsize(..))
import qualified Data.Primitive as PM
import qualified Control.Monad.Primitive as PM
import qualified Posix.Socket as S
import qualified Linux.Socket.Types as LST
-- | Raw binding to @recvmmsg(2)@.
--
-- Arguments, in order: the socket, the base address of an array of
-- @struct mmsghdr@, the number of entries in that array, the receive
-- flags, and a timeout pointer (every caller in this module passes @NULL@).
--
-- @recvmmsg@ is declared in C as returning @int@, so the result is imported
-- as 'CInt'. Importing it at a wider type such as 'CSsize' would read bits
-- of the return register that the C calling convention leaves unspecified
-- for an @int@ result, which could turn an error return of @-1@ into a
-- large positive message count.
foreign import ccall unsafe "sys/socket.h recvmmsg"
  c_unsafe_addr_recvmmsg_cint :: Fd
                              -> Addr#
                              -> CUInt
                              -> MessageFlags 'Receive
                              -> Addr#
                              -> IO CInt

-- | Wrapper around 'c_unsafe_addr_recvmmsg_cint' that keeps the 'CSsize'
-- result type the rest of this module expects. 'fromIntegral' from 'CInt'
-- to 'CSsize' preserves the sign, so @-1@ is still @-1@.
c_unsafe_addr_recvmmsg :: Fd
                       -> Addr#
                       -> CUInt
                       -> MessageFlags 'Receive
                       -> Addr#
                       -> IO CSsize
c_unsafe_addr_recvmmsg fd msgvec vlen flags timeout =
  fmap fromIntegral (c_unsafe_addr_recvmmsg_cint fd msgvec vlen flags timeout)
{-# INLINE c_unsafe_addr_recvmmsg #-}
-- | Combine 'SocketFlags' with a socket 'Type' by bitwise OR of the values
-- wrapped by the two newtypes.
applySocketFlags :: SocketFlags -> Type -> Type
applySocketFlags (SocketFlags flagBits) (Type typeBits) =
  Type (flagBits .|. typeBits)
-- | Receive up to @msgCount@ datagrams from socket @s@ with a single
-- @recvmmsg@ call, without capturing sender addresses.
--
-- Each message slot gets its own freshly allocated pinned buffer of
-- @msgSize@ bytes. On success, the result is the largest @msg_len@ reported
-- across the received messages, together with one frozen buffer per
-- received message. A buffer is shrunk to its @msg_len@ when that is
-- smaller than @msgSize@. On failure, the @errno@ observed after the call
-- is returned.
uninterruptibleReceiveMultipleMessageA ::
     Fd -- ^ Socket
  -> CSize -- ^ Capacity of each message buffer, in bytes
  -> CUInt -- ^ Maximum number of messages to receive
  -> MessageFlags 'Receive -- ^ Flags passed through to @recvmmsg@
  -> IO (Either Errno (CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageA !s !msgSize !msgCount !flags = do
  let slots = cuintToInt msgCount
  bufs <- PM.unsafeNewUnliftedArray slots
  mmsghdrsBuf <- PM.newPinnedByteArray (slots * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (slots * cintToInt S.sizeofIOVector)
  let !hdrsAddr@(Addr hdrsAddr#) = PM.mutableByteArrayContents mmsghdrsBuf
  initializeMultipleMessageHeadersWithoutSockAddr bufs
    (PM.mutableByteArrayContents iovecsBuf) hdrsAddr msgSize msgCount
  received <- c_unsafe_addr_recvmmsg s hdrsAddr# msgCount flags nullAddr#
  -- The pinned buffers reach C only as raw addresses, so their Haskell
  -- handles are touched on every exit path to keep them reachable.
  let keepAlive = do
        touchMutableUnliftedArray bufs
        touchMutableByteArray iovecsBuf
        touchMutableByteArray mmsghdrsBuf
  if received < 0
    then keepAlive >> fmap Left getErrno
    else do
      (_, largest, frozen) <- shrinkAndFreezeMessages msgSize 0 (cssizeToInt received) bufs hdrsAddr
      keepAlive
      pure (Right (largest, frozen))
-- | Receive up to @msgCount@ datagrams from socket @s@ with a single
-- @recvmmsg@ call, also capturing each sender's socket address.
--
-- Each message slot gets its own pinned buffer of @msgSize@ bytes, plus an
-- @expSockAddrSize@-byte slot in one contiguous address buffer. On success
-- the result is:
--
-- * a validation word that is @0@ exactly when every received message
--   reported a name length equal to @expSockAddrSize@,
-- * the received addresses, packed back to back with @expSockAddrSize@
--   bytes per received message,
-- * the largest @msg_len@ reported across the received messages,
-- * one frozen buffer per received message, shrunk to its @msg_len@ when
--   that is smaller than @msgSize@.
--
-- On failure, the @errno@ observed after the call is returned.
uninterruptibleReceiveMultipleMessageB ::
     Fd -- ^ Socket
  -> CInt -- ^ Expected size of each socket address, in bytes
  -> CSize -- ^ Capacity of each message buffer, in bytes
  -> CUInt -- ^ Maximum number of messages to receive
  -> MessageFlags 'Receive -- ^ Flags passed through to @recvmmsg@
  -> IO (Either Errno (CInt,ByteArray,CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageB !s !expSockAddrSize !msgSize !msgCount !flags = do
  bufs <- PM.unsafeNewUnliftedArray (cuintToInt msgCount)
  mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector)
  sockaddrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt expSockAddrSize)
  let sockaddrsAddr = PM.mutableByteArrayContents sockaddrsBuf
  let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = PM.mutableByteArrayContents mmsghdrsBuf
  let iovecsAddr = PM.mutableByteArrayContents iovecsBuf
  -- Every pinned buffer here reaches C only as a raw address stored in the
  -- mmsghdr/iovec arrays. That includes the per-message buffers that are
  -- reachable only through @bufs@. All of them must stay reachable until
  -- the call and post-processing finish. Both exits share this one action,
  -- so no path can skip a buffer. (The success branch previously did not
  -- touch @bufs@.)
  let keepAlive = do
        touchMutableUnliftedArray bufs
        touchMutableByteArray iovecsBuf
        touchMutableByteArray mmsghdrsBuf
        touchMutableByteArray sockaddrsBuf
  initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr mmsghdrsAddr sockaddrsAddr expSockAddrSize msgSize msgCount
  r <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr#
  if r > (-1)
    then do
      (validation,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize expSockAddrSize (cssizeToInt r) bufs mmsghdrsAddr
      -- Only the first @r@ address slots were filled in; drop the rest.
      shrinkMutableByteArray sockaddrsBuf (cssizeToInt r * cintToInt expSockAddrSize)
      sockaddrs <- PM.unsafeFreezeByteArray sockaddrsBuf
      keepAlive
      pure (Right (validation,sockaddrs,maxMsgSz,frozenBufs))
    else do
      keepAlive
      fmap Left getErrno
-- | Populate @msgCount@ consecutive @mmsghdr@ entries starting at
-- @mmsgHdrsAddr@, with no socket-address storage (null name, length 0).
--
-- Entry @i@ points at the @i@-th iovec after @iovecsAddr@. That iovec is
-- set up by 'initializeIOVector', which also allocates a fresh @msgSize@
-- buffer and stores it at index @i@ of @bufs@.
initializeMultipleMessageHeadersWithoutSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> Addr
  -> CSize
  -> CUInt
  -> IO ()
initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsgHdrsAddr msgSize msgCount =
  fill 0
  where
    total = cuintToInt msgCount
    fill !ix
      | ix >= total = pure ()
      | otherwise = do
          let iovecAddr = PM.plusAddr iovecsAddr (ix * cintToInt S.sizeofIOVector)
              hdrAddr = PM.plusAddr mmsgHdrsAddr (ix * cintToInt LST.sizeofMultipleMessageHeader)
          pokeMultipleMessageHeader hdrAddr PM.nullAddr 0 iovecAddr 1 PM.nullAddr 0 mempty 0
          initializeIOVector bufs iovecAddr msgSize ix
          fill (ix + 1)
-- | Populate @msgCount@ consecutive @mmsghdr@ entries starting at
-- @mmsgHdrsAddr0@, giving each entry its own socket-address slot.
--
-- Entry @i@ uses the @sockaddrSize@-byte slot at offset
-- @i * sockaddrSize@ from @sockaddrsAddr0@ as its name buffer, and the
-- @i@-th iovec after @iovecsAddr0@. That iovec is set up by
-- 'initializeIOVector', which also allocates a fresh @msgSize@ buffer and
-- stores it at index @i@ of @bufs@.
initializeMultipleMessageHeadersWithSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> Addr
  -> Addr
  -> CInt
  -> CSize
  -> CUInt
  -> IO ()
initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0 sockaddrSize msgSize msgCount =
  fill 0
  where
    total = cuintToInt msgCount
    fill !ix
      | ix >= total = pure ()
      | otherwise = do
          let iovecAddr = PM.plusAddr iovecsAddr0 (ix * cintToInt S.sizeofIOVector)
              hdrAddr = PM.plusAddr mmsgHdrsAddr0 (ix * cintToInt LST.sizeofMultipleMessageHeader)
              sockaddrAddr = PM.plusAddr sockaddrsAddr0 (ix * cintToInt sockaddrSize)
          pokeMultipleMessageHeader hdrAddr sockaddrAddr sockaddrSize iovecAddr 1 PM.nullAddr 0 mempty 0
          initializeIOVector bufs iovecAddr msgSize ix
          fill (ix + 1)
-- | Allocate a pinned buffer of @msgSize@ bytes and record it at index
-- @ix@ of @bufs@. Then point the iovec at @iovecAddr@ at that buffer,
-- with length @msgSize@.
initializeIOVector ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> CSize
  -> Int
  -> IO ()
initializeIOVector bufs iovecAddr msgSize ix = do
  msgBuf <- PM.newPinnedByteArray (csizeToInt msgSize)
  let base = PM.mutableByteArrayContents msgBuf
  S.pokeIOVectorBase iovecAddr base
  S.pokeIOVectorLength iovecAddr msgSize
  PM.writeUnliftedArray bufs ix msgBuf
-- | Post-process the first @n@ @mmsghdr@ entries, starting at @mmsghdr0@,
-- after a successful @recvmmsg@.
--
-- For each entry @i@ this reads @msg_len@ and the name length. It shrinks
-- @bufs[i]@ to @msg_len@ when that is below @bufSize@, then freezes the
-- buffer into a fresh result array. Returns:
--
-- * the bitwise OR of @(name length - expSockAddrSize)@ over all entries,
--   which is @0@ exactly when every name length equals @expSockAddrSize@;
-- * the largest @msg_len@ seen (@0@ when @n@ is @0@);
-- * the frozen buffers.
shrinkAndFreezeMessages ::
     CSize -- ^ Capacity each buffer was allocated with
  -> CInt -- ^ Expected socket-address length
  -> Int -- ^ Number of entries to process
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr -- ^ Address of the first @mmsghdr@
  -> IO (CInt,CUInt,UnliftedArray ByteArray)
shrinkAndFreezeMessages !bufSize !expSockAddrSize !n !bufs !mmsghdr0 = do
  out <- PM.unsafeNewUnliftedArray n
  let step !acc !ix !largest !hdr
        | ix >= n = do
            frozen <- PM.unsafeFreezeUnliftedArray out
            pure (acc, largest, frozen)
        | otherwise = do
            msgLen <- LST.peekMultipleMessageHeaderLength hdr
            nameLen <- LST.peekMultipleMessageHeaderNameLength hdr
            buf <- PM.readUnliftedArray bufs ix
            when (cuintToInt msgLen < csizeToInt bufSize) $
              shrinkMutableByteArray buf (cuintToInt msgLen)
            frozenBuf <- PM.unsafeFreezeByteArray buf
            PM.writeUnliftedArray out ix frozenBuf
            step (acc .|. (nameLen - expSockAddrSize)) (ix + 1) (max largest msgLen)
              (PM.plusAddr hdr (cintToInt LST.sizeofMultipleMessageHeader))
  step 0 0 0 mmsghdr0
-- | Write every field of the @mmsghdr@ located at the first argument, in
-- this order: name, name length, iovec pointer, iovec count, control
-- pointer, control length, flags, and message length.
pokeMultipleMessageHeader ::
     Addr -- ^ Address of the @mmsghdr@ to fill in
  -> Addr -- ^ Name (socket address) buffer
  -> CInt -- ^ Name buffer length
  -> Addr -- ^ iovec array
  -> CSize -- ^ Number of iovecs
  -> Addr -- ^ Control buffer
  -> CSize -- ^ Control buffer length
  -> MessageFlags 'Receive -- ^ Message flags
  -> CUInt -- ^ Message length field
  -> IO ()
pokeMultipleMessageHeader hdr name nameLen iov iovLen control controlLen msgFlags msgLen = do
  LST.pokeMultipleMessageHeaderName hdr name
  LST.pokeMultipleMessageHeaderNameLength hdr nameLen
  LST.pokeMultipleMessageHeaderIOVector hdr iov
  LST.pokeMultipleMessageHeaderIOVectorLength hdr iovLen
  LST.pokeMultipleMessageHeaderControl hdr control
  LST.pokeMultipleMessageHeaderControlLength hdr controlLen
  LST.pokeMultipleMessageHeaderFlags hdr msgFlags
  LST.pokeMultipleMessageHeaderLength hdr msgLen
-- | Shrink a mutable byte array in place to the given length, in bytes,
-- via the 'shrinkMutableByteArray#' primop. GHC documents that the new
-- length must not exceed the array's current length.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray marr#) (I# newLen#) =
  PM.primitive_ (\st -> shrinkMutableByteArray# marr# newLen# st)
-- | Convert a C @int@ to a Haskell 'Int'.
cintToInt :: CInt -> Int
cintToInt n = fromIntegral n

-- | Convert a C @unsigned int@ to a Haskell 'Int'.
cuintToInt :: CUInt -> Int
cuintToInt n = fromIntegral n

-- | Convert a C @size_t@ to a Haskell 'Int'.
csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n

-- | Convert a C @ssize_t@ to a Haskell 'Int'.
cssizeToInt :: CSsize -> Int
cssizeToInt n = fromIntegral n
-- | Keep a mutable unlifted array (and therefore everything it references)
-- reachable up to this point in the 'IO' sequence.
touchMutableUnliftedArray :: MutableUnliftedArray RealWorld a -> IO ()
touchMutableUnliftedArray marr = case marr of
  MutableUnliftedArray raw -> touchMutableUnliftedArray# raw

-- | Keep a mutable byte array reachable up to this point in the 'IO'
-- sequence.
touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray mba = case mba of
  MutableByteArray raw -> touchMutableByteArray# raw

-- | 'touch#' lifted into 'IO' for a raw @MutableArrayArray#@.
touchMutableUnliftedArray# :: MutableArrayArray# RealWorld -> IO ()
touchMutableUnliftedArray# raw = IO (\st0 -> case touch# raw st0 of st1 -> (# st1, () #))

-- | 'touch#' lifted into 'IO' for a raw @MutableByteArray#@.
touchMutableByteArray# :: MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# raw = IO (\st0 -> case touch# raw st0 of st1 -> (# st1, () #))