{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnliftedFFITypes #-}
module Linux.Socket
(
uninterruptibleAccept4
, uninterruptibleAccept4_
, SocketFlags (..)
, LST.headerInclude
, LST.dontWait
, LST.truncate
, LST.controlTruncate
, LST.closeOnExec
, LST.nonblocking
, applySocketFlags
, LST.sizeofUdpHeader
, LST.pokeUdpHeaderSourcePort
, LST.pokeUdpHeaderDestinationPort
, LST.pokeUdpHeaderLength
, LST.pokeUdpHeaderChecksum
, LST.sizeofIpHeader
, LST.pokeIpHeaderVersionIhl
, LST.pokeIpHeaderTypeOfService
, LST.pokeIpHeaderTotalLength
, LST.pokeIpHeaderIdentifier
, LST.pokeIpHeaderFragmentOffset
, LST.pokeIpHeaderTimeToLive
, LST.pokeIpHeaderProtocol
, LST.pokeIpHeaderChecksum
, LST.pokeIpHeaderSourceAddress
, LST.pokeIpHeaderDestinationAddress
) where
import Prelude hiding (truncate)
import Data.Bits ((.|.))
import Data.Primitive (MutableByteArray (..))
import Data.Void (Void)
import Foreign.C.Error (Errno, eINVAL, getErrno)
import Foreign.C.Types (CInt (..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Int (I#), MutableByteArray#, Ptr (..), RealWorld, shrinkMutableByteArray#)
import Linux.Socket.Types (SocketFlags (..))
import Posix.Socket (SocketAddress (..), Type (..))
import System.Posix.Types (Fd (..))
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket.Types as LST
-- | Raw @accept4@ binding that writes the peer address and its length into
-- two caller-supplied byte arrays.
--
-- Handing unpinned 'MutableByteArray#' values to C (enabled by
-- UnliftedFFITypes) is only sound because the import is @unsafe@: no garbage
-- collection can run, and so no buffer can move, while the call is in progress.
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_accept4 :: Fd -> MutableByteArray# RealWorld -> MutableByteArray# RealWorld -> SocketFlags -> IO Fd
-- | Raw @accept4@ binding that takes plain pointers for the address and
-- address-length arguments. In this module it is called with 'nullPtr' for
-- both, so the peer address is not retrieved.
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_ptr_accept4 :: Fd -> Ptr Void -> Ptr Void -> SocketFlags -> IO Fd
-- | Combine socket flags with a socket 'Type' by bitwise OR of their
-- underlying 'CInt' values. The result is still a 'Type'. Presumably it is
-- meant to be passed to @socket(2)@; confirm against callers.
applySocketFlags :: SocketFlags -> Type -> Type
applySocketFlags (SocketFlags flagBits) (Type typeBits) =
  Type (typeBits .|. flagBits)
-- | Shrink a mutable byte array in place to @newSize@ bytes, using the
-- 'shrinkMutableByteArray#' primop.
--
-- The primop's contract requires @newSize@ not to exceed the current size.
-- The caller in 'uninterruptibleAccept4' only shrinks when the new size is
-- strictly smaller.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray marr#) newSize =
  case newSize of
    I# newSize# ->
      PM.primitive_ (\s# -> shrinkMutableByteArray# marr# newSize# s#)
-- | Accept a connection on a socket with @accept4(2)@ and capture the peer's
-- socket address.
--
-- Arguments:
--
-- * the listening socket;
-- * the maximum number of bytes to reserve for the peer address (@maxSz@);
-- * the 'SocketFlags' passed straight through to @accept4@.
--
-- On success, the result is the address length reported by the kernel, the
-- peer address, and the accepted socket. If the reported length is smaller
-- than @maxSz@, the address buffer is shrunk to fit. If it is larger, the
-- address was truncated to @maxSz@ bytes. The reported length is still
-- returned, so callers can detect the truncation by comparing it with @maxSz@.
--
-- On failure, the 'Errno' read immediately after the call is returned. A
-- negative @maxSz@ yields @Left eINVAL@ without calling @accept4@, which
-- matches what the kernel reports for a negative address length.
--
-- The call goes through an @unsafe@ FFI import, so it is not interruptible.
-- NOTE(review): presumably callers only use this on sockets that are ready or
-- non-blocking; confirm.
uninterruptibleAccept4 ::
  Fd ->
  CInt ->
  SocketFlags ->
  IO (Either Errno (CInt, SocketAddress, Fd))
{-# INLINE uninterruptibleAccept4 #-}
uninterruptibleAccept4 !sock !maxSz !flags
  -- A negative size cannot be allocated. Report it the way accept4 would,
  -- instead of handing it to newByteArray.
  | maxSz < 0 = pure (Left eINVAL)
  | otherwise = do
      sockAddrBuf@(MutableByteArray sockAddrBuf#) <-
        PM.newByteArray (cintToInt maxSz)
      lenBuf@(MutableByteArray lenBuf#) <-
        PM.newByteArray (PM.sizeOf (0 :: CInt))
      -- accept4 reads the buffer capacity from the length argument and
      -- overwrites it with the actual address length.
      PM.writeByteArray lenBuf 0 maxSz
      r <- c_unsafe_accept4 sock sockAddrBuf# lenBuf# flags
      if r > (-1)
        then do
          (sz :: CInt) <- PM.readByteArray lenBuf 0
          -- Only shrink; when sz >= maxSz the buffer is already full-size
          -- (the address may have been truncated).
          if sz < maxSz
            then shrinkMutableByteArray sockAddrBuf (cintToInt sz)
            else pure ()
          sockAddr <- PM.unsafeFreezeByteArray sockAddrBuf
          pure (Right (sz, SocketAddress sockAddr, r))
        else fmap Left getErrno
-- | Accept a connection with @accept4(2)@ without retrieving the peer
-- address. Both the address and address-length arguments are 'nullPtr'.
--
-- Returns the accepted socket when @accept4@ returns a non-negative value.
-- Otherwise it returns the 'Errno' read right after the call. The call goes
-- through an @unsafe@ FFI import, so it is not interruptible.
uninterruptibleAccept4_ :: Fd -> SocketFlags -> IO (Either Errno Fd)
{-# INLINE uninterruptibleAccept4_ #-}
uninterruptibleAccept4_ !sock !flags =
  c_unsafe_ptr_accept4 sock nullPtr nullPtr flags >>= \fd ->
    if fd < 0
      then Left <$> getErrno
      else pure (Right fd)
-- | Widen a 'CInt' to an 'Int' by unwrapping the newtype and converting the
-- underlying integer.
cintToInt :: CInt -> Int
cintToInt (CInt n) = fromIntegral n