{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}
module Posix.Socket
(
uninterruptibleSocket
, uninterruptibleSocketPair
, uninterruptibleBind
, connect
, uninterruptibleConnect
, uninterruptibleListen
, accept
, uninterruptibleAccept
, accept_
, uninterruptibleGetSocketName
, uninterruptibleGetSocketOption
, uninterruptibleSetSocketOptionInt
, close
, uninterruptibleClose
, uninterruptibleErrorlessClose
, uninterruptibleShutdown
, send
, sendByteArray
, sendMutableByteArray
, uninterruptibleSend
, uninterruptibleSendByteArray
, uninterruptibleSendMutableByteArray
, uninterruptibleSendToByteArray
, uninterruptibleSendToMutableByteArray
, writeVector
, receive
, receiveByteArray
, uninterruptibleReceive
, uninterruptibleReceiveMutableByteArray
, uninterruptibleReceiveFromMutableByteArray
, uninterruptibleReceiveFromMutableByteArray_
, uninterruptibleReceiveMessageA
, uninterruptibleReceiveMessageB
, hostToNetworkLong
, hostToNetworkShort
, networkToHostLong
, networkToHostShort
, Domain(..)
, Type(..)
, Protocol(..)
, OptionName(..)
, OptionValue(..)
, Level(..)
, Message(..)
, MessageFlags(..)
, ShutdownType(..)
, SocketAddress(..)
, PST.SocketAddressInternet(..)
, PST.SocketAddressUnix(..)
, PSP.encodeSocketAddressInternet
, PSP.encodeSocketAddressUnix
, PSP.decodeSocketAddressInternet
, PSP.indexSocketAddressInternet
, PSP.sizeofSocketAddressInternet
, PST.unix
, PST.unspecified
, PST.internet
, PST.internet6
, PST.stream
, PST.datagram
, PST.raw
, PST.sequencedPacket
, PST.defaultProtocol
, PST.rawProtocol
, PST.icmp
, PST.tcp
, PST.udp
, PST.ip
, PST.ipv6
, PST.peek
, PST.outOfBand
, PST.waitAll
, PST.noSignal
, PST.read
, PST.write
, PST.readWrite
, PST.levelSocket
, PST.optionError
, PST.broadcast
, PST.peekMessageHeaderName
, PST.peekMessageHeaderNameLength
, PST.peekMessageHeaderIOVector
, PST.peekMessageHeaderIOVectorLength
, PST.pokeMessageHeaderName
, PST.pokeMessageHeaderNameLength
, PST.pokeMessageHeaderIOVector
, PST.pokeMessageHeaderIOVectorLength
, PST.pokeMessageHeaderControl
, PST.pokeMessageHeaderControlLength
, PST.pokeMessageHeaderFlags
, PST.sizeofMessageHeader
, PST.peekIOVectorBase
, PST.peekIOVectorLength
, PST.pokeIOVectorBase
, PST.pokeIOVectorLength
, PST.sizeofIOVector
) where
import GHC.ByteOrder (ByteOrder(BigEndian,LittleEndian),targetByteOrder)
import GHC.IO (IO(..))
import Data.Primitive (MutablePrimArray(..),MutableByteArray(..),Addr(..),ByteArray(..))
import Data.Primitive (MutableUnliftedArray(..),UnliftedArray(..))
import Data.Word (Word16,Word32,byteSwap16,byteSwap32)
import Data.Void (Void)
import Foreign.C.Error (Errno,getErrno)
import Foreign.C.Types (CInt(..),CSize(..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Ptr,RealWorld,ByteArray#,MutableByteArray#,Addr#)
import GHC.Exts (ArrayArray#,MutableArrayArray#,Int(I#))
import GHC.Exts (shrinkMutableByteArray#,touch#)
import Posix.Socket.Types (Domain(..),Protocol(..),Type(..),SocketAddress(..))
import Posix.Socket.Types (MessageFlags(..),Message(..),ShutdownType(..))
import Posix.Socket.Types (Level(..),OptionName(..),OptionValue(..))
import System.Posix.Types (Fd(..),CSsize(..))
import qualified Posix.Socket.Types as PST
import qualified Data.Primitive as PM
import qualified Control.Monad.Primitive as PM
import qualified Posix.Socket.Platform as PSP
-- Raw FFI bindings for socket creation, binding, listening, closing and
-- shutdown. Every import here except 'c_safe_close' is @unsafe@: such calls
-- cannot be interrupted and hold the calling capability until they return.

-- socket(2): returns the new descriptor, or -1 on error.
foreign import ccall unsafe "sys/socket.h socket"
  c_socket :: Domain -> Type -> Protocol -> IO Fd
-- socketpair(2): the buffer receives the two new descriptors.
foreign import ccall unsafe "sys/socket.h socketpair"
  c_socketpair :: Domain -> Type -> Protocol -> MutableByteArray# RealWorld -> IO CInt
-- listen(2)
foreign import ccall unsafe "sys/socket.h listen"
  c_listen :: Fd -> CInt -> IO CInt
-- close(2), safe and unsafe variants.
foreign import ccall safe "unistd.h close"
  c_safe_close :: Fd -> IO CInt
foreign import ccall unsafe "unistd.h close"
  c_unsafe_close :: Fd -> IO CInt
-- shutdown(2). POSIX declares it in <sys/socket.h>, not <unistd.h>.
foreign import ccall unsafe "sys/socket.h shutdown"
  c_unsafe_shutdown :: Fd -> ShutdownType -> IO CInt
-- bind(2): the address bytes and their length.
foreign import ccall unsafe "sys/socket.h bind"
  c_bind :: Fd -> ByteArray# -> CInt -> IO CInt
-- accept(2), safe variant. The address and length buffers are passed to a
-- safe call, during which the GC may run, so callers must pass pinned arrays.
foreign import ccall safe "sys/socket.h accept"
  c_safe_accept :: Fd
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> IO Fd
-- accept(2), unsafe variant. Unpinned buffers are acceptable here.
foreign import ccall unsafe "sys/socket.h accept"
  c_unsafe_accept :: Fd
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> IO Fd
-- accept(2) taking raw pointers. 'accept_' passes NULL for both to skip
-- retrieving the peer address.
foreign import ccall safe "sys/socket.h accept"
  c_safe_ptr_accept :: Fd -> Ptr Void -> Ptr CInt -> IO Fd
-- getsockname(2). The first buffer receives the address; the second holds
-- the length (a CInt), written before the call and read back afterwards.
foreign import ccall unsafe "sys/socket.h getsockname"
  c_unsafe_getsockname :: Fd
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> IO CInt
-- getsockopt(2). Uses the same value-buffer / length-buffer convention as
-- getsockname above.
foreign import ccall unsafe "sys/socket.h getsockopt"
  c_unsafe_getsockopt :: Fd
    -> Level
    -> OptionName
    -> MutableByteArray# RealWorld
    -> MutableByteArray# RealWorld
    -> IO CInt
-- setsockopt_int is not a standard libc symbol. It is presumably a C shim
-- shipped with this package that passes an int option value directly.
-- TODO confirm.
foreign import ccall unsafe "sys/socket.h setsockopt_int"
  c_unsafe_setsockopt_int :: Fd
    -> Level
    -> OptionName
    -> CInt
    -> IO CInt
-- connect(2): safe variants for an immutable address and for a (pinned)
-- mutable address, plus an unsafe variant.
foreign import ccall safe "sys/socket.h connect"
  c_safe_connect :: Fd -> ByteArray# -> CInt -> IO CInt
foreign import ccall safe "sys/socket.h connect"
  c_safe_mutablebytearray_connect :: Fd -> MutableByteArray# RealWorld -> CInt -> IO CInt
foreign import ccall unsafe "sys/socket.h connect"
  c_unsafe_connect :: Fd -> ByteArray# -> CInt -> IO CInt
-- send(2)/recv(2) family. The "*_offset" symbols are not standard libc
-- functions. Judging by the call sites, they are C shims (presumably in this
-- package's C sources; TODO confirm) that take a buffer, a CInt byte offset
-- into it, and a length.

-- send(2) from a raw address, safe call.
foreign import ccall safe "sys/socket.h send"
  c_safe_addr_send :: Fd -> Addr# -> CSize -> MessageFlags 'Send -> IO CSsize
-- send with offset, safe call. The byte array must be pinned.
foreign import ccall safe "sys/socket.h send_offset"
  c_safe_bytearray_send :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
foreign import ccall safe "sys/socket.h send_offset"
  c_safe_mutablebytearray_send :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- Plain send(2) from the start of a (pinned) mutable byte array.
foreign import ccall safe "sys/socket.h send"
  c_safe_mutablebytearray_no_offset_send :: Fd -> MutableByteArray# RealWorld -> CSize -> MessageFlags 'Send -> IO CSsize
-- Unsafe send variants: raw address, byte array with offset, and mutable
-- byte array with offset.
foreign import ccall unsafe "sys/socket.h send"
  c_unsafe_addr_send :: Fd -> Addr# -> CSize -> MessageFlags 'Send -> IO CSsize
foreign import ccall unsafe "sys/socket.h send_offset"
  c_unsafe_bytearray_send :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
foreign import ccall unsafe "sys/socket.h send_offset"
  c_unsafe_mutable_bytearray_send :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- sendto with offset. The trailing ByteArray#/CInt pair is the destination
-- address and its length.
foreign import ccall unsafe "sys/socket.h sendto_offset"
  c_unsafe_bytearray_sendto :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> ByteArray# -> CInt -> IO CSsize
foreign import ccall unsafe "sys/socket.h sendto_offset"
  c_unsafe_mutable_bytearray_sendto :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> ByteArray# -> CInt -> IO CSsize
-- writev(2): a pinned array of struct iovec and the number of entries.
foreign import ccall safe "sys/uio.h writev"
  c_safe_writev :: Fd -> MutableByteArray# RealWorld -> CInt -> IO CSsize
-- recv(2) into a raw address, safe and unsafe variants.
foreign import ccall safe "sys/socket.h recv"
  c_safe_addr_recv :: Fd -> Addr# -> CSize -> MessageFlags 'Receive -> IO CSsize
foreign import ccall unsafe "sys/socket.h recv"
  c_unsafe_addr_recv :: Fd -> Addr# -> CSize -> MessageFlags 'Receive -> IO CSsize
-- recv with offset into a mutable byte array.
foreign import ccall unsafe "sys/socket.h recv_offset"
  c_unsafe_mutable_byte_array_recv :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> IO CSsize
-- recvfrom with offset. The last two arguments are the source-address
-- buffer and its in/out length buffer.
foreign import ccall unsafe "sys/socket.h recvfrom_offset"
  c_unsafe_mutable_byte_array_recvfrom :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> MutableByteArray# RealWorld -> MutableByteArray# RealWorld -> IO CSsize
-- Same shim with raw pointers for the address arguments (callers pass NULL).
foreign import ccall unsafe "sys/socket.h recvfrom_offset"
  c_unsafe_mutable_byte_array_ptr_recvfrom :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> Ptr Void -> Ptr CInt -> IO CSsize
-- recvmsg(2). The Addr# points at a message header filled in by the caller.
foreign import ccall unsafe "sys/socket.h recvmsg"
  c_unsafe_addr_recvmsg :: Fd
    -> Addr#
    -> MessageFlags 'Receive
    -> IO CSsize
-- | Create a communication endpoint with @socket(2)@ through an @unsafe@
-- FFI call. Returns the new descriptor, or the @errno@ on failure.
uninterruptibleSocket ::
     Domain
  -> Type
  -> Protocol
  -> IO (Either Errno Fd)
uninterruptibleSocket domain sockType protocol =
  errorsFromFd =<< c_socket domain sockType protocol
-- | Create a pair of connected sockets with @socketpair(2)@ through an
-- @unsafe@ FFI call. The descriptors are returned in the order the call
-- wrote them into the two-slot buffer.
uninterruptibleSocketPair ::
     Domain
  -> Type
  -> Protocol
  -> IO (Either Errno (Fd,Fd))
uninterruptibleSocketPair dom typ prot = do
  -- Exactly two slots: socketpair fills both.
  fds@(MutablePrimArray fds#) <- PM.newPrimArray 2 :: IO (MutablePrimArray RealWorld Fd)
  status <- c_socketpair dom typ prot fds#
  case status of
    0 -> do
      first <- PM.readPrimArray fds 0
      second <- PM.readPrimArray fds 1
      pure (Right (first, second))
    _ -> Left <$> getErrno
-- | Assign an address to a socket with @bind(2)@ through an @unsafe@ FFI
-- call. The length given to the kernel is the full size of the address
-- byte array.
uninterruptibleBind ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
uninterruptibleBind fd (SocketAddress addr@(ByteArray addr#)) = do
  let addrLen = intToCInt (PM.sizeofByteArray addr)
  r <- c_bind fd addr# addrLen
  errorsFromInt r

-- | Mark a socket as passive with @listen(2)@, using the given backlog.
uninterruptibleListen ::
     Fd
  -> CInt
  -> IO (Either Errno ())
uninterruptibleListen fd backlog = do
  r <- c_listen fd backlog
  errorsFromInt r
-- | Connect a socket to an address with @connect(2)@ through a @safe@ FFI
-- call, so other Haskell threads keep running while it blocks.
--
-- A safe call must not receive an unpinned heap object. An unpinned
-- address is therefore first copied into a freshly allocated pinned buffer.
connect ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
connect fd (SocketAddress sockAddr@(ByteArray sockAddr#))
  | PM.isByteArrayPinned sockAddr =
      errorsFromInt =<< c_safe_connect fd sockAddr# (intToCInt len)
  | otherwise = do
      pinned@(MutableByteArray pinned#) <- PM.newPinnedByteArray len
      PM.copyByteArray pinned 0 sockAddr 0 len
      errorsFromInt =<< c_safe_mutablebytearray_connect fd pinned# (intToCInt len)
  where
    len = PM.sizeofByteArray sockAddr
-- | Connect a socket with @connect(2)@ through an @unsafe@ FFI call. The
-- address array is passed directly, whether or not it is pinned.
uninterruptibleConnect ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
uninterruptibleConnect fd (SocketAddress sockAddr@(ByteArray sockAddr#)) = do
  let addrLen = intToCInt (PM.sizeofByteArray sockAddr)
  r <- c_unsafe_connect fd sockAddr# addrLen
  errorsFromInt r
-- | Accept a connection with @accept(2)@ through a @safe@ FFI call.
--
-- @maxSz@ bytes are reserved for the peer address. On success the result
-- holds three things:
--
-- * the address length reported by the kernel, which may exceed @maxSz@;
-- * the peer address, at most @maxSz@ bytes, copied into a fresh unpinned
--   array;
-- * the new descriptor.
accept ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,Fd))
accept !sock !maxSz = do
  -- Both buffers go to a safe call, so both are pinned.
  addrBuf@(MutableByteArray addrBuf#) <- PM.newPinnedByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newPinnedByteArray (PM.sizeOf (undefined :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  newFd <- c_safe_accept sock addrBuf# lenBuf#
  if newFd <= (-1)
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenBuf 0 :: IO CInt
      -- Never copy more than the buffer actually holds.
      let kept = cintToInt (min reported maxSz)
      result <- PM.newByteArray kept
      PM.copyMutableByteArray result 0 addrBuf 0 kept
      frozen <- PM.unsafeFreezeByteArray result
      pure (Right (reported, SocketAddress frozen, newFd))
-- | Accept a connection with @accept(2)@ through an @unsafe@ FFI call.
--
-- Unlike 'accept', the address buffer is unpinned and is shrunk in place
-- to the reported length rather than copied. If the reported length is not
-- smaller than @maxSz@, the buffer keeps all @maxSz@ bytes.
uninterruptibleAccept ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,Fd))
uninterruptibleAccept !sock !maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  newFd <- c_unsafe_accept sock addrBuf# lenBuf#
  if newFd <= (-1)
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenBuf 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      addr <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress addr, newFd))
-- | Accept a connection with @accept(2)@ through a @safe@ FFI call without
-- retrieving the peer address: both address arguments are NULL.
accept_ ::
     Fd
  -> IO (Either Errno Fd)
accept_ sock = do
  r <- c_safe_ptr_accept sock nullPtr nullPtr
  errorsFromFd r
-- | Retrieve the local address of a socket with @getsockname(2)@ through
-- an @unsafe@ FFI call.
--
-- @maxSz@ bytes are reserved for the address. On success returns the
-- length reported by the kernel and the address. The address buffer is
-- shrunk to the reported length when that length is smaller than @maxSz@.
uninterruptibleGetSocketName ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress))
uninterruptibleGetSocketName sock maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  status <- c_unsafe_getsockname sock addrBuf# lenBuf#
  case status of
    0 -> do
      reported <- PM.readByteArray lenBuf 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      addr <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress addr))
    _ -> Left <$> getErrno
-- | Read a socket option with @getsockopt(2)@ through an @unsafe@ FFI call.
--
-- @maxSz@ bytes are reserved for the option value. On success returns the
-- length reported by the kernel and the value bytes. The value buffer is
-- shrunk to the reported length when that length is smaller than @maxSz@.
uninterruptibleGetSocketOption ::
     Fd
  -> Level
  -> OptionName
  -> CInt
  -> IO (Either Errno (CInt,OptionValue))
uninterruptibleGetSocketOption sock level optName maxSz = do
  valBuf@(MutableByteArray valBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  status <- c_unsafe_getsockopt sock level optName valBuf# lenBuf#
  case status of
    0 -> do
      reported <- PM.readByteArray lenBuf 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray valBuf (cintToInt reported)
        else pure ()
      value <- PM.unsafeFreezeByteArray valBuf
      pure (Right (reported, OptionValue value))
    _ -> Left <$> getErrno
-- | Set an integer-valued socket option through an @unsafe@ FFI call to the
-- @setsockopt_int@ symbol, which takes the value as a plain 'CInt'.
uninterruptibleSetSocketOptionInt ::
     Fd
  -> Level
  -> OptionName
  -> CInt
  -> IO (Either Errno ())
uninterruptibleSetSocketOptionInt sock level optName optValue = do
  r <- c_unsafe_setsockopt_int sock level optName optValue
  errorsFromInt r
-- | Send @len@ bytes starting at byte offset @off@ of a 'ByteArray' through
-- a @safe@ FFI call. Returns the number of bytes sent.
--
-- A safe call needs a pinned buffer. If the array is unpinned, the
-- requested slice is first copied to the start of a fresh pinned buffer,
-- which is then sent without an offset.
sendByteArray ::
     Fd
  -> ByteArray
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
sendByteArray fd b@(ByteArray b#) off len flags = if PM.isByteArrayPinned b
  then errorsFromSize =<< c_safe_bytearray_send fd b# off len flags
  else do
    x@(MutableByteArray x#) <- PM.newPinnedByteArray (csizeToInt len)
    -- Copy source bytes [off, off + len) to destination offset 0. The
    -- destination is only @len@ bytes long, so the offset belongs on the
    -- source side.
    PM.copyByteArray x 0 b (cintToInt off) (csizeToInt len)
    errorsFromSize =<< c_safe_mutablebytearray_no_offset_send fd x# len flags
-- | Gather-write every buffer to a descriptor with @writev(2)@ through a
-- @safe@ FFI call. Returns the number of bytes written.
--
-- Each buffer's address and length go into a pinned array of iovec
-- entries. Unpinned buffers are first copied into pinned ones. Those copies
-- are referenced only by raw addresses inside the iovec array, so they are
-- also collected in @pinned@. That array, the original buffers and the
-- iovec array are all touched after the call, so the garbage collector
-- (which may run during a safe call) cannot reclaim memory @writev@ is
-- still reading.
writeVector ::
     Fd
  -> UnliftedArray ByteArray
  -> IO (Either Errno CSize)
writeVector fd buffers = do
  let n = PM.sizeofUnliftedArray buffers
  iovecs@(MutableByteArray iovecs#) :: MutableByteArray RealWorld <-
    PM.newPinnedByteArray (cintToInt PST.sizeofIOVector * n)
  -- Every slot is filled by the loop below before it is read.
  pinned :: MutableUnliftedArray RealWorld ByteArray <- PM.unsafeNewUnliftedArray n
  downward n $ \i -> do
    buffer <- pinByteArray (PM.indexUnliftedArray buffers i)
    PM.writeUnliftedArray pinned i buffer
    let
      targetAddr :: Addr
      targetAddr =
        PM.mutableByteArrayContents iovecs `PM.plusAddr`
        (i * cintToInt PST.sizeofIOVector)
    PST.pokeIOVectorBase targetAddr (PM.byteArrayContents buffer)
    PST.pokeIOVectorLength targetAddr (intToCSize (PM.sizeofByteArray buffer))
  r <- errorsFromSize =<< c_safe_writev fd iovecs# (intToCInt n)
  touchUnliftedArray buffers
  touchMutableUnliftedArray pinned
  touchMutableByteArray iovecs
  pure r
-- | Run an action for each index from @hi - 1@ down to @0@, in that order,
-- discarding the results. Does nothing when @hi <= 0@.
downward :: Int -> (Int -> IO a) -> IO ()
downward !hi f = loop (hi - 1)
  where
    loop !ix
      | ix < 0 = pure ()
      | otherwise = do
          _ <- f ix
          loop (ix - 1)
-- | Return the array unchanged if it is already pinned. Otherwise return a
-- pinned copy with the same contents.
pinByteArray :: ByteArray -> IO ByteArray
pinByteArray byteArray
  | PM.isByteArrayPinned byteArray = pure byteArray
  | otherwise = do
      let n = PM.sizeofByteArray byteArray
      copy <- PM.newPinnedByteArray n
      PM.copyByteArray copy 0 byteArray 0 n
      PM.unsafeFreezeByteArray copy
-- | Send @len@ bytes starting at byte offset @off@ of a mutable byte array
-- through a @safe@ FFI call. Returns the number of bytes sent.
--
-- A safe call needs a pinned buffer. If the array is unpinned, the
-- requested slice is first copied to the start of a fresh pinned buffer,
-- which is then sent without an offset.
sendMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
sendMutableByteArray fd b@(MutableByteArray b#) off len flags = if PM.isMutableByteArrayPinned b
  then errorsFromSize =<< c_safe_mutablebytearray_send fd b# off len flags
  else do
    x@(MutableByteArray x#) <- PM.newPinnedByteArray (csizeToInt len)
    -- Copy source bytes [off, off + len) to destination offset 0. The
    -- destination is only @len@ bytes long, so the offset belongs on the
    -- source side.
    PM.copyMutableByteArray x 0 b (cintToInt off) (csizeToInt len)
    errorsFromSize =<< c_safe_mutablebytearray_no_offset_send fd x# len flags
-- | Send @len@ bytes starting at the given address with @send(2)@ through a
-- @safe@ FFI call. The memory is not managed here. Presumably the caller
-- keeps it alive and unmoved for the duration of the call.
send ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
send fd (Addr addr#) len flags = do
  r <- c_safe_addr_send fd addr# len flags
  errorsFromSize r

-- | Like 'send', but through an @unsafe@ FFI call.
uninterruptibleSend ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSend fd (Addr addr#) len flags = do
  r <- c_unsafe_addr_send fd addr# len flags
  errorsFromSize r
-- | Send @len@ bytes from byte offset @off@ of a 'ByteArray' through an
-- @unsafe@ FFI call. The array does not need to be pinned.
uninterruptibleSendByteArray ::
     Fd
  -> ByteArray
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSendByteArray fd (ByteArray arr#) off len flags = do
  r <- c_unsafe_bytearray_send fd arr# off len flags
  errorsFromSize r

-- | Mutable-array counterpart of 'uninterruptibleSendByteArray'.
uninterruptibleSendMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSendMutableByteArray fd (MutableByteArray arr#) off len flags = do
  r <- c_unsafe_mutable_bytearray_send fd arr# off len flags
  errorsFromSize r
-- | Send @len@ bytes from byte offset @off@ of a 'ByteArray' to the given
-- destination address through an @unsafe@ FFI call. The address length
-- passed along is the full size of the address array.
uninterruptibleSendToByteArray ::
     Fd
  -> ByteArray
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> SocketAddress
  -> IO (Either Errno CSize)
uninterruptibleSendToByteArray fd (ByteArray arr#) off len flags (SocketAddress dest@(ByteArray dest#)) = do
  let destLen = intToCInt (PM.sizeofByteArray dest)
  r <- c_unsafe_bytearray_sendto fd arr# off len flags dest# destLen
  errorsFromSize r

-- | Mutable-array counterpart of 'uninterruptibleSendToByteArray'.
uninterruptibleSendToMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> SocketAddress
  -> IO (Either Errno CSize)
uninterruptibleSendToMutableByteArray fd (MutableByteArray arr#) off len flags (SocketAddress dest@(ByteArray dest#)) = do
  let destLen = intToCInt (PM.sizeofByteArray dest)
  r <- c_unsafe_mutable_bytearray_sendto fd arr# off len flags dest# destLen
  errorsFromSize r
-- | Receive up to @len@ bytes into the given address with @recv(2)@
-- through a @safe@ FFI call. Returns the number of bytes received.
receive ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
receive fd (Addr addr#) len flags = do
  r <- c_safe_addr_recv fd addr# len flags
  errorsFromSize r
-- | Receive up to @len@ bytes with @recv(2)@ through a @safe@ FFI call and
-- return them as a fresh unpinned 'ByteArray' of exactly the received size.
receiveByteArray ::
     Fd
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno ByteArray)
receiveByteArray !fd !len !flags = do
  -- The safe call needs a pinned landing buffer.
  buf <- PM.newPinnedByteArray (csizeToInt len)
  let !(Addr buf#) = PM.mutableByteArrayContents buf
  received <- c_safe_addr_recv fd buf# len flags
  if received == (-1)
    then Left <$> getErrno
    else do
      let n = cssizeToInt received
      out <- PM.newByteArray n
      PM.copyMutableByteArray out 0 buf 0 n
      Right <$> PM.unsafeFreezeByteArray out
-- | Receive up to @len@ bytes into the given address with @recv(2)@
-- through an @unsafe@ FFI call.
uninterruptibleReceive ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceive !fd (Addr !addr) !len !flags = do
  r <- c_unsafe_addr_recv fd addr len flags
  errorsFromSize r

-- | Receive up to @len@ bytes into a mutable byte array at byte offset
-- @off@ through an @unsafe@ FFI call. The array may be unpinned.
uninterruptibleReceiveMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceiveMutableByteArray !fd (MutableByteArray !b) !off !len !flags = do
  r <- c_unsafe_mutable_byte_array_recv fd b off len flags
  errorsFromSize r
-- | Receive into a mutable byte array at byte offset @off@ and capture the
-- sender's address, through an @unsafe@ FFI call.
--
-- @maxSz@ bytes are reserved for the address. On success returns three
-- things: the address length reported by the kernel, the address (shrunk
-- to that length when it is smaller than @maxSz@), and the number of bytes
-- received.
uninterruptibleReceiveFromMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,CSize))
{-# INLINE uninterruptibleReceiveFromMutableByteArray #-}
uninterruptibleReceiveFromMutableByteArray !fd (MutableByteArray !b) !off !len !flags !maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  received <- c_unsafe_mutable_byte_array_recvfrom fd b off len flags addrBuf# lenBuf#
  if received < 0
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenBuf 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      addr <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress addr, cssizeToCSize received))
-- | Like 'uninterruptibleReceiveFromMutableByteArray', but discards the
-- sender's address: both address arguments are NULL.
uninterruptibleReceiveFromMutableByteArray_ ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceiveFromMutableByteArray_ !fd (MutableByteArray !b) !off !len !flags = do
  r <- c_unsafe_mutable_byte_array_ptr_recvfrom fd b off len flags nullPtr nullPtr
  errorsFromSize r
-- | Receive a message with @recvmsg(2)@ through an @unsafe@ FFI call into
-- @chunkCount@ freshly allocated pinned buffers of @chunkSize@ bytes each.
-- The sender's address and control data are not collected.
--
-- On success returns the byte count reported by @recvmsg@ and the buffers
-- that received data. Full buffers are kept as they are. The last,
-- partially filled buffer is shrunk to the bytes it received. Unused
-- trailing buffers are dropped.
uninterruptibleReceiveMessageA ::
     Fd
  -> CSize -- ^ size of each chunk, in bytes
  -> CSize -- ^ number of chunks
  -> MessageFlags 'Receive
  -> IO (Either Errno (CSize,UnliftedArray ByteArray))
uninterruptibleReceiveMessageA !s !chunkSize !chunkCount !flags = do
  -- One pinned buffer per chunk, plus a pinned iovec array pointing at them.
  bufs <- PM.unsafeNewUnliftedArray (csizeToInt chunkCount)
  iovecsBuf <- PM.newPinnedByteArray (csizeToInt chunkCount * cintToInt PST.sizeofIOVector)
  let iovecsAddr = PM.mutableByteArrayContents iovecsBuf
  initializeIOVectors bufs iovecsAddr chunkSize chunkCount
  msgHdrBuf <- PM.newPinnedByteArray (cintToInt PST.sizeofMessageHeader)
  let !msgHdrAddr@(Addr msgHdrAddr#) = PM.mutableByteArrayContents msgHdrBuf
  -- No name buffer and no control buffer: both are NULL with length 0.
  pokeMessageHeader msgHdrAddr PM.nullAddr 0 iovecsAddr chunkCount PM.nullAddr 0 flags
  r <- c_unsafe_addr_recvmsg s msgHdrAddr# flags
  if r > (-1)
    then do
      filled <- countAndShrinkIOVectors (csizeToInt chunkCount) (cssizeToInt r) (csizeToInt chunkSize) bufs
      frozenBufs <- deepFreezeIOVectors filled bufs
      -- The iovec array and the message header are reached only through raw
      -- Addr values, and the chunk buffers only through the iovecs. Touch
      -- them all here so none can be reclaimed before recvmsg has used them.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray msgHdrBuf
      pure (Right (cssizeToCSize r,frozenBufs))
    else do
      -- Same keep-alive touches on the failure path.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray msgHdrBuf
      fmap Left getErrno
-- | Like 'uninterruptibleReceiveMessageA', but also captures the sender's
-- address into a pinned buffer of @maxSockAddrSz@ bytes.
--
-- On success returns four things:
--
-- * the address length read back from the message header;
-- * the address, shrunk to that length when it is smaller than
--   @maxSockAddrSz@;
-- * the byte count reported by @recvmsg@;
-- * the buffers that received data (see 'uninterruptibleReceiveMessageA').
uninterruptibleReceiveMessageB ::
     Fd
  -> CSize -- ^ size of each chunk, in bytes
  -> CSize -- ^ number of chunks
  -> MessageFlags 'Receive
  -> CInt -- ^ bytes reserved for the sender's address
  -> IO (Either Errno (CInt,SocketAddress,CSize,UnliftedArray ByteArray))
uninterruptibleReceiveMessageB !s !chunkSize !chunkCount !flags !maxSockAddrSz = do
  sockAddrBuf <- PM.newPinnedByteArray (cintToInt maxSockAddrSz)
  bufs <- PM.unsafeNewUnliftedArray (csizeToInt chunkCount)
  iovecsBuf <- PM.newPinnedByteArray (csizeToInt chunkCount * cintToInt PST.sizeofIOVector)
  let iovecsAddr = PM.mutableByteArrayContents iovecsBuf
  initializeIOVectors bufs iovecsAddr chunkSize chunkCount
  msgHdrBuf <- PM.newPinnedByteArray (cintToInt PST.sizeofMessageHeader)
  let !msgHdrAddr@(Addr msgHdrAddr#) = PM.mutableByteArrayContents msgHdrBuf
  -- The name field points at the sender-address buffer. There is no
  -- control buffer (NULL, length 0).
  pokeMessageHeader msgHdrAddr (PM.mutableByteArrayContents sockAddrBuf) maxSockAddrSz iovecsAddr chunkCount PM.nullAddr 0 flags
  r <- c_unsafe_addr_recvmsg s msgHdrAddr# flags
  if r > (-1)
    then do
      -- The name length field is read back from the header after the call.
      actualSockAddrSz <- PST.peekMessageHeaderNameLength msgHdrAddr
      if actualSockAddrSz < maxSockAddrSz
        then shrinkMutableByteArray sockAddrBuf (cintToInt actualSockAddrSz)
        else pure ()
      sockAddr <- PM.unsafeFreezeByteArray sockAddrBuf
      filled <- countAndShrinkIOVectors (csizeToInt chunkCount) (cssizeToInt r) (csizeToInt chunkSize) bufs
      frozenBufs <- deepFreezeIOVectors filled bufs
      -- Several buffers are reachable only through raw addresses stored in
      -- the header or the iovecs. Touch them so none can be reclaimed
      -- before recvmsg has used them.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray msgHdrBuf
      touchMutableByteArray sockAddrBuf
      pure (Right (actualSockAddrSz,SocketAddress sockAddr,cssizeToCSize r,frozenBufs))
    else do
      -- Same keep-alive touches on the failure path.
      touchMutableUnliftedArray bufs
      touchMutableByteArray iovecsBuf
      touchMutableByteArray msgHdrBuf
      touchMutableByteArray sockAddrBuf
      fmap Left getErrno
-- | Close a descriptor with @close(2)@ through a @safe@ FFI call.
close ::
     Fd
  -> IO (Either Errno ())
close fd = do
  r <- c_safe_close fd
  errorsFromInt r

-- | Close a descriptor with @close(2)@ through an @unsafe@ FFI call.
uninterruptibleClose ::
     Fd
  -> IO (Either Errno ())
uninterruptibleClose fd = do
  r <- c_unsafe_close fd
  errorsFromInt r

-- | Close a descriptor through an @unsafe@ FFI call, ignoring the result.
uninterruptibleErrorlessClose ::
     Fd
  -> IO ()
uninterruptibleErrorlessClose fd = () <$ c_unsafe_close fd

-- | Shut down part or all of a full-duplex connection with @shutdown(2)@
-- through an @unsafe@ FFI call.
uninterruptibleShutdown ::
     Fd
  -> ShutdownType
  -> IO (Either Errno ())
uninterruptibleShutdown fd how = do
  r <- c_unsafe_shutdown fd how
  errorsFromInt r
-- | Interpret a byte-count result. A non-negative value is returned as a
-- 'CSize'. A negative value signals failure, in which case @errno@ is read.
errorsFromSize :: CSsize -> IO (Either Errno CSize)
errorsFromSize r
  | r < 0 = Left <$> getErrno
  | otherwise = pure (Right (cssizeToCSize r))
-- | Interpret a descriptor result. A non-negative descriptor is success. A
-- negative value signals failure, in which case @errno@ is read.
errorsFromFd :: Fd -> IO (Either Errno Fd)
errorsFromFd r
  | r < 0 = Left <$> getErrno
  | otherwise = pure (Right r)

-- | Interpret a status result. Zero is success. Any other value signals
-- failure, in which case @errno@ is read.
errorsFromInt :: CInt -> IO (Either Errno ())
errorsFromInt r
  | r /= 0 = Left <$> getErrno
  | otherwise = pure (Right ())
-- Numeric conversions between Haskell and C integer types. Each one is just
-- 'fromIntegral'; the names only record intent at the call sites.

intToCInt :: Int -> CInt
intToCInt n = fromIntegral n

intToCSize :: Int -> CSize
intToCSize n = fromIntegral n

cintToInt :: CInt -> Int
cintToInt n = fromIntegral n

csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n

cssizeToInt :: CSsize -> Int
cssizeToInt n = fromIntegral n

cssizeToCSize :: CSsize -> CSize
cssizeToCSize n = fromIntegral n
-- | Shrink a mutable byte array in place to the given number of bytes.
-- 'shrinkMutableByteArray#' requires that the new size not exceed the
-- current size.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr#) (I# n#) =
  IO (\s -> (# shrinkMutableByteArray# arr# n# s, () #))
-- | Convert a 16-bit value from host byte order to network (big-endian)
-- byte order.
hostToNetworkShort :: Word16 -> Word16
hostToNetworkShort w = case targetByteOrder of
  LittleEndian -> byteSwap16 w
  BigEndian -> w

-- | Convert a 16-bit value from network (big-endian) byte order to host
-- byte order.
networkToHostShort :: Word16 -> Word16
networkToHostShort w = case targetByteOrder of
  LittleEndian -> byteSwap16 w
  BigEndian -> w

-- | Convert a 32-bit value from host byte order to network (big-endian)
-- byte order.
hostToNetworkLong :: Word32 -> Word32
hostToNetworkLong w = case targetByteOrder of
  LittleEndian -> byteSwap32 w
  BigEndian -> w

-- | Convert a 32-bit value from network (big-endian) byte order to host
-- byte order.
networkToHostLong :: Word32 -> Word32
networkToHostLong w = case targetByteOrder of
  LittleEndian -> byteSwap32 w
  BigEndian -> w
-- | Write every field of the message header at @hdr@, in this order:
--
-- * name pointer and name length;
-- * iovec pointer and iovec count;
-- * control pointer and control length;
-- * flags.
pokeMessageHeader :: Addr -> Addr -> CInt -> Addr -> CSize -> Addr -> CSize -> MessageFlags 'Receive -> IO ()
pokeMessageHeader hdr name nameLen iov iovLen control controlLen msgFlags = do
  PST.pokeMessageHeaderName hdr name
  PST.pokeMessageHeaderNameLength hdr nameLen
  PST.pokeMessageHeaderIOVector hdr iov
  PST.pokeMessageHeaderIOVectorLength hdr iovLen
  PST.pokeMessageHeaderControl hdr control
  PST.pokeMessageHeaderControlLength hdr controlLen
  PST.pokeMessageHeaderFlags hdr msgFlags
-- | For each index below @chunkCount@, allocate a chunk buffer, store it in
-- @bufs@, and describe it in the iovec slot at that index. Slots are
-- consecutive, starting at @iovecsAddr@.
initializeIOVectors ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> CSize
  -> CSize
  -> IO ()
initializeIOVectors bufs iovecsAddr chunkSize chunkCount = loop 0
  where
    count = csizeToInt chunkCount
    stride = cintToInt PST.sizeofIOVector
    loop !ix
      | ix >= count = pure ()
      | otherwise = do
          initializeIOVector bufs (PM.plusAddr iovecsAddr (ix * stride)) chunkSize ix
          loop (ix + 1)
-- | Allocate one pinned chunk of @chunkSize@ bytes and store it at index
-- @ix@ of @bufs@. Then write its base address and length into the iovec
-- at @iovecAddr@.
initializeIOVector ::
     MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Addr
  -> CSize
  -> Int
  -> IO ()
initializeIOVector bufs iovecAddr chunkSize ix = do
  chunk <- PM.newPinnedByteArray (csizeToInt chunkSize)
  PM.writeUnliftedArray bufs ix chunk
  let base = PM.mutableByteArrayContents chunk
  PST.pokeIOVectorBase iovecAddr base
  PST.pokeIOVectorLength iovecAddr chunkSize
-- | Given the total number of bytes received across @n@ buffers of
-- @maxBufSz@ bytes each, return how many buffers hold data. A final,
-- partially filled buffer is shrunk to the bytes it holds.
countAndShrinkIOVectors ::
     Int
  -> Int
  -> Int
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> IO Int
countAndShrinkIOVectors !n !totalUsedSz !maxBufSz !bufs = loop 0 totalUsedSz
  where
    loop !ix !remaining
      | ix >= n = pure ix
      | remaining >= maxBufSz = loop (ix + 1) (remaining - maxBufSz)
      | remaining == 0 = pure ix
      | otherwise = do
          buf <- PM.readUnliftedArray bufs ix
          shrinkMutableByteArray buf remaining
          pure (ix + 1)
-- | Build an immutable array from the first @n@ buffers of @m@. Each buffer
-- is frozen in place (no bytes are copied).
deepFreezeIOVectors ::
     Int
  -> MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> IO (UnliftedArray ByteArray)
deepFreezeIOVectors n m = do
  out <- PM.unsafeNewUnliftedArray n
  let fill !ix
        | ix >= n = PM.unsafeFreezeUnliftedArray out
        | otherwise = do
            buf <- PM.readUnliftedArray m ix
            frozen <- PM.unsafeFreezeByteArray buf
            PM.writeUnliftedArray out ix frozen
            fill (ix + 1)
  fill 0
-- Keep-alive helpers. Each wraps 'touch#', which keeps its argument alive
-- up to this point in the IO sequence.

touchMutableUnliftedArray :: MutableUnliftedArray RealWorld a -> IO ()
touchMutableUnliftedArray (MutableUnliftedArray arr#) = touchMutableUnliftedArray# arr#

touchUnliftedArray :: UnliftedArray a -> IO ()
touchUnliftedArray (UnliftedArray arr#) = touchUnliftedArray# arr#

touchMutableByteArray :: MutableByteArray RealWorld -> IO ()
touchMutableByteArray (MutableByteArray arr#) = touchMutableByteArray# arr#

touchMutableUnliftedArray# :: MutableArrayArray# RealWorld -> IO ()
touchMutableUnliftedArray# x = IO (\s -> (# touch# x s, () #))

touchUnliftedArray# :: ArrayArray# -> IO ()
touchUnliftedArray# x = IO (\s -> (# touch# x s, () #))

touchMutableByteArray# :: MutableByteArray# RealWorld -> IO ()
touchMutableByteArray# x = IO (\s -> (# touch# x s, () #))