{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
module Socket.Stream.IPv4
(
Listener
, Connection
, Endpoint(..)
, withListener
, withAccepted
, withConnection
, forkAccepted
, forkAcceptedUnmasked
, interruptibleForkAcceptedUnmasked
, sendByteArray
, sendByteArraySlice
, sendMutableByteArray
, sendMutableByteArraySlice
, interruptibleSendByteArray
, interruptibleSendByteArraySlice
, interruptibleSendMutableByteArraySlice
, receiveByteArray
, receiveBoundedByteArray
, receiveBoundedMutableByteArraySlice
, receiveMutableByteArray
, interruptibleReceiveByteArray
, interruptibleReceiveBoundedMutableByteArraySlice
, SendException(..)
, ReceiveException(..)
, ConnectException(..)
, SocketException(..)
, AcceptException(..)
, CloseException(..)
, Interruptibility(..)
, listen
, unlisten
, unlisten_
, connect
, disconnect
, disconnect_
, accept
, interruptibleAccept
) where
import Control.Applicative ((<|>))
import Control.Concurrent (ThreadId, threadWaitRead, threadWaitWrite)
import Control.Concurrent (threadWaitReadSTM,threadWaitWriteSTM)
import Control.Concurrent (forkIO, forkIOWithUnmask)
import Control.Exception (mask, mask_, onException, throwIO)
import Control.Monad.STM (STM,atomically,retry)
import Control.Concurrent.STM (TVar,modifyTVar',readTVar)
import Data.Bifunctor (bimap,first)
import Data.Bool (bool)
import Data.Functor (($>))
import Data.Primitive (ByteArray, MutableByteArray(..))
import Data.Word (Word16)
import Foreign.C.Error (Errno(..), eAGAIN, eINPROGRESS, eWOULDBLOCK, ePIPE, eNOTCONN)
import Foreign.C.Error (eADDRINUSE,eCONNRESET)
import Foreign.C.Error (eNFILE,eMFILE,eACCES,ePERM,eCONNABORTED)
import Foreign.C.Error (eTIMEDOUT,eADDRNOTAVAIL,eNETUNREACH,eCONNREFUSED)
import Foreign.C.Types (CInt, CSize)
import GHC.Exts (Int(I#), RealWorld, shrinkMutableByteArray#)
import Net.Types (IPv4(..))
import Socket (Interruptibility(..))
import Socket (SocketUnrecoverableException(..))
import Socket (cgetsockname,cclose)
import Socket.Debug (debug)
import Socket.IPv4 (Endpoint(..),describeEndpoint)
import Socket.Stream (ConnectException(..),SocketException(..),AcceptException(..))
import Socket.Stream (SendException(..),ReceiveException(..),CloseException(..))
import System.Posix.Types(Fd)
import qualified Control.Monad.Primitive as PM
import qualified Data.Primitive as PM
import qualified Linux.Socket as L
import qualified Posix.Socket as S
import qualified Socket as SCK
-- | A socket that 'listen' has bound to an IPv4 endpoint and put into the
-- listening state. Wraps the underlying file descriptor.
newtype Listener = Listener Fd
-- | A stream connection, either accepted from a 'Listener' or opened by
-- 'connect'. Wraps the underlying file descriptor.
newtype Connection = Connection Fd
-- | Open a nonblocking, close-on-exec IPv4 stream socket, bind it to the
-- given endpoint and start listening with a backlog of 16.
--
-- If the requested port is 0, the port the kernel picked is read back with
-- @getsockname@. Otherwise the requested port is returned unchanged.
--
-- Descriptor-limit, permission and address-in-use failures are returned as
-- 'Left'. Any other failure closes the socket (if one was opened) and
-- throws 'SocketUnrecoverableException'. The caller owns the returned
-- 'Listener' and must release it with 'unlisten' or 'unlisten_'; prefer
-- 'withListener' for exception safety.
listen :: Endpoint -> IO (Either SocketException (Listener, Word16))
listen endpoint@Endpoint{port = specifiedPort} = do
  debug ("listen: opening listen " ++ describeEndpoint endpoint)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("listen: opened listen " ++ describeEndpoint endpoint)
  case e1 of
    Left err -> handleSocketListenException SCK.functionWithListener err
    Right fd -> do
      e2 <- S.uninterruptibleBind fd
        (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
      debug ("listen: requested binding for listen " ++ describeEndpoint endpoint)
      case e2 of
        Left err -> do
          _ <- S.uninterruptibleClose fd
          handleBindListenException specifiedPort SCK.functionWithListener err
        Right _ -> S.uninterruptibleListen fd 16 >>= \case
          Left err -> do
            _ <- S.uninterruptibleClose fd
            debug "listen: listen failed with error code"
            handleBindListenException specifiedPort SCK.functionWithListener err
          Right _ -> do
            actualPort <- if specifiedPort == 0
              then boundPort fd
              else pure specifiedPort
            pure (Right (Listener fd, actualPort))
  where
    -- Ask the kernel which port was assigned to the socket. Every failure
    -- path closes the descriptor before throwing so that it is not leaked
    -- (previously the getsockname-error path did not).
    boundPort fd = S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
      Left err -> do
        _ <- S.uninterruptibleClose fd
        throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionWithListener
          [cgetsockname,describeEndpoint endpoint,describeErrorCode err]
      Right (sockAddrRequiredSz,sockAddr) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet
        then case S.decodeSocketAddressInternet sockAddr of
          Just S.SocketAddressInternet{port = networkPort} -> do
            let cleanActualPort = S.networkToHostShort networkPort
            debug ("listen: successfully bound listen " ++ describeEndpoint endpoint ++ " and got port " ++ show cleanActualPort)
            pure cleanActualPort
          Nothing -> do
            _ <- S.uninterruptibleClose fd
            throwIO $ SocketUnrecoverableException
              moduleSocketStreamIPv4
              functionWithListener
              [cgetsockname,"non-internet socket family"]
        else do
          _ <- S.uninterruptibleClose fd
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            functionWithListener
            [cgetsockname,describeEndpoint endpoint,"socket address size"]
-- | Close a 'Listener'. An error reported by @close@ is treated as a bug
-- and raised as 'SocketUnrecoverableException'.
unlisten :: Listener -> IO ()
unlisten (Listener fd) = do
  closeResult <- S.uninterruptibleClose fd
  either
    (\err -> throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      functionWithListener
      [cclose,describeErrorCode err])
    (const (pure ()))
    closeResult
-- | Close a 'Listener', discarding any error reported by @close@.
unlisten_ :: Listener -> IO ()
unlisten_ = \(Listener listenerFd) -> S.uninterruptibleErrorlessClose listenerFd
-- | Run a callback with a freshly opened 'Listener' and the port it is
-- bound to. The listener is always closed afterwards: quietly if the
-- callback throws, and with 'unlisten' (which may throw) otherwise. Failure
-- to open the listener is returned as 'Left' without running the callback.
withListener ::
     Endpoint
  -> (Listener -> Word16 -> IO a)
  -> IO (Either SocketException a)
withListener endpoint f = mask $ \restore -> listen endpoint >>= \case
  Left err -> pure (Left err)
  Right (sck, actualPort) -> do
    result <- restore (f sck actualPort) `onException` unlisten_ sck
    unlisten sck
    pure (Right result)
-- | Block until the listener is readable, then accept one connection,
-- returning it together with the peer's endpoint.
accept :: Listener -> IO (Either (AcceptException 'Uninterruptible) (Connection,Endpoint))
accept (Listener listenerFd) = threadWaitRead listenerFd *> waitlessAccept listenerFd
-- | Like 'accept', but gives up with 'AcceptInterrupted' once the
-- @abandon@ variable becomes 'True' while waiting for readability.
interruptibleAccept ::
     TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Endpoint))
interruptibleAccept abandon (Listener fd) = do
  ready <- interruptibleWaitRead abandon fd
  if ready
    then waitlessAccept fd
    else pure (Left AcceptInterrupted)
-- | Like 'interruptibleAccept', but increments @counter@ (via
-- 'interruptibleWaitReadCounting') in the same transaction that observes
-- readiness.
interruptibleAcceptCounting ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Endpoint))
interruptibleAcceptCounting counter abandon (Listener fd) =
  interruptibleWaitReadCounting counter abandon fd
    >>= bool (pure (Left AcceptInterrupted)) (waitlessAccept fd)
-- | Accept a connection on a listening descriptor without waiting for
-- readiness first. The callers in this module wait for readability before
-- calling it.
--
-- Accept errors are classified by 'handleAcceptException'. A peer address
-- of unexpected size or family closes the accepted descriptor and throws
-- 'SocketUnrecoverableException'.
waitlessAccept :: Fd -> IO (Either (AcceptException i) (Connection,Endpoint))
waitlessAccept lstn =
  S.uninterruptibleAccept lstn S.sizeofSocketAddressInternet >>= \case
    Left err -> handleAcceptException "withAccepted" err
    Right (sockAddrRequiredSz,sockAddr,acpt)
      | sockAddrRequiredSz /= S.sizeofSocketAddressInternet ->
          closeAndThrow acpt SCK.socketAddressSize
      | otherwise -> case S.decodeSocketAddressInternet sockAddr of
          Nothing -> closeAndThrow acpt SCK.nonInternetSocketFamily
          Just sockAddrInet -> do
            let !acceptedEndpoint = socketAddressInternetToEndpoint sockAddrInet
            debug ("internalAccepted: successfully accepted connection from " ++ show acceptedEndpoint)
            pure (Right (Connection acpt, acceptedEndpoint))
  where
    -- Release the freshly accepted descriptor, then report the problem.
    closeAndThrow acceptedFd reason = do
      _ <- S.uninterruptibleClose acceptedFd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionWithAccepted
        [SCK.cgetsockname,reason]
-- | Shared implementation of 'forkAccepted'.
--
-- Waits, unmasked, for the listener to become readable. It then accepts a
-- connection with async exceptions masked. On success, @wrap@ receives two
-- things:
--
-- * the outer @restore@ function;
-- * an action that runs the user callback @f@ through the restore function
--   it is given, closes the accepted descriptor if @f@ throws, and otherwise
--   closes it gracefully with 'gracefulCloseA', returning the close result
--   paired with @f@'s result.
--
-- @wrap@ decides where that action runs; 'forkAccepted' runs it in a new
-- thread.
--
-- Accept errors are classified by 'handleAcceptException'. A peer address
-- of unexpected size or family closes the accepted descriptor and throws
-- 'SocketUnrecoverableException'.
-- NOTE(review): the accept/decode logic duplicates 'waitlessAccept'.
internalAccepted ::
     ((forall x. IO x -> IO x) -> ((IO a -> IO d) -> IO (Either CloseException (),d)) -> IO (Either (AcceptException 'Uninterruptible) c))
  -> Listener
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) c)
internalAccepted wrap (Listener !lst) f = do
  threadWaitRead lst
  mask $ \restore -> do
    S.uninterruptibleAccept lst S.sizeofSocketAddressInternet >>= \case
      Left err -> handleAcceptException "withAccepted" err
      Right (sockAddrRequiredSz,sockAddr,acpt) -> if sockAddrRequiredSz == S.sizeofSocketAddressInternet
        then case S.decodeSocketAddressInternet sockAddr of
          Just sockAddrInet -> do
            let acceptedEndpoint = socketAddressInternetToEndpoint sockAddrInet
            debug ("internalAccepted: successfully accepted connection from " ++ show acceptedEndpoint)
            -- From here on the accepted descriptor is owned by the action
            -- handed to @wrap@: it is closed on exception or gracefully.
            wrap restore $ \restore' -> do
              a <- onException (restore' (f (Connection acpt) acceptedEndpoint)) (S.uninterruptibleClose acpt)
              e <- gracefulCloseA acpt
              pure (e,a)
          Nothing -> do
            _ <- S.uninterruptibleClose acpt
            throwIO $ SocketUnrecoverableException
              moduleSocketStreamIPv4
              SCK.functionWithAccepted
              [SCK.cgetsockname,SCK.nonInternetSocketFamily]
        else do
          _ <- S.uninterruptibleClose acpt
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionWithAccepted
            [SCK.cgetsockname,SCK.socketAddressSize]
-- | First step of a graceful close: shut down the write side of the
-- socket, then continue with 'gracefulCloseB'. ENOTCONN from @shutdown@ is
-- tolerated. Any other shutdown error closes the descriptor and throws
-- 'SocketUnrecoverableException'.
gracefulCloseA :: Fd -> IO (Either CloseException ())
gracefulCloseA fd = do
  shutdownResult <- S.uninterruptibleShutdown fd S.write
  case shutdownResult of
    Right _ -> gracefulCloseB fd
    Left err
      | err == eNOTCONN -> gracefulCloseB fd
      | otherwise -> do
          _ <- S.uninterruptibleClose fd
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionGracefulClose
            [SCK.cshutdown,describeErrorCode err]
-- | Second step of a graceful close: try to receive one byte, waiting once
-- for readability if the first attempt would block.
--
-- * A zero-byte receive (the peer has shut down) closes the descriptor and
--   returns @Right ()@. A @close@ error is raised as
--   'SocketUnrecoverableException'.
-- * Receiving data closes the descriptor and returns
--   'ClosePeerContinuedSending'.
-- * Any receive error closes the descriptor and throws
--   'SocketUnrecoverableException'.
gracefulCloseB :: Fd -> IO (Either CloseException ())
gracefulCloseB fd = do
  buf <- PM.newByteArray 1
  let receiveOne = S.uninterruptibleReceiveMutableByteArray fd buf 0 1 mempty
  receiveOne >>= \case
    Right sz -> finish "Socket.Stream.IPv4.gracefulClose: remote not shutdown B" sz
    Left err1
      | err1 == eWOULDBLOCK || err1 == eAGAIN -> do
          threadWaitRead fd
          receiveOne >>= \case
            Right sz -> finish "Socket.Stream.IPv4.gracefulClose: remote not shutdown A" sz
            Left err -> receiveFailed err
      | otherwise -> receiveFailed err1
  where
    -- Close the descriptor and report a failed @recv@.
    receiveFailed err = do
      _ <- S.uninterruptibleClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionGracefulClose
        [SCK.crecv,describeErrorCode err]
    -- Act on the byte count of a successful @recv@; @msg@ is logged when
    -- the peer is still sending.
    finish msg sz
      | sz == 0 = S.uninterruptibleClose fd >>= \case
          Left err -> throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionGracefulClose
            [SCK.cclose,describeErrorCode err]
          Right _ -> pure (Right ())
      | otherwise = do
          debug msg
          _ <- S.uninterruptibleClose fd
          pure (Left ClosePeerContinuedSending)
-- | Accept one connection and run @cb@ on it with the caller's masking
-- state restored. The connection is closed afterwards: quietly if @cb@
-- throws, gracefully otherwise. The graceful-close result and @cb@'s result
-- are then passed to @consumeException@, outside the mask.
withAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) b)
withAccepted lstn consumeException cb = do
  served <- mask $ \restore ->
    accept lstn >>= traverse (serve restore)
  traverse (uncurry consumeException) served
  where
    -- Run the callback on an accepted connection and close it afterwards.
    serve restore (conn, endpoint) = do
      a <- restore (cb conn endpoint) `onException` disconnect_ conn
      e <- disconnect conn
      pure (e, a)
-- | Accept one connection and handle it in a new thread via
-- 'internalAccepted'. The forked thread runs @cb@, closes the connection,
-- then passes the close result and @cb@'s result to @consumeException@.
forkAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAccepted lst consumeException cb =
  internalAccepted
    (\restore action ->
      Right <$> forkIO (action restore >>= \(e,x) -> restore (consumeException e x)))
    lst
    cb
-- | Accept one connection (with exceptions masked) and handle it in a new
-- thread started with 'forkIOWithUnmask'. Inside that thread:
--
-- * @cb@ and @consumeException@ run unmasked;
-- * the connection is closed quietly if @cb@ throws, and gracefully
--   otherwise.
forkAcceptedUnmasked ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAcceptedUnmasked lstn consumeException cb = mask_ $ do
  accepted <- accept lstn
  case accepted of
    Left e -> pure (Left e)
    Right (conn, endpoint) -> Right <$> forkIOWithUnmask (\unmask -> do
      a <- unmask (cb conn endpoint) `onException` disconnect_ conn
      e <- disconnect conn
      unmask (consumeException e a))
-- | Interruptible, counting variant of 'forkAcceptedUnmasked'.
--
-- 'interruptibleAcceptCounting' increments @counter@ when the listener
-- becomes readable. The increment is undone in these cases:
--
-- * in this thread, if the accept then fails with an error other than
--   'AcceptInterrupted' (in which case no increment happened);
-- * in the forked thread, when the connection handling finishes, whether
--   normally or by an exception. Exceptions covered include those from
--   @cb@, from the graceful close ('disconnect' may throw or be interrupted
--   while waiting) and from @consumeException@.
--
-- Previously the decrement was skipped if 'disconnect' or
-- @consumeException@ threw.
interruptibleForkAcceptedUnmasked ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Interruptible) ThreadId)
interruptibleForkAcceptedUnmasked !counter !abandon !lstn consumeException cb =
  mask_ $ interruptibleAcceptCounting counter abandon lstn >>= \case
    Left e -> do
      case e of
        AcceptInterrupted -> pure ()
        _ -> decrement
      pure (Left e)
    Right (conn, endpoint) -> fmap Right $ forkIOWithUnmask $ \unmask -> do
      let session = do
            a <- onException (unmask (cb conn endpoint)) (disconnect_ conn)
            e <- disconnect conn
            unmask (consumeException e a)
      -- Decrement exactly once, whether the session finishes or throws.
      r <- session `onException` decrement
      decrement
      pure r
  where
    decrement = atomically (modifyTVar' counter (subtract 1))
-- | Open a nonblocking, close-on-exec IPv4 stream socket and connect it to
-- @remote@.
--
-- If @connect@ reports EINPROGRESS, wait for the socket to become writable,
-- then read the socket's error option to learn whether the connect
-- succeeded. Recoverable failures are returned as 'Left'. Anything else
-- closes the socket and throws 'SocketUnrecoverableException'.
--
-- Fixes over the previous version:
--
-- * The descriptor is closed if an async exception interrupts the wait for
--   writability. 'threadWaitWrite' is interruptible even under
--   'withConnection''s 'mask', so the descriptor used to leak.
-- * @getsockopt@ failures are reported against the connection function
--   name rather than @withListener@.
connect ::
     Endpoint
  -> IO (Either (ConnectException 'Uninterruptible) Connection)
connect !remote = do
  debug ("connect: opening connection " ++ show remote)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("connect: opened connection " ++ show remote)
  case e1 of
    Left err -> handleSocketConnectException SCK.functionWithConnection err
    Right fd -> do
      let sockAddr = S.encodeSocketAddressInternet (endpointToSocketAddressInternet remote)
      merr <- S.uninterruptibleConnect fd sockAddr >>= \case
        Left err2 -> if err2 == eINPROGRESS
          then do
            threadWaitWrite fd `onException` S.uninterruptibleErrorlessClose fd
            pure Nothing
          else pure (Just err2)
        Right _ -> pure Nothing
      case merr of
        Just err -> do
          S.uninterruptibleErrorlessClose fd
          handleConnectException SCK.functionWithConnection err
        Nothing -> do
          e <- S.uninterruptibleGetSocketOption fd
            S.levelSocket S.optionError (intToCInt (PM.sizeOf (undefined :: CInt)))
          case e of
            Left err -> do
              S.uninterruptibleErrorlessClose fd
              throwIO $ SocketUnrecoverableException
                moduleSocketStreamIPv4
                SCK.functionWithConnection
                [SCK.cgetsockopt,describeEndpoint remote,describeErrorCode err]
            Right (sz,S.OptionValue val) -> if sz == intToCInt (PM.sizeOf (undefined :: CInt))
              then
                let err = PM.indexByteArray val 0 :: CInt in
                if err == 0
                  then pure (Right (Connection fd))
                  else do
                    S.uninterruptibleErrorlessClose fd
                    handleConnectException SCK.functionWithConnection (Errno err)
              else do
                S.uninterruptibleErrorlessClose fd
                throwIO $ SocketUnrecoverableException
                  moduleSocketStreamIPv4
                  SCK.functionWithConnection
                  [SCK.cgetsockopt,describeEndpoint remote,connectErrorOptionValueSize]
-- | Close a connection gracefully: shut down writing, then wait for the
-- peer's end-of-stream (see 'gracefulCloseA').
disconnect :: Connection -> IO (Either CloseException ())
disconnect = \(Connection connFd) -> gracefulCloseA connFd
-- | Close a connection immediately, discarding any error from @close@.
disconnect_ :: Connection -> IO ()
disconnect_ = \(Connection connFd) -> S.uninterruptibleErrorlessClose connFd
-- | Connect to @remote@, run @f@ on the connection with the caller's
-- masking state restored, then close it: quietly if @f@ throws, gracefully
-- otherwise. The graceful-close result and @f@'s result are passed to @g@.
--
-- @g@ now runs with the caller's masking state restored, as in
-- 'withAccepted' and 'forkAcceptedUnmasked'. The connection is already
-- closed at that point, so no resource depends on the mask.
withConnection ::
     Endpoint
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> IO a)
  -> IO (Either (ConnectException 'Uninterruptible) b)
withConnection !remote g f = mask $ \restore -> do
  connect remote >>= \case
    Left err -> pure (Left err)
    Right conn -> do
      a <- onException (restore (f conn)) (disconnect_ conn)
      m <- disconnect conn
      b <- restore (g m a)
      pure (Right b)
-- | Send the whole byte array, blocking until all of it is written.
sendByteArray ::
     Connection
  -> ByteArray
  -> IO (Either (SendException 'Uninterruptible) ())
sendByteArray conn arr = sendByteArraySlice conn arr 0 wholeLength
  where
    wholeLength = PM.sizeofByteArray arr
-- | Send the whole byte array, giving up with 'SendInterrupted' if
-- @abandon@ becomes 'True' while waiting for the socket to become writable.
interruptibleSendByteArray ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendByteArray abandon conn arr =
  let wholeLength = PM.sizeofByteArray arr
   in interruptibleSendByteArraySlice abandon conn arr 0 wholeLength
-- | Send @len0@ bytes of @payload@ starting at @off0@, looping over partial
-- writes until everything is sent. A negative length throws
-- 'SocketUnrecoverableException'.
sendByteArraySlice ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) ())
sendByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !remaining
      | remaining > 0 = internalSend conn payload off remaining >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (remaining - n)
      | remaining == 0 = pure (Right ())
      | otherwise = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendByteArray
          [SCK.negativeSliceLength]
-- | Interruptible counterpart of 'sendByteArraySlice'. It returns
-- 'SendInterrupted' if @abandon@ becomes 'True' while waiting for the
-- socket to become writable.
interruptibleSendByteArraySlice ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendByteArraySlice !abandon !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !remaining = case compare remaining 0 of
      GT -> internalInterruptibleSend abandon conn payload off remaining >>= \case
        Left e -> pure (Left e)
        Right sent -> do
          let n = csizeToInt sent
          loop (off + n) (remaining - n)
      EQ -> pure (Right ())
      LT -> throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        functionSendByteArray
        [SCK.negativeSliceLength]
-- | Send the whole mutable byte array, blocking until all of it is written.
sendMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArray conn arr = do
  wholeLength <- PM.getSizeofMutableByteArray arr
  sendMutableByteArraySlice conn arr 0 wholeLength
-- | Send @len0@ bytes of the mutable @payload@ starting at @off0@, looping
-- over partial writes. A negative length throws
-- 'SocketUnrecoverableException'.
sendMutableByteArraySlice ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !remaining
      | remaining > 0 = internalSendMutable conn payload off remaining >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (remaining - n)
      | remaining == 0 = pure (Right ())
      | otherwise = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendMutableByteArray
          [SCK.negativeSliceLength]
-- | Interruptible counterpart of 'sendMutableByteArraySlice'. It returns
-- 'SendInterrupted' if @abandon@ becomes 'True' while waiting for the
-- socket to become writable.
interruptibleSendMutableByteArraySlice ::
     TVar Bool
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendMutableByteArraySlice !abandon !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !remaining = case compare remaining 0 of
      GT -> internalInterruptibleSendMutable abandon conn payload off remaining >>= \case
        Left e -> pure (Left e)
        Right sent -> do
          let n = csizeToInt sent
          loop (off + n) (remaining - n)
      EQ -> pure (Right ())
      LT -> throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        functionSendMutableByteArray
        [SCK.negativeSliceLength]
-- | Single interruptible send of a mutable slice. When the socket would
-- block, it waits for writability or returns 'SendInterrupted' once
-- @abandon@ is 'True'.
internalInterruptibleSendMutable ::
     TVar Bool
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) CSize)
internalInterruptibleSendMutable !abandon !conn !payload !off !len =
  veryInternalSendMutable waitOrAbandon conn payload off len
  where
    waitOrAbandon fd =
      bool (Left SendInterrupted) (Right ()) <$> interruptibleWaitWrite abandon fd
-- | Single blocking send of a mutable slice. When the socket would block,
-- it waits for writability and retries once.
internalSendMutable ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) CSize)
internalSendMutable !conn !payload !off !len =
  veryInternalSendMutable (\fd -> Right () <$ threadWaitWrite fd) conn payload off len
-- | Send a mutable slice with MSG_NOSIGNAL-style flags ('S.noSignal'). On
-- EAGAIN/EWOULDBLOCK, run @wait@ (which may abort with its own error) and
-- try exactly once more. Returns the number of bytes written. Errors are
-- classified by 'handleSendException'.
veryInternalSendMutable ::
     (Fd -> IO (Either (SendException i) ()))
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException i) CSize)
{-# INLINE veryInternalSendMutable #-}
veryInternalSendMutable wait (Connection !s) !payload !off !len =
  attempt >>= \case
    Right sz -> pure (Right sz)
    Left err1
      | err1 == eWOULDBLOCK || err1 == eAGAIN -> wait s >>= \case
          Left err2 -> pure (Left err2)
          Right () -> attempt >>= either
            (handleSendException functionSendMutableByteArray)
            (pure . Right)
      | otherwise -> handleSendException "sendMutableByteArray" err1
  where
    attempt = S.uninterruptibleSendMutableByteArray s payload
      (intToCInt off)
      (intToCSize len)
      (S.noSignal)
-- | Single blocking send of an immutable slice. When the socket would
-- block, it waits for writability and retries once.
internalSend ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) CSize)
internalSend !conn !payload !off !len =
  veryInternalSend blockUntilWritable conn payload off len
  where
    blockUntilWritable fd = Right () <$ threadWaitWrite fd
-- | Single interruptible send of an immutable slice. When the socket would
-- block, it waits for writability or returns 'SendInterrupted' once
-- @abandon@ is 'True'.
internalInterruptibleSend ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) CSize)
internalInterruptibleSend !abandon !conn !payload !off !len =
  veryInternalSend waitOrAbandon conn payload off len
  where
    waitOrAbandon fd =
      bool (Left SendInterrupted) (Right ()) <$> interruptibleWaitWrite abandon fd
-- | Send an immutable slice with 'S.noSignal'. On EAGAIN/EWOULDBLOCK, run
-- @wait@ (which may abort with its own error) and try exactly once more.
-- Returns the number of bytes written. Errors are classified by
-- 'handleSendException'.
veryInternalSend ::
     (Fd -> IO (Either (SendException i) ()))
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException i) CSize)
{-# INLINE veryInternalSend #-}
veryInternalSend wait (Connection !s) !payload !off !len = do
  debug ("veryInternalSend: about to send chunk on stream socket, offset " ++ show off ++ " and length " ++ show len)
  firstAttempt <- sendChunk
  debug "veryInternalSend: just sent chunk on stream socket"
  case firstAttempt of
    Right sz -> pure (Right sz)
    Left err1
      | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> handleSendException functionSendByteArray err1
      | otherwise -> do
          debug "veryInternalSend: waiting to for write ready on stream socket"
          ready <- wait s
          case ready of
            Left e -> pure (Left e)
            Right _ -> sendChunk >>= \case
              Right sz -> pure (Right sz)
              Left err2 -> do
                debug "veryInternalSend: encountered error after sending chunk on stream socket"
                handleSendException functionSendByteArray err2
  where
    sendChunk = S.uninterruptibleSendByteArray s payload
      (intToCInt off)
      (intToCSize len)
      (S.noSignal)
-- | Wait until the socket is readable, then receive up to @maxSz@ bytes
-- into @buf@ at offset @off@. Returns the number of bytes received (0 means
-- the peer shut down). ECONNRESET maps to 'ReceiveReset'; other errors are
-- unrecoverable.
internalReceiveMaximally ::
     Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException i) Int)
internalReceiveMaximally (Connection !fd) !maxSz !buf !off = do
  debug "receive: stream socket about to wait"
  threadWaitRead fd
  debug ("receive: stream socket is now readable, receiving up to " ++ show maxSz ++ " bytes at offset " ++ show off)
  received <- S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty
  debug "receive: finished reading from stream socket"
  either
    (handleReceiveException "internalReceiveMaximally")
    (pure . Right . csizeToInt)
    received
-- | Interruptible counterpart of 'internalReceiveMaximally'. It returns
-- 'ReceiveInterrupted' if @abandon@ becomes 'True' before the socket is
-- readable.
internalInterruptibleReceiveMaximally ::
     TVar Bool
  -> Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) Int)
{-# INLINE internalInterruptibleReceiveMaximally #-}
internalInterruptibleReceiveMaximally abandon (Connection !fd) !maxSz !buf !off =
  interruptibleWaitRead abandon fd >>= bool
    (pure (Left ReceiveInterrupted))
    (S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty
      >>= either (handleReceiveException "internalReceiveMaximally") (pure . Right . csizeToInt))
-- | Receive exactly @total@ bytes, blocking as needed. It fails with
-- 'ReceiveShutdown' if the peer shuts down first.
receiveByteArray ::
     Connection
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) ByteArray)
receiveByteArray !conn !total =
  let receiver = internalReceiveMaximally
   in internalReceiveByteArray receiver conn total
-- | Interruptible counterpart of 'receiveByteArray'. Waiting stops with
-- 'ReceiveInterrupted' once @abandon@ is 'True'.
interruptibleReceiveByteArray ::
     TVar Bool
  -> Connection
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) ByteArray)
interruptibleReceiveByteArray !abandon !conn !total =
  let receiver = internalInterruptibleReceiveMaximally abandon
   in internalReceiveByteArray receiver conn total
-- | Fill a freshly allocated array of @total@ bytes by calling @recvMax@
-- repeatedly. A zero-byte receive before the array is full yields
-- 'ReceiveShutdown'.
--
-- A negative @total@ throws 'SocketUnrecoverableException'
-- ('SCK.negativeSliceLength'). This is now checked before allocating.
-- Previously 'PM.newByteArray' received the negative size first, so the
-- intended error branch was unreachable.
internalReceiveByteArray ::
     (Connection -> Int -> MutableByteArray RealWorld -> Int -> IO (Either (ReceiveException i) Int))
  -> Connection
  -> Int
  -> IO (Either (ReceiveException i) ByteArray)
internalReceiveByteArray recvMax !conn0 !total
  | total < 0 = negativeLength
  | otherwise = do
      marr <- PM.newByteArray total
      go conn0 marr 0 total
  where
    negativeLength = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      functionReceiveByteArray
      [SCK.negativeSliceLength]
    go !conn !marr !off !remaining = case compare remaining 0 of
      GT -> recvMax conn remaining marr off >>= \case
        Left err -> pure (Left err)
        Right sz -> if sz /= 0
          then go conn marr (off + sz) (remaining - sz)
          else pure (Left ReceiveShutdown)
      EQ -> Right <$> PM.unsafeFreezeByteArray marr
      -- Only reachable if recvMax reports more bytes than were requested.
      LT -> negativeLength
-- | Fill the whole mutable array from the connection, blocking as needed.
-- It fails with 'ReceiveShutdown' if the peer shuts down first.
receiveMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either (ReceiveException 'Uninterruptible) ())
receiveMutableByteArray !conn0 !marr0 =
  PM.getSizeofMutableByteArray marr0 >>= loop 0
  where
    loop !off !remaining
      | remaining <= 0 = pure (Right ())
      | otherwise = internalReceiveMaximally conn0 remaining marr0 off >>= \case
          Left err -> pure (Left err)
          Right 0 -> pure (Left ReceiveShutdown)
          Right sz -> loop (off + sz) (remaining - sz)
-- | Receive between 1 and @total@ bytes into @marr@ at offset @off@ and
-- return the count. @total == 0@ returns 0 without touching the socket. A
-- negative @total@ throws 'SocketUnrecoverableException'.
receiveBoundedMutableByteArraySlice ::
     Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) Int)
receiveBoundedMutableByteArraySlice !conn !total !marr !off = case compare total 0 of
  GT -> internalReceiveMaximally conn total marr off >>= \case
    Left err -> pure (Left err)
    Right 0 -> pure (Left ReceiveShutdown)
    Right sz -> pure (Right sz)
  EQ -> pure (Right 0)
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveMutableByteArraySlice
    [SCK.negativeSliceLength]
-- | Interruptible counterpart of 'receiveBoundedMutableByteArraySlice'. It
-- returns 'ReceiveInterrupted' if @abandon@ becomes 'True' before data is
-- available.
interruptibleReceiveBoundedMutableByteArraySlice ::
     TVar Bool
  -> Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) Int)
interruptibleReceiveBoundedMutableByteArraySlice !abandon !conn !total !marr !off = case compare total 0 of
  GT -> internalInterruptibleReceiveMaximally abandon conn total marr off >>= \case
    Left err -> pure (Left err)
    Right 0 -> pure (Left ReceiveShutdown)
    Right sz -> pure (Right sz)
  EQ -> pure (Right 0)
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveMutableByteArraySlice
    [SCK.negativeSliceLength]
-- | Receive between 1 and @total@ bytes into a new array, shrunk to the
-- number of bytes actually received. @total == 0@ yields an empty array. A
-- negative @total@ throws 'SocketUnrecoverableException'.
receiveBoundedByteArray ::
     Connection
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) ByteArray)
receiveBoundedByteArray !conn !total = case compare total 0 of
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveBoundedByteArray
    [SCK.negativeSliceLength]
  EQ -> pure (Right mempty)
  GT -> do
    buffer <- PM.newByteArray total
    receiveBoundedMutableByteArraySlice conn total buffer 0 >>= \case
      Left err -> pure (Left err)
      Right received -> do
        shrinkMutableByteArray buffer received
        Right <$> PM.unsafeFreezeByteArray buffer
-- | Convert a host-order 'Endpoint' into a network-order socket address.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet Endpoint{address = hostAddr, port = hostPort} =
  S.SocketAddressInternet
    { port = S.hostToNetworkShort hostPort
    , address = S.hostToNetworkLong (getIPv4 hostAddr)
    }
-- | Convert a network-order socket address into a host-order 'Endpoint'.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint
socketAddressInternetToEndpoint S.SocketAddressInternet{address = netAddr, port = netPort} =
  Endpoint
    { address = IPv4 (S.networkToHostLong netAddr)
    , port = S.networkToHostShort netPort
    }
-- | Convert an 'Int' to a 'CInt' (wrapping on overflow, as 'fromIntegral'
-- does).
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | Convert an 'Int' to a 'CSize' via 'fromIntegral'.
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n
-- | Convert a 'CSize' byte count back to an 'Int' via 'fromIntegral'.
csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n
-- | Shrink a mutable byte array in place to the given size, using the
-- 'shrinkMutableByteArray#' primop.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr) (I# newSize) =
  PM.primitive_ (\s -> shrinkMutableByteArray# arr newSize s)
-- | Module name reported in 'SocketUnrecoverableException'.
moduleSocketStreamIPv4 :: String
moduleSocketStreamIPv4 = "Socket.Stream.IPv4"
-- | Function name reported for failures while sending a mutable byte array.
functionSendMutableByteArray :: String
functionSendMutableByteArray = "sendMutableByteArray"
-- | Function name reported for failures while sending a byte array.
functionSendByteArray :: String
functionSendByteArray = "sendByteArray"
-- | Function name reported for failures while opening or closing a listener.
functionWithListener :: String
functionWithListener = "withListener"
-- | Function name reported for invalid arguments to 'receiveBoundedByteArray'.
functionReceiveBoundedByteArray :: String
functionReceiveBoundedByteArray = "receiveBoundedByteArray"
-- | Function name reported for invalid arguments when receiving a byte array.
functionReceiveByteArray :: String
functionReceiveByteArray = "receiveByteArray"
-- | Function name reported for invalid arguments to the bounded slice receives.
functionReceiveMutableByteArraySlice :: String
functionReceiveMutableByteArraySlice = "receiveMutableByteArraySlice"
-- | Render an errno value for inclusion in an exception message.
describeErrorCode :: Errno -> String
describeErrorCode (Errno code) = "error code " <> show code
-- | Classify a @recv@ errno. ECONNRESET becomes 'ReceiveReset'; anything
-- else throws 'SocketUnrecoverableException' tagged with @func@.
handleReceiveException :: String -> Errno -> IO (Either (ReceiveException i) a)
handleReceiveException func e =
  if e == eCONNRESET
    then pure (Left ReceiveReset)
    else throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @send@ errno. EPIPE becomes 'SendShutdown' and ECONNRESET
-- becomes 'SendReset'; anything else throws 'SocketUnrecoverableException'
-- tagged with @func@.
handleSendException :: String -> Errno -> IO (Either (SendException i) a)
{-# INLINE handleSendException #-}
handleSendException func e = case lookup e recoverable of
  Just failure -> pure (Left failure)
  Nothing -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    func
    [describeErrorCode e]
  where
    recoverable = [(ePIPE, SendShutdown), (eCONNRESET, SendReset)]
-- | Classify a @connect@ errno, either returned directly or read from the
-- socket's error option, into a recoverable 'ConnectException'. Unknown
-- codes throw 'SocketUnrecoverableException' tagged with @func@.
handleConnectException :: String -> Errno -> IO (Either (ConnectException i) a)
handleConnectException func e =
  maybe unrecoverable (pure . Left) (lookup e recoverable)
  where
    recoverable =
      [ (eACCES, ConnectFirewalled)
      , (ePERM, ConnectFirewalled)
      , (eNETUNREACH, ConnectNetworkUnreachable)
      , (eCONNREFUSED, ConnectRefused)
      , (eADDRNOTAVAIL, ConnectEphemeralPortsExhausted)
      , (eTIMEDOUT, ConnectTimeout)
      ]
    unrecoverable = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @socket@ errno seen while connecting. EMFILE and ENFILE
-- become 'ConnectFileDescriptorLimit'; anything else is unrecoverable.
handleSocketConnectException :: String -> Errno -> IO (Either (ConnectException i) a)
handleSocketConnectException func e
  | e == eMFILE || e == eNFILE = pure (Left ConnectFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @socket@ errno seen while opening a listener. EMFILE and
-- ENFILE become 'SocketFileDescriptorLimit'; anything else is
-- unrecoverable.
handleSocketListenException :: String -> Errno -> IO (Either SocketException a)
handleSocketListenException func e =
  if e `elem` [eMFILE, eNFILE]
    then pure (Left SocketFileDescriptorLimit)
    else throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @bind@ or @listen@ errno for a listener that requested port
-- @thePort@.
--
-- * EACCES becomes 'SocketPermissionDenied'.
-- * EADDRINUSE with port 0 means the kernel could not find a free
--   ephemeral port, so it becomes 'SocketEphemeralPortsExhausted'.
-- * EADDRINUSE with a specific port means that address is taken, so it
--   becomes 'SocketAddressInUse'.
--
-- The EADDRINUSE mapping was previously inverted. Any other errno throws
-- 'SocketUnrecoverableException' tagged with @func@.
handleBindListenException :: Word16 -> String -> Errno -> IO (Either SocketException a)
handleBindListenException thePort func e
  | e == eACCES = pure (Left SocketPermissionDenied)
  | e == eADDRINUSE = if thePort == 0
      then pure (Left SocketEphemeralPortsExhausted)
      else pure (Left SocketAddressInUse)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify an @accept@ errno:
--
-- * ECONNABORTED becomes 'AcceptConnectionAborted';
-- * EMFILE and ENFILE become 'AcceptFileDescriptorLimit';
-- * EPERM becomes 'AcceptFirewalled'.
--
-- Anything else throws 'SocketUnrecoverableException' tagged with @func@.
handleAcceptException :: String -> Errno -> IO (Either (AcceptException i) a)
handleAcceptException func e
  | e == eCONNABORTED = pure (Left AcceptConnectionAborted)
  | e == eMFILE || e == eNFILE = pure (Left AcceptFileDescriptorLimit)
  | e == ePERM = pure (Left AcceptFirewalled)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Exception detail used by 'connect' when the value read back for the
-- socket's error option does not have the size of a 'CInt'.
connectErrorOptionValueSize :: String
connectErrorOptionValueSize = "incorrectly sized value of SO_ERROR option"
-- | Wait until @fd@ is readable or @abandon@ becomes 'True'. Returns 'True'
-- for readable and 'False' for abandoned; abandonment wins when both hold.
--
-- The IO-manager registration is removed even if an async exception
-- interrupts the blocking transaction. Previously it was only removed on
-- normal return.
interruptibleWaitRead :: TVar Bool -> Fd -> IO Bool
interruptibleWaitRead !abandon !fd = do
  (isReady,deregister) <- threadWaitReadSTM fd
  shouldReceive <- atomically
    ((bool retry (pure False) =<< readTVar abandon) <|> (isReady $> True))
    `onException` deregister
  deregister
  pure shouldReceive
-- | Wait until @fd@ is writable or @abandon@ becomes 'True'. Returns 'True'
-- for writable and 'False' for abandoned; abandonment wins when both hold.
--
-- The IO-manager registration is removed even if an async exception
-- interrupts the blocking transaction. Previously it was only removed on
-- normal return.
interruptibleWaitWrite :: TVar Bool -> Fd -> IO Bool
interruptibleWaitWrite !abandon !fd = do
  (isReady,deregister) <- threadWaitWriteSTM fd
  shouldSend <- atomically
    ((bool retry (pure False) =<< readTVar abandon) <|> (isReady $> True))
    `onException` deregister
  deregister
  pure shouldSend
-- | Like 'interruptibleWaitRead', but when readiness is observed it also
-- increments @counter@ in the same STM transaction. The increment therefore
-- happens exactly when 'True' is returned.
--
-- The IO-manager registration is removed even if an async exception
-- interrupts the blocking transaction. In that case the transaction did not
-- commit, so the counter is unchanged.
interruptibleWaitReadCounting :: TVar Int -> TVar Bool -> Fd -> IO Bool
interruptibleWaitReadCounting !counter !abandon !fd = do
  (isReady,deregister) <- threadWaitReadSTM fd
  let waitOrAbandon = readTVar abandon >>= \case
        True -> pure False
        False -> do
          isReady
          modifyTVar' counter (+1)
          pure True
  shouldReceive <- atomically waitOrAbandon `onException` deregister
  deregister
  pure shouldReceive