{-# language BangPatterns #-}
{-# language CPP #-}
{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}
module Linux.Socket
(
uninterruptibleAccept4
, uninterruptibleAccept4_
#if defined(UNLIFTEDARRAYFUNCTIONS)
, uninterruptibleReceiveMultipleMessageA
, uninterruptibleReceiveMultipleMessageB
, uninterruptibleReceiveMultipleMessageC
, uninterruptibleReceiveMultipleMessageD
#endif
, SocketFlags(..)
, LST.headerInclude
, LST.dontWait
, LST.truncate
, LST.controlTruncate
, LST.closeOnExec
, LST.nonblocking
, applySocketFlags
, LST.sizeofUdpHeader
, LST.pokeUdpHeaderSourcePort
, LST.pokeUdpHeaderDestinationPort
, LST.pokeUdpHeaderLength
, LST.pokeUdpHeaderChecksum
, LST.sizeofIpHeader
, LST.pokeIpHeaderVersionIhl
, LST.pokeIpHeaderTypeOfService
, LST.pokeIpHeaderTotalLength
, LST.pokeIpHeaderIdentifier
, LST.pokeIpHeaderFragmentOffset
, LST.pokeIpHeaderTimeToLive
, LST.pokeIpHeaderProtocol
, LST.pokeIpHeaderChecksum
, LST.pokeIpHeaderSourceAddress
, LST.pokeIpHeaderDestinationAddress
) where
import Prelude hiding (truncate)
import Control.Monad (when)
import Data.Bits ((.|.))
import Data.Primitive (MutableByteArray(..),ByteArray(..),MutablePrimArray(..))
import Data.Primitive.Addr (Addr(..),plusAddr,nullAddr)
#if defined(UNLIFTEDARRAYFUNCTIONS)
import Data.Primitive.Unlifted.Array (MutableUnliftedArray,UnliftedArray)
import Data.Primitive.Unlifted.Array (MutableUnliftedArray_(MutableUnliftedArray))
import Data.Primitive.Unlifted.Array.Primops (MutableUnliftedArray#(MutableUnliftedArray#))
#endif
import Data.Void (Void)
import Data.Word (Word8)
import Foreign.C.Error (Errno,getErrno)
import Foreign.C.Types (CInt(..),CSize(..),CUInt(..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Ptr(..),RealWorld,MutableArray#,MutableByteArray#,Addr#,Int(I#))
import GHC.Exts (shrinkMutableByteArray#,touch#,nullAddr#)
import GHC.IO (IO(..))
import Linux.Socket.Types (SocketFlags(..))
import Posix.Socket (Type(..),MessageFlags(..),Message(Receive),SocketAddress(..))
import System.Posix.Types (Fd(..),CSsize(..))
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
#if defined(UNLIFTEDARRAYFUNCTIONS)
import qualified Data.Primitive.Unlifted.Array as PM
#endif
import qualified Linux.Socket.Types as LST
import qualified Posix.Socket as S
-- | Raw @recvmmsg(2)@ binding that takes the message-header array and the
-- timeout pointer as bare addresses. The call is @unsafe@, so it runs
-- without yielding to the RTS.
--
-- NOTE(review): @recvmmsg@ is declared in C as returning @int@, but this
-- binding reads the result as 'CSsize'. On LP64 ABIs the upper 32 bits of
-- the returned register are not guaranteed, so only the low 32 bits of this
-- value should be trusted. Consider declaring the result as 'CInt'.
foreign import ccall unsafe "sys/socket.h recvmmsg"
  c_unsafe_addr_recvmmsg
    :: Fd
    -> Addr#
    -> CUInt
    -> MessageFlags 'Receive
    -> Addr#
    -> IO CSsize
-- | Raw @accept4(2)@ binding.
--
-- The peer address is written into the first byte array. The second byte
-- array holds the address-length value that @accept4@ reads and updates.
-- Both arrays are passed as unlifted 'MutableByteArray#'. GHC only permits
-- that for @unsafe@ foreign calls, which this is.
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_accept4
    :: Fd
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> SocketFlags
    -> IO Fd
-- | Raw @accept4(2)@ binding that takes the address buffer and the
-- address-length pointer as plain 'Ptr's. This lets callers pass 'nullPtr'
-- for both when the peer address is not wanted.
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_ptr_accept4
    :: Fd
    -> Ptr Void
    -> Ptr Void
    -> SocketFlags
    -> IO Fd
#if defined(UNLIFTEDARRAYFUNCTIONS)
-- | Binding to the C helper @recvmmsg_sockaddr_in@ from @HaskellPosix.h@.
--
-- Going by how 'uninterruptibleReceiveMultipleMessageC' calls it, the
-- arguments are:
--
-- * the socket
-- * a per-message 'CInt' array (presumably receiving message lengths)
-- * an array of IPv4 socket addresses
-- * an array of payload buffers
-- * the message count
-- * receive flags
--
-- The 'CInt' result is passed to 'errorsFromInt' by that caller.
-- NOTE(review): confirm the exact contract against the C source.
foreign import ccall unsafe "HaskellPosix.h recvmmsg_sockaddr_in"
  c_unsafe_recvmmsg_sockaddr_in
    :: Fd
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> MutableArray# RealWorld (MutableByteArray# RealWorld)
    -> CUInt
    -> MessageFlags 'Receive
    -> IO CInt
-- | Binding to the C helper @recvmmsg_sockaddr_discard@ from
-- @HaskellPosix.h@.
--
-- It takes the same arguments as @recvmmsg_sockaddr_in@ minus the address
-- array, so source addresses are presumably not reported.
-- NOTE(review): confirm the exact contract against the C source.
foreign import ccall unsafe "HaskellPosix.h recvmmsg_sockaddr_discard"
  c_unsafe_recvmmsg_sockaddr_discard
    :: Fd
    -> MutableByteArray# RealWorld
    -> MutableArray# RealWorld (MutableByteArray# RealWorld)
    -> CUInt
    -> MessageFlags 'Receive
    -> IO CInt
#endif
-- | Merge Linux-specific 'SocketFlags' into a POSIX socket 'Type' by
-- bitwise OR of the two underlying 'CInt' values.
--
-- Linux accepts flags such as SOCK_NONBLOCK / SOCK_CLOEXEC OR'd into the
-- type argument of @socket(2)@.
applySocketFlags :: SocketFlags -> Type -> Type
applySocketFlags (SocketFlags s) (Type t) = Type (s .|. t)
#if defined(UNLIFTEDARRAYFUNCTIONS)
-- | Receive up to @msgCount@ datagrams with a single @recvmmsg(2)@ call,
-- discarding the source addresses.
--
-- Each message gets a freshly allocated pinned buffer of @msgSize@ bytes.
-- On success the result contains:
--
-- * the largest @msg_len@ reported across the received messages
-- * one frozen payload per received message, shrunk to its reported length
--   when that length is smaller than @msgSize@
--
-- Comparing the largest length with @msgSize@ presumably lets callers
-- detect truncation. On failure the result is the @errno@.
uninterruptibleReceiveMultipleMessageA ::
     Fd
  -> CSize
  -> CUInt
  -> MessageFlags 'Receive
  -> IO (Either Errno (CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageA !s !msgSize !msgCount !flags = do
  placeholder <- PM.newByteArray 0
  bufs <- PM.newUnliftedArray (cuintToInt msgCount) placeholder
  mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector)
  let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = ptrToAddr (PM.mutableByteArrayContents mmsghdrsBuf)
  let iovecsAddr = ptrToAddr (PM.mutableByteArrayContents iovecsBuf)
  initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsghdrsAddr msgSize msgCount
  rawResult <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr#
  -- recvmmsg(2) returns a C int, but the binding reads an ssize_t. Only the
  -- low 32 bits are defined, so narrow before testing the sign. Otherwise
  -- -1 can show up as 4294967295 and be mistaken for a message count.
  let r = fromIntegral rawResult :: CInt
  if r > (-1)
    then do
      (_,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize 0 (cintToInt r) bufs mmsghdrsAddr
      -- Keep every buffer whose raw address was handed to the kernel alive
      -- until after the results have been read back.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      pure (Right (maxMsgSz,frozenBufs))
    else do
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      fmap Left getErrno
-- | Receive up to @msgCount@ datagrams with a single @recvmmsg(2)@ call,
-- keeping the source addresses.
--
-- Each message gets a pinned payload buffer of @msgSize@ bytes and an
-- address slot of @expSockAddrSize@ bytes. On success the result contains:
--
-- * a validation value: the bitwise OR of @(reported name length -
--   expSockAddrSize)@ over all received messages. It is zero exactly when
--   every reported address length equals @expSockAddrSize@.
-- * the source addresses, packed back to back, @expSockAddrSize@ bytes each
-- * the largest @msg_len@ reported
-- * one frozen payload per received message, shrunk to its reported length
--   when that length is smaller than @msgSize@
--
-- On failure the result is the @errno@.
uninterruptibleReceiveMultipleMessageB ::
     Fd
  -> CInt
  -> CSize
  -> CUInt
  -> MessageFlags 'Receive
  -> IO (Either Errno (CInt,ByteArray,CUInt,UnliftedArray ByteArray))
uninterruptibleReceiveMultipleMessageB !s !expSockAddrSize !msgSize !msgCount !flags = do
  placeholder <- PM.newByteArray 0
  bufs <- PM.newUnliftedArray (cuintToInt msgCount) placeholder
  mmsghdrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt LST.sizeofMultipleMessageHeader)
  iovecsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt S.sizeofIOVector)
  sockaddrsBuf <- PM.newPinnedByteArray (cuintToInt msgCount * cintToInt expSockAddrSize)
  let sockaddrsAddr = ptrToAddr (PM.mutableByteArrayContents sockaddrsBuf)
  let !mmsghdrsAddr@(Addr mmsghdrsAddr#) = ptrToAddr (PM.mutableByteArrayContents mmsghdrsBuf)
  let iovecsAddr = ptrToAddr (PM.mutableByteArrayContents iovecsBuf)
  initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr mmsghdrsAddr sockaddrsAddr expSockAddrSize msgSize msgCount
  rawResult <- c_unsafe_addr_recvmmsg s mmsghdrsAddr# msgCount flags nullAddr#
  -- recvmmsg(2) returns a C int, but the binding reads an ssize_t. Only the
  -- low 32 bits are defined, so narrow before testing the sign. Otherwise
  -- -1 can show up as 4294967295 and be mistaken for a message count.
  let r = fromIntegral rawResult :: CInt
  if r > (-1)
    then do
      let received = cintToInt r
      (validation,maxMsgSz,frozenBufs) <- shrinkAndFreezeMessages msgSize expSockAddrSize received bufs mmsghdrsAddr
      -- Drop the address slots of messages that were not filled in.
      shrinkMutableByteArray sockaddrsBuf (received * cintToInt expSockAddrSize)
      sockaddrs <- PM.unsafeFreezeByteArray sockaddrsBuf
      -- Keep every buffer whose raw address was handed to the kernel alive
      -- until after the results have been read back. This matches the
      -- failure branch and uninterruptibleReceiveMultipleMessageA.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      touchMutableByteArray sockaddrsBuf
      pure (Right (validation,sockaddrs,maxMsgSz,frozenBufs))
    else do
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray mmsghdrsBuf
      touchMutableByteArray sockaddrsBuf
      fmap Left getErrno
-- | Receive several datagrams into caller-supplied buffers via the C helper
-- @recvmmsg_sockaddr_in@.
--
-- The caller supplies three arrays: a 'CInt' array (presumably message
-- lengths), an IPv4 socket-address array, and the payload buffers. The C
-- result is interpreted by 'errorsFromInt': a non-negative value is
-- returned as-is, a negative one yields the current @errno@.
uninterruptibleReceiveMultipleMessageC ::
     Fd
  -> MutablePrimArray RealWorld CInt
  -> MutablePrimArray RealWorld S.SocketAddressInternet
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> CUInt
  -> MessageFlags 'Receive
  -> IO (Either Errno CInt)
uninterruptibleReceiveMultipleMessageC !fd (MutablePrimArray lens#) (MutablePrimArray addrs#) (MutableUnliftedArray (MutableUnliftedArray# payloads#)) !count !msgFlags =
  errorsFromInt =<< c_unsafe_recvmmsg_sockaddr_in fd lens# addrs# payloads# count msgFlags
-- | Receive several datagrams into caller-supplied buffers via the C helper
-- @recvmmsg_sockaddr_discard@, without collecting source addresses.
--
-- The C result is interpreted by 'errorsFromInt': a non-negative value is
-- returned as-is, a negative one yields the current @errno@.
uninterruptibleReceiveMultipleMessageD ::
     Fd
  -> MutablePrimArray RealWorld CInt
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> CUInt
  -> MessageFlags 'Receive
  -> IO (Either Errno CInt)
uninterruptibleReceiveMultipleMessageD !fd (MutablePrimArray lens#) (MutableUnliftedArray (MutableUnliftedArray# payloads#)) !count !msgFlags =
  errorsFromInt =<< c_unsafe_recvmmsg_sockaddr_discard fd lens# payloads# count msgFlags
-- | Fill in @msgCount@ consecutive multiple-message headers starting at
-- @mmsgHdrsAddr@, each pointing at one I/O vector.
--
-- Header @i@ gets:
--
-- * no name buffer (null address, length 0)
-- * one I/O vector, the @i@-th entry of the array at @iovecsAddr@
-- * no control buffer
-- * empty flags and a zero length
--
-- Each I/O vector is backed by a fresh @msgSize@-byte buffer stored at
-- index @i@ of @bufs@.
initializeMultipleMessageHeadersWithoutSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> Addr
  -> CSize
  -> CUInt
  -> IO ()
initializeMultipleMessageHeadersWithoutSockAddr bufs iovecsAddr mmsgHdrsAddr msgSize msgCount =
  loop 0 iovecsAddr mmsgHdrsAddr
  where
    total = cuintToInt msgCount
    iovecStride = cintToInt S.sizeofIOVector
    hdrStride = cintToInt LST.sizeofMultipleMessageHeader
    loop !i !iov !hdr
      | i >= total = pure ()
      | otherwise = do
          pokeMultipleMessageHeader hdr nullAddr 0 iov 1 nullAddr 0 mempty 0
          initializeIOVector bufs iov msgSize i
          loop (i + 1) (plusAddr iov iovecStride) (plusAddr hdr hdrStride)
-- | Fill in @msgCount@ consecutive multiple-message headers starting at
-- @mmsgHdrsAddr0@.
--
-- Header @i@ gets:
--
-- * a name slot of @sockaddrSize@ bytes, at offset @i * sockaddrSize@ from
--   @sockaddrsAddr0@
-- * one I/O vector, the @i@-th entry of the array at @iovecsAddr0@
-- * no control buffer
-- * empty flags and a zero length
--
-- Each I/O vector is backed by a fresh @msgSize@-byte buffer stored at
-- index @i@ of @bufs@.
initializeMultipleMessageHeadersWithSockAddr ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> Addr
  -> Addr
  -> CInt
  -> CSize
  -> CUInt
  -> IO ()
initializeMultipleMessageHeadersWithSockAddr bufs iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0 sockaddrSize msgSize msgCount =
  loop 0 iovecsAddr0 mmsgHdrsAddr0 sockaddrsAddr0
  where
    total = cuintToInt msgCount
    iovecStride = cintToInt S.sizeofIOVector
    hdrStride = cintToInt LST.sizeofMultipleMessageHeader
    nameStride = cintToInt sockaddrSize
    loop !i !iov !hdr !name
      | i >= total = pure ()
      | otherwise = do
          pokeMultipleMessageHeader hdr name sockaddrSize iov 1 nullAddr 0 mempty 0
          initializeIOVector bufs iov msgSize i
          loop (i + 1)
            (plusAddr iov iovecStride)
            (plusAddr hdr hdrStride)
            (plusAddr name nameStride)
-- | Reinterpret a 'Ptr' as a primitive 'Addr' wrapping the same 'Addr#'.
ptrToAddr :: Ptr Word8 -> Addr
ptrToAddr p = case p of
  Ptr a# -> Addr a#
-- | Allocate a pinned @msgSize@-byte buffer and store it at index @ix@ of
-- @bufs@. Then write that buffer's address and @msgSize@ into the I/O
-- vector at @iovecAddr@.
--
-- The buffer is pinned so its address stays fixed while the kernel writes
-- into it.
initializeIOVector ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> CSize
  -> Int
  -> IO ()
initializeIOVector bufs iovecAddr msgSize ix = do
  payload <- PM.newPinnedByteArray (csizeToInt msgSize)
  PM.writeUnliftedArray bufs ix payload
  let base = ptrToAddr (PM.mutableByteArrayContents payload)
  S.pokeIOVectorBase iovecAddr base
  S.pokeIOVectorLength iovecAddr msgSize
-- | Walk the first @n@ multiple-message headers starting at @mmsghdr0@ and
-- collect their received payloads from @bufs@.
--
-- For each message, the payload buffer is shrunk to the reported message
-- length when that length is below @bufSize@, then frozen. The result
-- contains:
--
-- * the bitwise OR of @(reported name length - expSockAddrSize)@ over all
--   headers, which is zero iff every name length matched
-- * the largest reported message length
-- * the frozen payloads
shrinkAndFreezeMessages ::
     CSize
  -> CInt
  -> Int
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> IO (CInt,CUInt,UnliftedArray ByteArray)
shrinkAndFreezeMessages !bufSize !expSockAddrSize !n !bufs !mmsghdr0 =
  PM.unsafeNewUnliftedArray n >>= \dst -> step dst 0 0 0 mmsghdr0
  where
    hdrStride = cintToInt LST.sizeofMultipleMessageHeader
    step !dst !mismatch !i !largest !hdr
      | i >= n = do
          frozen <- PM.unsafeFreezeUnliftedArray dst
          pure (mismatch,largest,frozen)
      | otherwise = do
          msgLen <- LST.peekMultipleMessageHeaderLength hdr
          nameLen <- LST.peekMultipleMessageHeaderNameLength hdr
          buf <- PM.readUnliftedArray bufs i
          when (cuintToInt msgLen < csizeToInt bufSize) $
            shrinkMutableByteArray buf (cuintToInt msgLen)
          PM.writeUnliftedArray dst i =<< PM.unsafeFreezeByteArray buf
          step dst
            (mismatch .|. (nameLen - expSockAddrSize))
            (i + 1)
            (max largest msgLen)
            (plusAddr hdr hdrStride)
#endif
-- | Write every field of the multiple-message header at @mmsgHdrAddr@.
--
-- Arguments after the header address, in order:
--
-- 1. name (address) buffer
-- 2. name length
-- 3. I/O vector array
-- 4. I/O vector count
-- 5. control buffer
-- 6. control length
-- 7. message flags
-- 8. message-length field
pokeMultipleMessageHeader :: Addr -> Addr -> CInt -> Addr -> CSize -> Addr -> CSize -> MessageFlags 'Receive -> CUInt -> IO ()
pokeMultipleMessageHeader mmsgHdrAddr name nameLen iovecs iovecCount control controlLen flags len = do
  LST.pokeMultipleMessageHeaderName mmsgHdrAddr name
  LST.pokeMultipleMessageHeaderNameLength mmsgHdrAddr nameLen
  LST.pokeMultipleMessageHeaderIOVector mmsgHdrAddr iovecs
  LST.pokeMultipleMessageHeaderIOVectorLength mmsgHdrAddr iovecCount
  LST.pokeMultipleMessageHeaderControl mmsgHdrAddr control
  LST.pokeMultipleMessageHeaderControlLength mmsgHdrAddr controlLen
  LST.pokeMultipleMessageHeaderFlags mmsgHdrAddr flags
  LST.pokeMultipleMessageHeaderLength mmsgHdrAddr len
-- | Shrink a mutable byte array in place to @sz@ bytes.
--
-- The underlying 'shrinkMutableByteArray#' requires the new size to be no
-- larger than the current size.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr) (I# sz) =
  PM.primitive_ (shrinkMutableByteArray# arr sz)
-- | Accept a connection on @sock@ with @accept4(2)@ through an @unsafe@
-- foreign call, capturing the peer address.
--
-- A buffer of @maxSz@ bytes is allocated for the address, and its capacity
-- is passed to @accept4@ in the length argument.
--
-- On success the result is @(sz, addr, fd)@:
--
-- * @sz@ is the address length written back by @accept4@
-- * @addr@ is the address buffer, shrunk to @sz@ bytes when @sz < maxSz@;
--   otherwise it keeps all @maxSz@ bytes (per accept(2), a larger @sz@
--   means the address was truncated)
-- * @fd@ is the accepted descriptor
--
-- On failure the result is the @errno@.
uninterruptibleAccept4 ::
     Fd
  -> CInt
  -> SocketFlags
  -> IO (Either Errno (CInt,SocketAddress,Fd))
{-# inline uninterruptibleAccept4 #-}
uninterruptibleAccept4 !sock !maxSz !flags = do
  sockAddrBuf@(MutableByteArray sockAddrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (0 :: CInt))
  -- In: capacity of the address buffer. Out: actual address length.
  PM.writeByteArray lenBuf 0 maxSz
  r <- c_unsafe_accept4 sock sockAddrBuf# lenBuf# flags
  if r > (-1)
    then do
      (sz :: CInt) <- PM.readByteArray lenBuf 0
      when (sz < maxSz) (shrinkMutableByteArray sockAddrBuf (cintToInt sz))
      sockAddr <- PM.unsafeFreezeByteArray sockAddrBuf
      pure (Right (sz,SocketAddress sockAddr,r))
    else fmap Left getErrno
-- | Accept a connection on @sock@ with @accept4(2)@ through an @unsafe@
-- foreign call.
--
-- Null address and length pointers are passed, so the peer address is not
-- retrieved. Returns the new descriptor, or the @errno@ on failure.
uninterruptibleAccept4_ ::
     Fd
  -> SocketFlags
  -> IO (Either Errno Fd)
{-# inline uninterruptibleAccept4_ #-}
uninterruptibleAccept4_ !sock !flags = do
  r <- c_unsafe_ptr_accept4 sock nullPtr nullPtr flags
  if r > (-1)
    then pure (Right r)
    else fmap Left getErrno
-- | Widen a 'CInt' to an 'Int' (lossless wherever 'Int' is at least 32 bits).
cintToInt :: CInt -> Int
cintToInt = fromIntegral
-- | Convert a 'CUInt' to an 'Int' with 'fromIntegral'. Lossless when 'Int'
-- is wider than 32 bits.
cuintToInt :: CUInt -> Int
cuintToInt = fromIntegral
-- | Convert a 'CSize' to an 'Int' with 'fromIntegral'. Values above
-- @maxBound :: Int@ wrap around.
csizeToInt :: CSize -> Int
csizeToInt = fromIntegral
-- | Convert a 'CSsize' to an 'Int' with 'fromIntegral'.
cssizeToInt :: CSsize -> Int
cssizeToInt = fromIntegral
-- | Interpret the 'CInt' result of a C call.
--
-- A non-negative value is returned unchanged in 'Right'. A negative value
-- makes this read the current @errno@ and return it in 'Left'.
errorsFromInt :: CInt -> IO (Either Errno CInt)
{-# inline errorsFromInt #-}
errorsFromInt r =
  if r > (-1)
    then pure (Right r)
    else fmap Left getErrno
-- | Keep a 'MutableByteArray' reachable up to this point in the 'IO'
-- sequence.
--
-- Use it when a raw address taken from the array (e.g. from a pinned
-- buffer) is still in use and nothing else would keep the array alive.
touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray (MutableByteArray x) = touchMutableByteArray# x
-- | Unlifted counterpart of 'touchMutableByteArray'. It sequences a 'touch#'
-- on the array into 'IO' so that the array stays alive until this point.
touchMutableByteArray# :: MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# x = IO $ \s -> case touch# x s of s' -> (# s', () #)
#if defined(UNLIFTEDARRAYFUNCTIONS)
-- | Keep a mutable unlifted array, and therefore every buffer it references,
-- reachable up to this point in the 'IO' sequence.
touchMutableUnliftedArray :: MutableUnliftedArray RealWorld a -> IO ()
touchMutableUnliftedArray arr = case arr of
  MutableUnliftedArray arr# -> touchMutableUnliftedArray# arr#
-- | Unlifted counterpart of 'touchMutableUnliftedArray'. It threads a
-- 'touch#' on the array through the 'IO' state token.
touchMutableUnliftedArray# :: MutableUnliftedArray# RealWorld a -> IO ()
touchMutableUnliftedArray# arr# = IO (\s0 -> (# touch# arr# s0, () #))
#endif