{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE GADTs #-}
module Socket.Stream.IPv4
(
Listener
, Connection
, Endpoint(..)
, withListener
, withAccepted
, withConnection
, interruptibleWithConnection
, forkAccepted
, forkAcceptedUnmasked
, interruptibleForkAcceptedUnmasked
, sendByteArray
, sendByteArraySlice
, sendMutableByteArray
, sendMutableByteArraySlice
, sendAddr
, sendByteString
, sendLazyByteString
, interruptibleSendByteArray
, interruptibleSendByteArraySlice
, interruptibleSendMutableByteArraySlice
, receiveByteArray
, receiveBoundedByteArray
, receiveBoundedMutableByteArraySlice
, receiveMutableByteArray
, receiveByteString
, interruptibleReceiveByteArray
, interruptibleReceiveBoundedMutableByteArraySlice
, SendException(..)
, ReceiveException(..)
, ConnectException(..)
, SocketException(..)
, AcceptException(..)
, CloseException(..)
, Interruptibility(..)
, listen
, unlisten
, unlisten_
, connect
, interruptibleConnect
, disconnect
, disconnect_
, accept
, interruptibleAccept
) where
import Control.Applicative ((<|>))
import Control.Concurrent (ThreadId, threadWaitRead, threadWaitWrite)
import Control.Concurrent (threadWaitReadSTM,threadWaitWriteSTM)
import Control.Concurrent (forkIO, forkIOWithUnmask)
import Control.Exception (mask, mask_, onException, throwIO)
import Control.Monad.STM (atomically,retry)
import Control.Concurrent.STM (TVar,modifyTVar',readTVar)
import Data.Bool (bool)
import Data.ByteString (ByteString)
import Data.Functor (($>))
import Data.Primitive (Addr(..), ByteArray, MutableByteArray(..))
import Data.Word (Word16)
import Foreign.C.Error (Errno(..), eAGAIN, eINPROGRESS, eWOULDBLOCK, ePIPE, eNOTCONN)
import Foreign.C.Error (eADDRINUSE,eCONNRESET)
import Foreign.C.Error (eNFILE,eMFILE,eACCES,ePERM,eCONNABORTED)
import Foreign.C.Error (eTIMEDOUT,eADDRNOTAVAIL,eNETUNREACH,eCONNREFUSED)
import Foreign.C.Types (CInt, CSize)
import GHC.Exts (Ptr(Ptr),Int(I#), RealWorld, shrinkMutableByteArray#, byteArrayContents#, unsafeCoerce#)
import Net.Types (IPv4(..))
import Socket (Interruptibility(..))
import Socket (SocketUnrecoverableException(..))
import Socket (cgetsockname,cclose)
import Socket.Debug (debug)
import Socket.IPv4 (Endpoint(..),describeEndpoint)
import Socket.Stream (ConnectException(..),SocketException(..),AcceptException(..))
import Socket.Stream (SendException(..),ReceiveException(..),CloseException(..))
import System.Posix.Types(Fd)
import qualified Control.Monad.Primitive as PM
import qualified Data.ByteString.Internal as BI
import qualified Data.ByteString.Unsafe as BU
import qualified Data.ByteString.Lazy.Internal as LBS
import qualified Data.Primitive as PM
import qualified Foreign.C.Error.Describe as D
import qualified Linux.Socket as L
import qualified Posix.Socket as S
import qualified Socket as SCK
import qualified GHC.ForeignPtr as FP
-- | A listening IPv4 stream socket. Wraps the underlying file descriptor;
-- values are produced by 'listen' (and therefore 'withListener').
newtype Listener = Listener Fd
-- | An IPv4 stream connection, produced either by accepting from a
-- 'Listener' or by 'connect'. Wraps the underlying file descriptor.
newtype Connection = Connection Fd
-- | Open an IPv4 stream socket (nonblocking, close-on-exec), bind it to the
-- given endpoint, and call @listen@ on it with a backlog argument of 16.
--
-- If the endpoint's port is 0, the port actually assigned is read back with
-- @getsockname@. Otherwise the requested port is returned unchanged.
--
-- Errors from @socket@, @bind@, and @listen@ are classified by
-- 'handleSocketListenException' and 'handleBindListenException'. The
-- descriptor is closed on every failure after it has been created. Failures
-- while reading back the assigned port throw 'SocketUnrecoverableException'.
listen :: Endpoint -> IO (Either SocketException (Listener, Word16))
listen endpoint@Endpoint{port = specifiedPort} = do
  debug ("listen: opening listen " ++ describeEndpoint endpoint)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("listen: opened listen " ++ describeEndpoint endpoint)
  case e1 of
    Left err -> handleSocketListenException SCK.functionWithListener err
    Right fd -> do
      e2 <- S.uninterruptibleBind fd
        (S.encodeSocketAddressInternet (endpointToSocketAddressInternet endpoint))
      debug ("listen: requested binding for listen " ++ describeEndpoint endpoint)
      case e2 of
        Left err -> do
          _ <- S.uninterruptibleClose fd
          handleBindListenException specifiedPort SCK.functionWithListener err
        Right _ -> S.uninterruptibleListen fd 16 >>= \case
          Left err -> do
            _ <- S.uninterruptibleClose fd
            debug "listen: listen failed with error code"
            handleBindListenException specifiedPort SCK.functionWithListener err
          Right _ -> do
            actualPort <- if specifiedPort == 0
              then assignedPort fd
              else pure specifiedPort
            pure (Right (Listener fd, actualPort))
  where
    -- Read back the port chosen for a port-0 bind. Every failure path,
    -- including a failing getsockname, closes the descriptor before
    -- throwing so it is never leaked.
    assignedPort :: Fd -> IO Word16
    assignedPort fd = S.uninterruptibleGetSocketName fd S.sizeofSocketAddressInternet >>= \case
      Left err -> do
        _ <- S.uninterruptibleClose fd
        throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionWithListener
          [cgetsockname,describeEndpoint endpoint,describeErrorCode err]
      Right (sockAddrRequiredSz,sockAddr)
        | sockAddrRequiredSz /= S.sizeofSocketAddressInternet -> do
            _ <- S.uninterruptibleClose fd
            throwIO $ SocketUnrecoverableException
              moduleSocketStreamIPv4
              functionWithListener
              [cgetsockname,describeEndpoint endpoint,"socket address size"]
        | otherwise -> case S.decodeSocketAddressInternet sockAddr of
            Just S.SocketAddressInternet{port = netPort} -> do
              let cleanActualPort = S.networkToHostShort netPort
              debug ("listen: successfully bound listen " ++ describeEndpoint endpoint ++ " and got port " ++ show cleanActualPort)
              pure cleanActualPort
            Nothing -> do
              _ <- S.uninterruptibleClose fd
              throwIO $ SocketUnrecoverableException
                moduleSocketStreamIPv4
                functionWithListener
                [cgetsockname,"non-internet socket family"]
-- | Close a 'Listener'. If @close@ reports an error, it is rethrown as a
-- 'SocketUnrecoverableException'.
unlisten :: Listener -> IO ()
unlisten (Listener fd) = S.uninterruptibleClose fd >>= either closeFailed (\_ -> pure ())
  where
    closeFailed err = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      functionWithListener
      [cclose,describeErrorCode err]
-- | Close a 'Listener', discarding any error reported by @close@.
unlisten_ :: Listener -> IO ()
unlisten_ (Listener descriptor) = S.uninterruptibleErrorlessClose descriptor
-- | Open a 'Listener' with 'listen' and run the callback with it and its
-- bound port. The callback runs with async exceptions restored. If it
-- throws, the listener is closed with 'unlisten_'; otherwise it is closed
-- with 'unlisten'. A failure to open the listener is returned as 'Left'
-- and the callback is not run.
withListener ::
     Endpoint
  -> (Listener -> Word16 -> IO a)
  -> IO (Either SocketException a)
withListener endpoint f = mask $ \restore -> do
  opened <- listen endpoint
  case opened of
    Left err -> pure (Left err)
    Right (listener, boundPort) -> do
      result <- restore (f listener boundPort) `onException` unlisten_ listener
      unlisten listener
      pure (Right result)
-- | Block until the listener is readable, then accept one connection,
-- returning it together with the peer's endpoint.
accept :: Listener -> IO (Either (AcceptException 'Uninterruptible) (Connection,Endpoint))
accept (Listener fd) = threadWaitRead fd *> waitlessAccept fd
-- | Like 'accept', but returns 'AcceptInterrupted' instead of accepting if
-- @abandon@ is observed as 'True' before the listener becomes readable.
interruptibleAccept ::
     TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Endpoint))
interruptibleAccept abandon (Listener fd) = do
  ready <- interruptibleWaitRead abandon fd
  if ready then waitlessAccept fd else pure (Left AcceptInterrupted)
-- | Like 'interruptibleAccept', but waits with
-- 'interruptibleWaitReadCounting', which increments @counter@ in the same
-- STM transaction that observes readiness. The decrement is left to the
-- caller ('interruptibleForkAcceptedUnmasked').
interruptibleAcceptCounting ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> IO (Either (AcceptException 'Interruptible) (Connection,Endpoint))
interruptibleAcceptCounting counter abandon (Listener fd) = do
  ready <- interruptibleWaitReadCounting counter abandon fd
  if ready then waitlessAccept fd else pure (Left AcceptInterrupted)
-- | Call @accept@ on the listening descriptor without waiting for readiness.
-- Errors from @accept@ are classified by 'handleAcceptException'. If the
-- returned peer address has an unexpected size or is not an internet
-- address, the accepted descriptor is closed and
-- 'SocketUnrecoverableException' is thrown.
waitlessAccept :: Fd -> IO (Either (AcceptException i) (Connection,Endpoint))
waitlessAccept lstn = S.uninterruptibleAccept lstn S.sizeofSocketAddressInternet >>= \case
  Left err -> handleAcceptException "withAccepted" err
  Right (requiredSize, peerAddr, accepted)
    | requiredSize /= S.sizeofSocketAddressInternet -> do
        _ <- S.uninterruptibleClose accepted
        throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          SCK.functionWithAccepted
          [SCK.cgetsockname,SCK.socketAddressSize]
    | otherwise -> case S.decodeSocketAddressInternet peerAddr of
        Nothing -> do
          _ <- S.uninterruptibleClose accepted
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionWithAccepted
            [SCK.cgetsockname,SCK.nonInternetSocketFamily]
        Just inetAddr -> do
          let !peer = socketAddressInternetToEndpoint inetAddr
          debug ("internalAccepted: successfully accepted connection from " ++ show peer)
          pure (Right (Connection accepted, peer))
-- | First step of a graceful close: shut down the write direction of the
-- socket, then continue with 'gracefulCloseB'. An @ENOTCONN@ result from
-- @shutdown@ is tolerated and the close proceeds. Any other @shutdown@
-- error closes the descriptor and throws 'SocketUnrecoverableException'.
gracefulCloseA :: Fd -> IO (Either CloseException ())
gracefulCloseA fd = do
  shut <- S.uninterruptibleShutdown fd S.write
  case shut of
    Left err | err /= eNOTCONN -> do
      _ <- S.uninterruptibleClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionGracefulClose
        [SCK.cshutdown,describeErrorCode err]
    _ -> gracefulCloseB fd
-- | Second step of a graceful close: try to receive a single byte.
--
-- * A zero-length receive closes the descriptor and returns @Right ()@.
-- * A non-empty receive closes the descriptor and returns
--   'ClosePeerContinuedSending'.
-- * If the first receive fails with @EAGAIN@/@EWOULDBLOCK@, wait for
--   readability and receive once more with the same interpretation.
--
-- Any other receive error, or a failing @close@ after an empty receive,
-- throws 'SocketUnrecoverableException'.
gracefulCloseB :: Fd -> IO (Either CloseException ())
gracefulCloseB fd = do
  scratch <- PM.newByteArray 1
  let readOne = S.uninterruptibleReceiveMutableByteArray fd scratch 0 1 mempty
  readOne >>= \case
    Right n -> settle "Socket.Stream.IPv4.gracefulClose: remote not shutdown B" n
    Left err1
      | err1 == eWOULDBLOCK || err1 == eAGAIN -> do
          threadWaitRead fd
          readOne >>= \case
            Right n -> settle "Socket.Stream.IPv4.gracefulClose: remote not shutdown A" n
            Left err -> recvFailed err
      | otherwise -> recvFailed err1
  where
    -- Receive errors other than EAGAIN/EWOULDBLOCK are unrecoverable.
    recvFailed err = do
      _ <- S.uninterruptibleClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionGracefulClose
        [SCK.crecv,describeErrorCode err]
    -- Interpret the size of the probe receive; the label is the debug
    -- message emitted when data is received.
    settle label n
      | n == 0 = S.uninterruptibleClose fd >>= \case
          Left err -> throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionGracefulClose
            [SCK.cclose,describeErrorCode err]
          Right _ -> pure (Right ())
      | otherwise = do
          debug label
          _ <- S.uninterruptibleClose fd
          pure (Left ClosePeerContinuedSending)
-- | Accept one connection and run @cb@ on it with async exceptions restored.
-- The connection is then closed with 'disconnect', or with 'disconnect_' if
-- @cb@ throws. The close result and @cb@'s value are passed to
-- @consumeException@, which runs after the masked region has ended.
withAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) b)
withAccepted lstn consumeException cb = do
  outcome <- mask $ \restore -> accept lstn >>= \case
    Left err -> pure (Left err)
    Right (conn, endpoint) -> do
      value <- restore (cb conn endpoint) `onException` disconnect_ conn
      closeResult <- disconnect conn
      pure (Right (closeResult, value))
  either
    (pure . Left)
    (\(closeResult, value) -> Right <$> consumeException closeResult value)
    outcome
-- | Accept one connection while masked, then fork a thread that runs @cb@
-- with the parent's masking state restored. The thread closes the
-- connection with 'disconnect' (or 'disconnect_' if @cb@ throws) and then
-- runs @consumeException@ with the close result and @cb@'s value.
forkAccepted ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAccepted lstn consumeException cb = mask $ \restore -> do
  accepted <- accept lstn
  case accepted of
    Left err -> pure (Left err)
    Right (conn, endpoint) -> do
      let worker = do
            value <- restore (cb conn endpoint) `onException` disconnect_ conn
            closeResult <- disconnect conn
            restore (consumeException closeResult value)
      Right <$> forkIO worker
-- | Like 'forkAccepted', but the forked thread is started with
-- 'forkIOWithUnmask', so @cb@ and @consumeException@ run fully unmasked.
forkAcceptedUnmasked ::
     Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Uninterruptible) ThreadId)
forkAcceptedUnmasked lstn consumeException cb = mask_ $ do
  accepted <- accept lstn
  case accepted of
    Left err -> pure (Left err)
    Right (conn, endpoint) -> Right <$> forkIOWithUnmask (\unmask -> do
      value <- unmask (cb conn endpoint) `onException` disconnect_ conn
      closeResult <- disconnect conn
      unmask (consumeException closeResult value))
-- | Interruptible, counting variant of 'forkAcceptedUnmasked'.
--
-- @counter@ is incremented by 'interruptibleWaitReadCounting' when the
-- listener is observed readable. This function decrements it exactly once
-- on every later path:
--
-- * an accept error other than 'AcceptInterrupted';
-- * in the worker thread, an exception from the callback, from
--   'disconnect', or from @consumeException@;
-- * normal completion of the worker thread.
--
-- An 'AcceptInterrupted' result is returned without decrementing, because
-- the counter is only incremented when readiness was observed.
interruptibleForkAcceptedUnmasked ::
     TVar Int
  -> TVar Bool
  -> Listener
  -> (Either CloseException () -> a -> IO ())
  -> (Connection -> Endpoint -> IO a)
  -> IO (Either (AcceptException 'Interruptible) ThreadId)
interruptibleForkAcceptedUnmasked !counter !abandon !lstn consumeException cb =
  mask_ $ interruptibleAcceptCounting counter abandon lstn >>= \case
    Left e -> do
      case e of
        AcceptInterrupted -> pure ()
        _ -> decrement
      pure (Left e)
    Right (conn, endpoint) -> fmap Right $ forkIOWithUnmask $ \unmask -> do
      a <- onException
        (unmask (cb conn endpoint))
        (disconnect_ conn *> decrement)
      -- Previously an exception from disconnect or consumeException
      -- skipped the decrement below and permanently inflated the counter.
      e <- disconnect conn `onException` decrement
      r <- unmask (consumeException e a) `onException` decrement
      decrement
      pure r
  where
    decrement = atomically (modifyTVar' counter (subtract 1))
-- | Open a nonblocking IPv4 stream socket and connect it to @remote@. If
-- @connect@ reports @EINPROGRESS@, block until the socket is writable, then
-- check the outcome with 'afterEstablishment'. Other @connect@ errors
-- close the socket and are classified by 'handleConnectException'.
connect ::
     Endpoint
  -> IO (Either (ConnectException 'Uninterruptible) Connection)
connect !remote = beforeEstablishment remote >>= \case
  Left err -> pure (Left err)
  Right (fd,sockAddr) -> do
    attempt <- S.uninterruptibleConnect fd sockAddr
    case attempt of
      Right _ -> afterEstablishment fd
      Left err2
        | err2 == eINPROGRESS -> threadWaitWrite fd *> afterEstablishment fd
        | otherwise -> do
            S.uninterruptibleErrorlessClose fd
            handleConnectException SCK.functionWithConnection err2
-- | Like 'connect', but while waiting for an in-progress connection to
-- complete, gives up with 'ConnectInterrupted' if @abandon@ is observed as
-- 'True'. In that case the half-open socket is closed before returning, so
-- the descriptor is not leaked.
interruptibleConnect ::
     TVar Bool
  -> Endpoint
  -> IO (Either (ConnectException 'Interruptible) Connection)
interruptibleConnect !abandon !remote = do
  beforeEstablishment remote >>= \case
    Left err -> pure (Left err)
    Right (fd,sockAddr) -> S.uninterruptibleConnect fd sockAddr >>= \case
      Left err2 -> if err2 == eINPROGRESS
        then interruptibleWaitWrite abandon fd >>= \case
          True -> afterEstablishment fd
          False -> do
            -- Nobody else holds this descriptor; release it on abandonment.
            S.uninterruptibleErrorlessClose fd
            pure (Left ConnectInterrupted)
        else do
          S.uninterruptibleErrorlessClose fd
          handleConnectException SCK.functionWithConnection err2
      Right _ -> afterEstablishment fd
-- | Create a nonblocking, close-on-exec IPv4 stream socket and encode the
-- remote endpoint as a socket address, ready for @connect@. Errors from
-- @socket@ are classified by 'handleSocketConnectException'.
beforeEstablishment :: Endpoint -> IO (Either (ConnectException i) (Fd,S.SocketAddress))
{-# INLINE beforeEstablishment #-}
beforeEstablishment !remote = do
  debug ("beforeEstablishment: opening connection " ++ show remote)
  e1 <- S.uninterruptibleSocket S.internet
    (L.applySocketFlags (L.closeOnExec <> L.nonblocking) S.stream)
    S.defaultProtocol
  debug ("beforeEstablishment: opened connection " ++ show remote)
  either
    (handleSocketConnectException SCK.functionWithConnection)
    (\fd -> pure (Right (fd, target)))
    e1
  where
    target = S.encodeSocketAddressInternet (endpointToSocketAddressInternet remote)
-- | Check the result of a (possibly asynchronous) @connect@ by reading the
-- socket's @SO_ERROR@ option. An error value of 0 yields the 'Connection'.
-- A nonzero value closes the socket and is classified by
-- 'handleConnectException'. If @getsockopt@ fails or returns a value of an
-- unexpected size, the socket is closed and 'SocketUnrecoverableException'
-- is thrown. The exception is attributed to the connection path, not the
-- listener path.
afterEstablishment :: Fd -> IO (Either (ConnectException i) Connection)
afterEstablishment fd = do
  e <- S.uninterruptibleGetSocketOption fd
    S.levelSocket S.optionError expectedSize
  case e of
    Left err -> do
      S.uninterruptibleErrorlessClose fd
      throwIO $ SocketUnrecoverableException
        moduleSocketStreamIPv4
        SCK.functionWithConnection
        [SCK.cgetsockopt,describeErrorCode err]
    Right (sz,S.OptionValue val)
      | sz /= expectedSize -> do
          S.uninterruptibleErrorlessClose fd
          throwIO $ SocketUnrecoverableException
            moduleSocketStreamIPv4
            SCK.functionWithConnection
            [SCK.cgetsockopt,connectErrorOptionValueSize]
      | otherwise -> case PM.indexByteArray val 0 :: CInt of
          0 -> pure (Right (Connection fd))
          err -> do
            S.uninterruptibleErrorlessClose fd
            handleConnectException SCK.functionWithConnection (Errno err)
  where
    -- Byte size of a C int, the size of the SO_ERROR option value.
    expectedSize = intToCInt (PM.sizeOf (undefined :: CInt))
-- | Close a 'Connection' gracefully (see 'gracefulCloseA'), reporting
-- whether the peer continued sending.
disconnect :: Connection -> IO (Either CloseException ())
disconnect (Connection descriptor) = gracefulCloseA descriptor
-- | Close a 'Connection' immediately, discarding any error from @close@.
disconnect_ :: Connection -> IO ()
disconnect_ (Connection descriptor) = S.uninterruptibleErrorlessClose descriptor
-- | Connect to @remote@ and run @f@ on the connection with async exceptions
-- restored. Afterwards the connection is closed with 'disconnect', or with
-- 'disconnect_' if @f@ throws.
--
-- @g@ receives the close result and @f@'s value. It runs with the caller's
-- masking state restored, as the handlers in 'withAccepted' and
-- 'forkAccepted' do: the connection is already closed, so nothing remains
-- to protect, and a long-running @g@ should stay interruptible.
withConnection ::
     Endpoint
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> IO a)
  -> IO (Either (ConnectException 'Uninterruptible) b)
withConnection !remote g f = mask $ \restore -> do
  connect remote >>= \case
    Left err -> pure (Left err)
    Right conn -> do
      a <- onException (restore (f conn)) (disconnect_ conn)
      m <- disconnect conn
      b <- restore (g m a)
      pure (Right b)
-- | Like 'withConnection', but the connection attempt can be abandoned
-- through @abandon@ (see 'interruptibleConnect'). @f@ and @g@ both run with
-- the caller's masking state restored. @g@ runs only after the connection
-- has been closed.
interruptibleWithConnection ::
     TVar Bool
  -> Endpoint
  -> (Either CloseException () -> a -> IO b)
  -> (Connection -> IO a)
  -> IO (Either (ConnectException 'Interruptible) b)
interruptibleWithConnection !abandon !remote g f = mask $ \restore -> do
  interruptibleConnect abandon remote >>= \case
    Left err -> pure (Left err)
    Right conn -> do
      a <- onException (restore (f conn)) (disconnect_ conn)
      m <- disconnect conn
      b <- restore (g m a)
      pure (Right b)
-- | Send an entire 'ByteArray' (see 'sendByteArraySlice').
sendByteArray ::
     Connection
  -> ByteArray
  -> IO (Either (SendException 'Uninterruptible) ())
sendByteArray conn arr = sendByteArraySlice conn arr 0 size
  where
    size = PM.sizeofByteArray arr
-- | Send an entire 'ByteArray', abandoning the wait for writability when
-- @abandon@ is observed as 'True' (see 'interruptibleSendByteArraySlice').
interruptibleSendByteArray ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendByteArray abandon conn arr = interruptibleSendByteArraySlice abandon conn arr 0 size
  where
    size = PM.sizeofByteArray arr
-- | Send a strict 'ByteString' by handing its underlying buffer to
-- 'sendAddr' without copying.
sendByteString ::
     Connection
  -> ByteString
  -> IO (Either (SendException 'Uninterruptible) ())
sendByteString !conn !payload = BU.unsafeUseAsCStringLen payload transmit
  where
    transmit (Ptr addr, len) = sendAddr conn (Addr addr) len
-- | Send a lazy 'ByteString' chunk by chunk, stopping at the first chunk
-- whose send fails.
sendLazyByteString ::
     Connection
  -> LBS.ByteString
  -> IO (Either (SendException 'Uninterruptible) ())
sendLazyByteString !conn !lazyPayload = sendChunks lazyPayload
  where
    sendChunks = \case
      LBS.Empty -> pure (Right ())
      LBS.Chunk piece remaining -> do
        sent <- sendByteString conn piece
        either (pure . Left) (\_ -> sendChunks remaining) sent
-- | Send @len0@ bytes starting at address @payload0@, retrying after partial
-- writes until everything is sent or a send error is returned. A negative
-- length throws 'SocketUnrecoverableException'.
sendAddr ::
     Connection
  -> Addr
  -> Int
  -> IO (Either (SendException 'Uninterruptible) ())
sendAddr !conn !payload0 !len0 = loop payload0 len0
  where
    loop !payload !len
      | len == 0 = pure (Right ())
      | len < 0 = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendByteArray
          [SCK.negativeSliceLength]
      | otherwise = internalSendAddr conn payload len >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (PM.plusAddr payload n) (len - n)
-- | Send @len0@ bytes of @payload@ starting at offset @off0@, retrying after
-- partial writes until everything is sent or a send error is returned. A
-- negative length throws 'SocketUnrecoverableException'.
sendByteArraySlice ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) ())
sendByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !len
      | len == 0 = pure (Right ())
      | len < 0 = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendByteArray
          [SCK.negativeSliceLength]
      | otherwise = internalSend conn payload off len >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (len - n)
-- | Interruptible counterpart of 'sendByteArraySlice'. Every wait for
-- writability can be abandoned through @abandon@, which yields
-- 'SendInterrupted'.
interruptibleSendByteArraySlice ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendByteArraySlice !abandon !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !len
      | len == 0 = pure (Right ())
      | len < 0 = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendByteArray
          [SCK.negativeSliceLength]
      | otherwise = internalInterruptibleSend abandon conn payload off len >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (len - n)
-- | Send the entire contents of a 'MutableByteArray', using its current
-- size (see 'sendMutableByteArraySlice').
sendMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArray conn arr = do
  size <- PM.getSizeofMutableByteArray arr
  sendMutableByteArraySlice conn arr 0 size
-- | Send @len0@ bytes of a mutable buffer starting at offset @off0@,
-- retrying after partial writes until everything is sent or a send error
-- is returned. A negative length throws 'SocketUnrecoverableException'.
sendMutableByteArraySlice ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) ())
sendMutableByteArraySlice !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !len
      | len == 0 = pure (Right ())
      | len < 0 = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendMutableByteArray
          [SCK.negativeSliceLength]
      | otherwise = internalSendMutable conn payload off len >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (len - n)
-- | Interruptible counterpart of 'sendMutableByteArraySlice'. Every wait
-- for writability can be abandoned through @abandon@, which yields
-- 'SendInterrupted'.
interruptibleSendMutableByteArraySlice ::
     TVar Bool
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) ())
interruptibleSendMutableByteArraySlice !abandon !conn !payload !off0 !len0 = loop off0 len0
  where
    loop !off !len
      | len == 0 = pure (Right ())
      | len < 0 = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionSendMutableByteArray
          [SCK.negativeSliceLength]
      | otherwise = internalInterruptibleSendMutable abandon conn payload off len >>= \case
          Left e -> pure (Left e)
          Right sent -> do
            let n = csizeToInt sent
            loop (off + n) (len - n)
-- | One send attempt from a mutable buffer. If the socket is not writable,
-- wait with 'interruptibleWaitWrite'; abandonment yields 'SendInterrupted'.
internalInterruptibleSendMutable ::
     TVar Bool
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) CSize)
internalInterruptibleSendMutable !abandon !conn !payload !off !len =
  veryInternalSendMutable awaitWritable conn payload off len
  where
    awaitWritable fd = do
      writable <- interruptibleWaitWrite abandon fd
      pure (if writable then Right () else Left SendInterrupted)
-- | One send attempt from a mutable buffer. If the socket is not writable,
-- block with 'threadWaitWrite' and try once more.
internalSendMutable ::
     Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) CSize)
internalSendMutable !conn !payload !off !len =
  veryInternalSendMutable (\fd -> Right () <$ threadWaitWrite fd) conn payload off len
-- | Send from a mutable buffer with @MSG_NOSIGNAL@-style flags
-- ('S.noSignal'). If the first attempt fails with @EAGAIN@/@EWOULDBLOCK@,
-- run @wait@ and make exactly one more attempt. Returns the number of bytes
-- the kernel accepted. Errors are classified by 'handleSendException'.
veryInternalSendMutable ::
     (Fd -> IO (Either (SendException i) ()))
  -> Connection
  -> MutableByteArray RealWorld
  -> Int
  -> Int
  -> IO (Either (SendException i) CSize)
{-# INLINE veryInternalSendMutable #-}
veryInternalSendMutable wait (Connection !s) !payload !off !len = attempt >>= \case
  Right sz -> pure (Right sz)
  Left err1
    | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> handleSendException "sendMutableByteArray" err1
    | otherwise -> wait s >>= \case
        Left err2 -> pure (Left err2)
        Right () -> attempt >>= \case
          Left err3 -> handleSendException functionSendMutableByteArray err3
          Right sz -> pure (Right sz)
  where
    attempt = S.uninterruptibleSendMutableByteArray s payload
      (intToCInt off)
      (intToCSize len)
      (S.noSignal)
-- | One send attempt from an immutable 'ByteArray'. If the socket is not
-- writable, block with 'threadWaitWrite' and try once more.
internalSend ::
     Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Uninterruptible) CSize)
internalSend !conn !payload !off !len =
  veryInternalSend (\fd -> Right () <$ threadWaitWrite fd) conn payload off len
-- | One send attempt from a raw address. If the socket is not writable,
-- block with 'threadWaitWrite' and try once more.
internalSendAddr ::
     Connection
  -> Addr
  -> Int
  -> IO (Either (SendException 'Uninterruptible) CSize)
internalSendAddr !conn !payload !len =
  veryInternalSendAddr (\fd -> Right () <$ threadWaitWrite fd) conn payload len
-- | One send attempt from an immutable 'ByteArray'. If the socket is not
-- writable, wait with 'interruptibleWaitWrite'; abandonment yields
-- 'SendInterrupted'.
internalInterruptibleSend ::
     TVar Bool
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException 'Interruptible) CSize)
internalInterruptibleSend !abandon !conn !payload !off !len =
  veryInternalSend awaitWritable conn payload off len
  where
    awaitWritable fd = do
      writable <- interruptibleWaitWrite abandon fd
      pure (if writable then Right () else Left SendInterrupted)
-- | Send a slice of an immutable 'ByteArray' with 'S.noSignal'. If the first
-- attempt fails with @EAGAIN@/@EWOULDBLOCK@, run @wait@ and make exactly one
-- more attempt. Returns the number of bytes the kernel accepted. Errors are
-- classified by 'handleSendException'.
veryInternalSend ::
     (Fd -> IO (Either (SendException i) ()))
  -> Connection
  -> ByteArray
  -> Int
  -> Int
  -> IO (Either (SendException i) CSize)
{-# INLINE veryInternalSend #-}
veryInternalSend wait (Connection !s) !payload !off !len = do
  debug ("veryInternalSend: about to send chunk on stream socket, offset " ++ show off ++ " and length " ++ show len)
  initial <- sendOnce
  debug "veryInternalSend: just sent chunk on stream socket"
  case initial of
    Right sz -> pure (Right sz)
    Left err1
      | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> handleSendException functionSendByteArray err1
      | otherwise -> do
          debug "veryInternalSend: waiting to for write ready on stream socket"
          wait s >>= \case
            Left e -> pure (Left e)
            Right _ -> sendOnce >>= \case
              Right sz -> pure (Right sz)
              Left err2 -> do
                debug "veryInternalSend: encountered error after sending chunk on stream socket"
                handleSendException functionSendByteArray err2
  where
    sendOnce = S.uninterruptibleSendByteArray s payload
      (intToCInt off)
      (intToCSize len)
      (S.noSignal)
-- | Send @len@ bytes from a raw address with 'S.noSignal'. If the first
-- attempt fails with @EAGAIN@/@EWOULDBLOCK@, run @wait@ and make exactly one
-- more attempt. Returns the number of bytes the kernel accepted. Errors are
-- classified by 'handleSendException'.
veryInternalSendAddr ::
     (Fd -> IO (Either (SendException i) ()))
  -> Connection
  -> Addr
  -> Int
  -> IO (Either (SendException i) CSize)
{-# INLINE veryInternalSendAddr #-}
veryInternalSendAddr wait (Connection !s) !payload !len = do
  debug ("veryInternalSendAddr: about to send chunk on stream socket, length " ++ show len)
  initial <- sendOnce
  debug "veryInternalSendAddr: just sent chunk on stream socket"
  case initial of
    Right sz -> pure (Right sz)
    Left err1
      | err1 /= eWOULDBLOCK && err1 /= eAGAIN -> handleSendException functionSendByteArray err1
      | otherwise -> do
          debug "veryInternalSendAddr: waiting to for write ready on stream socket"
          wait s >>= \case
            Left e -> pure (Left e)
            Right _ -> sendOnce >>= \case
              Right sz -> pure (Right sz)
              Left err2 -> do
                debug "veryInternalSendAddr: encountered error after sending chunk on stream socket"
                handleSendException functionSendByteArray err2
  where
    sendOnce = S.uninterruptibleSend s payload
      (intToCSize len)
      (S.noSignal)
-- | Block until the socket is readable, then perform one receive of up to
-- @maxSz@ bytes into @buf@ at offset @off@. Returns the number of bytes
-- received (0 when the receive returned nothing). Errors are classified by
-- 'handleReceiveException'.
internalReceiveMaximally ::
     Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException i) Int)
internalReceiveMaximally (Connection !fd) !maxSz !buf !off = do
  debug "receive: stream socket about to wait"
  threadWaitRead fd
  debug ("receive: stream socket is now readable, receiving up to " ++ show maxSz ++ " bytes at offset " ++ show off)
  received <- S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty
  debug "receive: finished reading from stream socket"
  either
    (handleReceiveException "internalReceiveMaximally")
    (pure . Right . csizeToInt)
    received
-- | Like 'internalReceiveMaximally', but the wait for readability can be
-- abandoned through @abandon@, which yields 'ReceiveInterrupted'.
internalInterruptibleReceiveMaximally ::
     TVar Bool
  -> Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) Int)
{-# INLINE internalInterruptibleReceiveMaximally #-}
internalInterruptibleReceiveMaximally abandon (Connection !fd) !maxSz !buf !off =
  interruptibleWaitRead abandon fd >>= \case
    False -> pure (Left ReceiveInterrupted)
    True -> do
      received <- S.uninterruptibleReceiveMutableByteArray fd buf (intToCInt off) (intToCSize maxSz) mempty
      either
        (handleReceiveException "internalReceiveMaximally")
        (pure . Right . csizeToInt)
        received
-- | Receive exactly @total@ bytes into a fresh 'ByteArray'
-- (see 'internalReceiveByteArray').
receiveByteArray ::
     Connection
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) ByteArray)
receiveByteArray !conn !total = internalReceiveByteArray receiver conn total
  where
    receiver = internalReceiveMaximally
-- | Receive exactly @total@ bytes into a fresh 'ByteArray'. Each wait for
-- readability can be abandoned through @abandon@.
interruptibleReceiveByteArray ::
     TVar Bool
  -> Connection
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) ByteArray)
interruptibleReceiveByteArray !abandon !conn !total = internalReceiveByteArray receiver conn total
  where
    receiver = internalInterruptibleReceiveMaximally abandon
-- | Allocate a @total@-byte buffer and fill it completely using @recvMax@,
-- then freeze it. A zero-byte receive before the buffer is full yields
-- 'ReceiveShutdown'. A negative @total@ throws
-- 'SocketUnrecoverableException' (only if the allocation succeeds first).
internalReceiveByteArray ::
     (Connection -> Int -> MutableByteArray RealWorld -> Int -> IO (Either (ReceiveException i) Int))
  -> Connection
  -> Int
  -> IO (Either (ReceiveException i) ByteArray)
internalReceiveByteArray recvMax !conn0 !total = do
  buffer <- PM.newByteArray total
  fill buffer 0 total
  where
    fill !marr !off !remaining
      | remaining > 0 = recvMax conn0 remaining marr off >>= \case
          Left err -> pure (Left err)
          Right 0 -> pure (Left ReceiveShutdown)
          Right sz -> fill marr (off + sz) (remaining - sz)
      | remaining == 0 = Right <$> PM.unsafeFreezeByteArray marr
      | otherwise = throwIO $ SocketUnrecoverableException
          moduleSocketStreamIPv4
          functionReceiveByteArray
          [SCK.negativeSliceLength]
-- | Receive exactly @total@ bytes into a freshly allocated pinned buffer and
-- expose that buffer as a strict 'ByteString' without copying. The
-- 'ForeignPtr' points at the pinned array's contents and keeps the array
-- alive through 'FP.PlainPtr'.
receiveByteString ::
     Connection
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) ByteString)
receiveByteString !conn !total = do
  buffer@(MutableByteArray buffer#) <- PM.newPinnedByteArray total
  received <- receiveMutableByteArray conn buffer
  case received of
    Left err -> pure (Left err)
    -- unsafeCoerce# views the mutable array as immutable only to obtain its
    -- address; the array is pinned, so the address is stable.
    Right _ -> pure (Right (BI.PS (FP.ForeignPtr (byteArrayContents# (unsafeCoerce# buffer#)) (FP.PlainPtr buffer#)) 0 total))
-- | Fill the whole of @marr0@ (its current size) with received bytes. A
-- zero-byte receive before the buffer is full yields 'ReceiveShutdown'.
receiveMutableByteArray ::
     Connection
  -> MutableByteArray RealWorld
  -> IO (Either (ReceiveException 'Uninterruptible) ())
receiveMutableByteArray !conn0 !marr0 =
  PM.getSizeofMutableByteArray marr0 >>= drain 0
  where
    drain !off !remaining
      | remaining <= 0 = pure (Right ())
      | otherwise = internalReceiveMaximally conn0 remaining marr0 off >>= \case
          Left err -> pure (Left err)
          Right 0 -> pure (Left ReceiveShutdown)
          Right sz -> drain (off + sz) (remaining - sz)
-- | Perform one receive of at most @total@ bytes into @marr@ at offset
-- @off@ and return the number of bytes written. A zero-byte receive yields
-- 'ReceiveShutdown'. @total == 0@ returns @Right 0@ without touching the
-- socket. A negative @total@ throws 'SocketUnrecoverableException'.
receiveBoundedMutableByteArraySlice ::
     Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) Int)
receiveBoundedMutableByteArraySlice !conn !total !marr !off = case compare total 0 of
  GT -> internalReceiveMaximally conn total marr off >>= \case
    Left err -> pure (Left err)
    Right 0 -> pure (Left ReceiveShutdown)
    Right sz -> pure (Right sz)
  EQ -> pure (Right 0)
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveMutableByteArraySlice
    [SCK.negativeSliceLength]
-- | Interruptible counterpart of 'receiveBoundedMutableByteArraySlice'. The
-- wait for readability can be abandoned through @abandon@.
interruptibleReceiveBoundedMutableByteArraySlice ::
     TVar Bool
  -> Connection
  -> Int
  -> MutableByteArray RealWorld
  -> Int
  -> IO (Either (ReceiveException 'Interruptible) Int)
interruptibleReceiveBoundedMutableByteArraySlice !abandon !conn !total !marr !off = case compare total 0 of
  GT -> internalInterruptibleReceiveMaximally abandon conn total marr off >>= \case
    Left err -> pure (Left err)
    Right 0 -> pure (Left ReceiveShutdown)
    Right sz -> pure (Right sz)
  EQ -> pure (Right 0)
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveMutableByteArraySlice
    [SCK.negativeSliceLength]
-- | Receive at most @total@ bytes with a single receive. The result is a
-- 'ByteArray' shrunk to the number of bytes actually received. @total == 0@
-- returns an empty array. A negative @total@ throws
-- 'SocketUnrecoverableException'.
receiveBoundedByteArray ::
     Connection
  -> Int
  -> IO (Either (ReceiveException 'Uninterruptible) ByteArray)
receiveBoundedByteArray !conn !total = case compare total 0 of
  LT -> throwIO $ SocketUnrecoverableException
    moduleSocketStreamIPv4
    functionReceiveBoundedByteArray
    [SCK.negativeSliceLength]
  EQ -> pure (Right mempty)
  GT -> do
    buffer <- PM.newByteArray total
    receiveBoundedMutableByteArraySlice conn total buffer 0 >>= \case
      Left err -> pure (Left err)
      Right received -> do
        shrinkMutableByteArray buffer received
        Right <$> PM.unsafeFreezeByteArray buffer
-- | Convert an 'Endpoint' (host byte order) into a socket address with its
-- port and address in network byte order.
endpointToSocketAddressInternet :: Endpoint -> S.SocketAddressInternet
endpointToSocketAddressInternet Endpoint{address = hostAddr, port = hostPort} =
  S.SocketAddressInternet
    { address = S.hostToNetworkLong (getIPv4 hostAddr)
    , port = S.hostToNetworkShort hostPort
    }
-- | Convert a socket address (network byte order) back into an 'Endpoint'
-- in host byte order.
socketAddressInternetToEndpoint :: S.SocketAddressInternet -> Endpoint
socketAddressInternetToEndpoint S.SocketAddressInternet{address = netAddr, port = netPort} =
  Endpoint
    { address = IPv4 (S.networkToHostLong netAddr)
    , port = S.networkToHostShort netPort
    }
-- | Convert an 'Int' to a C @int@ with 'fromIntegral' (no range check).
intToCInt :: Int -> CInt
intToCInt n = fromIntegral n
-- | Convert an 'Int' to a C @size_t@ with 'fromIntegral' (no range check).
intToCSize :: Int -> CSize
intToCSize n = fromIntegral n
-- | Convert a C @size_t@ to an 'Int' with 'fromIntegral' (no range check).
csizeToInt :: CSize -> Int
csizeToInt n = fromIntegral n
-- | Shrink a mutable byte array in place to @sz@ bytes via the
-- @shrinkMutableByteArray#@ primop.
shrinkMutableByteArray :: MutableByteArray RealWorld -> Int -> IO ()
shrinkMutableByteArray (MutableByteArray arr) (I# sz) =
  PM.primitive_ (\s -> shrinkMutableByteArray# arr sz s)
-- | Module name placed in every 'SocketUnrecoverableException' raised here.
moduleSocketStreamIPv4 :: String
moduleSocketStreamIPv4 = "Socket.Stream.IPv4"
-- | Function label for errors on the mutable-buffer send paths.
functionSendMutableByteArray :: String
functionSendMutableByteArray = "sendMutableByteArray"
-- | Function label for errors on the immutable-buffer and address send paths
-- (also used by 'sendAddr' and 'veryInternalSendAddr').
functionSendByteArray :: String
functionSendByteArray = "sendByteArray"
-- | Function label used by 'listen' and 'unlisten'.
-- NOTE(review): 'afterEstablishment' also uses this label for connect-path
-- errors.
functionWithListener :: String
functionWithListener = "withListener"
-- | Function label for errors raised by 'receiveBoundedByteArray'.
functionReceiveBoundedByteArray :: String
functionReceiveBoundedByteArray = "receiveBoundedByteArray"
-- | Function label for errors raised by 'internalReceiveByteArray'.
functionReceiveByteArray :: String
functionReceiveByteArray = "receiveByteArray"
-- | Function label used by the bounded slice receivers
-- ('receiveBoundedMutableByteArraySlice' and its interruptible variant);
-- note the label text omits "Bounded".
functionReceiveMutableByteArraySlice :: String
functionReceiveMutableByteArraySlice = "receiveMutableByteArraySlice"
-- | Render an 'Errno' as @"error code <description> (<number>)"@, using
-- 'D.string' for the description.
describeErrorCode :: Errno -> String
describeErrorCode err@(Errno e) = concat ["error code ", D.string err, " (", show e, ")"]
-- | Classify a receive error: @ECONNRESET@ becomes 'ReceiveReset'. Anything
-- else throws 'SocketUnrecoverableException' attributed to @func@.
handleReceiveException :: String -> Errno -> IO (Either (ReceiveException i) a)
handleReceiveException func e =
  if e == eCONNRESET
    then pure (Left ReceiveReset)
    else throwIO (SocketUnrecoverableException moduleSocketStreamIPv4 func [describeErrorCode e])
-- | Classify a send error: @EPIPE@ becomes 'SendShutdown' and @ECONNRESET@
-- becomes 'SendReset'. Anything else throws 'SocketUnrecoverableException'
-- attributed to @func@.
handleSendException :: String -> Errno -> IO (Either (SendException i) a)
{-# INLINE handleSendException #-}
handleSendException func e
  | e == ePIPE = pure (Left SendShutdown)
  | e == eCONNRESET = pure (Left SendReset)
  | otherwise = throwIO unrecoverable
  where
    unrecoverable = SocketUnrecoverableException moduleSocketStreamIPv4 func [describeErrorCode e]
-- | Classify a @connect@ error. @EACCES@ and @EPERM@ become
-- 'ConnectFirewalled', @ENETUNREACH@ becomes 'ConnectNetworkUnreachable',
-- @ECONNREFUSED@ becomes 'ConnectRefused', @EADDRNOTAVAIL@ becomes
-- 'ConnectEphemeralPortsExhausted', and @ETIMEDOUT@ becomes
-- 'ConnectTimeout'. Anything else throws 'SocketUnrecoverableException'.
handleConnectException :: String -> Errno -> IO (Either (ConnectException i) a)
handleConnectException func e =
  maybe unrecoverable (pure . Left) (lookup e expected)
  where
    expected =
      [ (eACCES, ConnectFirewalled)
      , (ePERM, ConnectFirewalled)
      , (eNETUNREACH, ConnectNetworkUnreachable)
      , (eCONNREFUSED, ConnectRefused)
      , (eADDRNOTAVAIL, ConnectEphemeralPortsExhausted)
      , (eTIMEDOUT, ConnectTimeout)
      ]
    unrecoverable = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @socket@ error on the connect path: @EMFILE@ and @ENFILE@
-- become 'ConnectFileDescriptorLimit'. Anything else is unrecoverable.
handleSocketConnectException :: String -> Errno -> IO (Either (ConnectException i) a)
handleSocketConnectException func e
  | e `elem` [eMFILE, eNFILE] = pure (Left ConnectFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify a @socket@ error on the listen path: @EMFILE@ and @ENFILE@
-- become 'SocketFileDescriptorLimit'. Anything else is unrecoverable.
handleSocketListenException :: String -> Errno -> IO (Either SocketException a)
handleSocketListenException func e
  | e `elem` [eMFILE, eNFILE] = pure (Left SocketFileDescriptorLimit)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify an error from @bind@ or @listen@ on a listening socket.
--
-- * @EACCES@ becomes 'SocketPermissionDenied'.
-- * @EADDRINUSE@ depends on the port that was requested. When it was 0,
--   the kernel was choosing an ephemeral port, so the failure means no
--   ephemeral port was free ('SocketEphemeralPortsExhausted'). For an
--   explicit port it means that port is already taken
--   ('SocketAddressInUse'). These two outcomes were previously swapped.
-- * Anything else throws 'SocketUnrecoverableException' attributed to
--   @func@.
handleBindListenException :: Word16 -> String -> Errno -> IO (Either SocketException a)
handleBindListenException thePort func e
  | e == eACCES = pure (Left SocketPermissionDenied)
  | e == eADDRINUSE = if thePort == 0
      then pure (Left SocketEphemeralPortsExhausted)
      else pure (Left SocketAddressInUse)
  | otherwise = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Classify an @accept@ error. @ECONNABORTED@ becomes
-- 'AcceptConnectionAborted', @EMFILE@ and @ENFILE@ become
-- 'AcceptFileDescriptorLimit', and @EPERM@ becomes 'AcceptFirewalled'.
-- Anything else throws 'SocketUnrecoverableException'.
handleAcceptException :: String -> Errno -> IO (Either (AcceptException i) a)
handleAcceptException func e =
  maybe unrecoverable (pure . Left) (lookup e expected)
  where
    expected =
      [ (eCONNABORTED, AcceptConnectionAborted)
      , (eMFILE, AcceptFileDescriptorLimit)
      , (eNFILE, AcceptFileDescriptorLimit)
      , (ePERM, AcceptFirewalled)
      ]
    unrecoverable = throwIO $ SocketUnrecoverableException
      moduleSocketStreamIPv4
      func
      [describeErrorCode e]
-- | Detail string used when the @SO_ERROR@ option comes back with an
-- unexpected size.
connectErrorOptionValueSize :: String
connectErrorOptionValueSize = "incorrectly sized value of SO_ERROR option"
-- | Wait in one STM transaction until @fd@ is readable or @abandon@ is
-- 'True', then deregister the readiness registration. Returns 'True' if
-- the descriptor became readable, or 'False' if the wait was abandoned.
-- When both hold, abandonment wins.
interruptibleWaitRead :: TVar Bool -> Fd -> IO Bool
interruptibleWaitRead !abandon !fd = do
  (isReady,deregister) <- threadWaitReadSTM fd
  verdict <- atomically (abandoned <|> (isReady $> True))
  deregister
  pure verdict
  where
    -- False once abandon is set; otherwise retry so readiness decides.
    abandoned = readTVar abandon >>= bool retry (pure False)
-- | Write-side counterpart of 'interruptibleWaitRead'. Returns 'True' when
-- @fd@ became writable, or 'False' if abandoned first.
interruptibleWaitWrite :: TVar Bool -> Fd -> IO Bool
interruptibleWaitWrite !abandon !fd = do
  (isReady,deregister) <- threadWaitWriteSTM fd
  verdict <- atomically (abandoned <|> (isReady $> True))
  deregister
  pure verdict
  where
    abandoned = readTVar abandon >>= bool retry (pure False)
-- | Like 'interruptibleWaitRead', but when readiness is observed it also
-- increments @counter@ in the same transaction.
interruptibleWaitReadCounting :: TVar Int -> TVar Bool -> Fd -> IO Bool
interruptibleWaitReadCounting !counter !abandon !fd = do
  (isReady,deregister) <- threadWaitReadSTM fd
  let claimReadiness = isReady *> modifyTVar' counter (+1) $> True
  shouldReceive <- atomically (readTVar abandon >>= bool claimReadiness (pure False))
  deregister
  pure shouldReceive