{-# language BangPatterns #-}
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DerivingStrategies #-}
{-# language DuplicateRecordFields #-}
{-# language GADTs #-}
{-# language KindSignatures #-}
{-# language MagicHash #-}
{-# language NamedFieldPuns #-}
{-# language StandaloneDeriving #-}
{-# language UnboxedTuples #-}
module Socket.IPv4
( Peer(..)
, Message(..)
, IPv4Slab(..)
, SocketException(..)
, describeEndpoint
, freezeIPv4Slab
, newIPv4Slab
) where
import Control.Exception (Exception)
import Control.Monad.Primitive (primitive_,primitive)
import Data.Kind (Type)
import Data.Primitive.Unlifted.Array (MutableUnliftedArray)
import Data.Primitive (ByteArray(..),MutableByteArray)
import Data.Primitive (MutablePrimArray)
import Data.Primitive (SmallArray,SmallMutableArray)
import Data.Word (Word16)
import Foreign.C.Types (CInt)
import GHC.Exts (RealWorld,Int(I#))
import Net.Types (IPv4(..))
import Socket.Error (die)
import qualified Data.Primitive as PM
import qualified Data.Primitive.Unlifted.Array as PM
import qualified Data.Text as T
import qualified GHC.Exts as Exts
import qualified Net.IPv4 as IPv4
import qualified Posix.Socket as S
-- | An IPv4 endpoint: an address paired with a port number.
data Peer = Peer
  { address :: !IPv4
    -- ^ IPv4 address of the endpoint.
  , port :: !Word16
    -- ^ Port number of the endpoint.
  }
  deriving stock (Eq, Show)
-- | An immutable payload together with the endpoint it is associated with.
data Message = Message
  { peer :: {-# UNPACK #-} !Peer
    -- ^ Endpoint associated with the payload (presumably the sender for
    -- received messages; confirm against the receive path).
  , payload :: !ByteArray
    -- ^ The message bytes.
  }
  deriving stock (Eq, Show)
-- | A set of parallel mutable arrays with one slot per message.
--
-- 'newIPv4Slab' creates all three arrays with the same number of slots.
-- The code that fills the slots is not part of this section.
data IPv4Slab = IPv4Slab
  { sizes :: !(MutablePrimArray RealWorld CInt)
    -- ^ Per-slot length; 'freezeIPv4Slab' shrinks or grows the slot's
    -- payload buffer to this many bytes before freezing it.
  , peers :: !(MutablePrimArray RealWorld S.SocketAddressInternet)
    -- ^ Per-slot socket address, converted to a 'Peer' when the slot is
    -- frozen.
  , payloads :: !(MutableUnliftedArray RealWorld (MutableByteArray RealWorld))
    -- ^ Per-slot mutable payload buffer.
  }
-- | Allocate a slab with @n@ message slots, each backed by a freshly
-- allocated payload buffer of @m@ bytes.
--
-- Both arguments must be at least 1; otherwise the action fails through
-- 'die', labelled with this function's name.
newIPv4Slab ::
     Int -- ^ Number of message slots; must be at least 1.
  -> Int -- ^ Size in bytes of every payload buffer; must be at least 1.
  -> IO IPv4Slab
newIPv4Slab !n !m
  -- The label now matches the exported function name (it used to read
  -- "newSlabIPv4", which does not name anything in this module).
  | n < 1 || m < 1 = die "newIPv4Slab"
  | otherwise = do
      sizes <- PM.newPrimArray n
      peers <- PM.newPrimArray n
      payloads <- PM.unsafeNewUnliftedArray n
      -- Every slot is given its own buffer before the slab is returned;
      -- unsafeNewUnliftedArray presumably leaves the slots uninitialised.
      let fill !ix
            | ix < 0 = pure ()
            | otherwise = do
                writeMutableByteArrayArray payloads ix =<< PM.newByteArray m
                fill (ix - 1)
      fill (n - 1)
      pure IPv4Slab{sizes,peers,payloads}
-- | Render a peer as @ADDRESS:PORT@, where the address part is the text
-- produced by 'IPv4.encode' and the port is rendered with 'show'.
describeEndpoint :: Peer -> String
describeEndpoint Peer{address = addr, port = portNumber} =
  concat [T.unpack (IPv4.encode addr), ":", show portNumber]
-- | Exceptions describing socket failures.
--
-- Each constructor names one failure condition. The code that throws
-- these values is not part of this section, so the exact circumstances
-- for each constructor should be confirmed at the throwing sites.
data SocketException :: Type where
  SocketPermissionDenied :: SocketException
  SocketAddressInUse :: SocketException
  SocketEphemeralPortsExhausted :: SocketException
  SocketFileDescriptorLimit :: SocketException

deriving stock instance Show SocketException

deriving anyclass instance Exception SocketException
-- | Freeze the first @n@ slots of the slab into immutable 'Message's.
--
-- Every frozen slot has its payload buffer replaced by a new buffer of the
-- same size (see 'freezeSlabGo'), so the slab can be reused afterwards.
--
-- The count must lie between 0 and the number of slots in the slab;
-- otherwise the action fails through 'die'. The slot count is taken from
-- the @sizes@ array, which assumes all three slab arrays share one length,
-- as they do when built by 'newIPv4Slab'.
freezeIPv4Slab ::
     IPv4Slab -- ^ Slab whose leading slots are to be frozen.
  -> Int -- ^ Number of leading slots to freeze.
  -> IO (SmallArray Message)
freezeIPv4Slab slab@IPv4Slab{sizes} n = do
  capacity <- PM.getSizeofMutablePrimArray sizes
  -- The array allocation and the per-slot reads/writes that follow are
  -- assumed to be unchecked, so an out-of-range count is rejected here
  -- instead of being allowed to touch memory outside the slab.
  if n < 0 || n > capacity
    then die "freezeIPv4Slab"
    else do
      -- Slots start out as an error thunk; the helper overwrites all of
      -- them before the array is frozen.
      msgs <- PM.newSmallArray n errorThunk
      freezeSlabGo slab msgs (n - 1)
-- | Worker for 'freezeIPv4Slab'. Walks slot indices downwards from @ix@ to
-- 0, turning each slot into a 'Message' stored at the same index of the
-- destination array, and freezes the destination once the index drops
-- below zero.
freezeSlabGo ::
     IPv4Slab
  -> SmallMutableArray RealWorld Message
  -> Int
  -> IO (SmallArray Message)
freezeSlabGo slab@IPv4Slab{payloads,peers,sizes} !arr !ix
  | ix < 0 = PM.unsafeFreezeSmallArray arr
  | otherwise = do
      !received <- PM.readPrimArray sizes ix
      !rawPeer <- PM.readPrimArray peers ix
      buffer <- readMutableByteArrayArray payloads ix
      -- The buffer's size has to be captured before it is resized, because
      -- the replacement buffer is allocated with that original size.
      capacity <- PM.getSizeofMutableByteArray buffer
      trimmed <- PM.resizeMutableByteArray buffer (cintToInt received)
      !payload <- PM.unsafeFreezeByteArray trimmed
      -- The frozen buffer now belongs to the message, so the slot gets a
      -- new buffer of the original size.
      replacement <- PM.newByteArray capacity
      writeMutableByteArrayArray payloads ix replacement
      let !peer = sockAddrToPeer rawPeer
          !msg = Message {peer,payload}
      PM.writeSmallArray arr ix msg
      freezeSlabGo slab arr (ix - 1)
-- | Placeholder element used to initialise the message array in
-- 'freezeIPv4Slab'. Forcing it raises an error carrying this binding's
-- qualified name.
errorThunk :: Message
errorThunk = error "Socket.IPv4.errorThunk"
{-# NOINLINE errorThunk #-}
-- | Convert a C @int@ to a Haskell 'Int' using 'fromIntegral'.
cintToInt :: CInt -> Int
cintToInt n = fromIntegral n
-- | Build a 'Peer' from a socket address by passing the address through
-- 'S.networkToHostLong' and the port through 'S.networkToHostShort'.
sockAddrToPeer :: S.SocketAddressInternet -> Peer
sockAddrToPeer S.SocketAddressInternet{address = rawAddress, port = rawPort} =
  Peer {address = hostAddress, port = hostPort}
  where
  hostAddress = IPv4 (S.networkToHostLong rawAddress)
  hostPort = S.networkToHostShort rawPort
-- | Store a mutable byte array at the given index of an unlifted array by
-- calling the @writeMutableByteArrayArray#@ primop directly on the
-- unwrapped arguments. No bounds check is performed in this wrapper.
writeMutableByteArrayArray
  :: MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Int
  -> MutableByteArray RealWorld
  -> IO ()
writeMutableByteArrayArray (PM.MutableUnliftedArray slots#) (I# ix#) (PM.MutableByteArray buf#) =
  primitive $ \s0 ->
    case Exts.writeMutableByteArrayArray# slots# ix# buf# s0 of
      s1 -> (# s1, () #)
-- | Fetch the mutable byte array stored at the given index of an unlifted
-- array by calling the @readMutableByteArrayArray#@ primop directly on the
-- unwrapped arguments. No bounds check is performed in this wrapper.
readMutableByteArrayArray
  :: MutableUnliftedArray RealWorld (MutableByteArray RealWorld)
  -> Int
  -> IO (MutableByteArray RealWorld)
readMutableByteArrayArray (PM.MutableUnliftedArray slots#) (I# ix#) =
  primitive $ \s0 ->
    case Exts.readMutableByteArrayArray# slots# ix# s0 of
      (# s1, buf# #) -> (# s1, PM.MutableByteArray buf# #)