{-# LINE 1 "src-linux/Posix/Socket/Platform.hsc" #-}
{-# language BangPatterns #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language GeneralizedNewtypeDeriving #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language UnboxedTuples #-}
{-# language ScopedTypeVariables #-}








module Posix.Socket.Platform
  ( -- * Encoding Socket Addresses
    encodeSocketAddressInternet
  , encodeSocketAddressUnix
    -- * Decoding Socket Addresses
  , decodeSocketAddressInternet
  , indexSocketAddressInternet
    -- * Sizes
  , sizeofSocketAddressInternet
  ) where

import Control.Monad (when)
import Data.Primitive (MutableByteArray,ByteArray(..),writeByteArray,indexByteArray)
import Data.Primitive.Addr (Addr(..))
import Data.Word (Word8)
import Foreign.C.Types (CUShort,CInt)
import GHC.Exts (ByteArray#,State#,RealWorld,runRW#,Ptr(..))
import GHC.ST (ST(..))
import Posix.Socket.Types (SocketAddress(..))
import Posix.Socket.Types (SocketAddressInternet(..),SocketAddressUnix(..))
import Foreign.Storable (peekByteOff)

import qualified Data.Primitive as PM
import qualified Data.Primitive.Addr as PMA
import qualified Foreign.Storable as FS

-- | Size, in bytes, of a serialized IPv4 socket address (@sockaddr_in@)
--   on Linux. Buffers holding an encoded internet socket address are
--   allocated with exactly this many bytes.
sizeofSocketAddressInternet :: CInt
sizeofSocketAddressInternet = 16
{-# LINE 46 "src-linux/Posix/Socket/Platform.hsc" #-}

-- | Fill a buffer with the Linux @sockaddr_in@ encoding of an IPv4
--   socket address.
--
--   The indices passed to 'PM.writeByteArray' count elements of the type
--   being written, not bytes. So the family (a 'CUShort' at index 0) lands
--   on byte 0, the port (index 1 of 16-bit elements) on byte 2, and the
--   address (index 1 of 32-bit elements) on byte 4.
internalWriteSocketAddressInternet ::
     MutableByteArray s -- ^ Buffer, must have length of @sockaddr_in@
  -> SocketAddressInternet
  -> ST s ()
internalWriteSocketAddressInternet buf SocketAddressInternet{port, address} = do
  -- Clear all 16 bytes first so the trailing sin_zero padding that Linux
  -- expects is guaranteed to be zero.
  PM.setByteArray buf 0 16 (0 :: Word8)
  -- sin_family: AF_INET (2). The family is written as a CUShort, which
  -- hard-codes the assumption that sa_family_t is an unsigned short on
  -- Linux.
  PM.writeByteArray buf 0 (2 :: CUShort)
  -- sin_port and sin_addr: SocketAddressInternet is expected to already
  -- carry both in network byte order, so they are copied without swapping.
  PM.writeByteArray buf 1 port
  PM.writeByteArray buf 1 address
{-# LINE 66 "src-linux/Posix/Socket/Platform.hsc" #-}

-- | Serialize an IPv4 socket address so that it may be passed to @bind@.
--   This serialization is operating-system dependent: the result is the
--   16-byte Linux @sockaddr_in@ layout.
encodeSocketAddressInternet :: SocketAddressInternet -> SocketAddress
encodeSocketAddressInternet sockAddrInternet =
  SocketAddress (runByteArrayST (unboxByteArrayST build))
  where
    -- Allocate a fresh @sockaddr_in@-sized buffer, fill it, and freeze it
    -- in place; the mutable buffer never escapes this action.
    build :: ST RealWorld ByteArray
    build = do
      buf <- PM.newByteArray 16
      internalWriteSocketAddressInternet buf sockAddrInternet
      PM.unsafeFreezeByteArray buf

-- | Decode a @sockaddr_in@ from a @sockaddr@ of unknown family.
--   The result is 'Nothing' if the buffer is not exactly 16 bytes long
--   (the size of a Linux @sockaddr_in@), or if its @sin_family@ is not
--   @AF_INET@ (2).
decodeSocketAddressInternet :: SocketAddress -> Maybe SocketAddressInternet
decodeSocketAddressInternet (SocketAddress arr)
  | PM.sizeofByteArray arr /= 16 = Nothing
    -- The family is read as a 16-bit CUShort, mirroring the encoder's
    -- assumption about the width of sa_family_t.
  | indexByteArray arr 0 /= (2 :: CUShort) = Nothing
    -- Element indices, not byte offsets: index 1 of 16-bit elements is
    -- byte 2 (sin_port); index 1 of 32-bit elements is byte 4 (sin_addr).
  | otherwise = Just (SocketAddressInternet
      { port = indexByteArray arr 1
      , address = indexByteArray arr 1
      })

-- | Read the @ix@-th @sockaddr_in@ from a contiguous run of them that
--   starts at @addr@. The index counts whole 16-byte @sockaddr_in@
--   elements, not bytes.
--
--   This is unsafe: no bounds checking is done, so the caller must ensure
--   the element lies in valid memory. It exists to support the wrappers
--   of @recvmmsg@. If the element's @sin_family@ is not @AF_INET@ (2),
--   the family that was found is returned in 'Left'.
indexSocketAddressInternet :: Addr -> Int -> IO (Either CInt SocketAddressInternet)
indexSocketAddressInternet addr ix = do
  let !(Addr elemAddr#) = PMA.plusAddr addr (ix * 16)
      elemPtr = Ptr elemAddr#
  -- sin_family lives at byte 0 and is read as a CUShort.
  fam <- peekByteOff elemPtr 0 :: IO CUShort
  if fam /= (2 :: CUShort)
    then pure (Left (cushortToCInt fam))
    else do
      -- sin_port at byte 2, sin_addr at byte 4; both are returned as
      -- stored, without byte swapping.
      port <- peekByteOff elemPtr 2
      address <- peekByteOff elemPtr 4
      pure (Right (SocketAddressInternet { port, address }))

-- | Serialize a unix domain socket address so that it may be passed to
--   @bind@. This serialization is operating-system dependent. If the
--   given path is as long as or longer than @sun_path@ (typically between
--   92 and 108 bytes, depending on the platform; 108 here), the path is
--   encoded as the empty string instead. This typically makes @bind@
--   return an error code.
encodeSocketAddressUnix :: SocketAddressUnix -> SocketAddress
encodeSocketAddressUnix (SocketAddressUnix !name) =
  SocketAddress (runByteArrayST (unboxByteArrayST build))
  where
    -- On Linux, sun_path is exactly 108 bytes and holds a NUL-terminated
    -- string, which is why the whole path area starts out zeroed.
    pathSize :: Int
    pathSize = 108
    -- sa_family_t is hard-coded as the width of an unsigned short.
    -- 'FS.sizeOf' never inspects its argument; 0 only fixes the type.
    familySize :: Int
    familySize = FS.sizeOf (0 :: CUShort)
    nameLen :: Int
    nameLen = PM.sizeofByteArray name
    build :: ST RealWorld ByteArray
    build = do
      buf <- PM.newByteArray (familySize + pathSize)
      PM.setByteArray buf familySize pathSize (0 :: Word8)
      -- sun_family: AF_UNIX (1).
      PM.writeByteArray buf 0 (1 :: CUShort)
      -- Copy only paths that leave room for at least one NUL terminator;
      -- otherwise the zero-filled (empty) path is kept.
      when (nameLen < pathSize) $
        PM.copyByteArray buf familySize name 0 nameLen
      PM.unsafeFreezeByteArray buf

-- | Widen an address-family value read as a 'CUShort' into a 'CInt'.
--   'CInt' is wider than 'CUShort' on the platforms GHC targets, so this
--   conversion is expected to be lossless.
cushortToCInt :: CUShort -> CInt
cushortToCInt family = fromIntegral family

-- | Turn an 'ST' action that yields a boxed 'ByteArray' into a raw state
--   transformer that yields the underlying unlifted 'ByteArray#'. This
--   lets 'runByteArrayST' consume the result without a boxed wrapper.
unboxByteArrayST :: ST s ByteArray -> State# s -> (# State# s, ByteArray# #)
unboxByteArrayST (ST stRep) s0 =
  case stRep s0 of
    (# s1, ByteArray payload# #) -> (# s1, payload# #)

-- | A specialization of 'runST' for state transformers whose result is an
--   unlifted 'ByteArray#'. Working with the unboxed result avoids a
--   needless intermediate data constructor allocation; the array is
--   boxed exactly once, here.
runByteArrayST :: (State# RealWorld -> (# State# RealWorld, ByteArray# #)) -> ByteArray
runByteArrayST action =
  case runRW# action of
    (# _finalState, payload# #) -> ByteArray payload#