{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnliftedFFITypes #-}

module Linux.Socket
  ( -- * Functions
    uninterruptibleAccept4
  , uninterruptibleAccept4_

    -- * Types
  , SocketFlags (..)

    -- * Option Names
  , LST.headerInclude

    -- * Message Flags
  , LST.dontWait
  , LST.truncate
  , LST.controlTruncate

    -- * Socket Flags
  , LST.closeOnExec
  , LST.nonblocking

    -- * Twiddle
  , applySocketFlags

    -- * UDP Header
  , LST.sizeofUdpHeader
  , LST.pokeUdpHeaderSourcePort
  , LST.pokeUdpHeaderDestinationPort
  , LST.pokeUdpHeaderLength
  , LST.pokeUdpHeaderChecksum

    -- * IPv4 Header
  , LST.sizeofIpHeader
  , LST.pokeIpHeaderVersionIhl
  , LST.pokeIpHeaderTypeOfService
  , LST.pokeIpHeaderTotalLength
  , LST.pokeIpHeaderIdentifier
  , LST.pokeIpHeaderFragmentOffset
  , LST.pokeIpHeaderTimeToLive
  , LST.pokeIpHeaderProtocol
  , LST.pokeIpHeaderChecksum
  , LST.pokeIpHeaderSourceAddress
  , LST.pokeIpHeaderDestinationAddress
  ) where

import Prelude hiding (truncate)

import Data.Bits ((.|.))
import Data.Primitive (MutableByteArray (..))
import Data.Void (Void)
import Foreign.C.Error (Errno, getErrno)
import Foreign.C.Types (CInt (..))
import Foreign.Ptr (nullPtr)
import GHC.Exts (Int (I#), MutableByteArray#, Ptr (..), RealWorld, shrinkMutableByteArray#)
import Linux.Socket.Types (SocketFlags (..))
import Posix.Socket (SocketAddress (..), Type (..))
import System.Posix.Types (Fd (..))

import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket.Types as LST

-- Raw binding to Linux accept4(2). It returns the accepted file descriptor,
-- or -1 on failure (with errno set, which callers read via getErrno).
--
-- The address buffer and the value-result length cell are passed as unlifted
-- MutableByteArray# values (allowed by UnliftedFFITypes). This is sound only
-- because the call is `unsafe`: no garbage collection can happen during an
-- unsafe foreign call, so the (possibly unpinned) arrays cannot move while C
-- holds pointers into them. Do not change this import to `safe`.
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_accept4 ::
    Fd ->
    MutableByteArray# RealWorld -> -- SocketAddress (struct sockaddr *)
    MutableByteArray# RealWorld -> -- Ptr CInt (in: capacity, out: address length)
    SocketFlags ->
    IO Fd

-- Variant of c_unsafe_accept4 that uses Ptr instead of MutableByteArray#.
-- Currently, we expect that the two pointers are set to NULL. accept(2)
-- permits a NULL address pointer, in which case the length pointer should
-- also be NULL and no peer address is written back.
-- This is only used internally (by uninterruptibleAccept4_).
foreign import ccall unsafe "sys/socket.h accept4"
  c_unsafe_ptr_accept4 ::
    Fd ->
    Ptr Void -> -- SocketAddress (expected: nullPtr)
    Ptr Void -> -- Ptr CInt (expected: nullPtr)
    SocketFlags ->
    IO Fd

{- | Linux extends the @type@ argument of
  <http://man7.org/linux/man-pages/man2/socket.2.html socket> to allow
  setting two socket flags on socket creation: @SOCK_CLOEXEC@ and
  @SOCK_NONBLOCK@. It is advisable to set @SOCK_CLOEXEC@ when
  opening a socket on Linux. For example, we may open a TCP Internet
  socket with:

  > uninterruptibleSocket internet (applySocketFlags closeOnExec stream) defaultProtocol

  To additionally open the socket in nonblocking mode
  (i.e. with @SOCK_NONBLOCK@):

  > uninterruptibleSocket internet (applySocketFlags (closeOnExec <> nonblocking) stream) defaultProtocol

  The flags are merged into the socket type with a bitwise OR.
-}
applySocketFlags :: SocketFlags -> Type -> Type
applySocketFlags (SocketFlags s) (Type t) = Type (s .|. t)

-- | Shrink a mutable byte array in place to the given size in bytes.
--
-- Thin lifted wrapper around the 'shrinkMutableByteArray#' primop. The
-- primop can only shrink. The requested size must not exceed the array's
-- current size, and the only caller in this module guards the call
-- accordingly.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr) (I# sz) =
  PM.primitive_ (shrinkMutableByteArray# arr sz)

{- | Variant of 'Posix.Socket.uninterruptibleAccept' that allows setting
  flags on the newly-accepted connection.

  A buffer of @maxSz@ bytes is offered to the kernel for the peer address.
  On success, the returned 'CInt' is the address length reported by the
  kernel:

  * If it is smaller than @maxSz@, the address buffer is shrunk to fit.
  * If it is larger, the kernel truncated the address (see @accept(2)@),
    and the buffer keeps all @maxSz@ bytes.

  On failure, the @errno@ value is returned in 'Left'.
-}
uninterruptibleAccept4 ::
  -- | Listening socket
  Fd ->
  -- | Maximum socket address size
  CInt ->
  -- | Set non-blocking and close-on-exec without extra syscall
  SocketFlags ->
  -- | Peer information and connected socket
  IO (Either Errno (CInt, SocketAddress, Fd))
{-# INLINE uninterruptibleAccept4 #-}
uninterruptibleAccept4 !sock !maxSz !flags = do
  sockAddrBuf@(MutableByteArray sockAddrBuf#) <- PM.newByteArray (cintToInt maxSz)
  -- Value-result length cell: accept4 reads the buffer capacity from it and
  -- overwrites it with the actual length of the peer address.
  lenBuf@(MutableByteArray lenBuf#) <- PM.newByteArray (PM.sizeOf (0 :: CInt))
  PM.writeByteArray lenBuf 0 maxSz
  r <- c_unsafe_accept4 sock sockAddrBuf# lenBuf# flags
  if r > (-1)
    then do
      (sz :: CInt) <- PM.readByteArray lenBuf 0
      -- Only ever shrink. A length larger than maxSz means the address was
      -- truncated to maxSz bytes, and the primop cannot grow an array.
      if sz < maxSz
        then shrinkMutableByteArray sockAddrBuf (cintToInt sz)
        else pure ()
      sockAddr <- PM.unsafeFreezeByteArray sockAddrBuf
      pure (Right (sz, SocketAddress sockAddr, r))
    else Left <$> getErrno

{- | Variant of 'uninterruptibleAccept4' that requests that the kernel not
include the socket address in its response.

Both address arguments of @accept4@ are passed as NULL, so no buffer is
allocated. On failure, the @errno@ value is returned in 'Left'.
-}
uninterruptibleAccept4_ ::
  -- | Listening socket
  Fd ->
  -- | Set non-blocking and close-on-exec without extra syscall
  SocketFlags ->
  -- | Connected socket
  IO (Either Errno Fd)
{-# INLINE uninterruptibleAccept4_ #-}
uninterruptibleAccept4_ !sock !flags = do
  r <- c_unsafe_ptr_accept4 sock nullPtr nullPtr flags
  if r > (-1)
    then pure (Right r)
    else Left <$> getErrno

-- | Convert a 'CInt' to an 'Int'. Used here to turn C-side lengths into
-- byte-array sizes.
cintToInt :: CInt -> Int
cintToInt = fromIntegral