-- | TCP implementation of the transport layer.
--
-- The TCP implementation guarantees that only a single TCP connection (socket)
-- will be used between endpoints, provided that the addresses specified are
-- canonical. If /A/ connects to /B/ and reports its address as
-- @192.168.0.1:8080@ and /B/ subsequently tries to connect to /A/ as
-- @client1.local:http-alt@ then the transport layer will not realize that the
-- TCP connection can be reused.
--
-- Applications that use the TCP transport should use
-- 'Network.Socket.withSocketsDo' in their main function for Windows
-- compatibility (see "Network.Socket").

{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE LambdaCase #-}

module Network.Transport.TCP
  ( -- * Main API
    createTransport
  , TCPAddr(..)
  , defaultTCPAddr
  , TCPAddrInfo(..)
  , TCPParameters(..)
  , defaultTCPParameters
    -- * Internals (exposed for unit tests)
  , createTransportExposeInternals
  , TransportInternals(..)
  , EndPointId
  , ControlHeader(..)
  , ConnectionRequestResponse(..)
  , firstNonReservedLightweightConnectionId
  , firstNonReservedHeavyweightConnectionId
  , socketToEndPoint
  , LightweightConnectionId
  , QDisc(..)
  , simpleUnboundedQDisc
  , simpleOnePlaceQDisc
    -- * Design notes
    -- $design
  ) where

import Prelude hiding
  ( mapM_
#if ! MIN_VERSION_base(4,6,0)
  , catch
#endif
  )

import Network.Transport
import Network.Transport.TCP.Internal
  ( ControlHeader(..)
  , encodeControlHeader
  , decodeControlHeader
  , ConnectionRequestResponse(..)
  , encodeConnectionRequestResponse
  , decodeConnectionRequestResponse
  , forkServer
  , recvWithLength
  , recvExact
  , recvWord32
  , encodeWord32
  , tryCloseSocket
  , tryShutdownSocketBoth
  , resolveSockAddr
  , EndPointId
  , encodeEndPointAddress
  , decodeEndPointAddress
  , currentProtocolVersion
  , randomEndPointAddress
  )
import Network.Transport.Internal
  ( prependLength
  , mapIOException
  , tryIO
  , tryToEnum
  , void
  , timeoutMaybe
  , asyncWhenCancelled
  )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
  ( HostName
  , ServiceName
  , Socket
  , getAddrInfo
  , maxListenQueue
  , socket
  , addrFamily
  , addrAddress
  , SocketType(Stream)
  , defaultProtocol
  , setSocketOption
  , SocketOption(ReuseAddr, NoDelay, UserTimeout, KeepAlive)
  , isSupportedSocketOption
  , connect
  , AddrInfo
  , SockAddr(..)
  )

#ifdef USE_MOCK_NETWORK
import Network.Transport.TCP.Mock.Socket.ByteString (sendMany)
#else
import Network.Socket.ByteString (sendMany)
#endif

import Control.Concurrent
  ( forkIO
  , ThreadId
  , killThread
  , myThreadId
  , threadDelay
  , throwTo
  )
import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
import Control.Concurrent.MVar
  ( MVar
  , newMVar
  , modifyMVar
  , modifyMVar_
  , readMVar
  , tryReadMVar
  , takeMVar
  , putMVar
  , tryPutMVar
  , newEmptyMVar
  , withMVar
  )
import Control.Concurrent.Async (async, wait)
import Control.Category ((>>>))
import Control.Applicative ((<$>))
import Control.Monad (when, unless, join, mplus, (<=<))
import Control.Exception
  ( IOException
  , SomeException
  , AsyncException
  , handle
  , throw
  , throwIO
  , try
  , bracketOnError
  , bracket
  , fromException
  , finally
  , catch
  , bracket
  , mask
  , mask_
  )
import Data.IORef (IORef, newIORef, writeIORef, readIORef, writeIORef)
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (concat, length, null)
import qualified Data.ByteString.Char8 as BSC (pack, unpack)
import Data.Bits (shiftL, (.|.))
import Data.Maybe (isJust, isNothing, fromJust)
import Data.Word (Word32)
import Data.Set (Set)
import qualified Data.Set as Set
  ( empty
  , insert
  , elems
  , singleton
  , null
  , delete
  , member
  )
import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Traversable (traverse)
import Data.Accessor (Accessor, accessor, (^.), (^=), (^:))
import qualified Data.Accessor.Container as DAC (mapMaybe)
import Data.Foldable (forM_, mapM_)
import qualified System.Timeout (timeout)

-- $design
--
-- [Goals]
--
-- The TCP transport maps multiple logical connections between /A/ and /B/ (in
-- either direction) to a single TCP connection:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |   Q   |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~Q   |
-- > |   \~~~|~~~~~~~~~~~~~~~~~~~~~~~~~<|       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- Ignoring the complications detailed below, the TCP connection is set up
-- when the first lightweight connection is created (in either direction), and
-- torn down when the last lightweight connection (in either direction) is
-- closed.
--
-- [Connecting]
--
-- Let /A/, /B/ be two endpoints without any connections. When /A/ wants to
-- connect to /B/, it locally records that it is trying to connect to /B/ and
-- sends a request to /B/. As part of the request /A/ sends its own endpoint
-- address to /B/ (so that /B/ can reuse the connection in the other direction).
--
-- When /B/ receives the connection request it first checks whether it has
-- already initiated a connection request to /A/. If it has not, it
-- acknowledges the connection request by sending 'ConnectionRequestAccepted'
-- to /A/ and records that it has a TCP connection to /A/.
--
-- The tricky case arises when /A/ sends a connection request to /B/ and /B/
-- finds that it had already sent a connection request to /A/. In this case /B/
-- will accept the connection request from /A/ if /A/'s endpoint address is
-- smaller (lexicographically) than /B/'s, and reject it otherwise. If it rejects
-- it, it sends a 'ConnectionRequestCrossed' message to /A/. The
-- lexicographical ordering is an arbitrary but convenient way to break the
-- tie. If a connection exists between /A/ and /B/ when /B/ rejects the request,
-- /B/ will probe the connection to make sure it is healthy. If /A/ does not
-- answer timely to the probe, /B/ will discard the connection.
--
-- When it receives a 'ConnectionRequestCrossed' message the /A/ thread that
-- initiated the request just needs to wait until the /A/ thread that is dealing
-- with /B/'s connection request completes, unless there is a network failure.
-- If there is a network failure, the initiator thread would timeout and return
-- an error.
--
-- [Disconnecting]
--
-- The TCP connection is created as soon as the first logical connection from
-- /A/ to /B/ (or /B/ to /A/) is established. At this point a thread (@#@) is
-- spawned that listens for incoming connections from /B/:
--
-- > +-------+                          +-------+
-- > | A     |==========================| B     |
-- > |       |>~~~~~~~~~~~~~~~~~~~~~~~~~|~~~\   |
-- > |       |                          |   Q   |
-- > |      #|                          |       |
-- > |       |==========================|       |
-- > +-------+                          +-------+
--
-- The question is when the TCP connection can be closed again.  Conceptually,
-- we want to do reference counting: when there are no logical connections left
-- between /A/ and /B/ we want to close the socket (possibly after some
-- timeout).
--
-- However, /A/ and /B/ need to agree that the refcount has reached zero.  It
-- might happen that /B/ sends a connection request over the existing socket at
-- the same time that /A/ closes its logical connection to /B/ and closes the
-- socket. This will cause a failure in /B/ (which will have to retry) which is
-- not caused by a network failure, which is unfortunate. (Note that the
-- connection request from /B/ might succeed even if /A/ closes the socket.)
--
-- Instead, when /A/ is ready to close the socket it sends a 'CloseSocket'
-- request to /B/ and records that its connection to /B/ is closing. If /A/
-- receives a new connection request from /B/ after having sent the
-- 'CloseSocket' request it simply forgets that it sent a 'CloseSocket' request
-- and increments the reference count of the connection again.
--
-- When /B/ receives a 'CloseSocket' message and it too is ready to close the
-- connection, it will respond with a reciprocal 'CloseSocket' request to /A/
-- and then actually close the socket. /A/ meanwhile will not send any more
-- requests to /B/ after having sent a 'CloseSocket' request, and will actually
-- close its end of the socket only when receiving the 'CloseSocket' message
-- from /B/. (Since /A/ recorded that its connection to /B/ is in closing state
-- after sending a 'CloseSocket' request to /B/, it knows not to answer /B/'s
-- reciprocal 'CloseSocket' message with another one.)
--
-- If there is a concurrent thread in /A/ waiting to connect to /B/ after /A/
-- has sent a 'CloseSocket' request then this thread will block until /A/ knows
-- whether to reuse the old socket (if /B/ sends a new connection request
-- instead of acknowledging the 'CloseSocket') or to set up a new socket.

--------------------------------------------------------------------------------
-- Internal datatypes                                                         --
--------------------------------------------------------------------------------

-- We use underscores for fields that we might update (using accessors)
--
-- All data types follow the same structure:
--
-- * A top-level data type describing static properties (TCPTransport,
--   LocalEndPoint, RemoteEndPoint)
-- * The 'static' properties include an MVar containing a data structure for
--   the dynamic properties (TransportState, LocalEndPointState,
--   RemoteState). The state could be invalid/valid/closed/etc.
-- * For the case of "valid" we use a third data structure to give more
--   details about the state (ValidTransportState, ValidLocalEndPointState,
--   ValidRemoteEndPointState).

-- | Information about the network addresses of a transport: the external
-- host/port as well as the bound host/port, which are not necessarily the
-- same (for instance, the external port is derived from the port that was
-- actually bound, which is only known once binding has completed).
data TransportAddrInfo = TransportAddrInfo
  { transportHost     :: !N.HostName
    -- ^ External host reported to peers.
  , transportPort     :: !N.ServiceName
    -- ^ External port reported to peers.
  , transportBindHost :: !N.HostName
    -- ^ Host the server socket was asked to bind to.
  , transportBindPort :: !N.ServiceName
    -- ^ Port the server socket actually bound to.
  }

-- | Static description of a TCP transport. Its mutable part lives in
-- 'transportState'.
data TCPTransport = TCPTransport
  { transportAddrInfo :: !(Maybe TransportAddrInfo)
    -- ^ This is 'Nothing' in case the transport is not addressable from the
    -- network: peers cannot connect to it unless it has a connection to the
    -- peer.
  , transportState    :: !(MVar TransportState)
    -- ^ Dynamic state of the transport (valid or closed).
  , transportParams   :: !TCPParameters
    -- ^ Parameters the transport was created with.
  }

-- | Dynamic state of a 'TCPTransport'.
data TransportState
  = TransportValid !ValidTransportState
    -- ^ The transport is open; details in 'ValidTransportState'.
  | TransportClosed
    -- ^ The transport has been closed.

-- | Details of an open transport.
data ValidTransportState = ValidTransportState
  { _localEndPoints :: !(Map EndPointId LocalEndPoint)
    -- ^ Local endpoints of this transport, indexed by endpoint ID.
  , _nextEndPointId :: !EndPointId
    -- ^ Next available endpoint ID (initially 0).
  }

-- | Static description of a local endpoint. Its mutable part lives in
-- 'localState'.
data LocalEndPoint = LocalEndPoint
  { localAddress    :: !EndPointAddress
    -- ^ Address of this endpoint.
  , localEndPointId :: !EndPointId
    -- ^ ID of this endpoint within its transport.
  , localState      :: !(MVar LocalEndPointState)
    -- ^ Dynamic state of the endpoint (valid or closed).
  , localQueue      :: !(QDisc Event)
    -- ^ A 'QDisc' is held here rather than on the 'ValidLocalEndPointState'
    -- because even closed 'LocalEndPoint's can have queued input data.
  }

-- | Dynamic state of a 'LocalEndPoint'.
data LocalEndPointState
  = LocalEndPointValid !ValidLocalEndPointState
    -- ^ The endpoint is open; details in 'ValidLocalEndPointState'.
  | LocalEndPointClosed
    -- ^ The endpoint has been closed.

-- | Details of an open local endpoint.
data ValidLocalEndPointState = ValidLocalEndPointState
  { _localNextConnOutId :: !LightweightConnectionId
    -- ^ Next available ID for an outgoing lightweight self-connection
    -- (see also '_remoteNextConnOutId').
  , _nextConnInId       :: !HeavyweightConnectionId
    -- ^ Next available ID for an incoming heavyweight connection.
  , _localConnections   :: !(Map EndPointAddress RemoteEndPoint)
    -- ^ Currently active heavyweight connections, keyed by the remote
    -- 'EndPointAddress'.
  }

-- REMOTE ENDPOINTS
--
-- Remote endpoints (basically, TCP connections) have the following lifecycle:
--
--   Init  ---+---> Invalid
--            |
--            +-------------------------------\
--            |                               |
--            |       /----------\            |
--            |       |          |            |
--            |       v          |            v
--            +---> Valid ---> Closing ---> Closed
--            |       |          |            |
--            |       |          |            v
--            \-------+----------+--------> Failed
--
-- Init: There are two places where we create new remote endpoints: in
--   createConnectionTo (in response to an API 'connect' call) and in
--   handleConnectionRequest (when a remote node tries to connect to us).
--   'Init' carries an MVar () 'resolved' which concurrent threads can use to
--   wait for the remote endpoint to finish initialization. We record who
--   requested the connection (the local endpoint or the remote endpoint).
--
-- Invalid: We put the remote endpoint in invalid state only during
--   createConnectionTo when we fail to connect.
--
-- Valid: This is the "normal" state for a working remote endpoint.
--
-- Closing: When we detect that a remote endpoint is no longer used, we send a
--   CloseSocket request across the connection and put the remote endpoint in
--   closing state. As with Init, 'Closing' carries an MVar () 'resolved' which
--   concurrent threads can use to wait for the remote endpoint to either be
--   closed fully (if the communication partner responds with another
--   CloseSocket) or be put back in 'Valid' state if the remote endpoint denies
--   the request.
--
--   We also put the endpoint in Closed state, directly from Init, if our
--   outbound connection request crossed an inbound connection request and we
--   decide to keep the inbound (i.e., the remote endpoint sent us a
--   ConnectionRequestCrossed message).
--
-- Closed: The endpoint is put in Closed state after a successful garbage
--   collection.
--
-- Failed: If the connection to the remote endpoint is lost, or the local
-- endpoint (or the whole transport) is closed manually, the remote endpoint is
-- put in Failed state, and we record the reason.
--
-- Invariants for dealing with remote endpoints:
--
-- INV-SEND: Whenever we send data the remote endpoint must be locked (to avoid
--   interleaving bits of payload).
--
-- INV-CLOSE: Local endpoints should never point to remote endpoint in closed
--   state.  Whenever we put an endpoint in Closed state we remove that
--   endpoint from localConnections first, so that if a concurrent thread reads
--   the MVar, finds RemoteEndPointClosed, and then looks up the endpoint in
--   localConnections it is guaranteed to either find a different remote
--   endpoint, or else none at all (if we don't insist on this order, some
--   threads might start spinning).
--
-- INV-RESOLVE: We should only signal on 'resolved' while the remote endpoint is
--   locked, and the remote endpoint must be in Valid or Closed state once
--   unlocked. This guarantees that there will not be two threads attempting to
--   both signal on 'resolved'.
--
-- INV-LOST: If a send or recv fails, or a socket is closed unexpectedly, we
--   first put the remote endpoint in Closed state, and then send a
--   EventConnectionLost event. This guarantees that we only send this event
--   once.
--
-- INV-CLOSING: An endpoint in closing state is for all intents and purposes
--   closed; that is, we shouldn't do any 'send's on it (although 'recv' is
--   acceptable, of course -- as we are waiting for the remote endpoint to
--   confirm or deny the request).
--
-- INV-LOCK-ORDER: Remote endpoints must be locked before their local endpoints.
--   In other words: it is okay to call modifyMVar on a local endpoint inside a
--   modifyMVar on a remote endpoint, but not the other way around. In
--   particular, it is okay to call removeRemoteEndPoint inside
--   modifyRemoteState.

-- | Static description of a remote endpoint (basically, a TCP connection).
-- Its mutable part lives in 'remoteState'.
data RemoteEndPoint = RemoteEndPoint
  { remoteAddress   :: !EndPointAddress
    -- ^ Address of the remote endpoint.
  , remoteState     :: !(MVar RemoteState)
    -- ^ Lifecycle state (Init/Valid/Closing/Closed/Failed/Invalid).
  , remoteId        :: !HeavyweightConnectionId
    -- ^ Heavyweight connection ID of this connection.
  , remoteScheduled :: !(Chan (IO ()))
    -- ^ Channel of 'IO' actions scheduled for this remote endpoint.
  }

-- | Which side requested a connection; recorded in 'RemoteEndPointInit'.
data RequestedBy = RequestedByUs | RequestedByThem
  deriving (Eq, Show)

-- | Lifecycle state of a 'RemoteEndPoint'.
data RemoteState
  = RemoteEndPointInvalid !(TransportError ConnectErrorCode)
    -- ^ Invalid remote endpoint (for example, invalid address).
  | RemoteEndPointInit !(MVar ()) !(MVar ()) !RequestedBy
    -- ^ The remote endpoint is being initialized.
  | RemoteEndPointValid !ValidRemoteEndPointState
    -- ^ "Normal" working endpoint.
  | RemoteEndPointClosing !(MVar ()) !ValidRemoteEndPointState
    -- ^ The remote endpoint is being closed (garbage collected).
  | RemoteEndPointClosed
    -- ^ The remote endpoint has been closed (garbage collected).
  | RemoteEndPointFailed !IOException
    -- ^ The remote endpoint has failed, or has been forcefully shut down
    -- using a closeTransport or closeEndPoint API call.

-- TODO: we might want to replace Set (here and elsewhere) by faster
-- containers
--
-- TODO: we could get rid of '_remoteIncoming' (and maintain less state) if
-- we introduce a new event 'AllConnectionsClosed'

-- | Details of a remote endpoint in 'RemoteEndPointValid' or
-- 'RemoteEndPointClosing' state.
data ValidRemoteEndPointState = ValidRemoteEndPointState
  { _remoteOutgoing      :: !Int
  , _remoteIncoming      :: !(Set LightweightConnectionId)
  , _remoteLastIncoming  :: !LightweightConnectionId
  , _remoteNextConnOutId :: !LightweightConnectionId
  , remoteSocket         :: !N.Socket
    -- ^ The TCP socket underlying this connection.
  , remoteProbing        :: Maybe (IO ())
    -- ^ When the connection is being probed, yields an IO action that can be
    -- used to release any resources dedicated to the probing.
  , remoteSendLock       :: !(MVar (Maybe SomeException))
    -- ^ Protects the socket from concurrent use by several threads and
    -- prohibits its use after an exception.
    --
    -- 'Nothing' allows the socket to be used. @Just e@ is set on an
    -- exception after which the socket should not be used (see 'sendOn').
  , remoteSocketClosed   :: !(IO ())
    -- ^ An IO action which returns when the socket ('remoteSocket') has been
    -- closed. The program/thread which created the socket is always
    -- responsible for closing it, but sometimes other threads need to know
    -- when this happens.
  }

-- | A local endpoint together with a remote endpoint; an alias that keeps
-- signatures concise.
type EndPointPair = (LocalEndPoint, RemoteEndPoint)

-- | Lightweight connection ID (sender allocated).
--
-- A 'ConnectionId' is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'.
type LightweightConnectionId = Word32

-- | Heavyweight connection ID (recipient allocated).
--
-- A 'ConnectionId' is the concatenation of a 'HeavyweightConnectionId' and a
-- 'LightweightConnectionId'.
type HeavyweightConnectionId = Word32

-- | A transport which is addressable from the network must give a host/port
-- on which to bind/listen, and determine its external address (host/port) from
-- the actual port (which may not be known, in case 0 is used for the bind
-- port).
data TCPAddrInfo = TCPAddrInfo {
    tcpBindHost :: N.HostName
    -- ^ Host on which to bind/listen.
  , tcpBindPort :: N.ServiceName
    -- ^ Port on which to bind/listen (may be 0).
  , tcpExternalAddress :: N.ServiceName -> (N.HostName, N.ServiceName)
    -- ^ Computes the external (host, port) reported to peers from the port
    -- that was actually bound.
  }

-- | Addressability of a transport. Use 'Unaddressable' when your transport
-- cannot be connected to, for instance because it runs behind NAT.
data TCPAddr
  = Addressable TCPAddrInfo
    -- ^ Peers can connect to the transport using this address information.
  | Unaddressable
    -- ^ Peers cannot connect to the transport unless it has a connection to
    -- the peer.

-- | An addressable 'TCPAddr' whose bind and external host/port are the same.
defaultTCPAddr :: N.HostName -> N.ServiceName -> TCPAddr
defaultTCPAddr host port = Addressable $ TCPAddrInfo
  { tcpBindHost        = host
  , tcpBindPort        = port
    -- The external address reuses the bind host together with whichever
    -- port was actually bound (relevant when the bind port is 0).
  , tcpExternalAddress = (,) host
  }

-- | Parameters for setting up the TCP transport
data TCPParameters = TCPParameters {
    -- | Backlog for 'listen'.
    -- Defaults to SOMAXCONN.
    tcpBacklog :: Int
    -- | Should we set SO_REUSEADDR on the server socket?
    -- Defaults to True.
  , tcpReuseServerAddr :: Bool
    -- | Should we set SO_REUSEADDR on client sockets?
    -- Defaults to True.
  , tcpReuseClientAddr :: Bool
    -- | Should we set TCP_NODELAY on connection sockets?
    -- Defaults to True.
  , tcpNoDelay :: Bool
    -- | Should we set SO_KEEPALIVE on connection sockets?
    -- Defaults to False.
  , tcpKeepAlive :: Bool
    -- | Value of TCP_USER_TIMEOUT in milliseconds.
    --
    -- When creating an addressable transport, 'createTransport' fails if this
    -- is set but TCP_USER_TIMEOUT is not supported on this system.
  , tcpUserTimeout :: Maybe Int
    -- | A connect timeout for all 'connect' calls of the transport
    -- in microseconds
    --
    -- This can be overridden for each connect call with the
    -- 'connectTimeout' field of 'ConnectHints'.
    --
    -- Connection requests to this transport will also timeout if they don't
    -- send the required data before this many microseconds.
    --
    -- Defaults to Nothing (no timeout).
  , transportConnectTimeout :: Maybe Int
    -- | Create a QDisc for an EndPoint.
  , tcpNewQDisc :: forall t . IO (QDisc t)
    -- | Maximum length (in bytes) for a peer's address.
    -- If a peer attempts to send an address of length exceeding the limit,
    -- the connection will be refused (socket will close).
  , tcpMaxAddressLength :: Word32
    -- | Maximum length (in bytes) to receive from a peer.
    -- If a peer attempts to send data on a lightweight connection exceeding
    -- the limit, the heavyweight connection which carries that lightweight
    -- connection will go down. The peer and the local node will get an
    -- EventConnectionLost.
  , tcpMaxReceiveLength :: Word32
    -- | If True, new connections will be accepted only if the socket's host
    -- matches the host that the peer claims in its EndPointAddress.
    -- This is useful when operating on untrusted networks, because the peer
    -- could otherwise deny service to some victim by claiming the victim's
    -- address.
    -- Defaults to False.
  , tcpCheckPeerHost :: Bool
    -- | What to do if there's an exception when accepting a new TCP
    -- connection. Throwing an exception here will cause the server to
    -- terminate.
    -- Defaults to 'throwIO'.
  , tcpServerExceptionHandler :: SomeException -> IO ()
  }

-- | Internal functionality we expose for unit testing
data TransportInternals = TransportInternals
  { -- | The ID of the thread that listens for new incoming connections
    -- ('Nothing' for an unaddressable transport, which runs no server).
    transportThread     :: Maybe ThreadId
    -- | A variant of newEndPoint in which the QDisc determined by the
    -- transport's TCPParameters can be optionally overridden.
  , newEndPointInternal :: (forall t . Maybe (QDisc t))
                        -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
    -- | Find the socket between a local and a remote endpoint
  , socketBetween       :: EndPointAddress
                        -> EndPointAddress
                        -> IO N.Socket
  }

--------------------------------------------------------------------------------
-- Top-level functionality                                                    --
--------------------------------------------------------------------------------

-- | Create a TCP transport.
--
-- Returns 'Left' when the transport cannot be set up. For example, starting
-- the server of an addressable transport may fail, or 'tcpUserTimeout' may be
-- set on a system that does not support TCP_USER_TIMEOUT. See
-- 'createTransportExposeInternals'.
createTransport
  :: TCPAddr
  -> TCPParameters
  -> IO (Either IOException Transport)
createTransport addr params =
  -- Keep only the transport; the internals are exposed for unit tests only.
  fmap fst <$> createTransportExposeInternals addr params

-- | You should probably not use this function (used for unit testing only)
createTransportExposeInternals
  :: TCPAddr
  -> TCPParameters
  -> IO (Either IOException (Transport, TransportInternals))
createTransportExposeInternals :: TCPAddr
-> TCPParameters
-> IO (Either IOError (Transport, TransportInternals))
createTransportExposeInternals TCPAddr
addr TCPParameters
params = do
    MVar TransportState
state <- TransportState -> IO (MVar TransportState)
forall a. a -> IO (MVar a)
newMVar (TransportState -> IO (MVar TransportState))
-> (ValidTransportState -> TransportState)
-> ValidTransportState
-> IO (MVar TransportState)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ValidTransportState -> TransportState
TransportValid (ValidTransportState -> IO (MVar TransportState))
-> ValidTransportState -> IO (MVar TransportState)
forall a b. (a -> b) -> a -> b
$ ValidTransportState
      { _localEndPoints :: Map HeavyweightConnectionId LocalEndPoint
_localEndPoints = Map HeavyweightConnectionId LocalEndPoint
forall k a. Map k a
Map.empty
      , _nextEndPointId :: HeavyweightConnectionId
_nextEndPointId = HeavyweightConnectionId
0
      }
    case TCPAddr
addr of

      TCPAddr
Unaddressable ->
        let transport :: TCPTransport
transport = TCPTransport { transportState :: MVar TransportState
transportState    = MVar TransportState
state
                                     , transportAddrInfo :: Maybe TransportAddrInfo
transportAddrInfo = Maybe TransportAddrInfo
forall a. Maybe a
Nothing
                                     , transportParams :: TCPParameters
transportParams   = TCPParameters
params
                                     }
        in  ((Transport, TransportInternals)
 -> Either IOError (Transport, TransportInternals))
-> IO (Transport, TransportInternals)
-> IO (Either IOError (Transport, TransportInternals))
forall a b. (a -> b) -> IO a -> IO b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Transport, TransportInternals)
-> Either IOError (Transport, TransportInternals)
forall a b. b -> Either a b
Right (TCPTransport
-> Maybe ThreadId -> IO (Transport, TransportInternals)
mkTransport TCPTransport
transport Maybe ThreadId
forall a. Maybe a
Nothing)

      Addressable (TCPAddrInfo HostName
bindHost HostName
bindPort HostName -> (HostName, HostName)
mkExternal) -> IO (Transport, TransportInternals)
-> IO (Either IOError (Transport, TransportInternals))
forall (m :: * -> *) a. MonadIO m => IO a -> m (Either IOError a)
tryIO (IO (Transport, TransportInternals)
 -> IO (Either IOError (Transport, TransportInternals)))
-> IO (Transport, TransportInternals)
-> IO (Either IOError (Transport, TransportInternals))
forall a b. (a -> b) -> a -> b
$ mdo
        Bool -> IO () -> IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when ( Maybe Int -> Bool
forall a. Maybe a -> Bool
isJust (TCPParameters -> Maybe Int
tcpUserTimeout TCPParameters
params) Bool -> Bool -> Bool
&&
               Bool -> Bool
not (SocketOption -> Bool
N.isSupportedSocketOption SocketOption
N.UserTimeout)
             ) (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$
          IOError -> IO ()
forall e a. Exception e => e -> IO a
throwIO (IOError -> IO ()) -> IOError -> IO ()
forall a b. (a -> b) -> a -> b
$ HostName -> IOError
userError (HostName -> IOError) -> HostName -> IOError
forall a b. (a -> b) -> a -> b
$ HostName
"Network.Transport.TCP.createTransport: " HostName -> ShowS
forall a. [a] -> [a] -> [a]
++
                                HostName
"the parameter tcpUserTimeout is unsupported " HostName -> ShowS
forall a. [a] -> [a] -> [a]
++
                                HostName
"in this system."
        -- We don't know for sure the actual port 'forkServer' binded until it
        -- completes (see description of 'forkServer'), yet we need the port to
        -- construct a transport. So we tie a recursive knot.
        (HostName
port', (Transport, TransportInternals)
result) <- do
          let (HostName
externalHost, HostName
externalPort) = HostName -> (HostName, HostName)
mkExternal HostName
port'
          let addrInfo :: TransportAddrInfo
addrInfo = TransportAddrInfo { transportHost :: HostName
transportHost     = HostName
externalHost
                                           , transportPort :: HostName
transportPort     = HostName
externalPort
                                           , transportBindHost :: HostName
transportBindHost = HostName
bindHost
                                           , transportBindPort :: HostName
transportBindPort = HostName
port'
                                           }
          let transport :: TCPTransport
transport = TCPTransport { transportState :: MVar TransportState
transportState    = MVar TransportState
state
                                       , transportAddrInfo :: Maybe TransportAddrInfo
transportAddrInfo = TransportAddrInfo -> Maybe TransportAddrInfo
forall a. a -> Maybe a
Just TransportAddrInfo
addrInfo
                                       , transportParams :: TCPParameters
transportParams   = TCPParameters
params
                                       }
          IO (HostName, ThreadId)
-> ((HostName, ThreadId) -> IO ())
-> ((HostName, ThreadId)
    -> IO (HostName, (Transport, TransportInternals)))
-> IO (HostName, (Transport, TransportInternals))
forall a b c. IO a -> (a -> IO b) -> (a -> IO c) -> IO c
bracketOnError (HostName
-> HostName
-> Int
-> Bool
-> (SomeException -> IO ())
-> (SomeException -> IO ())
-> (IO () -> (Socket, SockAddr) -> IO ())
-> IO (HostName, ThreadId)
forkServer
                              HostName
bindHost
                              HostName
bindPort
                              (TCPParameters -> Int
tcpBacklog TCPParameters
params)
                              (TCPParameters -> Bool
tcpReuseServerAddr TCPParameters
params)
                              (TCPTransport -> SomeException -> IO ()
errorHandler TCPTransport
transport)
                              (TCPTransport -> SomeException -> IO ()
terminationHandler TCPTransport
transport)
                              (TCPTransport -> IO () -> (Socket, SockAddr) -> IO ()
handleConnectionRequest TCPTransport
transport))
                       (\(HostName
_port', ThreadId
tid) -> ThreadId -> IO ()
killThread ThreadId
tid)
                       (\(HostName
port'', ThreadId
tid) -> (HostName
port'',) ((Transport, TransportInternals)
 -> (HostName, (Transport, TransportInternals)))
-> IO (Transport, TransportInternals)
-> IO (HostName, (Transport, TransportInternals))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TCPTransport
-> Maybe ThreadId -> IO (Transport, TransportInternals)
mkTransport TCPTransport
transport (ThreadId -> Maybe ThreadId
forall a. a -> Maybe a
Just ThreadId
tid))
        (Transport, TransportInternals)
-> IO (Transport, TransportInternals)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (Transport, TransportInternals)
result
  where
    mkTransport :: TCPTransport
                -> Maybe ThreadId
                -> IO (Transport, TransportInternals)
    mkTransport :: TCPTransport
-> Maybe ThreadId -> IO (Transport, TransportInternals)
mkTransport TCPTransport
transport Maybe ThreadId
mtid = do
      (Transport, TransportInternals)
-> IO (Transport, TransportInternals)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return
        ( Transport
            { newEndPoint :: IO (Either (TransportError NewEndPointErrorCode) EndPoint)
newEndPoint = do
                QDisc Event
qdisc <- TCPParameters -> forall t. IO (QDisc t)
tcpNewQDisc TCPParameters
params
                TCPTransport
-> QDisc Event
-> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint TCPTransport
transport QDisc Event
qdisc
            , closeTransport :: IO ()
closeTransport = let evs :: [Event]
evs = [ Event
EndPointClosed ]
                               in TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport TCPTransport
transport Maybe ThreadId
mtid [Event]
evs
            }
        , TransportInternals
            { transportThread :: Maybe ThreadId
transportThread     = Maybe ThreadId
mtid
            , socketBetween :: EndPointAddress -> EndPointAddress -> IO Socket
socketBetween       = TCPTransport -> EndPointAddress -> EndPointAddress -> IO Socket
internalSocketBetween TCPTransport
transport
            , newEndPointInternal :: (forall t. Maybe (QDisc t))
-> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
newEndPointInternal = \forall t. Maybe (QDisc t)
mqdisc -> case Maybe (QDisc Event)
forall t. Maybe (QDisc t)
mqdisc of
                Just QDisc Event
qdisc -> TCPTransport
-> QDisc Event
-> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint TCPTransport
transport QDisc Event
qdisc
                Maybe (QDisc Event)
Nothing -> do
                  QDisc Event
qdisc <- TCPParameters -> forall t. IO (QDisc t)
tcpNewQDisc TCPParameters
params
                  TCPTransport
-> QDisc Event
-> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint TCPTransport
transport QDisc Event
qdisc
            }
        )

    errorHandler :: TCPTransport -> SomeException -> IO ()
    errorHandler :: TCPTransport -> SomeException -> IO ()
errorHandler TCPTransport
_ = TCPParameters -> SomeException -> IO ()
tcpServerExceptionHandler TCPParameters
params

    terminationHandler :: TCPTransport -> SomeException -> IO ()
    terminationHandler :: TCPTransport -> SomeException -> IO ()
terminationHandler TCPTransport
transport SomeException
ex = do
      let evs :: [Event]
evs = [ TransportError EventErrorCode -> Event
ErrorEvent (EventErrorCode -> HostName -> TransportError EventErrorCode
forall error. error -> HostName -> TransportError error
TransportError EventErrorCode
EventTransportFailed (SomeException -> HostName
forall a. Show a => a -> HostName
show SomeException
ex))
                , IOError -> Event
forall a e. Exception e => e -> a
throw (IOError -> Event) -> IOError -> Event
forall a b. (a -> b) -> a -> b
$ HostName -> IOError
userError HostName
"Transport closed"
                ]
      TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport TCPTransport
transport Maybe ThreadId
forall a. Maybe a
Nothing [Event]
evs

-- | Default TCP parameters.
--
-- Every field of 'TCPParameters' is set explicitly here. Callers usually
-- start from this value and override individual fields with record update
-- syntax.
defaultTCPParameters :: TCPParameters
defaultTCPParameters = TCPParameters
  { -- Handed to 'forkServer' when the listening socket is created.
    tcpBacklog                = N.maxListenQueue
  , tcpReuseServerAddr        = True
    -- Used as the server error handler; the default simply rethrows.
  , tcpServerExceptionHandler = throwIO
    -- Connection-level settings.
  , tcpReuseClientAddr        = True
  , transportConnectTimeout   = Nothing
  , tcpNoDelay                = True
  , tcpKeepAlive              = False
  , tcpUserTimeout            = Nothing
    -- Size limits and peer checks.
  , tcpMaxAddressLength       = maxBound
  , tcpMaxReceiveLength       = maxBound
  , tcpCheckPeerHost          = False
    -- Builds the event queue of each newly created endpoint.
  , tcpNewQDisc               = simpleUnboundedQDisc
  }

--------------------------------------------------------------------------------
-- API functions                                                              --
--------------------------------------------------------------------------------

-- | Close the transport.
--
-- The transport state is swapped to 'TransportClosed' under its 'MVar'. If
-- it was still 'TransportValid', every local endpoint is then closed with
-- 'apiCloseEndPoint', which receives the given events. Finally, the transport
-- thread is killed if one was supplied. On an already-closed transport only
-- that last step has any effect.
apiCloseTransport :: TCPTransport -> Maybe ThreadId -> [Event] -> IO ()
apiCloseTransport transport mTransportThread evs =
    asyncWhenCancelled return closeAll
  where
    closeAll = do
      mPreviouslyValid <- modifyMVar (transportState transport) markClosed
      forM_ mPreviouslyValid $ \vst ->
        mapM_ (apiCloseEndPoint transport evs) (vst ^. localEndPoints)
      -- Killing the transport thread invokes the termination handler, which
      -- calls apiCloseTransport again. By then the transport is already
      -- closed and no transport thread is passed, so that call returns
      -- immediately.
      forM_ mTransportThread killThread

    -- Always leave the state closed; report the previous valid state (if any).
    markClosed (TransportValid vst) = return (TransportClosed, Just vst)
    markClosed TransportClosed      = return (TransportClosed, Nothing)

-- | Create a new endpoint on the transport, using the given 'QDisc' as its
-- event queue.
--
-- A 'TransportError' 'NewEndPointErrorCode' thrown while creating the local
-- endpoint is caught and returned as 'Left'. Creation is wrapped in
-- 'asyncWhenCancelled', with 'closeEndPoint' as the cleanup action.
--
-- Multicast is not supported: both multicast operations always return
-- 'Left'.
apiNewEndPoint :: TCPTransport
               -> QDisc Event
               -> IO (Either (TransportError NewEndPointErrorCode) EndPoint)
apiNewEndPoint transport qdisc =
    try $ asyncWhenCancelled closeEndPoint $
      toEndPoint <$> createLocalEndPoint transport qdisc
  where
    -- Expose a freshly created local endpoint through the public API record.
    toEndPoint ep = EndPoint
      { receive               = qdiscDequeue (localQueue ep)
      , address               = localAddress ep
      , connect               = apiConnect transport ep
      , closeEndPoint         = apiCloseEndPoint transport [EndPointClosed] ep
      , newMulticastGroup     = return (Left newMulticastGroupError)
      , resolveMulticastGroup = \_ -> return (Left resolveMulticastGroupError)
      }

    newMulticastGroupError =
      TransportError NewMulticastGroupUnsupported "Multicast not supported"

    resolveMulticastGroupError =
      TransportError ResolveMulticastGroupUnsupported "Multicast not supported"

-- | Abstraction of a queue for an 'EndPoint'.
--
--   A value of type @QDisc t@ is a queue of events of an abstract type @t@.
--
--   This specifies which 'Event's will come from
--   @receive :: EndPoint -> IO Event@, and when. It is highly general so that
--   the simple yet potentially very fast implementation backed by a single
--   unbounded channel can be used, without excluding more nuanced policies
--   like class-based queueing with bounded buffers for each peer, which may be
--   faster in certain conditions but probably has lower maximal throughput.
--
--   A 'QDisc' must satisfy some properties in order for the semantics of
--   network-transport to hold true. In general, an event fed with
--   'qdiscEnqueue' must not be dropped, i.e. provided that no other event in
--   the QDisc has higher priority, the event should eventually be returned by
--   'qdiscDequeue'. The exception to this is 'Receive' events of unreliable
--   connections.
--
--   Every call to 'receive' is just 'qdiscDequeue' on the 'QDisc' of that
--   'EndPoint'. Whenever an event arises from a socket, 'qdiscEnqueue' is
--   called with the relevant metadata in the same thread that reads from the
--   socket. You can be clever about when to block here, so as to control
--   network ingress. This also applies to loopback connections (an 'EndPoint'
--   connecting to itself), in which case blocking on the enqueue would only
--   block some thread in your program rather than some chatty network peer.
--   The 'Event' to be enqueued is given to 'qdiscEnqueue' so that the 'QDisc'
--   can know about open connections, their identifiers, peer addresses, etc.
data QDisc t = QDisc
  { qdiscDequeue :: IO t
    -- ^ Dequeue an event.
  , qdiscEnqueue :: EndPointAddress -> Event -> t -> IO ()
    -- ^ @qdiscEnqueue ep ev t@ enqueues an event @t@ that originated from
    -- the remote endpoint @ep@ and carries the data @ev@.
    --
    -- @ep@ may be the local endpoint when the event relates to a
    -- self-connection.
    --
    -- In practice @ev@ may be the very value given as @t@. It is passed in
    -- the abstract form @t@ to enforce that it is dequeued unmodified, while
    -- the 'QDisc' implementation can still inspect the concrete form @ev@ to
    -- make prioritization decisions.
  }

-- | Post an 'Event' using a 'QDisc'.
--
-- For a @QDisc Event@ the abstract payload and the concrete event are the
-- same thing, so the event is passed as both the @ev@ and the @t@ argument
-- of 'qdiscEnqueue'.
qdiscEnqueue' :: QDisc Event -> EndPointAddress -> Event -> IO ()
qdiscEnqueue' q fromAddr ev = enqueue ev ev
  where
    enqueue = qdiscEnqueue q fromAddr

-- | A very simple 'QDisc' backed by an unbounded channel.
--
-- Enqueueing ignores the origin address and the concrete event and writes
-- the abstract value to the channel. Dequeueing reads the next value from
-- it, blocking while the channel is empty.
simpleUnboundedQDisc :: forall t . IO (QDisc t)
simpleUnboundedQDisc = chanQDisc <$> newChan
  where
    chanQDisc chan = QDisc
      { qdiscDequeue = readChan chan
      , qdiscEnqueue = \_fromAddr _event payload -> writeChan chan payload
      }

-- | A very simple 'QDisc' backed by a one-place queue (an 'MVar').
--
--   Every enqueue is a 'putMVar' on the same 'MVar', so all threads reading
--   from sockets block until the previous event has been taken. Each call to
--   'receive' performs a 'takeMVar', emptying it again. As a result, the rate
--   at which data is read from the wire is tied directly to the rate at which
--   data is pulled from the 'EndPoint' by 'receive'.
simpleOnePlaceQDisc :: forall t . IO (QDisc t)
simpleOnePlaceQDisc = slotQDisc <$> newEmptyMVar
  where
    slotQDisc slot = QDisc
      { qdiscDequeue = takeMVar slot
      , qdiscEnqueue = \_fromAddr _event payload -> putMVar slot payload
      }

-- | Connect to an endpoint.
--
-- Connecting to our own address is handled by 'connectToSelf'. Otherwise
-- 'resetIfBroken' is called first (presumably to discard a previously failed
-- remote state), a lightweight connection is set up by 'createConnectionTo',
-- and the resulting 'Connection' shares one fresh @connAlive@ flag between
-- its 'send' and 'close' actions.
--
-- A 'TransportError' 'ConnectErrorCode' thrown along the way is caught and
-- returned as 'Left'. The whole action is wrapped in 'asyncWhenCancelled',
-- with 'close' as the cleanup action.
apiConnect :: TCPTransport
           -> LocalEndPoint    -- ^ Local end point
           -> EndPointAddress  -- ^ Remote address
           -> Reliability      -- ^ Reliability (ignored)
           -> ConnectHints     -- ^ Hints
           -> IO (Either (TransportError ConnectErrorCode) Connection)
apiConnect transport ourEndPoint theirAddress _reliability hints =
  try . asyncWhenCancelled close $
    if localAddress ourEndPoint == theirAddress
      then connectToSelf ourEndPoint
      else do
        resetIfBroken ourEndPoint theirAddress
        (theirEndPoint, connId) <-
          createConnectionTo transport ourEndPoint theirAddress hints
        -- connAlive can be an IORef rather than an MVar because it is protected
        -- by the remoteState MVar. We don't need the overhead of locking twice.
        connAlive <- newIORef True
        return Connection
          { send  = apiSend  (ourEndPoint, theirEndPoint) connId connAlive
          , close = apiClose (ourEndPoint, theirEndPoint) connId connAlive
          }

-- | Close a connection.
--
-- While holding the remote endpoint's state 'MVar':
--
-- * if the remote endpoint is 'RemoteEndPointValid' and the connection is
--   still alive, the connection is marked dead, a 'CloseConnection' control
--   message carrying the connection ID is scheduled on the remote endpoint,
--   and its 'remoteOutgoing' count is decremented;
--
-- * in every other case the remote state is left unchanged.
--
-- 'closeIfUnused' always runs afterwards (via 'finally'). Any 'IOError' is
-- discarded through 'tryIO', so closing is best-effort.
apiClose :: EndPointPair -> LightweightConnectionId -> IORef Bool -> IO ()
apiClose (ourEndPoint, theirEndPoint) connId connAlive =
    void . tryIO . asyncWhenCancelled return $
      closeConnection `finally` closeIfUnused (ourEndPoint, theirEndPoint)
  where
    closeConnection =
      withScheduledAction ourEndPoint $ \sched ->
        modifyMVar_ (remoteState theirEndPoint) (step sched)

    step sched (RemoteEndPointValid vst) = do
      alive <- readIORef connAlive
      if not alive
        then return (RemoteEndPointValid vst)
        else do
          writeIORef connAlive False
          sched theirEndPoint $ sendOn vst
            [ encodeWord32 (encodeControlHeader CloseConnection)
            , encodeWord32 connId
            ]
          return (RemoteEndPointValid ((remoteOutgoing ^: subtract 1) vst))
    step _ st = return st


-- | Send data across a connection
apiSend :: EndPointPair             -- ^ Local and remote endpoint
        -> LightweightConnectionId  -- ^ Connection ID
        -> IORef Bool               -- ^ Is the connection still alive?
        -> [ByteString]             -- ^ Payload
        -> IO (Either (TransportError SendErrorCode) ())
apiSend :: EndPointPair
-> HeavyweightConnectionId
-> IORef Bool
-> [ByteString]
-> IO (Either (TransportError SendErrorCode) ())
apiSend (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) HeavyweightConnectionId
connId IORef Bool
connAlive [ByteString]
payload =
    -- We don't need the overhead of asyncWhenCancelled here
    IO () -> IO (Either (TransportError SendErrorCode) ())
forall e a. Exception e => IO a -> IO (Either e a)
try (IO () -> IO (Either (TransportError SendErrorCode) ()))
-> (IO () -> IO ())
-> IO ()
-> IO (Either (TransportError SendErrorCode) ())
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (IOError -> TransportError SendErrorCode) -> IO () -> IO ()
forall e a. Exception e => (IOError -> e) -> IO a -> IO a
mapIOException IOError -> TransportError SendErrorCode
sendFailed (IO () -> IO (Either (TransportError SendErrorCode) ()))
-> IO () -> IO (Either (TransportError SendErrorCode) ())
forall a b. (a -> b) -> a -> b
$ LocalEndPoint
-> ((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ()
forall a.
LocalEndPoint
-> ((RemoteEndPoint -> IO a -> IO ()) -> IO ()) -> IO ()
withScheduledAction LocalEndPoint
ourEndPoint (((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ())
-> ((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \RemoteEndPoint -> IO () -> IO ()
sched -> do
      MVar RemoteState -> (RemoteState -> IO ()) -> IO ()
forall a b. MVar a -> (a -> IO b) -> IO b
withMVar (RemoteEndPoint -> MVar RemoteState
remoteState RemoteEndPoint
theirEndPoint) ((RemoteState -> IO ()) -> IO ())
-> (RemoteState -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \RemoteState
st -> case RemoteState
st of
        RemoteEndPointInvalid TransportError ConnectErrorCode
_ ->
          EndPointPair -> HostName -> IO ()
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) HostName
"apiSend"
        RemoteEndPointInit MVar ()
_ MVar ()
_ RequestedBy
_ ->
          EndPointPair -> HostName -> IO ()
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) HostName
"apiSend"
        RemoteEndPointValid ValidRemoteEndPointState
vst -> do
          Bool
alive <- IORef Bool -> IO Bool
forall a. IORef a -> IO a
readIORef IORef Bool
connAlive
          if Bool
alive
            then RemoteEndPoint -> IO () -> IO ()
sched RemoteEndPoint
theirEndPoint (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$
              ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn ValidRemoteEndPointState
vst (HeavyweightConnectionId -> ByteString
encodeWord32 HeavyweightConnectionId
connId ByteString -> [ByteString] -> [ByteString]
forall a. a -> [a] -> [a]
: [ByteString] -> [ByteString]
prependLength [ByteString]
payload)
            else TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendClosed HostName
"Connection closed"
        RemoteEndPointClosing MVar ()
_ ValidRemoteEndPointState
_ -> do
          Bool
alive <- IORef Bool -> IO Bool
forall a. IORef a -> IO a
readIORef IORef Bool
connAlive
          if Bool
alive
            -- RemoteEndPointClosing is only entered by 'closeIfUnused',
            -- which guarantees that there are no alive connections.
            then EndPointPair -> HostName -> IO ()
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) HostName
"apiSend RemoteEndPointClosing"
            else TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendClosed HostName
"Connection closed"
        RemoteState
RemoteEndPointClosed -> do
          Bool
alive <- IORef Bool -> IO Bool
forall a. IORef a -> IO a
readIORef IORef Bool
connAlive
          if Bool
alive
            -- This is normal. If the remote endpoint closes up while we have
            -- an outgoing connection (CloseEndPoint or CloseSocket message),
            -- we'll post the connection lost event but we won't update these
            -- 'connAlive' IORefs.
            then TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendFailed HostName
"Remote endpoint closed"
            else TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendClosed HostName
"Connection closed"
        RemoteEndPointFailed IOError
err -> do
          Bool
alive <- IORef Bool -> IO Bool
forall a. IORef a -> IO a
readIORef IORef Bool
connAlive
          if Bool
alive
            then TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendFailed (IOError -> HostName
forall a. Show a => a -> HostName
show IOError
err)
            else TransportError SendErrorCode -> IO ()
forall e a. Exception e => e -> IO a
throwIO (TransportError SendErrorCode -> IO ())
-> TransportError SendErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendClosed HostName
"Connection closed"
  where
    sendFailed :: IOError -> TransportError SendErrorCode
sendFailed = SendErrorCode -> HostName -> TransportError SendErrorCode
forall error. error -> HostName -> TransportError error
TransportError SendErrorCode
SendFailed (HostName -> TransportError SendErrorCode)
-> (IOError -> HostName) -> IOError -> TransportError SendErrorCode
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IOError -> HostName
forall a. Show a => a -> HostName
show

-- | Force-close the endpoint.
--
-- Removes the endpoint from the transport state, marks its local state as
-- 'LocalEndPointClosed', tears down every remote endpoint listed in its
-- 'localConnections', and finally enqueues each of the given events on the
-- endpoint's local event queue. If the endpoint was already closed, only the
-- removal from the transport state is performed.
apiCloseEndPoint :: TCPTransport    -- ^ Transport
                 -> [Event]         -- ^ Events used to report closure
                 -> LocalEndPoint   -- ^ Local endpoint
                 -> IO ()
apiCloseEndPoint :: TCPTransport -> [Event] -> LocalEndPoint -> IO ()
apiCloseEndPoint TCPTransport
transport [Event]
evs LocalEndPoint
ourEndPoint =
  -- NOTE(review): 'asyncWhenCancelled' is defined elsewhere; presumably it
  -- lets the teardown below run to completion even if the caller is
  -- cancelled. Confirm against its definition.
  (() -> IO ()) -> IO () -> IO ()
forall a. (a -> IO ()) -> IO a -> IO a
asyncWhenCancelled () -> IO ()
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
    -- Remove the reference from the transport state
    TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint TCPTransport
transport LocalEndPoint
ourEndPoint
    -- Close the local endpoint: its state becomes 'LocalEndPointClosed' and
    -- we get back the previous valid state, or 'Nothing' if it was already
    -- closed (in which case nothing further is done below).
    Maybe ValidLocalEndPointState
mOurState <- MVar LocalEndPointState
-> (LocalEndPointState
    -> IO (LocalEndPointState, Maybe ValidLocalEndPointState))
-> IO (Maybe ValidLocalEndPointState)
forall a b. MVar a -> (a -> IO (a, b)) -> IO b
modifyMVar (LocalEndPoint -> MVar LocalEndPointState
localState LocalEndPoint
ourEndPoint) ((LocalEndPointState
  -> IO (LocalEndPointState, Maybe ValidLocalEndPointState))
 -> IO (Maybe ValidLocalEndPointState))
-> (LocalEndPointState
    -> IO (LocalEndPointState, Maybe ValidLocalEndPointState))
-> IO (Maybe ValidLocalEndPointState)
forall a b. (a -> b) -> a -> b
$ \LocalEndPointState
st ->
      case LocalEndPointState
st of
        LocalEndPointValid ValidLocalEndPointState
vst ->
          (LocalEndPointState, Maybe ValidLocalEndPointState)
-> IO (LocalEndPointState, Maybe ValidLocalEndPointState)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (LocalEndPointState
LocalEndPointClosed, ValidLocalEndPointState -> Maybe ValidLocalEndPointState
forall a. a -> Maybe a
Just ValidLocalEndPointState
vst)
        LocalEndPointState
LocalEndPointClosed ->
          (LocalEndPointState, Maybe ValidLocalEndPointState)
-> IO (LocalEndPointState, Maybe ValidLocalEndPointState)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (LocalEndPointState
LocalEndPointClosed, Maybe ValidLocalEndPointState
forall a. Maybe a
Nothing)
    Maybe ValidLocalEndPointState
-> (ValidLocalEndPointState -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ Maybe ValidLocalEndPointState
mOurState ((ValidLocalEndPointState -> IO ()) -> IO ())
-> (ValidLocalEndPointState -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \ValidLocalEndPointState
vst -> do
      -- Tear down every remote endpoint recorded in 'localConnections'.
      Map EndPointAddress RemoteEndPoint
-> (RemoteEndPoint -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (ValidLocalEndPointState
vst ValidLocalEndPointState
-> T ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
-> Map EndPointAddress RemoteEndPoint
forall r a. r -> T r a -> a
^. T ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections) RemoteEndPoint -> IO ()
tryCloseRemoteSocket
      -- Report the closure: enqueue each event in 'evs' on the local event
      -- queue via 'qdiscEnqueue'', passing this endpoint's address.
      let qdisc :: QDisc Event
qdisc = LocalEndPoint -> QDisc Event
localQueue LocalEndPoint
ourEndPoint
      [Event] -> (Event -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Event]
evs (QDisc Event -> EndPointAddress -> Event -> IO ()
qdiscEnqueue' QDisc Event
qdisc (LocalEndPoint -> EndPointAddress
localAddress LocalEndPoint
ourEndPoint))
  where
    -- Tear down the connection to a single remote endpoint. Endpoints whose
    -- state is invalid, closed or failed are left as they are; every other
    -- state is replaced by 'RemoteEndPointFailed' (with the user error
    -- "apiCloseEndPoint"). Socket shutdown is handed to 'sched' (obtained
    -- from 'withScheduledAction') instead of being run directly inside
    -- 'modifyMVar_'. Presumably the scheduled action runs after the MVar is
    -- released; confirm against 'withScheduledAction'.
    tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
    tryCloseRemoteSocket :: RemoteEndPoint -> IO ()
tryCloseRemoteSocket RemoteEndPoint
theirEndPoint = LocalEndPoint
-> ((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ()
forall a.
LocalEndPoint
-> ((RemoteEndPoint -> IO a -> IO ()) -> IO ()) -> IO ()
withScheduledAction LocalEndPoint
ourEndPoint (((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ())
-> ((RemoteEndPoint -> IO () -> IO ()) -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \RemoteEndPoint -> IO () -> IO ()
sched -> do
      -- For a valid connection we make an attempt to close it nicely, by
      -- sending a CloseEndPoint message before shutting the socket down
      -- (see the 'RemoteEndPointValid' case below).
      let closed :: RemoteState
closed = IOError -> RemoteState
RemoteEndPointFailed (IOError -> RemoteState)
-> (HostName -> IOError) -> HostName -> RemoteState
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HostName -> IOError
userError (HostName -> RemoteState) -> HostName -> RemoteState
forall a b. (a -> b) -> a -> b
$ HostName
"apiCloseEndPoint"
      MVar RemoteState -> (RemoteState -> IO RemoteState) -> IO ()
forall a. MVar a -> (a -> IO a) -> IO ()
modifyMVar_ (RemoteEndPoint -> MVar RemoteState
remoteState RemoteEndPoint
theirEndPoint) ((RemoteState -> IO RemoteState) -> IO ())
-> (RemoteState -> IO RemoteState) -> IO ()
forall a b. (a -> b) -> a -> b
$ \RemoteState
st ->
        case RemoteState
st of
          -- Invalid (holds a connect error): nothing to tear down, keep it.
          RemoteEndPointInvalid TransportError ConnectErrorCode
_ ->
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return RemoteState
st
          -- Still initialising: fill the 'resolved' MVar (presumably this
          -- unblocks threads waiting for the endpoint to resolve; confirm)
          -- and mark the endpoint failed.
          RemoteEndPointInit MVar ()
resolved MVar ()
_ RequestedBy
_ -> do
            MVar () -> () -> IO ()
forall a. MVar a -> a -> IO ()
putMVar MVar ()
resolved ()
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return RemoteState
closed
          RemoteEndPointValid ValidRemoteEndPointState
vst -> do
            -- Schedule an action to send a CloseEndPoint message and then
            -- wait for the socket to actually close (meaning that this
            -- end point is no longer receiving from it).
            -- Since we replace the state in this MVar with 'closed', it's
            -- guaranteed that no other actions will be scheduled after this
            -- one.
            RemoteEndPoint -> IO () -> IO ()
sched RemoteEndPoint
theirEndPoint (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
              -- Best effort: an IOError from the send is discarded.
              IO (Either IOError ()) -> IO ()
forall (m :: * -> *) a. Monad m => m a -> m ()
void (IO (Either IOError ()) -> IO ())
-> IO (Either IOError ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ IO () -> IO (Either IOError ())
forall (m :: * -> *) a. MonadIO m => IO a -> m (Either IOError a)
tryIO (IO () -> IO (Either IOError ()))
-> IO () -> IO (Either IOError ())
forall a b. (a -> b) -> a -> b
$ ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn ValidRemoteEndPointState
vst
                [ HeavyweightConnectionId -> ByteString
encodeWord32 (ControlHeader -> HeavyweightConnectionId
encodeControlHeader ControlHeader
CloseEndPoint) ]
              -- Release probing resources if probing.
              Maybe (IO ()) -> (IO () -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (ValidRemoteEndPointState -> Maybe (IO ())
remoteProbing ValidRemoteEndPointState
vst) IO () -> IO ()
forall a. a -> a
id
              Socket -> IO ()
tryShutdownSocketBoth (ValidRemoteEndPointState -> Socket
remoteSocket ValidRemoteEndPointState
vst)
              ValidRemoteEndPointState -> IO ()
remoteSocketClosed ValidRemoteEndPointState
vst
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return RemoteState
closed
          -- Already closing: release probing, fill 'resolved', and schedule
          -- the socket shutdown. Unlike the valid case, no CloseEndPoint
          -- message is sent here.
          RemoteEndPointClosing MVar ()
resolved ValidRemoteEndPointState
vst -> do
            -- Release probing resources if probing.
            Maybe (IO ()) -> (IO () -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (ValidRemoteEndPointState -> Maybe (IO ())
remoteProbing ValidRemoteEndPointState
vst) IO () -> IO ()
forall a. a -> a
id
            MVar () -> () -> IO ()
forall a. MVar a -> a -> IO ()
putMVar MVar ()
resolved ()
            -- Schedule an action to wait for the socket to actually close (this
            -- end point is no longer receiving from it).
            -- Since we replace the state in this MVar with 'closed', it's
            -- guaranteed that no other actions will be scheduled after this
            -- one.
            RemoteEndPoint -> IO () -> IO ()
sched RemoteEndPoint
theirEndPoint (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
              Socket -> IO ()
tryShutdownSocketBoth (ValidRemoteEndPointState -> Socket
remoteSocket ValidRemoteEndPointState
vst)
              ValidRemoteEndPointState -> IO ()
remoteSocketClosed ValidRemoteEndPointState
vst
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return RemoteState
closed
          -- Already closed or failed: keep the existing state.
          RemoteState
RemoteEndPointClosed ->
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return RemoteState
st
          RemoteEndPointFailed IOError
err ->
            RemoteState -> IO RemoteState
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (IOError -> RemoteState
RemoteEndPointFailed IOError
err)


--------------------------------------------------------------------------------
-- Incoming requests                                                          --
--------------------------------------------------------------------------------

-- | Handle a connection request (that is, a remote endpoint that is trying to
-- establish a TCP connection with us)
--
-- 'handleConnectionRequest' runs in the context of the transport thread, which
-- can be killed asynchronously by 'closeTransport'. We fork a separate thread
-- as soon as we have located the local endpoint that the remote endpoint is
-- interested in. We cannot fork any sooner because then we have no way of
-- storing the thread ID and hence no way of killing the thread when we take
-- the transport down. We must be careful to close the socket when a (possibly
-- asynchronous, ThreadKilled) exception occurs. (If an exception escapes from
-- handleConnectionRequest the transport will be shut down.)
handleConnectionRequest :: TCPTransport -> IO () -> (N.Socket, N.SockAddr) -> IO ()
handleConnectionRequest :: TCPTransport -> IO () -> (Socket, SockAddr) -> IO ()
handleConnectionRequest TCPTransport
transport IO ()
socketClosed (Socket
sock, SockAddr
sockAddr) = (SomeException -> IO ()) -> IO () -> IO ()
forall e a. Exception e => (e -> IO a) -> IO a -> IO a
handle SomeException -> IO ()
handleException (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
    Bool -> IO () -> IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TCPParameters -> Bool
tcpNoDelay (TCPParameters -> Bool) -> TCPParameters -> Bool
forall a b. (a -> b) -> a -> b
$ TCPTransport -> TCPParameters
transportParams TCPTransport
transport) (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$
      Socket -> SocketOption -> Int -> IO ()
N.setSocketOption Socket
sock SocketOption
N.NoDelay Int
1
    Bool -> IO () -> IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TCPParameters -> Bool
tcpKeepAlive (TCPParameters -> Bool) -> TCPParameters -> Bool
forall a b. (a -> b) -> a -> b
$ TCPTransport -> TCPParameters
transportParams TCPTransport
transport) (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$
      Socket -> SocketOption -> Int -> IO ()
N.setSocketOption Socket
sock SocketOption
N.KeepAlive Int
1
    Maybe Int -> (Int -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (TCPParameters -> Maybe Int
tcpUserTimeout (TCPParameters -> Maybe Int) -> TCPParameters -> Maybe Int
forall a b. (a -> b) -> a -> b
$ TCPTransport -> TCPParameters
transportParams TCPTransport
transport) ((Int -> IO ()) -> IO ()) -> (Int -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$
      Socket -> SocketOption -> Int -> IO ()
N.setSocketOption Socket
sock SocketOption
N.UserTimeout
    let handleVersioned :: IO (Maybe (IO ()))
handleVersioned = do
          -- Always receive the protocol version and a handshake (content of the
          -- handshake is version-dependent, but the length is always sent,
          -- regardless of the version).
          HeavyweightConnectionId
protocolVersion <- Socket -> IO HeavyweightConnectionId
recvWord32 Socket
sock
          HeavyweightConnectionId
handshakeLength <- Socket -> IO HeavyweightConnectionId
recvWord32 Socket
sock
          -- For now we support only version 0.0.0.0.
          case HeavyweightConnectionId
protocolVersion of
            HeavyweightConnectionId
0x00000000 -> (Socket, SockAddr) -> IO (Maybe (IO ()))
handleConnectionRequestV0 (Socket
sock, SockAddr
sockAddr)
            HeavyweightConnectionId
_ -> do
              -- Inform the peer that we want version 0x00000000
              Socket -> [ByteString] -> IO ()
sendMany Socket
sock [
                  HeavyweightConnectionId -> ByteString
encodeWord32 (ConnectionRequestResponse -> HeavyweightConnectionId
encodeConnectionRequestResponse ConnectionRequestResponse
ConnectionRequestUnsupportedVersion)
                , HeavyweightConnectionId -> ByteString
encodeWord32 HeavyweightConnectionId
0x00000000
                ]
              -- Clear the socket of the unsupported handshake data.
              [ByteString]
_ <- Socket -> HeavyweightConnectionId -> IO [ByteString]
recvExact Socket
sock HeavyweightConnectionId
handshakeLength
              IO (Maybe (IO ()))
handleVersioned
    -- The handshake must complete within the optional timeout duration.
    -- No socket 'recv's are to be run outside the timeout. The continuation
    -- returned may 'send', but not 'recv'.
    let connTimeout :: Maybe Int
connTimeout = TCPParameters -> Maybe Int
transportConnectTimeout (TCPTransport -> TCPParameters
transportParams TCPTransport
transport)
    Maybe (Maybe (IO ()))
outcome <- (IO (Maybe (IO ())) -> IO (Maybe (Maybe (IO ()))))
-> (Int -> IO (Maybe (IO ())) -> IO (Maybe (Maybe (IO ()))))
-> Maybe Int
-> IO (Maybe (IO ()))
-> IO (Maybe (Maybe (IO ())))
forall b a. b -> (a -> b) -> Maybe a -> b
maybe ((Maybe (IO ()) -> Maybe (Maybe (IO ())))
-> IO (Maybe (IO ())) -> IO (Maybe (Maybe (IO ())))
forall a b. (a -> b) -> IO a -> IO b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Maybe (IO ()) -> Maybe (Maybe (IO ()))
forall a. a -> Maybe a
Just) Int -> IO (Maybe (IO ())) -> IO (Maybe (Maybe (IO ())))
forall a. Int -> IO a -> IO (Maybe a)
System.Timeout.timeout Maybe Int
connTimeout IO (Maybe (IO ()))
handleVersioned
    case Maybe (Maybe (IO ()))
outcome of
      Maybe (Maybe (IO ()))
Nothing -> IOError -> IO ()
forall e a. Exception e => e -> IO a
throwIO (HostName -> IOError
userError HostName
"handleConnectionRequest: timed out")
      Just Maybe (IO ())
act -> Maybe (IO ()) -> (IO () -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ Maybe (IO ())
act IO () -> IO ()
forall a. a -> a
id

  where

    handleException :: SomeException -> IO ()
    handleException :: SomeException -> IO ()
handleException SomeException
ex = do
      Maybe AsyncException -> IO ()
rethrowIfAsync (SomeException -> Maybe AsyncException
forall e. Exception e => SomeException -> Maybe e
fromException SomeException
ex)

    rethrowIfAsync :: Maybe AsyncException -> IO ()
    rethrowIfAsync :: Maybe AsyncException -> IO ()
rethrowIfAsync = (AsyncException -> IO Any) -> Maybe AsyncException -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ AsyncException -> IO Any
forall e a. Exception e => e -> IO a
throwIO

    handleConnectionRequestV0 :: (N.Socket, N.SockAddr) -> IO (Maybe (IO ()))
    handleConnectionRequestV0 :: (Socket, SockAddr) -> IO (Maybe (IO ()))
handleConnectionRequestV0 (Socket
sock, SockAddr
sockAddr) = do
      -- Get the OS-determined host and port.
      (HostName
numericHost, HostName
resolvedHost, HostName
actualPort) <-
        SockAddr -> IO (Maybe (HostName, HostName, HostName))
resolveSockAddr SockAddr
sockAddr IO (Maybe (HostName, HostName, HostName))
-> (Maybe (HostName, HostName, HostName)
    -> IO (HostName, HostName, HostName))
-> IO (HostName, HostName, HostName)
forall a b. IO a -> (a -> IO b) -> IO b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>=
          IO (HostName, HostName, HostName)
-> ((HostName, HostName, HostName)
    -> IO (HostName, HostName, HostName))
-> Maybe (HostName, HostName, HostName)
-> IO (HostName, HostName, HostName)
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (IOError -> IO (HostName, HostName, HostName)
forall e a. Exception e => e -> IO a
throwIO (HostName -> IOError
userError HostName
"handleConnectionRequest: invalid socket address")) (HostName, HostName, HostName) -> IO (HostName, HostName, HostName)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return
      -- The peer must send our identifier and their address promptly, if a
      -- timeout is set.
      (HeavyweightConnectionId
ourEndPointId, EndPointAddress
theirAddress, Maybe HostName
mTheirHost) <- do
        HeavyweightConnectionId
ourEndPointId <- Socket -> IO HeavyweightConnectionId
recvWord32 Socket
sock
        let maxAddressLength :: HeavyweightConnectionId
maxAddressLength = TCPParameters -> HeavyweightConnectionId
tcpMaxAddressLength (TCPParameters -> HeavyweightConnectionId)
-> TCPParameters -> HeavyweightConnectionId
forall a b. (a -> b) -> a -> b
$ TCPTransport -> TCPParameters
transportParams TCPTransport
transport
        ByteString
mTheirAddress <- [ByteString] -> ByteString
BS.concat ([ByteString] -> ByteString) -> IO [ByteString] -> IO ByteString
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> HeavyweightConnectionId -> Socket -> IO [ByteString]
recvWithLength HeavyweightConnectionId
maxAddressLength Socket
sock
        -- Sending a length = 0 address means unaddressable.
        if ByteString -> Bool
BS.null ByteString
mTheirAddress
        then do
          EndPointAddress
theirAddress <- IO EndPointAddress
randomEndPointAddress
          (HeavyweightConnectionId, EndPointAddress, Maybe HostName)
-> IO (HeavyweightConnectionId, EndPointAddress, Maybe HostName)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (HeavyweightConnectionId
ourEndPointId, EndPointAddress
theirAddress, Maybe HostName
forall a. Maybe a
Nothing)
        else do
          let theirAddress :: EndPointAddress
theirAddress = ByteString -> EndPointAddress
EndPointAddress ByteString
mTheirAddress
          (HostName
theirHost, HostName
_, HeavyweightConnectionId
_)
            <- IO (HostName, HostName, HeavyweightConnectionId)
-> ((HostName, HostName, HeavyweightConnectionId)
    -> IO (HostName, HostName, HeavyweightConnectionId))
-> Maybe (HostName, HostName, HeavyweightConnectionId)
-> IO (HostName, HostName, HeavyweightConnectionId)
forall b a. b -> (a -> b) -> Maybe a -> b
maybe (IOError -> IO (HostName, HostName, HeavyweightConnectionId)
forall e a. Exception e => e -> IO a
throwIO (HostName -> IOError
userError HostName
"handleConnectionRequest: peer gave malformed address"))
                     (HostName, HostName, HeavyweightConnectionId)
-> IO (HostName, HostName, HeavyweightConnectionId)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return
                     (EndPointAddress
-> Maybe (HostName, HostName, HeavyweightConnectionId)
decodeEndPointAddress EndPointAddress
theirAddress)
          (HeavyweightConnectionId, EndPointAddress, Maybe HostName)
-> IO (HeavyweightConnectionId, EndPointAddress, Maybe HostName)
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (HeavyweightConnectionId
ourEndPointId, EndPointAddress
theirAddress, HostName -> Maybe HostName
forall a. a -> Maybe a
Just HostName
theirHost)
      let checkPeerHost :: Bool
checkPeerHost = TCPParameters -> Bool
tcpCheckPeerHost (TCPTransport -> TCPParameters
transportParams TCPTransport
transport)
      Bool
continue <- case (Maybe HostName
mTheirHost, Bool
checkPeerHost) of
        (Just HostName
theirHost, Bool
True) -> do
          -- If the OS-determined host doesn't match the host that the peer gave us,
          -- then we have no choice but to reject the connection. It's because we
          -- use the EndPointAddress to key the remote end points (localConnections)
          -- and we don't want to allow a peer to deny service to other peers by
          -- claiming to have their host and port.
          if HostName
theirHost HostName -> HostName -> Bool
forall a. Eq a => a -> a -> Bool
== HostName
numericHost Bool -> Bool -> Bool
|| HostName
theirHost HostName -> HostName -> Bool
forall a. Eq a => a -> a -> Bool
== HostName
resolvedHost
          then Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
          else do
            Socket -> [ByteString] -> IO ()
sendMany Socket
sock ([ByteString] -> IO ()) -> [ByteString] -> IO ()
forall a b. (a -> b) -> a -> b
$
                HeavyweightConnectionId -> ByteString
encodeWord32 (ConnectionRequestResponse -> HeavyweightConnectionId
encodeConnectionRequestResponse ConnectionRequestResponse
ConnectionRequestHostMismatch)
              ByteString -> [ByteString] -> [ByteString]
forall a. a -> [a] -> [a]
: ([ByteString] -> [ByteString]
prependLength [HostName -> ByteString
BSC.pack HostName
theirHost] [ByteString] -> [ByteString] -> [ByteString]
forall a. [a] -> [a] -> [a]
++ [ByteString] -> [ByteString]
prependLength [HostName -> ByteString
BSC.pack HostName
numericHost] [ByteString] -> [ByteString] -> [ByteString]
forall a. [a] -> [a] -> [a]
++ [ByteString] -> [ByteString]
prependLength [HostName -> ByteString
BSC.pack HostName
resolvedHost])
            Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
        (Maybe HostName, Bool)
_ -> Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
      if Bool
continue
      then do
        LocalEndPoint
ourEndPoint <- MVar TransportState
-> (TransportState -> IO LocalEndPoint) -> IO LocalEndPoint
forall a b. MVar a -> (a -> IO b) -> IO b
withMVar (TCPTransport -> MVar TransportState
transportState TCPTransport
transport) ((TransportState -> IO LocalEndPoint) -> IO LocalEndPoint)
-> (TransportState -> IO LocalEndPoint) -> IO LocalEndPoint
forall a b. (a -> b) -> a -> b
$ \TransportState
st -> case TransportState
st of
          TransportValid ValidTransportState
vst ->
            case ValidTransportState
vst ValidTransportState
-> T ValidTransportState (Maybe LocalEndPoint)
-> Maybe LocalEndPoint
forall r a. r -> T r a -> a
^. HeavyweightConnectionId
-> T ValidTransportState (Maybe LocalEndPoint)
localEndPointAt HeavyweightConnectionId
ourEndPointId of
              Maybe LocalEndPoint
Nothing -> do
                Socket -> [ByteString] -> IO ()
sendMany Socket
sock [HeavyweightConnectionId -> ByteString
encodeWord32 (ConnectionRequestResponse -> HeavyweightConnectionId
encodeConnectionRequestResponse ConnectionRequestResponse
ConnectionRequestInvalid)]
                IOError -> IO LocalEndPoint
forall e a. Exception e => e -> IO a
throwIO (IOError -> IO LocalEndPoint) -> IOError -> IO LocalEndPoint
forall a b. (a -> b) -> a -> b
$ HostName -> IOError
userError HostName
"handleConnectionRequest: Invalid endpoint"
              Just LocalEndPoint
ourEndPoint ->
                LocalEndPoint -> IO LocalEndPoint
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return LocalEndPoint
ourEndPoint
          TransportState
TransportClosed ->
            IOError -> IO LocalEndPoint
forall e a. Exception e => e -> IO a
throwIO (IOError -> IO LocalEndPoint) -> IOError -> IO LocalEndPoint
forall a b. (a -> b) -> a -> b
$ HostName -> IOError
userError HostName
"Transport closed"
        Maybe (IO ()) -> IO (Maybe (IO ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (IO () -> Maybe (IO ())
forall a. a -> Maybe a
Just (LocalEndPoint -> EndPointAddress -> IO ()
go LocalEndPoint
ourEndPoint EndPointAddress
theirAddress))
      else Maybe (IO ()) -> IO (Maybe (IO ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (IO ())
forall a. Maybe a
Nothing

      where

      go :: LocalEndPoint -> EndPointAddress -> IO ()
      go :: LocalEndPoint -> EndPointAddress -> IO ()
go LocalEndPoint
ourEndPoint EndPointAddress
theirAddress = (SomeException -> IO ()) -> IO () -> IO ()
forall e a. Exception e => (e -> IO a) -> IO a -> IO a
handle SomeException -> IO ()
handleException (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do

        LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken LocalEndPoint
ourEndPoint EndPointAddress
theirAddress
        (RemoteEndPoint
theirEndPoint, Bool
isNew) <-
          LocalEndPoint
-> EndPointAddress
-> RequestedBy
-> Maybe (IO ())
-> IO (RemoteEndPoint, Bool)
findRemoteEndPoint LocalEndPoint
ourEndPoint EndPointAddress
theirAddress RequestedBy
RequestedByThem Maybe (IO ())
forall a. Maybe a
Nothing

        if Bool -> Bool
not Bool
isNew
          then do
            IO (Either IOError ()) -> IO ()
forall (m :: * -> *) a. Monad m => m a -> m ()
void (IO (Either IOError ()) -> IO ())
-> IO (Either IOError ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ IO () -> IO (Either IOError ())
forall (m :: * -> *) a. MonadIO m => IO a -> m (Either IOError a)
tryIO (IO () -> IO (Either IOError ()))
-> IO () -> IO (Either IOError ())
forall a b. (a -> b) -> a -> b
$ Socket -> [ByteString] -> IO ()
sendMany Socket
sock
              [HeavyweightConnectionId -> ByteString
encodeWord32 (ConnectionRequestResponse -> HeavyweightConnectionId
encodeConnectionRequestResponse ConnectionRequestResponse
ConnectionRequestCrossed)]
            RemoteEndPoint -> IO ()
probeIfValid RemoteEndPoint
theirEndPoint
          else do
            MVar (Maybe SomeException)
sendLock <- Maybe SomeException -> IO (MVar (Maybe SomeException))
forall a. a -> IO (MVar a)
newMVar Maybe SomeException
forall a. Maybe a
Nothing
            let vst :: ValidRemoteEndPointState
vst = ValidRemoteEndPointState
                        {  remoteSocket :: Socket
remoteSocket        = Socket
sock
                        ,  remoteSocketClosed :: IO ()
remoteSocketClosed  = IO ()
socketClosed
                        ,  remoteProbing :: Maybe (IO ())
remoteProbing       = Maybe (IO ())
forall a. Maybe a
Nothing
                        ,  remoteSendLock :: MVar (Maybe SomeException)
remoteSendLock      = MVar (Maybe SomeException)
sendLock
                        , _remoteOutgoing :: Int
_remoteOutgoing      = Int
0
                        , _remoteIncoming :: Set HeavyweightConnectionId
_remoteIncoming      = Set HeavyweightConnectionId
forall a. Set a
Set.empty
                        , _remoteLastIncoming :: HeavyweightConnectionId
_remoteLastIncoming  = HeavyweightConnectionId
0
                        , _remoteNextConnOutId :: HeavyweightConnectionId
_remoteNextConnOutId = HeavyweightConnectionId
firstNonReservedLightweightConnectionId
                        }
            Socket -> [ByteString] -> IO ()
sendMany Socket
sock [HeavyweightConnectionId -> ByteString
encodeWord32 (ConnectionRequestResponse -> HeavyweightConnectionId
encodeConnectionRequestResponse ConnectionRequestResponse
ConnectionRequestAccepted)]
            -- resolveInit will update the shared state, and handleIncomingMessages
            -- will always ultimately clean up after it.
            -- Closing up the socket is also out of our hands. It will happen
            -- when handleIncomingMessages finishes.
            EndPointPair -> RemoteState -> IO ()
resolveInit (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) (ValidRemoteEndPointState -> RemoteState
RemoteEndPointValid ValidRemoteEndPointState
vst)
              IO () -> IO () -> IO ()
forall a b. IO a -> IO b -> IO a
`finally`
              TCPParameters -> EndPointPair -> IO ()
handleIncomingMessages (TCPTransport -> TCPParameters
transportParams TCPTransport
transport) (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)

      -- Start probing a remote endpoint, but only if it is valid and no probe
      -- is already running; in every other state it is left untouched.
      --
      -- A background thread sends a 'ProbeSocket' control message and then
      -- waits, bounded by 'transportConnectTimeout' ('System.Timeout.timeout'
      -- treats the -1 used for 'Nothing' as "no bound"). The action that
      -- kills this thread is stored in 'remoteProbing'. If the thread is not
      -- killed before the wait ends (i.e. the probe ack does not arrive on
      -- time), it closes the socket, and the thread handling incoming
      -- messages will detect the closed socket and report the failure
      -- upwards.
      --
      -- Waiting for the probe ack and closing the socket is only needed on
      -- platforms where TCP_USER_TIMEOUT is not available or when the user
      -- does not set it. Otherwise the ack would be handled at the TCP level
      -- and the thread handling incoming messages would get the error.
      probeIfValid :: RemoteEndPoint -> IO ()
      probeIfValid theirEndPoint = modifyMVar_ (remoteState theirEndPoint) $
        \st -> case st of
          RemoteEndPointValid
            vst@(ValidRemoteEndPointState { remoteProbing = Nothing }) -> do
              tid <- forkIO $ do
                let params    = transportParams transport
                    waitBound = maybe (-1) id (transportConnectTimeout params)
                -- Send errors are ignored here: on failure we fall through
                -- to closing the socket below.
                void $ tryIO $ System.Timeout.timeout waitBound $ do
                  -- Send the probe through 'sendOn', the same path used for
                  -- other control messages on this heavyweight connection,
                  -- instead of a raw 'sendMany' on the socket, so the probe
                  -- is not interleaved with a concurrent send.
                  -- NOTE(review): assumes 'sendOn' serializes writers on
                  -- 'remoteSocket vst' (per-socket send lock) -- confirm.
                  sendOn vst [encodeWord32 (encodeControlHeader ProbeSocket)]
                  threadDelay maxBound
                -- Only reached if nobody killed us: discard the connection.
                tryCloseSocket (remoteSocket vst)
              return $ RemoteEndPointValid
                vst { remoteProbing = Just (killThread tid) }
          _ -> return st

-- | Handle requests from a remote endpoint.
--
-- The socket is taken from the shared remote endpoint state ('acquire') and
-- then read in a loop ('act'). Returns only if the remote party closes the
-- socket or if an error occurs; a premature end is reported by 'release'
-- or by the handler installed in 'act'. This runs in a thread that will
-- never be killed.
handleIncomingMessages :: TCPParameters -> EndPointPair -> IO ()
handleIncomingMessages params (ourEndPoint, theirEndPoint) =
    bracket acquire release act

  where

    -- Look up the socket to read from in the shared remote endpoint state.
    -- A valid or closing endpoint yields its socket. A closed or failed one
    -- yields a 'userError' explaining why there is nothing to read from.
    -- Finding the endpoint still invalid or initializing is a
    -- 'relyViolation'.
    acquire :: IO (Either IOError N.Socket)
    acquire = withMVar theirState $ \case
      RemoteEndPointValid vst ->
        return (Right (remoteSocket vst))
      RemoteEndPointClosing _ vst ->
        return (Right (remoteSocket vst))
      RemoteEndPointClosed ->
        return (Left (userError "handleIncomingMessages (already closed)"))
      RemoteEndPointFailed _ ->
        return (Left (userError "handleIncomingMessages (failed)"))
      RemoteEndPointInvalid _ ->
        relyViolation (ourEndPoint, theirEndPoint)
          "handleIncomingMessages (invalid)"
      RemoteEndPointInit {} ->
        relyViolation (ourEndPoint, theirEndPoint)
          "handleIncomingMessages (init)"

    -- Cleanup for the 'bracket' in 'handleIncomingMessages'.
    --
    -- 'Left' means 'acquire' found no live socket, so 'act' did nothing and
    -- the premature exit must be reported here. 'Right' is the normal case:
    -- there was a live socket, so 'act' ran the receive loop under its own
    -- exception handler and there is nothing left to do.
    release :: Either IOError N.Socket -> IO ()
    release = either prematureExit (const (return ()))

    -- Body of the 'bracket' in 'handleIncomingMessages'. With a socket, run
    -- the receive loop and route any 'IOError' it throws to
    -- 'prematureExit'. Without one there is nothing to do here; 'release'
    -- reports that case.
    act :: Either IOError N.Socket -> IO ()
    act = either (\_ -> return ()) (\sock -> go sock `catch` prematureExit)

    -- Receive loop / dispatch.
    --
    -- Each iteration reads one 32-bit header. Values at or above
    -- 'firstNonReservedLightweightConnectionId' name a lightweight connection
    -- whose message is read by 'readMessage'; smaller values are control
    -- headers. The loop stops after 'CloseEndPoint', or after a 'CloseSocket'
    -- that 'closeSocket' reports as having closed the socket.
    --
    -- If a recv throws an exception it is caught by the handler in 'act' and
    -- 'prematureExit' is invoked. The same happens if the remote endpoint is
    -- put into a Closed (or Closing) state by a concurrent thread (because a
    -- 'send' failed): the individual handlers throw a user exception which
    -- is caught and handled the same way as an exception thrown by 'recv'.
    go :: N.Socket -> IO ()
    go sock = do
        header <- recvWord32 sock :: IO LightweightConnectionId
        if header < firstNonReservedLightweightConnectionId
          then onControl (decodeControlHeader header)
          else do
            readMessage sock header
            loop
      where
        loop :: IO ()
        loop = go sock

        onControl :: Maybe ControlHeader -> IO ()
        onControl Nothing =
          throwIO $ userError "Invalid control request"
        onControl (Just ctrl) = case ctrl of
          CreatedNewConnection -> do
            recvWord32 sock >>= createdNewConnection
            loop
          CloseConnection -> do
            recvWord32 sock >>= closeConnection
            loop
          CloseSocket -> do
            closed <- recvWord32 sock >>= closeSocket sock
            unless closed loop
          CloseEndPoint -> do
            removeRemoteEndPoint (ourEndPoint, theirEndPoint)
            modifyMVar_ theirState $ \case
              RemoteEndPointValid vst -> do
                closeRemote vst
                return RemoteEndPointClosed
              RemoteEndPointClosing resolved vst -> do
                closeRemote vst
                putMVar resolved ()
                return RemoteEndPointClosed
              other ->
                return other
          ProbeSocket -> do
            -- Acknowledge from a separate thread (presumably so that a
            -- blocked send cannot stall this receive loop).
            void . forkIO $
              sendMany sock [encodeWord32 (encodeControlHeader ProbeSocketAck)]
            loop
          ProbeSocketAck -> do
            stopProbing
            loop

        -- Tear down our view of a remote endpoint that sent 'CloseEndPoint':
        -- run the pending probe-cancel action (if any), report every
        -- incoming connection as closed, and, if we still had outgoing
        -- connections, report the connection as lost.
        closeRemote :: ValidRemoteEndPointState -> IO ()
        closeRemote vst = do
          maybe (return ()) id (remoteProbing vst)
          forM_ (Set.elems (vst ^. remoteIncoming)) $ \incoming ->
            qdiscEnqueue' ourQueue theirAddr (ConnectionClosed (connId incoming))
          when (vst ^. remoteOutgoing > 0) $
            qdiscEnqueue' ourQueue theirAddr $ ErrorEvent $
              TransportError (EventConnectionLost (remoteAddress theirEndPoint))
                             "The remote endpoint was closed."

    -- Register a lightweight connection that the remote endpoint has just
    -- opened to us, then announce it with a 'ConnectionOpened' event.
    --
    -- In every accepted case the id becomes the last incoming id and the
    -- endpoint ends up 'RemoteEndPointValid'. A failed endpoint rethrows its
    -- error. Invalid, initializing or closed endpoints are a 'relyViolation'.
    createdNewConnection :: LightweightConnectionId -> IO ()
    createdNewConnection lcid = do
        modifyMVar_ theirState (fmap RemoteEndPointValid . register)
        qdiscEnqueue' ourQueue theirAddr $
          ConnectionOpened (connId lcid) ReliableOrdered theirAddr
      where
        register :: RemoteState -> IO ValidRemoteEndPointState
        register = \case
          RemoteEndPointValid vst ->
            return . (remoteIncoming ^: Set.insert lcid)
                   . (remoteLastIncoming ^= lcid)
                   $ vst
          RemoteEndPointClosing resolved vst -> do
            -- Closing state means we sent a CloseSocket request to the
            -- remote endpoint. If it now reports a new connection, it either
            -- ignored our request or sent this before it got ours. Either
            -- way, signal 'resolved' and restore the endpoint to
            -- RemoteEndPointValid with this id as its only incoming
            -- connection.
            putMVar resolved ()
            return . (remoteIncoming ^= Set.singleton lcid)
                   . (remoteLastIncoming ^= lcid)
                   $ vst
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (invalid)"
          RemoteEndPointInit {} ->
            relyViolation (ourEndPoint, theirEndPoint)
              "handleIncomingMessages:createNewConnection (init)"
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint)
              "createNewConnection (closed)"

    -- Close an incoming lightweight connection and emit 'ConnectionClosed'.
    --
    -- It is important to verify that the connection is in fact open, because
    -- otherwise we should not decrement the reference count. An id missing
    -- from the incoming set of a valid endpoint is therefore rejected with a
    -- 'userError'. If the remote endpoint is in the Closing state, then as
    -- far as we are concerned there are no incoming connections, so a
    -- CloseConnection request at that point is invalid as well.
    closeConnection :: LightweightConnectionId -> IO ()
    closeConnection lcid = do
        modifyMVar_ theirState $ \case
          RemoteEndPointValid vst
            | lcid `Set.member` (vst ^. remoteIncoming) ->
                return . RemoteEndPointValid
                       . (remoteIncoming ^: Set.delete lcid)
                       $ vst
            | otherwise ->
                throwIO $ userError "Invalid CloseConnection"
          RemoteEndPointClosing _ _ ->
            throwIO $ userError "Invalid CloseConnection request"
          RemoteEndPointFailed err ->
            throwIO err
          RemoteEndPointInvalid _ ->
            relyViolation (ourEndPoint, theirEndPoint) "closeConnection (invalid)"
          RemoteEndPointInit {} ->
            relyViolation (ourEndPoint, theirEndPoint) "closeConnection (init)"
          RemoteEndPointClosed ->
            relyViolation (ourEndPoint, theirEndPoint) "closeConnection (closed)"
        qdiscEnqueue' ourQueue theirAddr (ConnectionClosed (connId lcid))

    -- Close the socket (if we don't have any outgoing connections)
    closeSocket :: N.Socket -> LightweightConnectionId -> IO Bool
    closeSocket :: Socket -> HeavyweightConnectionId -> IO Bool
closeSocket Socket
sock HeavyweightConnectionId
lastReceivedId = do
      Maybe (Action ())
mAct <- MVar RemoteState
-> (RemoteState -> IO (RemoteState, Maybe (Action ())))
-> IO (Maybe (Action ()))
forall a b. MVar a -> (a -> IO (a, b)) -> IO b
modifyMVar MVar RemoteState
theirState ((RemoteState -> IO (RemoteState, Maybe (Action ())))
 -> IO (Maybe (Action ())))
-> (RemoteState -> IO (RemoteState, Maybe (Action ())))
-> IO (Maybe (Action ()))
forall a b. (a -> b) -> a -> b
$ \RemoteState
st -> do
        case RemoteState
st of
          RemoteEndPointInvalid TransportError ConnectErrorCode
_ ->
            EndPointPair -> HostName -> IO (RemoteState, Maybe (Action ()))
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)
              HostName
"handleIncomingMessages:closeSocket (invalid)"
          RemoteEndPointInit MVar ()
_ MVar ()
_ RequestedBy
_ ->
            EndPointPair -> HostName -> IO (RemoteState, Maybe (Action ()))
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)
              HostName
"handleIncomingMessages:closeSocket (init)"
          RemoteEndPointValid ValidRemoteEndPointState
vst -> do
            -- We regard a CloseSocket message as an (optimized) way for the
            -- remote endpoint to indicate that all its connections to us are
            -- now properly closed
            [HeavyweightConnectionId]
-> (HeavyweightConnectionId -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (Set HeavyweightConnectionId -> [HeavyweightConnectionId]
forall a. Set a -> [a]
Set.elems (Set HeavyweightConnectionId -> [HeavyweightConnectionId])
-> Set HeavyweightConnectionId -> [HeavyweightConnectionId]
forall a b. (a -> b) -> a -> b
$ ValidRemoteEndPointState
vst ValidRemoteEndPointState
-> T ValidRemoteEndPointState (Set HeavyweightConnectionId)
-> Set HeavyweightConnectionId
forall r a. r -> T r a -> a
^. T ValidRemoteEndPointState (Set HeavyweightConnectionId)
remoteIncoming) ((HeavyweightConnectionId -> IO ()) -> IO ())
-> (HeavyweightConnectionId -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$
              QDisc Event -> EndPointAddress -> Event -> IO ()
qdiscEnqueue' QDisc Event
ourQueue EndPointAddress
theirAddr (Event -> IO ())
-> (HeavyweightConnectionId -> Event)
-> HeavyweightConnectionId
-> IO ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ConnectionId -> Event
ConnectionClosed (ConnectionId -> Event)
-> (HeavyweightConnectionId -> ConnectionId)
-> HeavyweightConnectionId
-> Event
forall b c a. (b -> c) -> (a -> b) -> a -> c
. HeavyweightConnectionId -> ConnectionId
connId
            let vst' :: ValidRemoteEndPointState
vst' = T ValidRemoteEndPointState (Set HeavyweightConnectionId)
remoteIncoming T ValidRemoteEndPointState (Set HeavyweightConnectionId)
-> Set HeavyweightConnectionId
-> ValidRemoteEndPointState
-> ValidRemoteEndPointState
forall r a. T r a -> a -> r -> r
^= Set HeavyweightConnectionId
forall a. Set a
Set.empty (ValidRemoteEndPointState -> ValidRemoteEndPointState)
-> ValidRemoteEndPointState -> ValidRemoteEndPointState
forall a b. (a -> b) -> a -> b
$ ValidRemoteEndPointState
vst
            -- The peer sends the connection id of the last connection which
            -- they accepted from us.
            --
            -- If it's not the same as the id of the last connection that we
            -- have made to them (assuming we haven't cycled through all
            -- identifiers so fast) then they hadn't seen the request before
            -- they tried to close the socket. In that case, we don't close the
            -- socket. They'll see our in-flight connection request and then
            -- abandon their attempt to close the socket.
            --
            -- It's possible that a local connection is coming up but has not
            -- yet sent CreatedNewConnection (see createConnectionTo), in
            -- which case remoteOutgoing is positive, but the sent and received
            -- ids do match. In this case we don't close the socket, because
            -- that connection will soon sent the message and bump the lastSentId.
            --
            -- Both disjuncts are needed: it's possible that remoteOutgoing is
            -- 0 and the ids do not match, in case we have created and closed
            -- a connection but the peer has not yet heard of it.
            if ValidRemoteEndPointState
vst ValidRemoteEndPointState
-> Accessor ValidRemoteEndPointState Int -> Int
forall r a. r -> T r a -> a
^. Accessor ValidRemoteEndPointState Int
remoteOutgoing Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0 Bool -> Bool -> Bool
|| HeavyweightConnectionId
lastReceivedId HeavyweightConnectionId -> HeavyweightConnectionId -> Bool
forall a. Eq a => a -> a -> Bool
/= ValidRemoteEndPointState -> HeavyweightConnectionId
lastSentId ValidRemoteEndPointState
vst
              then
                (RemoteState, Maybe (Action ()))
-> IO (RemoteState, Maybe (Action ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (ValidRemoteEndPointState -> RemoteState
RemoteEndPointValid ValidRemoteEndPointState
vst', Maybe (Action ())
forall a. Maybe a
Nothing)
              else do
                -- Release probing resources if probing.
                Maybe (IO ()) -> (IO () -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (ValidRemoteEndPointState -> Maybe (IO ())
remoteProbing ValidRemoteEndPointState
vst) IO () -> IO ()
forall a. a -> a
id
                EndPointPair -> IO ()
removeRemoteEndPoint (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)
                -- Attempt to reply (but don't insist)
                Action ()
act <- RemoteEndPoint -> IO () -> IO (Action ())
forall a. RemoteEndPoint -> IO a -> IO (Action a)
schedule RemoteEndPoint
theirEndPoint (IO () -> IO (Action ())) -> IO () -> IO (Action ())
forall a b. (a -> b) -> a -> b
$ do
                  IO (Either IOError ()) -> IO ()
forall (m :: * -> *) a. Monad m => m a -> m ()
void (IO (Either IOError ()) -> IO ())
-> IO (Either IOError ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ IO () -> IO (Either IOError ())
forall (m :: * -> *) a. MonadIO m => IO a -> m (Either IOError a)
tryIO (IO () -> IO (Either IOError ()))
-> IO () -> IO (Either IOError ())
forall a b. (a -> b) -> a -> b
$ ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn ValidRemoteEndPointState
vst'
                    [ HeavyweightConnectionId -> ByteString
encodeWord32 (ControlHeader -> HeavyweightConnectionId
encodeControlHeader ControlHeader
CloseSocket)
                    , HeavyweightConnectionId -> ByteString
encodeWord32 (ValidRemoteEndPointState
vst ValidRemoteEndPointState
-> Accessor ValidRemoteEndPointState HeavyweightConnectionId
-> HeavyweightConnectionId
forall r a. r -> T r a -> a
^. Accessor ValidRemoteEndPointState HeavyweightConnectionId
remoteLastIncoming)
                    ]
                (RemoteState, Maybe (Action ()))
-> IO (RemoteState, Maybe (Action ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (RemoteState
RemoteEndPointClosed, Action () -> Maybe (Action ())
forall a. a -> Maybe a
Just Action ()
act)
          RemoteEndPointClosing MVar ()
resolved ValidRemoteEndPointState
vst ->  do
            -- Like above, we need to check if there is a ConnectionCreated
            -- message that we sent but that the remote endpoint has not yet
            -- received. However, since we are in 'closing' state, the only
            -- way this may happen is when we sent a ConnectionCreated,
            -- ConnectionClosed, and CloseSocket message, none of which have
            -- yet been received. It's sufficient to check that the peer has
            -- not seen the ConnectionCreated message. In case they have seen
            -- it (so that lastReceivedId == lastSendId vst) then they must
            -- have seen the other messages or else they would not have sent
            -- CloseSocket.
            -- We leave the endpoint in closing state in that case.
            if HeavyweightConnectionId
lastReceivedId HeavyweightConnectionId -> HeavyweightConnectionId -> Bool
forall a. Eq a => a -> a -> Bool
/= ValidRemoteEndPointState -> HeavyweightConnectionId
lastSentId ValidRemoteEndPointState
vst
              then do
                (RemoteState, Maybe (Action ()))
-> IO (RemoteState, Maybe (Action ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (MVar () -> ValidRemoteEndPointState -> RemoteState
RemoteEndPointClosing MVar ()
resolved ValidRemoteEndPointState
vst, Maybe (Action ())
forall a. Maybe a
Nothing)
              else do
                -- Release probing resources if probing.
                Bool -> IO () -> IO ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (ValidRemoteEndPointState
vst ValidRemoteEndPointState
-> Accessor ValidRemoteEndPointState Int -> Int
forall r a. r -> T r a -> a
^. Accessor ValidRemoteEndPointState Int
remoteOutgoing Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0) (IO () -> IO ()) -> IO () -> IO ()
forall a b. (a -> b) -> a -> b
$ do
                  let code :: EventErrorCode
code = EndPointAddress -> EventErrorCode
EventConnectionLost (RemoteEndPoint -> EndPointAddress
remoteAddress RemoteEndPoint
theirEndPoint)
                  let msg :: HostName
msg  = HostName
"socket closed prematurely by peer"
                  QDisc Event -> EndPointAddress -> Event -> IO ()
qdiscEnqueue' QDisc Event
ourQueue EndPointAddress
theirAddr (Event -> IO ())
-> (TransportError EventErrorCode -> Event)
-> TransportError EventErrorCode
-> IO ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TransportError EventErrorCode -> Event
ErrorEvent (TransportError EventErrorCode -> IO ())
-> TransportError EventErrorCode -> IO ()
forall a b. (a -> b) -> a -> b
$ EventErrorCode -> HostName -> TransportError EventErrorCode
forall error. error -> HostName -> TransportError error
TransportError EventErrorCode
code HostName
msg
                Maybe (IO ()) -> (IO () -> IO ()) -> IO ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ (ValidRemoteEndPointState -> Maybe (IO ())
remoteProbing ValidRemoteEndPointState
vst) IO () -> IO ()
forall a. a -> a
id
                EndPointPair -> IO ()
removeRemoteEndPoint (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)
                -- Nothing to do, but we want to indicate that the socket
                -- really did close.
                Action ()
act <- RemoteEndPoint -> IO () -> IO (Action ())
forall a. RemoteEndPoint -> IO a -> IO (Action a)
schedule RemoteEndPoint
theirEndPoint (IO () -> IO (Action ())) -> IO () -> IO (Action ())
forall a b. (a -> b) -> a -> b
$ () -> IO ()
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return ()
                MVar () -> () -> IO ()
forall a. MVar a -> a -> IO ()
putMVar MVar ()
resolved ()
                (RemoteState, Maybe (Action ()))
-> IO (RemoteState, Maybe (Action ()))
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return (RemoteState
RemoteEndPointClosed, Action () -> Maybe (Action ())
forall a. a -> Maybe a
Just Action ()
act)
          RemoteEndPointFailed IOError
err ->
            IOError -> IO (RemoteState, Maybe (Action ()))
forall e a. Exception e => e -> IO a
throwIO IOError
err
          RemoteState
RemoteEndPointClosed ->
            EndPointPair -> HostName -> IO (RemoteState, Maybe (Action ()))
forall a. EndPointPair -> HostName -> IO a
relyViolation (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint)
              HostName
"handleIncomingMessages:closeSocket (closed)"
      case Maybe (Action ())
mAct of
        Maybe (Action ())
Nothing -> Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
False
        Just Action ()
act -> do
          EndPointPair -> Action () -> IO ()
forall a. EndPointPair -> Action a -> IO a
runScheduledAction (LocalEndPoint
ourEndPoint, RemoteEndPoint
theirEndPoint) Action ()
act
          Bool -> IO Bool
forall a. a -> IO a
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True

    -- Read a message and output it on the endPoint's channel. By rights we
    -- should verify that the connection ID is valid, but this is unnecessary
    -- overhead
    --
    -- The payload is read with 'recvWithLength', bounded by 'recvLimit'
    -- ('tcpMaxReceiveLength'), and enqueued as a 'Received' event tagged
    -- with the full 'ConnectionId' for @lcid@ on this remote endpoint.
    readMessage :: N.Socket -> LightweightConnectionId -> IO ()
    readMessage sock lcid = do
      payload <- recvWithLength recvLimit sock
      qdiscEnqueue' ourQueue theirAddr (Received (connId lcid) payload)

    -- Stop probing a connection as a result of receiving a probe ack.
    --
    -- If the remote endpoint is valid and a probe is in progress
    -- ('remoteProbing' holds a stop action), run that action and clear the
    -- field. In every other state the remote state is returned unchanged.
    stopProbing :: IO ()
    stopProbing = modifyMVar_ theirState $ \case
      RemoteEndPointValid vst
        | Just stop <- remoteProbing vst -> do
            stop
            return $ RemoteEndPointValid vst { remoteProbing = Nothing }
      st -> return st

    -- Arguments
    --
    -- Shorthands for the local/remote endpoint fields and the transport
    -- parameter used by the helpers above and below.
    ourQueue   = localQueue ourEndPoint
    ourState   = localState ourEndPoint
    theirState = remoteState theirEndPoint
    theirAddr  = remoteAddress theirEndPoint
    -- Upper bound handed to 'recvWithLength' when reading messages.
    recvLimit  = tcpMaxReceiveLength params

    -- Deal with a premature exit
    --
    -- Moves the remote endpoint to 'RemoteEndPointFailed', recording @err@.
    -- The behavior depends on the current remote state:
    --
    --   * Valid: release probing resources and post an 'EventConnectionLost'
    --     error event.
    --   * Closing: release probing resources and signal its @resolved@ MVar.
    --   * Already failed: keep the original failure, and post a
    --     connection-lost event only while the local endpoint is still valid.
    --   * Invalid, init and closed: reported through 'relyViolation'.
    prematureExit :: IOException -> IO ()
    prematureExit err =
      modifyMVar_ theirState $ \case
        RemoteEndPointInvalid _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages:prematureExit"
        RemoteEndPointInit _ _ _ ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages:prematureExit"
        RemoteEndPointValid vst -> do
          -- Release probing resources if probing.
          forM_ (remoteProbing vst) id
          let code = EventConnectionLost theirAddr
          qdiscEnqueue' ourQueue theirAddr . ErrorEvent $
            TransportError code (show err)
          return (RemoteEndPointFailed err)
        RemoteEndPointClosing resolved vst -> do
          -- Release probing resources if probing.
          forM_ (remoteProbing vst) id
          putMVar resolved ()
          return (RemoteEndPointFailed err)
        RemoteEndPointClosed ->
          relyViolation (ourEndPoint, theirEndPoint)
            "handleIncomingMessages:prematureExit"
        RemoteEndPointFailed err' -> do
          -- Here we post a connection-lost event, but only if the
          -- local endpoint is not closed; if it's closed, the EndPointClosed
          -- event will be posted without connection-lost events, and this is
          -- part of the network-transport specification (there's a test
          -- case for it).
          modifyMVar_ ourState $ \st' -> case st' of
            LocalEndPointClosed -> return st'
            LocalEndPointValid _ -> do
              -- Named 'lostError' (not 'err') so the outer IOException
              -- is not shadowed.
              let code      = EventConnectionLost theirAddr
                  lostError = TransportError code (show err')
              qdiscEnqueue' ourQueue theirAddr (ErrorEvent lostError)
              return st'
          return (RemoteEndPointFailed err')

    -- Construct a connection ID
    --
    -- Combines this remote endpoint's 'remoteId' with a lightweight
    -- connection id via 'createConnectionId'.
    connId :: LightweightConnectionId -> ConnectionId
    connId = createConnectionId (remoteId theirEndPoint)

    -- The ID of the last connection _we_ created (or 0 for none)
    --
    -- 'remoteNextConnOutId' is the id the next outgoing connection will get.
    -- If it is still 'firstNonReservedLightweightConnectionId', nothing has
    -- been created yet and the result is 0. Otherwise the result is the
    -- id just before it.
    lastSentId :: ValidRemoteEndPointState -> LightweightConnectionId
    lastSentId vst
      | nextId == firstNonReservedLightweightConnectionId = 0
      | otherwise                                         = nextId - 1
      where
        nextId = vst ^. remoteNextConnOutId

--------------------------------------------------------------------------------
-- Uninterruptible auxiliary functions                                        --
--                                                                            --
-- All these functions assume they are running in a thread which will never   --
-- be killed.
--------------------------------------------------------------------------------

-- | Create a connection to a remote endpoint
--
-- If the remote endpoint is in 'RemoteEndPointClosing' state then we will
-- block until that is resolved.
--
-- May throw a TransportError ConnectErrorCode exception.
--
-- The effective timeout is taken from the 'ConnectHints', falling back to
-- 'transportConnectTimeout' from the transport parameters. When a timeout
-- applies, a helper thread fills an 'MVar' after that many microseconds, and
-- reading that 'MVar' is the @timer@ action passed to 'findRemoteEndPoint'.
createConnectionTo
  :: TCPTransport
  -> LocalEndPoint
  -> EndPointAddress
  -> ConnectHints
  -> IO (RemoteEndPoint, LightweightConnectionId)
createConnectionTo transport ourEndPoint theirAddress hints = do
    -- @timer@ is an IO action that completes when the timeout expires.
    -- NOTE(review): the sleeper thread is not cancelled if we finish early;
    -- it simply exits after @t@ microseconds.
    timer <- case connTimeout of
      Just t -> do
        mv <- newEmptyMVar
        _ <- forkIO $ threadDelay t >> putMVar mv ()
        return $ Just $ readMVar mv
      Nothing -> return Nothing
    go timer Nothing

  where

    params :: TCPParameters
    params = transportParams transport

    -- Timeout from the hints, or else the transport-wide default.
    connTimeout :: Maybe Int
    connTimeout = connectTimeout hints `mplus` transportConnectTimeout params

    -- The second argument indicates the response obtained to the last
    -- connection request and the remote endpoint that was used.
    go :: Maybe (IO ())
       -> Maybe (RemoteEndPoint, ConnectionRequestResponse)
       -> IO (RemoteEndPoint, LightweightConnectionId)
    go timer mr = do
      (theirEndPoint, isNew) <-
        mapIOException connectFailed
          (findRemoteEndPoint ourEndPoint theirAddress RequestedByUs timer)
          `finally` discardCrossed mr
      if isNew
        then do
          -- A fresh remote endpoint was created for us: send the connection
          -- request via 'setupRemoteEndPoint' and loop with its response
          -- ('Nothing' if setup threw).
          mr' <- handle (absorbAllExceptions Nothing) $
            setupRemoteEndPoint transport (ourEndPoint, theirEndPoint) connTimeout
          go timer (fmap ((,) theirEndPoint) mr')
        else
          -- 'findRemoteEndPoint' will have increased 'remoteOutgoing'
          mapIOException connectFailed $ do
            act <- modifyMVar (remoteState theirEndPoint) $ \case
              RemoteEndPointValid vst -> do
                let newId = vst ^. remoteNextConnOutId
                sendAct <- schedule theirEndPoint $ do
                  sendOn vst
                    [ encodeWord32 (encodeControlHeader CreatedNewConnection)
                    , encodeWord32 newId
                    ]
                  return newId
                return ( RemoteEndPointValid
                       $ remoteNextConnOutId ^= newId + 1
                       $ vst
                       , sendAct
                       )
              -- Error cases
              RemoteEndPointInvalid err ->
                throwIO err
              RemoteEndPointFailed err ->
                throwIO err
              -- Algorithmic errors
              _ ->
                relyViolation (ourEndPoint, theirEndPoint) "createConnectionTo"
            -- TODO: deal with exception case?
            connId <- runScheduledAction (ourEndPoint, theirEndPoint) act
            return (theirEndPoint, connId)

    -- Cleanup after an attempt whose connection request was answered with
    -- 'ConnectionRequestCrossed'. If that remote endpoint is still in its
    -- init state, signal its @resolved@ MVar, remove it, and mark it closed.
    -- Any other response, or other remote state, is left alone.
    discardCrossed :: Maybe (RemoteEndPoint, ConnectionRequestResponse) -> IO ()
    discardCrossed (Just (crossedEndPoint, ConnectionRequestCrossed)) =
      modifyMVar_ (remoteState crossedEndPoint) $ \case
        RemoteEndPointInit resolved _ _ -> do
          putMVar resolved ()
          removeRemoteEndPoint (ourEndPoint, crossedEndPoint)
          return RemoteEndPointClosed
        rst -> return rst
    discardCrossed _ = return ()

    -- Wrap an I/O failure as a 'ConnectFailed' transport error.
    connectFailed :: IOException -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

    -- Exception handler that ignores the exception and returns @a@.
    -- NOTE(review): this also absorbs asynchronous exceptions, which is only
    -- acceptable under this section's assumption that the calling thread is
    -- never killed.
    absorbAllExceptions :: a -> SomeException -> IO a
    absorbAllExceptions a _ex = return a

-- | Set up a remote endpoint
--
-- Connects to the remote address with 'socketToEndPoint' and resolves the
-- remote endpoint (which 'resolveInit' requires to be in
-- 'RemoteEndPointInit' state) according to the peer's
-- 'ConnectionRequestResponse':
--
--   * 'ConnectionRequestAccepted': the remote endpoint becomes
--     'RemoteEndPointValid' and a thread is forked that runs
--     'handleIncomingMessages' and closes the socket when it finishes.
--   * 'ConnectionRequestCrossed': the @crossed@ MVar of the 'Init' state is
--     filled (the state itself is not changed here) and the socket is closed.
--   * any other response, or a failure to connect: the remote endpoint
--     becomes 'RemoteEndPointInvalid' and the socket (if any) is closed.
--
-- In every branch that does not hand the socket over to the reader thread,
-- the socket is closed and its closed-MVar filled even if resolving the
-- remote state throws.
--
-- Returns the peer's response, or 'Nothing' if no response was received.
setupRemoteEndPoint
  :: TCPTransport
  -> EndPointPair
  -> Maybe Int
  -> IO (Maybe ConnectionRequestResponse)
setupRemoteEndPoint transport (ourEndPoint, theirEndPoint) connTimeout = do
    let mOurAddress = fmap (const ourAddress) (transportAddrInfo transport)
    result <- socketToEndPoint mOurAddress
                               theirAddress
                               (tcpReuseClientAddr params)
                               (tcpNoDelay params)
                               (tcpKeepAlive params)
                               (tcpUserTimeout params)
                               connTimeout
    didAccept <- case result of
      -- Since a socket was created, we are now responsible for closing it.
      --
      -- In case the connection was accepted, we have some work to do.
      -- We'll remember how to wait for the socket to close
      -- (readMVar socketClosedVar), and we'll take care of closing it up
      -- once handleIncomingMessages has finished.
      Right (socketClosedVar, sock, ConnectionRequestAccepted) -> do
        sendLock <- newMVar Nothing
        let vst = ValidRemoteEndPointState
                    {  remoteSocket        = sock
                    ,  remoteSocketClosed  = readMVar socketClosedVar
                    ,  remoteProbing       = Nothing
                    ,  remoteSendLock      = sendLock
                    , _remoteOutgoing      = 0
                    , _remoteIncoming      = Set.empty
                    , _remoteLastIncoming  = 0
                    , _remoteNextConnOutId = firstNonReservedLightweightConnectionId
                    }
        -- If resolveInit throws, vst was never installed, so nobody else
        -- holds the socket: release it before propagating the exception.
        handle (releaseAndRethrow socketClosedVar sock) $
          resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointValid vst)
        return (Just (socketClosedVar, sock))
      Right (socketClosedVar, sock, ConnectionRequestUnsupportedVersion) -> do
        -- If the peer doesn't support V0 then there's nothing we can do, for
        -- it's the only version we support.
        let err = connectFailed "setupRemoteEndPoint: unsupported version"
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestInvalid) -> do
        let err = invalidAddress "setupRemoteEndPoint: Invalid endpoint"
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestCrossed) -> do
        -- Wake up whoever is waiting on the crossed MVar of the Init state.
        let signalCrossed = withMVar (remoteState theirEndPoint) $ \case
              RemoteEndPointInit _ crossed _ ->
                putMVar crossed ()
              RemoteEndPointFailed ex ->
                throwIO ex
              _ ->
                relyViolation (ourEndPoint, theirEndPoint)
                              "setupRemoteEndPoint: Crossed"
        signalCrossed `finally` releaseSocket socketClosedVar sock
        return Nothing
      Right (socketClosedVar, sock, ConnectionRequestHostMismatch) -> do
        -- The peer follows up with three length-prefixed strings (the
        -- claimed host, the actual numeric host and the actual resolved
        -- host); read them to build the error. Any failure while reading is
        -- turned into the error itself.
        let handler :: SomeException -> IO (TransportError ConnectErrorCode)
            handler err = return (TransportError ConnectFailed (show err))
        err <- handle handler $ do
          claimedHost <- recvWithLength (tcpMaxReceiveLength params) sock
          actualNumericHost <- recvWithLength (tcpMaxReceiveLength params) sock
          actualResolvedHost <- recvWithLength (tcpMaxReceiveLength params) sock
          let reason = concat [
                  "setupRemoteEndPoint: Host mismatch"
                , ". Claimed: "
                , BSC.unpack (BS.concat claimedHost)
                , "; Numeric: "
                , BSC.unpack (BS.concat actualNumericHost)
                , "; Resolved: "
                , BSC.unpack (BS.concat actualResolvedHost)
                ]
          return (TransportError ConnectFailed reason)
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
          `finally` releaseSocket socketClosedVar sock
        return Nothing
      Left err -> do
        resolveInit (ourEndPoint, theirEndPoint) (RemoteEndPointInvalid err)
        return Nothing

    -- We handle incoming messages in a separate thread, and are careful to
    -- always close the socket once that thread is finished.
    forM_ didAccept $ \(socketClosed, sock) -> void $ forkIO $
      handleIncomingMessages params (ourEndPoint, theirEndPoint)
        `finally` releaseSocket socketClosed sock
    return $ either (const Nothing) (\(_, _, response) -> Just response) result
  where
    params          = transportParams transport
    ourAddress      = localAddress ourEndPoint
    theirAddress    = remoteAddress theirEndPoint

    invalidAddress, connectFailed :: String -> TransportError ConnectErrorCode
    invalidAddress  = TransportError ConnectNotFound
    connectFailed   = TransportError ConnectFailed

    -- Close the socket, then fill its closed-MVar even if closing failed.
    releaseSocket socketClosedVar sock =
      tryCloseSocket sock `finally` putMVar socketClosedVar ()

    -- Release the socket, then re-throw the exception that interrupted setup.
    releaseAndRethrow socketClosedVar sock err = do
      releaseSocket socketClosedVar sock
      throwIO (err :: SomeException)

-- | Send a CloseSocket request if the remote endpoint is unused
--
-- A 'RemoteEndPointValid' endpoint with no outgoing connections and an empty
-- set of incoming connections is moved to 'RemoteEndPointClosing', and a
-- 'CloseSocket' control message carrying the last incoming connection ID is
-- sent on its socket. An endpoint in any other situation is left untouched.
closeIfUnused :: EndPointPair -> IO ()
closeIfUnused (ourEndPoint, theirEndPoint) = do
  mAct <- modifyMVar (remoteState theirEndPoint) $ \case
    RemoteEndPointValid vst
      | isUnused vst -> do
          resolved <- newEmptyMVar
          -- The send is only scheduled while we hold the remote state lock;
          -- it is run below, once modifyMVar has put the state back.
          act <- schedule theirEndPoint $
            sendOn vst [ encodeWord32 (encodeControlHeader CloseSocket)
                       , encodeWord32 (vst ^. remoteLastIncoming)
                       ]
          return (RemoteEndPointClosing resolved vst, Just act)
    st ->
      return (st, Nothing)
  forM_ mAct $ runScheduledAction (ourEndPoint, theirEndPoint)
  where
    -- No outgoing and no incoming connections remain on this endpoint.
    isUnused vst =
      vst ^. remoteOutgoing == 0 && Set.null (vst ^. remoteIncoming)

-- | Reset a remote endpoint if it is broken
--
-- If the remote endpoint is currently in a broken state
-- ('RemoteEndPointInvalid' or 'RemoteEndPointFailed'), and
--
--   - a user calls the API function 'connect', or
--   - an inbound connection request comes in from this remote address
--
-- we remove the remote endpoint first. If there is no remote endpoint for
-- this address, or it is in any other state, nothing happens.
--
-- Throws a TransportError ConnectFailed exception if the local endpoint is
-- closed.
resetIfBroken :: LocalEndPoint -> EndPointAddress -> IO ()
resetIfBroken ourEndPoint theirAddress = do
  mTheirEndPoint <- withMVar (localState ourEndPoint) $ \case
    LocalEndPointValid vst ->
      return (vst ^. localConnectionTo theirAddress)
    LocalEndPointClosed ->
      throwIO $ TransportError ConnectFailed "Endpoint closed"
  forM_ mTheirEndPoint $ \theirEndPoint ->
    withMVar (remoteState theirEndPoint) $ \case
      RemoteEndPointInvalid _ ->
        removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      RemoteEndPointFailed _ ->
        removeRemoteEndPoint (ourEndPoint, theirEndPoint)
      _ ->
        return ()

-- | Special case of 'apiConnect': connect an endpoint to itself
--
-- Allocates a fresh lightweight connection ID, enqueues 'ConnectionOpened' on
-- our own event queue and returns a 'Connection' whose 'send' and 'close'
-- enqueue 'Received' / 'ConnectionClosed' events on that same queue.
--
-- 'send' returns 'SendClosed' once the connection has been closed and
-- 'SendFailed' once the local endpoint has been closed; 'close' is a no-op in
-- both of those cases.
--
-- May throw a TransportError ConnectErrorCode (if the local endpoint is closed)
connectToSelf :: LocalEndPoint -> IO Connection
connectToSelf ourEndPoint = do
    connAlive <- newIORef True  -- Protected by the local endpoint lock
    lconnId   <- mapIOException connectFailed $ getLocalNextConnOutId ourEndPoint
    let connId = createConnectionId heavyweightSelfConnectionId lconnId
    qdiscEnqueue' ourQueue ourAddress $
      ConnectionOpened connId ReliableOrdered ourAddress
    return Connection
      { send  = selfSend connAlive connId
      , close = selfClose connAlive connId
      }
  where
    selfSend :: IORef Bool
             -> ConnectionId
             -> [ByteString]
             -> IO (Either (TransportError SendErrorCode) ())
    selfSend connAlive connId msg =
      try . withMVar ourState $ \case
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          -- When alive, evaluate every chunk of the message (to WHNF)
          -- before enqueueing it.
          if alive
            then foldr seq () msg `seq`
                   qdiscEnqueue' ourQueue ourAddress (Received connId msg)
            else throwIO $ TransportError SendClosed "Connection closed"
        LocalEndPointClosed ->
          throwIO $ TransportError SendFailed "Endpoint closed"

    selfClose :: IORef Bool -> ConnectionId -> IO ()
    selfClose connAlive connId =
      withMVar ourState $ \case
        LocalEndPointValid _ -> do
          alive <- readIORef connAlive
          when alive $ do
            qdiscEnqueue' ourQueue ourAddress (ConnectionClosed connId)
            writeIORef connAlive False
        LocalEndPointClosed ->
          return ()

    ourQueue   = localQueue ourEndPoint
    ourState   = localState ourEndPoint
    ourAddress = localAddress ourEndPoint

    connectFailed :: IOError -> TransportError ConnectErrorCode
    connectFailed = TransportError ConnectFailed . show

-- | Resolve an endpoint currently in 'Init' state
--
-- Fills the @resolved@ MVar (and, if it is still empty, the @crossed@ MVar)
-- of the 'RemoteEndPointInit' state and replaces that state with @newState@.
-- If @newState@ is 'RemoteEndPointClosed', the remote endpoint is also
-- removed.
--
-- Re-throws the stored exception if the remote endpoint is in
-- 'RemoteEndPointFailed' state; any other state triggers 'relyViolation'.
resolveInit :: EndPointPair -> RemoteState -> IO ()
resolveInit (ourEndPoint, theirEndPoint) newState =
  modifyMVar_ (remoteState theirEndPoint) $ \case
    RemoteEndPointInit resolved crossed _ -> do
      putMVar resolved ()
      -- Unblock the reader (if any) if the ConnectionRequestCrossed
      -- message did not come within the connection timeout.
      void $ tryPutMVar crossed ()
      case newState of
        RemoteEndPointClosed ->
          removeRemoteEndPoint (ourEndPoint, theirEndPoint)
        _ ->
          return ()
      return newState
    RemoteEndPointFailed ex ->
      throwIO ex
    _ ->
      relyViolation (ourEndPoint, theirEndPoint) "resolveInit"

-- | Get the next outgoing self-connection ID
--
-- Returns the current value of 'localNextConnOutId' and stores that value
-- plus one back in the local endpoint state.
--
-- Throws an IO exception when the endpoint is closed.
getLocalNextConnOutId :: LocalEndPoint -> IO LightweightConnectionId
getLocalNextConnOutId ourEndpoint =
  modifyMVar (localState ourEndpoint) $ \case
    LocalEndPointValid vst -> do
      let connId = vst ^. localNextConnOutId
          vst'   = (localNextConnOutId ^= connId + 1) vst
      return (LocalEndPointValid vst', connId)
    LocalEndPointClosed ->
      throwIO $ userError "Local endpoint closed"

-- | Create a new local endpoint and register it with the transport.
--
-- The endpoint is given the transport's next endpoint id, and the transport's
-- counter is advanced. Its address is encoded from the transport's host, port
-- and that id, or is random when the transport has no address info.
--
-- May throw a TransportError NewEndPointErrorCode exception if the transport
-- is closed.
createLocalEndPoint :: TCPTransport
                    -> QDisc Event
                    -> IO LocalEndPoint
createLocalEndPoint transport qdisc = do
    endPointState <- newMVar (LocalEndPointValid freshState)
    modifyMVar (transportState transport) $ \case
      TransportClosed ->
        throwIO (TransportError NewEndPointFailed "Transport closed")
      TransportValid vst -> do
        let newId = vst ^. nextEndPointId
        addr <- maybe randomEndPointAddress
                      (return . addressFor newId)
                      (transportAddrInfo transport)
        let endPoint = LocalEndPoint
              { localAddress    = addr
              , localEndPointId = newId
              , localQueue      = qdisc
              , localState      = endPointState
              }
            -- Bump the id counter, then record the endpoint in its slot.
            register = (localEndPointAt newId ^= Just endPoint)
                     . (nextEndPointId ^= newId + 1)
        return (TransportValid (register vst), endPoint)
  where
    -- State of a brand-new endpoint: no connections, counters at the first
    -- non-reserved ids.
    freshState = ValidLocalEndPointState
      { _localNextConnOutId = firstNonReservedLightweightConnectionId
      , _localConnections   = Map.empty
      , _nextConnInId       = firstNonReservedHeavyweightConnectionId
      }

    -- Deterministic address built from the transport's host/port and the id.
    addressFor endPointId info =
      encodeEndPointAddress (transportHost info) (transportPort info) endPointId


-- | Remove a local endpoint's reference to a remote endpoint.
--
-- The entry for the remote address is cleared only if it still holds an
-- endpoint with the same 'remoteId' as the one given. An entry that now
-- refers to a different remote endpoint under that address is kept.
--
-- If the local endpoint is closed, do nothing.
removeRemoteEndPoint :: EndPointPair -> IO ()
removeRemoteEndPoint (ourEndPoint, theirEndPoint) =
    modifyMVar_ (localState ourEndPoint) $ \case
      LocalEndPointClosed ->
        return LocalEndPointClosed
      st@(LocalEndPointValid vst) ->
        case vst ^. localConnectionTo theirAddress of
          Just current | remoteId current == remoteId theirEndPoint ->
            return . LocalEndPointValid
                   . (localConnectionTo theirAddress ^= Nothing)
                   $ vst
          _ ->
            return st
  where
    theirAddress = remoteAddress theirEndPoint

-- | Remove a local endpoint from the transport's endpoint table.
--
-- Clears the slot keyed by the endpoint's 'localEndPointId'. If the transport
-- is closed, the state stays 'TransportClosed' and nothing else happens.
removeLocalEndPoint :: TCPTransport -> LocalEndPoint -> IO ()
removeLocalEndPoint transport ourEndPoint =
    modifyMVar_ (transportState transport) $ \case
      TransportClosed ->
        return TransportClosed
      TransportValid vst ->
        return . TransportValid $ forget vst
  where
    forget = localEndPointAt (localEndPointId ourEndPoint) ^= Nothing

-- | Find a remote endpoint. If the remote endpoint does not yet exist we
-- create it in Init state.
--
-- The returned 'Bool' is:
--
-- * 'True' when the remote endpoint was just created by this call;
--
-- * for a request by them that finds our own request still in Init state:
--   'True' if our address is greater than theirs and the @crossed@ signal
--   has been received, 'False' if our address is not greater;
--
-- * 'False' when the remote endpoint is already in Valid state.
--
-- If the endpoint is in Valid state and the request is 'RequestedByUs', its
-- outgoing connection count is incremented.
--
-- This function never returns 'Nothing' on timeout. While it blocks waiting
-- for another connection attempt to resolve, the optional timer runs
-- concurrently. If the timer completes first,
-- @TransportError ConnectTimeout "Timed out"@ is thrown to the calling thread.
--
-- Other exceptions:
--
-- * @userError "Local endpoint closed"@ if the local endpoint is closed;
--
-- * @userError "Already connected"@ if a request by them finds an earlier
--   request by them still in Init state;
--
-- * the stored error if the remote endpoint is Invalid or Failed.
findRemoteEndPoint
  :: LocalEndPoint
  -> EndPointAddress         -- ^ address of the remote endpoint
  -> RequestedBy             -- ^ which side initiated this request
  -> Maybe (IO ())           -- ^ an action which completes when the time is up
  -> IO (RemoteEndPoint, Bool)
findRemoteEndPoint ourEndPoint theirAddress findOrigin mtimer = go
  where
    go :: IO (RemoteEndPoint, Bool)
    go = do
      (theirEndPoint, isNew) <- modifyMVar ourState $ \st -> case st of
        LocalEndPointValid vst -> case vst ^. localConnectionTo theirAddress of
          Just theirEndPoint ->
            return (st, (theirEndPoint, False))
          Nothing -> do
            -- First time we see this address: register a fresh remote
            -- endpoint in Init state under the next incoming connection id.
            resolved   <- newEmptyMVar
            crossed    <- newEmptyMVar
            theirState <- newMVar (RemoteEndPointInit resolved crossed findOrigin)
            scheduled  <- newChan
            let theirEndPoint = RemoteEndPoint
                                  { remoteAddress   = theirAddress
                                  , remoteState     = theirState
                                  , remoteId        = vst ^. nextConnInId
                                  , remoteScheduled = scheduled
                                  }
            return ( LocalEndPointValid
                   . (localConnectionTo theirAddress ^= Just theirEndPoint)
                   . (nextConnInId ^: (+ 1))
                   $ vst
                   , (theirEndPoint, True)
                   )
        LocalEndPointClosed ->
          throwIO $ userError "Local endpoint closed"

      if isNew
        then
          return (theirEndPoint, True)
        else do
          let theirState = remoteState theirEndPoint
          snapshot <- modifyMVar theirState $ \st -> case st of
            RemoteEndPointValid vst ->
              case findOrigin of
                RequestedByUs -> do
                  -- We are reusing a live connection for an outgoing
                  -- request: bump its outgoing count.
                  let st' = RemoteEndPointValid
                          . (remoteOutgoing ^: (+ 1))
                          $ vst
                  return (st', st')
                RequestedByThem ->
                  return (st, st)
            _ ->
              return (st, st)
          -- The snapshot may no longer be up to date at this point, but if we
          -- increased the refcount then it can only either be Valid or Failed
          -- (after an explicit call to 'closeEndPoint' or 'closeTransport')
          case snapshot of
            RemoteEndPointInvalid err ->
              throwIO err
            RemoteEndPointInit resolved crossed initOrigin ->
              case (findOrigin, initOrigin) of
                (RequestedByUs, _) ->
                  -- Someone else is initialising the connection; wait for it
                  -- to resolve, then look again.
                  readMVarTimeout mtimer resolved >> go
                (RequestedByThem, RequestedByUs) ->
                  if ourAddress > theirAddress
                    then do
                      -- Wait for the Crossed message and recheck the state
                      -- of the remote endpoint after this (it may well be
                      -- invalid already in case of a timeout).
                      tryReadMVar crossed >>= \case
                        Nothing -> readMVarTimeout mtimer crossed >> go
                        _       -> return (theirEndPoint, True)
                    else
                      return (theirEndPoint, False)
                (RequestedByThem, RequestedByThem) ->
                  throwIO $ userError "Already connected"
            RemoteEndPointValid _ ->
              -- We assume that the request crossed if we find the endpoint in
              -- Valid state. It is possible that this is really an invalid
              -- request, but only in the case of a broken client (we don't
              -- maintain enough history to be able to tell the difference).
              return (theirEndPoint, False)
            RemoteEndPointClosing resolved _ ->
              readMVarTimeout mtimer resolved >> go
            RemoteEndPointClosed ->
              -- NOTE(review): retries immediately without blocking. This
              -- presumably relies on the closed entry being removed from our
              -- connection map soon (e.g. via 'removeRemoteEndPoint') —
              -- confirm, otherwise this loop busy-waits.
              go
            RemoteEndPointFailed err ->
              throwIO err

    ourState :: MVar LocalEndPointState
    ourState   = localState ourEndPoint

    ourAddress :: EndPointAddress
    ourAddress = localAddress ourEndPoint

    -- | Like 'readMVar' but it throws
    -- @TransportError ConnectTimeout "Timed out"@ to the calling thread if the
    -- timer action completes first. The timer runs in a forked thread that is
    -- killed once the read finishes (or is interrupted).
    readMVarTimeout :: Maybe (IO b) -> MVar a -> IO a
    readMVarTimeout Nothing mv = readMVar mv
    readMVarTimeout (Just timer) mv = do
      let connectTimedout = TransportError ConnectTimeout "Timed out"
      tid <- myThreadId
      bracket (forkIO $ timer >> throwTo tid connectTimedout) killThread $
        const $ readMVar mv

-- | Send a payload over a heavyweight connection (thread safe)
--
-- The socket cannot be used for sending after the non-atomic 'sendMany'
-- is interrupted - otherwise, the other side may get the msg corrupted.
--
-- There are two types of possible exceptions here:
-- 1) Outer asynchronous exceptions (like 'ProcessLinkException').
-- 2) Synchronous exceptions (inner or outer).
-- On a synchronous exception the remote endpoint is failed (see 'runScheduledAction',
-- for example) and its socket is not supposed to be used again.
--
-- With 'async' the code is run in a new thread which is not-targeted (and
-- thus, not interrupted) by the 1st type of exceptions. With 'remoteSendLock'
-- we protect the socket usage by the concurrent threads, as well as prevent
-- that usage after SomeException.
--
-- The lock holds 'Nothing' while the socket is healthy. After a failed send
-- it holds @Just e@; every later call then skips the socket and throws a
-- 'userError' that mentions @e@.
sendOn :: ValidRemoteEndPointState -> [ByteString] -> IO ()
sendOn vst bs = async sendTask >>= wait
  where
    lock = remoteSendLock vst

    sendTask = mask $ \restore -> do
      previous <- takeMVar lock
      case previous of
        Nothing ->
          -- Only the send itself is interruptible. On failure, record the
          -- exception in the lock before re-throwing it.
          restore (sendMany (remoteSocket vst) bs) `catch` \ex -> do
            putMVar lock (Just ex)
            throwIO (ex :: SomeException)
        Just _ ->
          return ()
      putMVar lock previous
      forM_ previous $ \e ->
        throwIO . userError $ "sendOn failed earlier with: " ++ show e

--------------------------------------------------------------------------------
-- Scheduling actions                                                         --
--------------------------------------------------------------------------------

-- | Result slot of an action queued with 'schedule'.
--
-- 'schedule' fills it with @Right@ the action's result, or @Left@ the
-- exception the action threw. See 'schedule'/'runScheduledAction'.
type Action a = MVar (Either SomeException a)

-- | Queue @act@ on the remote endpoint's 'remoteScheduled' channel and hand
-- back the 'Action' slot that will receive its outcome (see also
-- 'runScheduledAction').
--
-- The queued job stores a normal result as 'Right'. Any exception thrown by
-- @act@, or by storing its result, is caught and stored as 'Left', so running
-- the job never propagates the action's exception itself.
schedule :: RemoteEndPoint -> IO a -> IO (Action a)
schedule theirEndPoint act = do
    resultVar <- newEmptyMVar
    let deliver = putMVar resultVar
        -- Storing the 'Right' result stays inside 'catch' on purpose, so an
        -- exception at that point is still reported through the slot.
        job     = (act >>= deliver . Right) `catch` (deliver . Left)
    writeChan (remoteScheduled theirEndPoint) job
    return resultVar

-- | Run a scheduled action. Every call to 'schedule' should be paired with a
-- call to 'runScheduledAction' so that every scheduled action is run. Note
-- however that there is no guarantee that in
--
-- > do act <- schedule theirEndPoint p
-- >    runScheduledAction (ourEndPoint, theirEndPoint) act
--
-- 'runScheduledAction' will run @p@ (it might run some other scheduled action
-- queued on the same remote endpoint). However, it will then wait until @p@
-- is executed (by this call to 'runScheduledAction' or by another).
--
-- If @p@ failed with an 'IOException', the remote endpoint's state is
-- replaced by 'RemoteEndPointFailed' before the exception is rethrown. If the
-- state was 'RemoteEndPointValid', any probing action is run first and the
-- socket is shut down. Any other exception is rethrown without touching the
-- remote state.
runScheduledAction :: EndPointPair -> Action a -> IO a
runScheduledAction (_, theirEndPoint) mvar = do
    -- Run whichever job is next in the queue (not necessarily ours).
    join $ readChan (remoteScheduled theirEndPoint)
    -- Block until our own job has been run by someone.
    ma <- readMVar mvar
    case ma of
      Right a -> return a
      Left e -> do
        forM_ (fromException e) $ \ioe ->
          modifyMVar_ (remoteState theirEndPoint) $ \st ->
            case st of
              RemoteEndPointValid vst -> handleIOException ioe vst
              _ -> return (RemoteEndPointFailed ioe)
        throwIO e
  where
    -- Tear down a valid remote endpoint after an I/O error and report it as
    -- failed.
    handleIOException :: IOException
                      -> ValidRemoteEndPointState
                      -> IO RemoteState
    handleIOException ex vst = do
      -- Release probing resources if probing.
      forM_ (remoteProbing vst) id
      -- Must shut down the socket here, so that the other end will realize
      -- we lost the connection
      tryShutdownSocketBoth (remoteSocket vst)
      -- Eventually, handleIncomingMessages will fail while trying to
      -- receive, and ultimately enqueue the 'EventConnectionLost'.
      return (RemoteEndPointFailed ex)

-- | Use 'schedule' and 'runScheduledAction' together in an exception-safe way.
--
-- @f@ is given a callback that schedules an action on a remote endpoint.
-- Asynchronous exceptions are masked while the action is queued and recorded.
-- When @f@ returns, or throws, the recorded action (if any) is run with
-- 'runScheduledAction'.
--
-- The callback is assumed to be used at most once. If it is used more than
-- once, only the last scheduled action is waited for here, so the guarantees
-- of 'runScheduledAction' are not respected.
withScheduledAction :: LocalEndPoint -> ((RemoteEndPoint -> IO a -> IO ()) -> IO ()) -> IO ()
withScheduledAction ourEndPoint f =
    bracket (newIORef Nothing) runRecorded (f . scheduleInto)
  where
    -- Queue the action and remember it, without being interrupted in between.
    scheduleInto ref theirEndPoint act = mask_ $ do
      action <- schedule theirEndPoint act
      writeIORef ref (Just (theirEndPoint, action))

    -- Run (or wait for) the remembered action, if the callback was used.
    runRecorded ref = do
      recorded <- readIORef ref
      forM_ recorded $ \(theirEndPoint, action) ->
        runScheduledAction (ourEndPoint, theirEndPoint) action

--------------------------------------------------------------------------------
-- "Stateless" (MVar free) functions                                          --
--------------------------------------------------------------------------------

-- | Establish a connection to a remote endpoint.
--
-- Failures are returned as 'Left' rather than thrown: an unparsable or
-- unresolvable address, socket creation failure, socket option or handshake
-- I/O failure, connect timeout, or an unrecognised handshake response. Any
-- such error raised after the socket is created closes it (via
-- 'bracketOnError').
--
-- If a socket is created and returned (Right is given) then the caller is
-- responsible for eventually closing the socket and filling the MVar (which
-- is empty). The MVar must be filled immediately after, and never before,
-- the socket is closed.
socketToEndPoint :: Maybe EndPointAddress -- ^ Our address
                 -> EndPointAddress       -- ^ Their address
                 -> Bool                  -- ^ Use SO_REUSEADDR?
                 -> Bool                  -- ^ Use TCP_NODELAY
                 -> Bool                  -- ^ Use TCP_KEEPALIVE
                 -> Maybe Int             -- ^ Maybe TCP_USER_TIMEOUT
                 -> Maybe Int             -- ^ Timeout for connect
                 -> IO (Either (TransportError ConnectErrorCode)
                               (MVar (), N.Socket, ConnectionRequestResponse))
socketToEndPoint mOurAddress theirAddress reuseAddr noDelay keepAlive
                 mUserTimeout timeout =
  try $ do
    (host, port, theirEndPointId) <- case decodeEndPointAddress theirAddress of
      Nothing  -> throwIO (failed . userError $ "Could not parse")
      Just dec -> return dec
    addrs <- mapIOException invalidAddress $
      N.getAddrInfo Nothing (Just host) (Just port)
    -- Match explicitly instead of binding @addr:_@: a failed pattern would
    -- raise a plain 'IOError' that escapes the 'try' above instead of being
    -- reported as a 'TransportError'.
    addr <- case addrs of
      firstAddr : _ -> return firstAddr
      []            -> throwIO (invalidAddress . userError $
                                  "No addresses found for " ++ host ++ ":" ++ port)
    bracketOnError (createSocket addr) tryCloseSocket $ \sock -> do
      when reuseAddr $
        mapIOException failed $ N.setSocketOption sock N.ReuseAddr 1
      when noDelay $
        mapIOException failed $ N.setSocketOption sock N.NoDelay 1
      when keepAlive $
        mapIOException failed $ N.setSocketOption sock N.KeepAlive 1
      forM_ mUserTimeout $
        mapIOException failed . N.setSocketOption sock N.UserTimeout
      -- The timeout covers both the TCP connect and the handshake exchange.
      response <- timeoutMaybe timeout timeoutError $ do
        mapIOException invalidAddress $
          N.connect sock (N.addrAddress addr)
        mapIOException failed $ do
          -- Handshake: protocol version, then a length-prefixed payload of
          -- their endpoint id followed by our (length-prefixed) address.
          -- Without an address of our own, a zero address length is sent.
          case mOurAddress of
            Just (EndPointAddress ourAddress) ->
              sendMany sock $
                  encodeWord32 currentProtocolVersion
                : prependLength (encodeWord32 theirEndPointId : prependLength [ourAddress])
            Nothing ->
              sendMany sock $
                  encodeWord32 currentProtocolVersion
                : prependLength ([encodeWord32 theirEndPointId, encodeWord32 0])
          recvWord32 sock
      case decodeConnectionRequestResponse response of
        Nothing -> throwIO (failed . userError $ "Unexpected response")
        Just r  -> do
          socketClosedVar <- newEmptyMVar
          return (socketClosedVar, sock, r)
  where
    createSocket :: N.AddrInfo -> IO N.Socket
    createSocket addr = mapIOException insufficientResources $
      N.socket (N.addrFamily addr) N.Stream N.defaultProtocol

    invalidAddress, insufficientResources, failed
      :: IOError -> TransportError ConnectErrorCode
    invalidAddress        = TransportError ConnectNotFound . show
    insufficientResources = TransportError ConnectInsufficientResources . show
    failed                = TransportError ConnectFailed . show

    timeoutError :: TransportError ConnectErrorCode
    timeoutError          = TransportError ConnectTimeout "Timed out"

-- | Pack a heavyweight and a lightweight connection ID into one
-- 'ConnectionId': the heavyweight ID is shifted left by 32 bits and the
-- lightweight ID is OR-ed into the low bits.
--
-- NOTE(review): assumes 'ConnectionId' is at least 64 bits wide and that the
-- lightweight ID fits in 32 bits; confirm against the type definitions.
createConnectionId :: HeavyweightConnectionId
                   -> LightweightConnectionId
                   -> ConnectionId
createConnectionId hcid lcid = upperHalf .|. lowerHalf
  where
    upperHalf = fromIntegral hcid `shiftL` 32
    lowerHalf = fromIntegral lcid

--------------------------------------------------------------------------------
-- Functions from TransportInternals                                          --
--------------------------------------------------------------------------------

-- Find the socket used between a local endpoint and a remote endpoint.
--
-- Throws an 'IOError' ('userError') if the local address is malformed, the
-- transport or local endpoint is closed, either endpoint cannot be found, or
-- the remote endpoint is still initialising or already closed. If the remote
-- endpoint is invalid or failed, its stored error is rethrown instead.
-- A socket is returned both for valid and for closing remote endpoints.
internalSocketBetween :: TCPTransport    -- ^ Transport
                      -> EndPointAddress -- ^ Local endpoint
                      -> EndPointAddress -- ^ Remote endpoint
                      -> IO N.Socket
internalSocketBetween transport ourAddress theirAddress = do
    ourEndPointId <-
      maybe (throwIO $ userError "Malformed local EndPointAddress")
            (\(_, _, eid) -> return eid)
            (decodeEndPointAddress ourAddress)
    ourEndPoint <- withMVar (transportState transport) $ \case
      TransportClosed ->
        throwIO $ userError "Transport closed"
      TransportValid vst ->
        maybe (throwIO $ userError "Local endpoint not found") return
              (vst ^. localEndPointAt ourEndPointId)
    theirEndPoint <- withMVar (localState ourEndPoint) $ \case
      LocalEndPointClosed ->
        throwIO $ userError "Local endpoint closed"
      LocalEndPointValid vst ->
        maybe (throwIO $ userError "Remote endpoint not found") return
              (vst ^. localConnectionTo theirAddress)
    withMVar (remoteState theirEndPoint) $ \case
      RemoteEndPointInit _ _ _ ->
        throwIO $ userError "Remote endpoint not yet initialized"
      RemoteEndPointValid vst ->
        return (remoteSocket vst)
      RemoteEndPointClosing _ vst ->
        return (remoteSocket vst)
      RemoteEndPointClosed ->
        throwIO $ userError "Remote endpoint closed"
      RemoteEndPointInvalid err ->
        throwIO err
      RemoteEndPointFailed err ->
        throwIO err

--------------------------------------------------------------------------------
-- Constants                                                                  --
--------------------------------------------------------------------------------

-- | The first lightweight connection ID not reserved for control messages.
-- All IDs below this value are reserved.
firstNonReservedLightweightConnectionId :: LightweightConnectionId
firstNonReservedLightweightConnectionId = 1024

-- | Heavyweight connection ID used for an endpoint's connection to itself.
heavyweightSelfConnectionId :: HeavyweightConnectionId
heavyweightSelfConnectionId = 0

-- | The first heavyweight connection ID not reserved for special heavyweight
-- connections (such as 'heavyweightSelfConnectionId'). All IDs below this
-- value are reserved.
firstNonReservedHeavyweightConnectionId :: HeavyweightConnectionId
firstNonReservedHeavyweightConnectionId = 1

--------------------------------------------------------------------------------
-- Accessor definitions                                                       --
--------------------------------------------------------------------------------

-- | Accessor for the '_localEndPoints' field of 'ValidTransportState'.
localEndPoints :: Accessor ValidTransportState (Map EndPointId LocalEndPoint)
localEndPoints = accessor _localEndPoints setEndPoints
  where
    setEndPoints newEndPoints st = st { _localEndPoints = newEndPoints }

-- | Accessor for the '_nextEndPointId' field of 'ValidTransportState'.
nextEndPointId :: Accessor ValidTransportState EndPointId
nextEndPointId = accessor _nextEndPointId setNextId
  where
    setNextId newId st = st { _nextEndPointId = newId }

-- | Accessor for the '_localNextConnOutId' field of 'ValidLocalEndPointState'.
localNextConnOutId :: Accessor ValidLocalEndPointState LightweightConnectionId
localNextConnOutId = accessor _localNextConnOutId setNextOut
  where
    setNextOut newId st = st { _localNextConnOutId = newId }

-- | Accessor for the '_localConnections' field of 'ValidLocalEndPointState'.
localConnections :: Accessor ValidLocalEndPointState (Map EndPointAddress RemoteEndPoint)
localConnections = accessor _localConnections setConnections
  where
    setConnections newConns st = st { _localConnections = newConns }

-- | Accessor for the '_nextConnInId' field of 'ValidLocalEndPointState'.
nextConnInId :: Accessor ValidLocalEndPointState HeavyweightConnectionId
nextConnInId = accessor _nextConnInId setNextIn
  where
    setNextIn newId st = st { _nextConnInId = newId }

-- | Accessor for the '_remoteOutgoing' field of 'ValidRemoteEndPointState'.
remoteOutgoing :: Accessor ValidRemoteEndPointState Int
remoteOutgoing = accessor _remoteOutgoing setOutgoing
  where
    setOutgoing count conn = conn { _remoteOutgoing = count }

-- | Accessor for the '_remoteIncoming' field of a 'ValidRemoteEndPointState':
-- a set of lightweight connection ids.
remoteIncoming :: Accessor ValidRemoteEndPointState (Set LightweightConnectionId)
remoteIncoming =
  accessor _remoteIncoming (\cs conn -> conn { _remoteIncoming = cs })

-- | Accessor for the '_remoteLastIncoming' field of a
-- 'ValidRemoteEndPointState'.
remoteLastIncoming :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteLastIncoming =
  accessor _remoteLastIncoming (\lcid st -> st { _remoteLastIncoming = lcid })

-- | Accessor for the '_remoteNextConnOutId' field of a
-- 'ValidRemoteEndPointState'.
remoteNextConnOutId :: Accessor ValidRemoteEndPointState LightweightConnectionId
remoteNextConnOutId =
  accessor _remoteNextConnOutId (\cix st -> st { _remoteNextConnOutId = cix })

-- | Accessor for the (optional) 'LocalEndPoint' registered under the given
-- 'EndPointId'. It is built by composing 'localEndPoints' with
-- 'DAC.mapMaybe', so reading yields 'Nothing' when no entry exists.
-- NOTE(review): the write semantics of 'Nothing' (presumably deleting the
-- key) come from "Data.Accessor.Container"; confirm there.
localEndPointAt :: EndPointId -> Accessor ValidTransportState (Maybe LocalEndPoint)
localEndPointAt eid = localEndPoints >>> DAC.mapMaybe eid

-- | Accessor for the (optional) 'RemoteEndPoint' stored under the given
-- remote 'EndPointAddress'. It is built by composing 'localConnections' with
-- 'DAC.mapMaybe', so reading yields 'Nothing' when no entry exists.
localConnectionTo :: EndPointAddress -> Accessor ValidLocalEndPointState (Maybe RemoteEndPoint)
localConnectionTo addr = localConnections >>> DAC.mapMaybe addr

-------------------------------------------------------------------------------
-- Debugging                                                                 --
-------------------------------------------------------------------------------

-- | Report a RELY violation for the given endpoint pair.
--
-- The message is suffixed with @\" RELY violation\"@, logged via 'elog', and
-- then raised with 'fail'. In 'IO', 'fail' throws an 'IOError' ('userError').
-- The action therefore never returns, which is why its result type is
-- polymorphic.
relyViolation :: EndPointPair -> String -> IO a
relyViolation (ourEndPoint, theirEndPoint) str = do
  elog (ourEndPoint, theirEndPoint) msg
  fail msg
  where
    msg = str ++ " RELY violation"

-- | Debug logging: print one line to stdout of the form
--
-- > <localAddress>/<remoteAddress>(<remoteId>)/<threadId>: <msg>
--
-- where each component is rendered with 'show' (except @msg@, which is
-- printed verbatim).
elog :: EndPointPair -> String -> IO ()
elog (ourEndPoint, theirEndPoint) msg = do
  tid <- myThreadId
  putStrLn $ concat
    [ show (localAddress ourEndPoint)
    , "/",  show (remoteAddress theirEndPoint)
    , "(",  show (remoteId theirEndPoint), ")"
    , "/",  show tid
    , ": ", msg
    ]