{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language MagicHash #-}
{-# language ScopedTypeVariables #-}
{-# language UnboxedTuples #-}
{-# language UnliftedFFITypes #-}
module Posix.Socket
(
uninterruptibleSocket
, uninterruptibleSocketPair
, uninterruptibleBind
, connect
, uninterruptibleConnect
, uninterruptibleListen
, accept
, uninterruptibleAccept
, accept_
, uninterruptibleGetSocketName
, uninterruptibleGetSocketOption
, close
, uninterruptibleClose
, uninterruptibleErrorlessClose
, uninterruptibleShutdown
, send
, sendByteArray
, sendMutableByteArray
, uninterruptibleSend
, uninterruptibleSendByteArray
, uninterruptibleSendMutableByteArray
, uninterruptibleSendToByteArray
, uninterruptibleSendToMutableByteArray
, receive
, receiveByteArray
, uninterruptibleReceive
, uninterruptibleReceiveMutableByteArray
, uninterruptibleReceiveFromMutableByteArray
, uninterruptibleReceiveFromMutableByteArray_
, hostToNetworkLong
, hostToNetworkShort
, networkToHostLong
, networkToHostShort
, Domain(..)
, Type(..)
, Protocol(..)
, OptionName(..)
, OptionValue(..)
, Level(..)
, MessageFlags(..)
, ShutdownType(..)
, SocketAddress(..)
, PST.SocketAddressInternet(..)
, PST.SocketAddressUnix(..)
, PSP.encodeSocketAddressInternet
, PSP.encodeSocketAddressUnix
, PSP.decodeSocketAddressInternet
, PSP.sizeofSocketAddressInternet
, PST.unix
, PST.unspecified
, PST.internet
, PST.internet6
, PST.stream
, PST.datagram
, PST.raw
, PST.sequencedPacket
, PST.defaultProtocol
, PST.rawProtocol
, PST.icmp
, PST.tcp
, PST.udp
, PST.ip
, PST.ipv6
, PST.peek
, PST.outOfBand
, PST.waitAll
, PST.read
, PST.write
, PST.readWrite
, PST.levelSocket
, PST.optionError
) where
import GHC.ByteOrder (ByteOrder(BigEndian,LittleEndian),targetByteOrder)
import GHC.IO (IO(..))
import Data.Primitive (MutablePrimArray(..),MutableByteArray(..),Addr(..),ByteArray(..))
import Data.Word (Word16,Word32,byteSwap16,byteSwap32)
import Data.Void (Void)
import Foreign.C.Error (Errno,getErrno)
import Foreign.C.Types (CInt(..),CSize(..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Ptr,RealWorld,ByteArray#,MutableByteArray#,Addr#,Int(I#))
import GHC.Exts (shrinkMutableByteArray#)
import Posix.Socket.Types (Domain(..),Protocol(..),Type(..),SocketAddress(..))
import Posix.Socket.Types (MessageFlags(..),Message(..),ShutdownType(..))
import Posix.Socket.Types (Level(..),OptionName(..),OptionValue(..))
import System.Posix.Types (Fd(..),CSsize(..))
import qualified Posix.Socket.Types as PST
import qualified Data.Primitive as PM
import qualified Control.Monad.Primitive as PM
import qualified Posix.Socket.Platform as PSP
-- Unsafe binding to the C function @socket@; the result is returned as an 'Fd'.
foreign import ccall unsafe "sys/socket.h socket"
c_socket :: Domain -> Type -> Protocol -> IO Fd
-- Unsafe binding to the C function @socketpair@. The mutable byte array is
-- the output buffer; the caller in this file passes a two-element array of
-- 'Fd' and reads both elements back after a zero return.
foreign import ccall unsafe "sys/socket.h socketpair"
c_socketpair :: Domain -> Type -> Protocol -> MutableByteArray# RealWorld -> IO CInt
-- Unsafe binding to the C function @listen@; the 'CInt' argument is passed
-- through unchanged from 'uninterruptibleListen' (its backlog argument).
foreign import ccall unsafe "sys/socket.h listen"
c_listen :: Fd -> CInt -> IO CInt
-- Binding to the C function @close@ using the @safe@ FFI calling annotation.
-- Used by 'close'.
foreign import ccall safe "unistd.h close"
c_safe_close :: Fd -> IO CInt
-- Binding to the same C function @close@ using the @unsafe@ FFI calling
-- annotation. Used by 'uninterruptibleClose' and
-- 'uninterruptibleErrorlessClose'.
foreign import ccall unsafe "unistd.h close"
c_unsafe_close :: Fd -> IO CInt
-- | Unsafe binding to the C function @shutdown@. POSIX declares @shutdown@
-- in @sys/socket.h@ (not @unistd.h@), so that is the header named here,
-- matching the other socket imports in this module. The 'CInt' result is
-- interpreted by the caller ('uninterruptibleShutdown'), which treats zero
-- as success.
foreign import ccall unsafe "sys/socket.h shutdown"
  c_unsafe_shutdown :: Fd -> ShutdownType -> IO CInt
-- Unsafe binding to the C function @bind@. The byte array holds the encoded
-- socket address and the 'CInt' is its length in bytes (the caller passes
-- the size of the byte array).
foreign import ccall unsafe "sys/socket.h bind"
c_bind :: Fd -> ByteArray# -> CInt -> IO CInt
-- Binding to the C function @accept@ using the @safe@ calling annotation.
-- First buffer: receives the peer address. Second buffer: holds a 'CInt'
-- length that the caller writes before the call and reads back afterwards.
-- The caller in this file ('accept') allocates both buffers pinned.
foreign import ccall safe "sys/socket.h accept"
c_safe_accept :: Fd
-> MutableByteArray# RealWorld
-> MutableByteArray# RealWorld
-> IO Fd
-- Binding to the C function @accept@ using the @unsafe@ calling annotation;
-- same argument layout as 'c_safe_accept'. The caller in this file
-- ('uninterruptibleAccept') passes ordinary (not explicitly pinned) buffers.
foreign import ccall unsafe "sys/socket.h accept"
c_unsafe_accept :: Fd
-> MutableByteArray# RealWorld
-> MutableByteArray# RealWorld
-> IO Fd
-- Binding to the C function @accept@ (safe call) taking raw pointers for the
-- address and length arguments. 'accept_' passes 'nullPtr' for both.
foreign import ccall safe "sys/socket.h accept"
c_safe_ptr_accept :: Fd -> Ptr Void -> Ptr CInt -> IO Fd
-- Unsafe binding to the C function @getsockname@. First buffer: receives the
-- socket address. Second buffer: a 'CInt' length cell that the caller
-- initialises with the buffer capacity and reads back after the call.
foreign import ccall unsafe "sys/socket.h getsockname"
c_unsafe_getsockname :: Fd
-> MutableByteArray# RealWorld
-> MutableByteArray# RealWorld
-> IO CInt
-- Unsafe binding to the C function @getsockopt@. First buffer: receives the
-- option value. Second buffer: a 'CInt' length cell that the caller
-- initialises with the buffer capacity and reads back after the call.
foreign import ccall unsafe "sys/socket.h getsockopt"
c_unsafe_getsockopt :: Fd
-> Level
-> OptionName
-> MutableByteArray# RealWorld
-> MutableByteArray# RealWorld
-> IO CInt
-- Binding to the C function @connect@ using the @safe@ calling annotation,
-- taking the encoded address as an immutable byte array plus its length.
-- 'connect' only uses this import when the byte array is pinned.
foreign import ccall safe "sys/socket.h connect"
c_safe_connect :: Fd -> ByteArray# -> CInt -> IO CInt
-- Same C function and calling annotation, but taking a mutable byte array.
-- 'connect' uses this for the pinned copy it makes of an unpinned address.
foreign import ccall safe "sys/socket.h connect"
c_safe_mutablebytearray_connect :: Fd -> MutableByteArray# RealWorld -> CInt -> IO CInt
-- Binding to the C function @connect@ using the @unsafe@ calling annotation.
-- Used by 'uninterruptibleConnect' without any pinnedness check.
foreign import ccall unsafe "sys/socket.h connect"
c_unsafe_connect :: Fd -> ByteArray# -> CInt -> IO CInt
-- Safe binding to the C function @send@ taking a raw address.
foreign import ccall safe "sys/socket.h send"
c_safe_addr_send :: Fd -> Addr# -> CSize -> MessageFlags 'Send -> IO CSsize
-- Safe binding to the C symbol @send_offset@, taking an immutable byte array
-- followed by a 'CInt' offset and a 'CSize' length.
-- NOTE(review): @send_offset@ is not a standard @sys/socket.h@ function; it is
-- presumably a C helper shipped with this package that applies the offset
-- before calling @send@ -- confirm against the C sources.
foreign import ccall safe "sys/socket.h send_offset"
c_safe_bytearray_send :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- Safe binding to @send_offset@ taking a mutable byte array.
foreign import ccall safe "sys/socket.h send_offset"
c_safe_mutablebytearray_send :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- Safe binding to the C function @send@ taking a mutable byte array with no
-- offset argument: the payload is read from the start of the array. Used for
-- the pinned scratch copies made by 'sendByteArray' and 'sendMutableByteArray'.
foreign import ccall safe "sys/socket.h send"
c_safe_mutablebytearray_no_offset_send :: Fd -> MutableByteArray# RealWorld -> CSize -> MessageFlags 'Send -> IO CSsize
-- Unsafe binding to the C function @send@ taking a raw address.
foreign import ccall unsafe "sys/socket.h send"
c_unsafe_addr_send :: Fd -> Addr# -> CSize -> MessageFlags 'Send -> IO CSsize
-- Unsafe binding to @send_offset@ taking an immutable byte array.
foreign import ccall unsafe "sys/socket.h send_offset"
c_unsafe_bytearray_send :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- Unsafe binding to @send_offset@ taking a mutable byte array.
foreign import ccall unsafe "sys/socket.h send_offset"
c_unsafe_mutable_bytearray_send :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> IO CSsize
-- Unsafe binding to the C symbol @sendto_offset@: payload byte array, offset,
-- length and flags, followed by the encoded destination address and its
-- length in bytes.
-- NOTE(review): like @send_offset@, presumably a package-local C helper.
foreign import ccall unsafe "sys/socket.h sendto_offset"
c_unsafe_bytearray_sendto :: Fd -> ByteArray# -> CInt -> CSize -> MessageFlags 'Send -> ByteArray# -> CInt -> IO CSsize
-- Unsafe binding to @sendto_offset@ taking a mutable payload byte array.
foreign import ccall unsafe "sys/socket.h sendto_offset"
c_unsafe_mutable_bytearray_sendto :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Send -> ByteArray# -> CInt -> IO CSsize
-- Safe binding to the C function @recv@ writing into a raw address.
foreign import ccall safe "sys/socket.h recv"
c_safe_addr_recv :: Fd -> Addr# -> CSize -> MessageFlags 'Receive -> IO CSsize
-- Unsafe binding to the C function @recv@ writing into a raw address.
foreign import ccall unsafe "sys/socket.h recv"
c_unsafe_addr_recv :: Fd -> Addr# -> CSize -> MessageFlags 'Receive -> IO CSsize
-- Unsafe binding to the C symbol @recv_offset@: destination mutable byte
-- array, 'CInt' offset, 'CSize' length and flags.
-- NOTE(review): @recv_offset@ is not a standard @sys/socket.h@ function; it is
-- presumably a C helper shipped with this package -- confirm.
foreign import ccall unsafe "sys/socket.h recv_offset"
c_unsafe_mutable_byte_array_recv :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> IO CSsize
-- Unsafe binding to the C symbol @recvfrom_offset@. After the payload
-- arguments come a buffer for the sender address and a 'CInt' length cell
-- that the caller initialises and reads back.
foreign import ccall unsafe "sys/socket.h recvfrom_offset"
c_unsafe_mutable_byte_array_recvfrom :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> MutableByteArray# RealWorld -> MutableByteArray# RealWorld -> IO CSsize
-- Unsafe binding to @recvfrom_offset@ taking raw pointers for the address and
-- length arguments; the caller in this file passes 'nullPtr' for both.
foreign import ccall unsafe "sys/socket.h recvfrom_offset"
c_unsafe_mutable_byte_array_ptr_recvfrom :: Fd -> MutableByteArray# RealWorld -> CInt -> CSize -> MessageFlags 'Receive -> Ptr Void -> Ptr CInt -> IO CSsize
-- | Create a socket through the unsafe @socket@ import. A negative
-- descriptor from the C call is turned into @Left errno@; any other
-- descriptor is returned in 'Right'.
uninterruptibleSocket ::
     Domain
  -> Type
  -> Protocol
  -> IO (Either Errno Fd)
uninterruptibleSocket domain sockType protocol = do
  descriptor <- c_socket domain sockType protocol
  errorsFromFd descriptor
-- | Create a pair of sockets through the unsafe @socketpair@ import.
--
-- A two-element primitive array of 'Fd' is handed to the C call as its
-- output buffer. When the call returns zero both elements are read back and
-- returned; on any other result the current 'Errno' is returned instead.
uninterruptibleSocketPair ::
     Domain
  -> Type
  -> Protocol
  -> IO (Either Errno (Fd,Fd))
uninterruptibleSocketPair domain sockType protocol = do
  (pair@(MutablePrimArray pair#) :: MutablePrimArray RealWorld Fd) <- PM.newPrimArray 2
  status <- c_socketpair domain sockType protocol pair#
  case status of
    0 -> do
      firstFd <- PM.readPrimArray pair 0
      secondFd <- PM.readPrimArray pair 1
      pure (Right (firstFd, secondFd))
    _ -> Left <$> getErrno
-- | Bind a socket to an address through the unsafe @bind@ import. The whole
-- byte array inside the 'SocketAddress' is passed, with its size as the
-- address length. A zero result is success; anything else yields the
-- current 'Errno'.
uninterruptibleBind ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
uninterruptibleBind fd (SocketAddress addr@(ByteArray addr#)) = do
  status <- c_bind fd addr# (intToCInt (PM.sizeofByteArray addr))
  errorsFromInt status
-- | Mark a socket as listening through the unsafe @listen@ import. The
-- second argument is forwarded unchanged as the backlog. A zero result is
-- success; anything else yields the current 'Errno'.
uninterruptibleListen ::
     Fd
  -> CInt
  -> IO (Either Errno ())
uninterruptibleListen fd backlog = do
  status <- c_listen fd backlog
  errorsFromInt status
-- | Connect a socket to an address using the safe @connect@ import.
--
-- A pinned address byte array is passed to the C call directly. An unpinned
-- one is first copied in full into a freshly allocated pinned buffer, and
-- that copy is passed instead. A zero result is success; anything else
-- yields the current 'Errno'.
connect ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
connect fd (SocketAddress sockAddr@(ByteArray sockAddr#))
  | PM.isByteArrayPinned sockAddr =
      errorsFromInt =<< c_safe_connect fd sockAddr# (intToCInt addrLen)
  | otherwise = do
      pinned@(MutableByteArray pinned#) <- PM.newPinnedByteArray addrLen
      PM.copyByteArray pinned 0 sockAddr 0 addrLen
      errorsFromInt =<< c_safe_mutablebytearray_connect fd pinned# (intToCInt addrLen)
  where
    -- Size in bytes of the encoded address.
    addrLen = PM.sizeofByteArray sockAddr
-- | Connect a socket to an address through the unsafe @connect@ import. The
-- address byte array is passed as-is (no pinnedness check, no copy) together
-- with its size. A zero result is success; anything else yields the current
-- 'Errno'.
uninterruptibleConnect ::
     Fd
  -> SocketAddress
  -> IO (Either Errno ())
uninterruptibleConnect fd (SocketAddress addr@(ByteArray addr#)) = do
  status <- c_unsafe_connect fd addr# (intToCInt (PM.sizeofByteArray addr))
  errorsFromInt status
-- | Accept a connection using the safe @accept@ import.
--
-- The second argument is the capacity, in bytes, of the buffer offered for
-- the peer address. Both the address buffer and the length cell are
-- allocated pinned for the call. On success the result carries the address
-- length reported by the call, the peer address copied into a fresh
-- (not explicitly pinned) byte array holding at most the offered capacity,
-- and the new descriptor. A negative descriptor yields the current 'Errno'.
accept ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,Fd))
accept !sock !maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newPinnedByteArray (cintToInt maxSz)
  lenCell@(MutableByteArray lenCell#) <- PM.newPinnedByteArray (PM.sizeOf (undefined :: CInt))
  -- The length cell starts out holding the capacity of the address buffer.
  PM.writeByteArray lenCell 0 maxSz
  client <- c_safe_accept sock addrBuf# lenCell#
  if client < 0
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenCell 0 :: IO CInt
      -- Never copy more than what was offered to the call.
      let kept = cintToInt (min reported maxSz)
      unpinned <- PM.newByteArray kept
      PM.copyMutableByteArray unpinned 0 addrBuf 0 kept
      frozen <- PM.unsafeFreezeByteArray unpinned
      pure (Right (reported, SocketAddress frozen, client))
-- | Accept a connection through the unsafe @accept@ import.
--
-- The second argument is the capacity, in bytes, of the buffer offered for
-- the peer address. On success the result carries the address length
-- reported by the call, the peer address, and the new descriptor; the
-- address buffer is shrunk in place when the reported length is smaller
-- than the capacity. A negative descriptor yields the current 'Errno'.
uninterruptibleAccept ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,Fd))
uninterruptibleAccept !sock !maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenCell@(MutableByteArray lenCell#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  -- The length cell starts out holding the capacity of the address buffer.
  PM.writeByteArray lenCell 0 maxSz
  client <- c_unsafe_accept sock addrBuf# lenCell#
  if client < 0
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenCell 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      frozen <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress frozen, client))
-- | Accept a connection using the safe @accept@ import, passing null
-- pointers for both the address and the address-length arguments, so no
-- peer address is returned. A negative descriptor yields the current
-- 'Errno'.
accept_ ::
     Fd
  -> IO (Either Errno Fd)
accept_ sock = do
  client <- c_safe_ptr_accept sock nullPtr nullPtr
  errorsFromFd client
-- | Retrieve the address a socket is bound to through the unsafe
-- @getsockname@ import.
--
-- The second argument is the capacity, in bytes, of the buffer offered for
-- the address. On a zero result the function returns the length reported by
-- the call together with the address; the buffer is shrunk in place when
-- the reported length is smaller than the capacity. Any other result yields
-- the current 'Errno'.
uninterruptibleGetSocketName ::
     Fd
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress))
uninterruptibleGetSocketName sock maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenCell@(MutableByteArray lenCell#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  -- The length cell starts out holding the capacity of the address buffer.
  PM.writeByteArray lenCell 0 maxSz
  status <- c_unsafe_getsockname sock addrBuf# lenCell#
  case status of
    0 -> do
      reported <- PM.readByteArray lenCell 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      frozen <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress frozen))
    _ -> Left <$> getErrno
-- | Read a socket option through the unsafe @getsockopt@ import.
--
-- The last argument is the capacity, in bytes, of the buffer offered for
-- the option value. On a zero result the function returns the length
-- reported by the call together with the value; the buffer is shrunk in
-- place when the reported length is smaller than the capacity. Any other
-- result yields the current 'Errno'.
uninterruptibleGetSocketOption ::
     Fd
  -> Level
  -> OptionName
  -> CInt
  -> IO (Either Errno (CInt,OptionValue))
uninterruptibleGetSocketOption sock level optName maxSz = do
  optBuf@(MutableByteArray optBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenCell@(MutableByteArray lenCell#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  -- The length cell starts out holding the capacity of the value buffer.
  PM.writeByteArray lenCell 0 maxSz
  status <- c_unsafe_getsockopt sock level optName optBuf# lenCell#
  case status of
    0 -> do
      reported <- PM.readByteArray lenCell 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray optBuf (cintToInt reported)
        else pure ()
      frozen <- PM.unsafeFreezeByteArray optBuf
      pure (Right (reported, OptionValue frozen))
    _ -> Left <$> getErrno
-- | Send @len@ bytes of an immutable byte array, starting at byte offset
-- @off@, using the safe @send@ imports.
--
-- A pinned array is passed to the C call directly, together with the
-- offset. An unpinned array must not be handed to a safe foreign call, so
-- the requested slice is first copied into a fresh pinned buffer of exactly
-- @len@ bytes. That buffer is sent with the no-offset import, so the slice
-- has to sit at the start of the buffer.
--
-- Returns the byte count reported by the call, or the current 'Errno' when
-- the call reports a negative result.
sendByteArray ::
     Fd -- ^ Socket
  -> ByteArray -- ^ Source bytes (pinned or unpinned)
  -> CInt -- ^ Offset into the source, in bytes
  -> CSize -- ^ Number of bytes to send
  -> MessageFlags 'Send -- ^ Flags
  -> IO (Either Errno CSize)
sendByteArray fd b@(ByteArray b#) off len flags = if PM.isByteArrayPinned b
  then errorsFromSize =<< c_safe_bytearray_send fd b# off len flags
  else do
    x@(MutableByteArray x#) <- PM.newPinnedByteArray (csizeToInt len)
    -- Argument order is: destination, destination offset, source, source
    -- offset, count. The slice starts at @off@ in the source and lands at
    -- offset 0 of the scratch buffer, which is only @len@ bytes long.
    PM.copyByteArray x 0 b (cintToInt off) (csizeToInt len)
    errorsFromSize =<< c_safe_mutablebytearray_no_offset_send fd x# len flags
-- | Send @len@ bytes of a mutable byte array, starting at byte offset
-- @off@, using the safe @send@ imports.
--
-- A pinned array is passed to the C call directly, together with the
-- offset. An unpinned array must not be handed to a safe foreign call, so
-- the requested slice is first copied into a fresh pinned buffer of exactly
-- @len@ bytes. That buffer is sent with the no-offset import, so the slice
-- has to sit at the start of the buffer.
--
-- Returns the byte count reported by the call, or the current 'Errno' when
-- the call reports a negative result.
sendMutableByteArray ::
     Fd -- ^ Socket
  -> MutableByteArray RealWorld -- ^ Source bytes (pinned or unpinned)
  -> CInt -- ^ Offset into the source, in bytes
  -> CSize -- ^ Number of bytes to send
  -> MessageFlags 'Send -- ^ Flags
  -> IO (Either Errno CSize)
sendMutableByteArray fd b@(MutableByteArray b#) off len flags = if PM.isMutableByteArrayPinned b
  then errorsFromSize =<< c_safe_mutablebytearray_send fd b# off len flags
  else do
    x@(MutableByteArray x#) <- PM.newPinnedByteArray (csizeToInt len)
    -- Argument order is: destination, destination offset, source, source
    -- offset, count. The slice starts at @off@ in the source and lands at
    -- offset 0 of the scratch buffer, which is only @len@ bytes long.
    PM.copyMutableByteArray x 0 b (cintToInt off) (csizeToInt len)
    errorsFromSize =<< c_safe_mutablebytearray_no_offset_send fd x# len flags
-- | Send @len@ bytes starting at a raw address using the safe @send@
-- import. Returns the byte count reported by the call, or the current
-- 'Errno' when the call reports a negative result.
send ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
send fd (Addr src#) len flags = do
  sent <- c_safe_addr_send fd src# len flags
  errorsFromSize sent
-- | Send @len@ bytes starting at a raw address through the unsafe @send@
-- import. Returns the byte count reported by the call, or the current
-- 'Errno' when the call reports a negative result.
uninterruptibleSend ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSend fd (Addr src#) len flags = do
  sent <- c_unsafe_addr_send fd src# len flags
  errorsFromSize sent
-- | Send a slice of an immutable byte array through the unsafe
-- @send_offset@ import. Offset and length are forwarded unchanged to the C
-- call. Returns the byte count reported by the call, or the current 'Errno'
-- when the call reports a negative result.
uninterruptibleSendByteArray ::
     Fd
  -> ByteArray
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSendByteArray fd (ByteArray payload#) off len flags = do
  sent <- c_unsafe_bytearray_send fd payload# off len flags
  errorsFromSize sent
-- | Send a slice of a mutable byte array through the unsafe @send_offset@
-- import. Offset and length are forwarded unchanged to the C call. Returns
-- the byte count reported by the call, or the current 'Errno' when the call
-- reports a negative result.
uninterruptibleSendMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> IO (Either Errno CSize)
uninterruptibleSendMutableByteArray fd (MutableByteArray payload#) off len flags = do
  sent <- c_unsafe_mutable_bytearray_send fd payload# off len flags
  errorsFromSize sent
-- | Send a slice of an immutable byte array to an explicit destination
-- through the unsafe @sendto_offset@ import. The destination's byte array is
-- passed whole, with its size as the address length. Returns the byte count
-- reported by the call, or the current 'Errno' when the call reports a
-- negative result.
uninterruptibleSendToByteArray ::
     Fd
  -> ByteArray
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> SocketAddress
  -> IO (Either Errno CSize)
uninterruptibleSendToByteArray fd (ByteArray payload#) off len flags (SocketAddress dest@(ByteArray dest#)) = do
  let destLen = intToCInt (PM.sizeofByteArray dest)
  sent <- c_unsafe_bytearray_sendto fd payload# off len flags dest# destLen
  errorsFromSize sent
-- | Send a slice of a mutable byte array to an explicit destination through
-- the unsafe @sendto_offset@ import. The destination's byte array is passed
-- whole, with its size as the address length. Returns the byte count
-- reported by the call, or the current 'Errno' when the call reports a
-- negative result.
uninterruptibleSendToMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Send
  -> SocketAddress
  -> IO (Either Errno CSize)
uninterruptibleSendToMutableByteArray fd (MutableByteArray payload#) off len flags (SocketAddress dest@(ByteArray dest#)) = do
  let destLen = intToCInt (PM.sizeofByteArray dest)
  sent <- c_unsafe_mutable_bytearray_sendto fd payload# off len flags dest# destLen
  errorsFromSize sent
-- | Receive up to @len@ bytes into a raw address using the safe @recv@
-- import. Returns the byte count reported by the call, or the current
-- 'Errno' when the call reports a negative result.
receive ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
receive fd (Addr dst#) len flags = do
  received <- c_safe_addr_recv fd dst# len flags
  errorsFromSize received
-- | Receive up to @len@ bytes using the safe @recv@ import and return them
-- as a freshly allocated byte array.
--
-- The call writes into a pinned scratch buffer of @len@ bytes. Unless the
-- call returns @-1@, the received prefix is copied into a new array (one
-- that is not explicitly pinned) sized to the returned count. A result of
-- @-1@ yields the current 'Errno'.
receiveByteArray ::
     Fd
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno ByteArray)
receiveByteArray !fd !len !flags = do
  scratch <- PM.newPinnedByteArray (csizeToInt len)
  let !(Addr dst#) = PM.mutableByteArrayContents scratch
  received <- c_safe_addr_recv fd dst# len flags
  if received == (-1)
    then Left <$> getErrno
    else do
      let count = cssizeToInt received
      result <- PM.newByteArray count
      PM.copyMutableByteArray result 0 scratch 0 count
      Right <$> PM.unsafeFreezeByteArray result
-- | Receive up to @len@ bytes into a raw address through the unsafe @recv@
-- import. Returns the byte count reported by the call, or the current
-- 'Errno' when the call reports a negative result.
uninterruptibleReceive ::
     Fd
  -> Addr
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceive !fd (Addr !dst#) !len !flags = do
  received <- c_unsafe_addr_recv fd dst# len flags
  errorsFromSize received
-- | Receive up to @len@ bytes into a mutable byte array through the unsafe
-- @recv_offset@ import. Offset and length are forwarded unchanged to the C
-- call. Returns the byte count reported by the call, or the current 'Errno'
-- when the call reports a negative result.
uninterruptibleReceiveMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceiveMutableByteArray !fd (MutableByteArray !dst#) !off !len !flags = do
  received <- c_unsafe_mutable_byte_array_recv fd dst# off len flags
  errorsFromSize received
-- | Receive up to @len@ bytes into a mutable byte array through the unsafe
-- @recvfrom_offset@ import, also capturing the sender's address.
--
-- The last argument is the capacity, in bytes, of the buffer offered for
-- the sender address. On success the result carries the address length
-- reported by the call, the sender address, and the number of bytes
-- received; the address buffer is shrunk in place when the reported length
-- is smaller than the capacity. A negative result yields the current
-- 'Errno'.
uninterruptibleReceiveFromMutableByteArray ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> CInt
  -> IO (Either Errno (CInt,SocketAddress,CSize))
{-# INLINE uninterruptibleReceiveFromMutableByteArray #-}
uninterruptibleReceiveFromMutableByteArray !fd (MutableByteArray !dst#) !off !len !flags !maxSz = do
  addrBuf@(MutableByteArray addrBuf#) <- PM.newByteArray (cintToInt maxSz)
  lenCell@(MutableByteArray lenCell#) <- PM.newByteArray (PM.sizeOf (undefined :: CInt))
  -- The length cell starts out holding the capacity of the address buffer.
  PM.writeByteArray lenCell 0 maxSz
  received <- c_unsafe_mutable_byte_array_recvfrom fd dst# off len flags addrBuf# lenCell#
  if received < 0
    then Left <$> getErrno
    else do
      reported <- PM.readByteArray lenCell 0 :: IO CInt
      if reported < maxSz
        then shrinkMutableByteArray addrBuf (cintToInt reported)
        else pure ()
      frozen <- PM.unsafeFreezeByteArray addrBuf
      pure (Right (reported, SocketAddress frozen, cssizeToCSize received))
-- | Receive up to @len@ bytes into a mutable byte array through the unsafe
-- @recvfrom_offset@ import, passing null pointers for the address and
-- address-length arguments so the sender address is not captured. Returns
-- the byte count reported by the call, or the current 'Errno' when the call
-- reports a negative result.
uninterruptibleReceiveFromMutableByteArray_ ::
     Fd
  -> MutableByteArray RealWorld
  -> CInt
  -> CSize
  -> MessageFlags 'Receive
  -> IO (Either Errno CSize)
uninterruptibleReceiveFromMutableByteArray_ !fd (MutableByteArray !dst#) !off !len !flags = do
  received <- c_unsafe_mutable_byte_array_ptr_recvfrom fd dst# off len flags nullPtr nullPtr
  errorsFromSize received
-- | Close a descriptor using the safe @close@ import. A zero result is
-- success; anything else yields the current 'Errno'.
close ::
     Fd
  -> IO (Either Errno ())
close fd = do
  status <- c_safe_close fd
  errorsFromInt status
-- | Close a descriptor through the unsafe @close@ import. A zero result is
-- success; anything else yields the current 'Errno'.
uninterruptibleClose ::
     Fd
  -> IO (Either Errno ())
uninterruptibleClose fd = do
  status <- c_unsafe_close fd
  errorsFromInt status
-- | Close a descriptor through the unsafe @close@ import and deliberately
-- discard the result code, so a failure of the call is never reported.
uninterruptibleErrorlessClose ::
     Fd
  -> IO ()
uninterruptibleErrorlessClose fd = () <$ c_unsafe_close fd
-- | Shut down part or all of a connection through the unsafe @shutdown@
-- import; the 'ShutdownType' is forwarded unchanged. A zero result is
-- success; anything else yields the current 'Errno'.
uninterruptibleShutdown ::
     Fd
  -> ShutdownType
  -> IO (Either Errno ())
uninterruptibleShutdown fd how = do
  status <- c_unsafe_shutdown fd how
  errorsFromInt status
-- | Interpret a signed byte count returned by a C call: a negative value
-- means failure and is replaced by the current 'Errno'; any other value is
-- converted to 'CSize' and returned in 'Right'.
errorsFromSize :: CSsize -> IO (Either Errno CSize)
errorsFromSize count
  | count < 0 = Left <$> getErrno
  | otherwise = pure (Right (cssizeToCSize count))
-- | Interpret a descriptor returned by a C call: a negative value means
-- failure and is replaced by the current 'Errno'; any other descriptor is
-- returned unchanged in 'Right'.
errorsFromFd :: Fd -> IO (Either Errno Fd)
errorsFromFd fd
  | fd < 0 = Left <$> getErrno
  | otherwise = pure (Right fd)
-- | Interpret a status code returned by a C call: zero is success; every
-- other value is replaced by the current 'Errno'.
errorsFromInt :: CInt -> IO (Either Errno ())
errorsFromInt status = case status of
  0 -> pure (Right ())
  _ -> Left <$> getErrno
-- | Convert an 'Int' to a 'CInt' with 'fromIntegral'.
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n

-- | Convert a 'CInt' to an 'Int' with 'fromIntegral'.
cintToInt :: CInt -> Int
cintToInt n = fromIntegral n

-- | Convert a 'CSize' to an 'Int' with 'fromIntegral'.
csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n

-- | Convert a 'CSsize' to an 'Int' with 'fromIntegral'.
cssizeToInt :: CSsize -> Int
cssizeToInt n = fromIntegral n

-- | Convert a 'CSsize' to a 'CSize' with 'fromIntegral'. Callers in this
-- module only apply it after checking that the value is not negative.
cssizeToCSize :: CSsize -> CSize
cssizeToCSize n = fromIntegral n
-- | Shrink a mutable byte array in place to the given size in bytes by
-- running the 'shrinkMutableByteArray#' primop as an 'IO' action.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray buf#) (I# newSize#) =
  IO (\s0 -> case shrinkMutableByteArray# buf# newSize# s0 of
    s1 -> (# s1, () #))
-- | Convert a 16-bit word from host byte order to network (big-endian)
-- byte order: the identity on big-endian targets, a byte swap on
-- little-endian targets.
hostToNetworkShort :: Word16 -> Word16
hostToNetworkShort w = case targetByteOrder of
  LittleEndian -> byteSwap16 w
  BigEndian -> w
-- | Convert a 16-bit word from network (big-endian) byte order to host
-- byte order: the identity on big-endian targets, a byte swap on
-- little-endian targets.
networkToHostShort :: Word16 -> Word16
networkToHostShort w = case targetByteOrder of
  LittleEndian -> byteSwap16 w
  BigEndian -> w
-- | Convert a 32-bit word from host byte order to network (big-endian)
-- byte order: the identity on big-endian targets, a byte swap on
-- little-endian targets.
hostToNetworkLong :: Word32 -> Word32
hostToNetworkLong w = case targetByteOrder of
  LittleEndian -> byteSwap32 w
  BigEndian -> w
-- | Convert a 32-bit word from network (big-endian) byte order to host
-- byte order: the identity on big-endian targets, a byte swap on
-- little-endian targets.
networkToHostLong :: Word32 -> Word32
networkToHostLong w = case targetByteOrder of
  LittleEndian -> byteSwap32 w
  BigEndian -> w