-- | Utility functions for TCP sockets
module Network.Transport.TCP.Internal
  ( ControlHeader(..)
  , encodeControlHeader
  , decodeControlHeader
  , ConnectionRequestResponse(..)
  , encodeConnectionRequestResponse
  , decodeConnectionRequestResponse
  , forkServer
  , recvWithLength
  , recvExact
  , recvWord32
  , encodeWord32
  , tryCloseSocket
  , tryShutdownSocketBoth
  , resolveSockAddr
  , EndPointId
  , encodeEndPointAddress
  , decodeEndPointAddress
  , randomEndPointAddress
  , ProtocolVersion
  , currentProtocolVersion
  ) where

#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif

import Network.Transport.Internal
  ( decodeWord32
  , encodeWord32
  , void
  , tryIO
  , forkIOWithUnmask
  )

import Network.Transport ( EndPointAddress(..) )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket as N
#else
import qualified Network.Socket as N
#endif
  ( HostName
  , NameInfoFlag(NI_NUMERICHOST)
  , ServiceName
  , Socket
  , SocketType(Stream)
  , SocketOption(ReuseAddr)
  , getAddrInfo
  , defaultHints
  , socket
  , bind
  , listen
  , addrFamily
  , addrAddress
  , defaultProtocol
  , setSocketOption
  , accept
  , close
  , socketPort
  , shutdown
  , ShutdownCmd(ShutdownBoth)
  , SockAddr(..)
  , getNameInfo
  )

#ifdef USE_MOCK_NETWORK
import qualified Network.Transport.TCP.Mock.Socket.ByteString as NBS (recv)
#else
import qualified Network.Socket.ByteString as NBS (recv)
#endif

import Data.Word (Word32)

import Control.Monad (forever, when)
import Control.Exception (SomeException, catch, bracketOnError, throwIO, mask_)
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar
  ( MVar
  , newEmptyMVar
  , putMVar
  , readMVar
  )
import Control.Exception
  ( mask
  , finally
  )

import Data.ByteString (ByteString)
import qualified Data.ByteString as BS (length, concat, null)
import Data.ByteString.Lazy.Internal (smallChunkSize)
import qualified Data.ByteString.Char8 as BSC (unpack, pack)
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID

-- | Local identifier for an endpoint within this transport.
type EndPointId = Word32

-- | Identifies the version of the network-transport-tcp protocol.
-- It's the first piece of data sent when a new heavyweight connection is
-- established.
type ProtocolVersion = Word32

-- | The protocol version implemented by this module.
currentProtocolVersion :: ProtocolVersion
currentProtocolVersion = 0x00000000

-- | Control headers sent to the remote endpoint.
--
-- Every header has a fixed 'Word32' code; see 'encodeControlHeader' and
-- 'decodeControlHeader'.
data ControlHeader =
    -- | Tell the remote endpoint that we created a new connection
    CreatedNewConnection
    -- | Tell the remote endpoint we will no longer be using a connection
  | CloseConnection
    -- | Request to close the connection (see module description)
  | CloseSocket
    -- | Sent by an endpoint when it is closed.
  | CloseEndPoint
    -- | Message sent to probe a socket
  | ProbeSocket
    -- | Acknowledgement of the ProbeSocket message
  | ProbeSocketAck
  deriving (Show, Eq)

-- | Decode a control header from its 'Word32' code.
--
-- Returns 'Nothing' for any code that does not name a known header.
decodeControlHeader :: Word32 -> Maybe ControlHeader
decodeControlHeader 0 = Just CreatedNewConnection
decodeControlHeader 1 = Just CloseConnection
decodeControlHeader 2 = Just CloseSocket
decodeControlHeader 3 = Just CloseEndPoint
decodeControlHeader 4 = Just ProbeSocket
decodeControlHeader 5 = Just ProbeSocketAck
decodeControlHeader _ = Nothing

-- | Encode a control header as its 'Word32' code.
--
-- This is the inverse of 'decodeControlHeader'.
encodeControlHeader :: ControlHeader -> Word32
encodeControlHeader CreatedNewConnection = 0
encodeControlHeader CloseConnection      = 1
encodeControlHeader CloseSocket          = 2
encodeControlHeader CloseEndPoint        = 3
encodeControlHeader ProbeSocket          = 4
encodeControlHeader ProbeSocketAck       = 5

-- | Response sent by /B/ to /A/ when /A/ tries to connect.
--
-- Every response has a fixed 'Word32' code; see
-- 'encodeConnectionRequestResponse' and 'decodeConnectionRequestResponse'.
data ConnectionRequestResponse =
    -- | /B/ does not support the protocol version requested by /A/.
    ConnectionRequestUnsupportedVersion
    -- | /B/ accepts the connection
  | ConnectionRequestAccepted
    -- | /A/ requested an invalid endpoint
  | ConnectionRequestInvalid
    -- | /A/'s request crossed with a request from /B/ (see protocols)
  | ConnectionRequestCrossed
    -- | /A/ gave an incorrect host (did not match the host that /B/ observed).
  | ConnectionRequestHostMismatch
  deriving (Show, Eq)

-- | Decode a connection request response from its 'Word32' code.
--
-- Returns 'Nothing' for any code that does not name a known response.
decodeConnectionRequestResponse :: Word32 -> Maybe ConnectionRequestResponse
decodeConnectionRequestResponse 0xFFFFFFFF = Just ConnectionRequestUnsupportedVersion
decodeConnectionRequestResponse 0x00000000 = Just ConnectionRequestAccepted
decodeConnectionRequestResponse 0x00000001 = Just ConnectionRequestInvalid
decodeConnectionRequestResponse 0x00000002 = Just ConnectionRequestCrossed
decodeConnectionRequestResponse 0x00000003 = Just ConnectionRequestHostMismatch
decodeConnectionRequestResponse _          = Nothing

-- | Encode a connection request response as its 'Word32' code.
--
-- This is the inverse of 'decodeConnectionRequestResponse'.
encodeConnectionRequestResponse :: ConnectionRequestResponse -> Word32
encodeConnectionRequestResponse ConnectionRequestUnsupportedVersion = 0xFFFFFFFF
encodeConnectionRequestResponse ConnectionRequestAccepted           = 0x00000000
encodeConnectionRequestResponse ConnectionRequestInvalid            = 0x00000001
encodeConnectionRequestResponse ConnectionRequestCrossed            = 0x00000002
encodeConnectionRequestResponse ConnectionRequestHostMismatch       = 0x00000003

-- | Generate an EndPointAddress which does not encode a host/port/endpointid.
-- Such addresses are used for unreachable endpoints, and for ephemeral
-- addresses when such endpoints establish new heavyweight connections.
--
-- The address is the ASCII text form of a freshly generated random
-- (version 4) UUID.
randomEndPointAddress :: IO EndPointAddress
randomEndPointAddress =
  fmap (EndPointAddress . UUID.toASCIIBytes) UUID.nextRandom

-- | Start a server at the specified address.
--
-- This sets up a server socket for the specified host and port. Exceptions
-- thrown during setup are not caught (the server socket is closed before
-- they propagate).
--
-- Once the socket is created we spawn a new thread which repeatedly accepts
-- incoming connections and executes the given request handler in another
-- thread. Any exception raised while accepting a connection (including an
-- asynchronous one, such as the one raised by 'killThread' on the accepting
-- thread) is passed to the error handler, and accepting continues. If the
-- error handler itself throws, the accepting thread closes the server socket,
-- calls the termination handler and terminates. Threads spawned for
-- previously accepted connections are not killed.
--
-- The request handler is not responsible for closing the socket. It will be
-- closed once that handler returns. Take care to ensure that the socket is not
-- used after the handler returns, or you will get undefined behavior
-- (the file descriptor may be re-used).
--
-- The return value includes the port that was bound to. This is not always
-- the same port as that given in the argument. For example, binding to port 0
-- actually binds to a random port, selected by the OS.
forkServer :: N.HostName                     -- ^ Host
           -> N.ServiceName                  -- ^ Port
           -> Int                            -- ^ Backlog (maximum number of queued connections)
           -> Bool                           -- ^ Set ReuseAddr option?
           -> (SomeException -> IO ())       -- ^ Error handler. Called with an
                                             --   exception raised when
                                             --   accepting a connection.
           -> (SomeException -> IO ())       -- ^ Termination handler. Called
                                             --   when the error handler throws
                                             --   an exception.
           -> (IO () -> (N.Socket, N.SockAddr) -> IO ())
                                             -- ^ Request handler. Gets an
                                             --   action which completes when
                                             --   the socket is closed.
           -> IO (N.ServiceName, ThreadId)
forkServer host port backlog reuseAddr errorHandler terminationHandler requestHandler = do
    -- Resolve the specified address. By specification, getAddrInfo will never
    -- return an empty list (but will throw an exception instead) and will return
    -- the "best" address first, whatever that means. The empty case is still
    -- handled explicitly (as an IOError) rather than by a partial pattern.
    addrs <- N.getAddrInfo (Just N.defaultHints) (Just host) (Just port)
    addr <- case addrs of
      (a:_) -> return a
      []    -> throwIO (userError "forkServer: getAddrInfo returned no addresses")
    bracketOnError (N.socket (N.addrFamily addr) N.Stream N.defaultProtocol)
                   tryCloseSocket $ \sock -> do
      when reuseAddr $ N.setSocketOption sock N.ReuseAddr 1
      N.bind sock (N.addrAddress addr)
      N.listen sock backlog

      -- Close a client socket, then fill its synchronizing MVar. The MVar is
      -- filled even if 'N.close' throws, so the "socket closed" action handed
      -- to the request handler always completes.
      let release :: N.Socket -> MVar () -> IO ()
          release clientSock socketClosed =
            N.close clientSock `finally` putMVar socketClosed ()

      -- Run the request handler for one accepted connection in a new thread.
      -- This is called with async exceptions masked, and the forked thread
      -- inherits that state. The 'finally' that releases the client socket
      -- is therefore installed *before* 'restore' unmasks, so the socket is
      -- released even if an async exception arrives the moment the new
      -- thread becomes interruptible.
      let act :: (IO () -> IO ()) -> (N.Socket, N.SockAddr) -> IO ()
          act restore (clientSock, clientAddr) = do
            socketClosed <- newEmptyMVar
            void $ forkIO $
              restore (requestHandler (readMVar socketClosed) (clientSock, clientAddr))
                `finally` release clientSock socketClosed

      -- Accept a single connection. Async exceptions are masked so that, if
      -- 'N.accept' does give a socket, it is always handed to 'act' (which
      -- arranges for it to be closed) before any exception is raised.
      let acceptRequest :: IO ()
          acceptRequest = mask $ \restore -> do
            (clientSock, clientAddr) <- N.accept sock
            -- 'act' is not expected to throw, but if it does (for instance
            -- if forking fails) close the client socket so it is not leaked.
            let handler :: SomeException -> IO ()
                handler _ = N.close clientSock
            catch (act restore (clientSock, clientAddr)) handler

      -- We start listening for incoming requests in a separate thread. When
      -- that thread terminates, we close the server socket and the termination
      -- handler is run. We have to make sure that the exception handler is
      -- installed /before/ any asynchronous exception occurs. So we mask_, then
      -- fork (the child thread inherits the masked state from the parent), then
      -- unmask only inside the catch.
      boundPort <- fmap show (N.socketPort sock)
      acceptThread <- mask_ $ forkIOWithUnmask $ \unmask ->
        catch (unmask (forever (catch acceptRequest errorHandler))) $ \ex -> do
          tryCloseSocket sock
          terminationHandler ex
      return (boundPort, acceptThread)

-- | Read a length and then a payload of that length, subject to a limit
--   on the length.
--
--   The length is the first 'Word32' received (see 'recvWord32'). If it is
--   greater than the limit, an 'IOError' ('userError') is thrown before any
--   payload is read. Otherwise the payload is returned as the list of
--   chunks produced by 'recvExact'.
recvWithLength :: Word32 -> N.Socket -> IO [ByteString]
recvWithLength limit sock = do
  len <- recvWord32 sock
  when (len > limit) $
    throwIO (userError "recvWithLength: limit exceeded")
  recvExact sock len

-- | Receive a 32-bit unsigned integer.
--
-- Reads exactly four bytes with 'recvExact' (so it throws if the socket
-- closes early) and decodes them with 'decodeWord32'.
recvWord32 :: N.Socket -> IO Word32
recvWord32 sock = fmap (decodeWord32 . BS.concat) (recvExact sock 4)

-- | Close a socket, ignoring I/O exceptions.
--
-- Only 'IOError's raised by 'N.close' are discarded (via 'tryIO'); any other
-- exception propagates to the caller.
tryCloseSocket :: N.Socket -> IO ()
tryCloseSocket sock = void (tryIO (N.close sock))

-- | Shutdown socket sends and receives, ignoring I/O exceptions.
--
-- Calls 'N.shutdown' with 'N.ShutdownBoth'. Only 'IOError's are discarded
-- (via 'tryIO'); any other exception propagates to the caller. The socket is
-- not closed.
tryShutdownSocketBoth :: N.Socket -> IO ()
tryShutdownSocketBoth sock = void (tryIO (N.shutdown sock N.ShutdownBoth))

-- | Read an exact number of bytes from a socket.
--
-- The data is returned as the list of received chunks, in the order they
-- were received. Each individual receive asks for at most 'smallChunkSize'
-- bytes. Requesting 0 bytes returns @[]@ without touching the socket.
--
-- Throws an I/O exception ('userError') if the socket closes before the
-- specified number of bytes could be read.
recvExact :: N.Socket        -- ^ Socket to read from
          -> Word32          -- ^ Number of bytes to read
          -> IO [ByteString] -- ^ Data read
recvExact sock len = go [] len
  where
    go :: [ByteString] -> Word32 -> IO [ByteString]
    go acc 0 = return (reverse acc)
    go acc remaining = do
      -- Take the minimum in 'Word32' before narrowing to 'Int', so a large
      -- request can never become a negative size where 'Int' is 32 bits.
      let chunkSize = fromIntegral (min remaining (fromIntegral smallChunkSize))
      bs <- NBS.recv sock chunkSize
      if BS.null bs
        then throwIO (userError "recvExact: Socket closed")
        else go (bs : acc) (remaining - fromIntegral (BS.length bs))

-- | Get the numeric host, resolved host (via getNameInfo), and port from a
-- SockAddr. The numeric host is first, then resolved host (which may be the
-- same as the numeric host).
--
-- Will only give 'Just' for IPv4 addresses ('N.SockAddrInet'); every other
-- address gives 'Nothing'.
--
-- Throws an 'IOError' if no numeric host can be obtained. It calls 'error'
-- if the name lookup does not return exactly a host name and no service.
resolveSockAddr :: N.SockAddr -> IO (Maybe (N.HostName, N.HostName, N.ServiceName))
resolveSockAddr sockAddr = case sockAddr of
  N.SockAddrInet port _ -> do
    -- Ask only for the host name (not the service).
    (mResolvedHost, mResolvedPort) <- N.getNameInfo [] True False sockAddr
    case (mResolvedHost, mResolvedPort) of
      (Just resolvedHost, Nothing) -> do
        (mNumericHost, _) <- N.getNameInfo [N.NI_NUMERICHOST] True False sockAddr
        case mNumericHost of
          Just numericHost ->
            return $ Just (numericHost, resolvedHost, show port)
          Nothing ->
            throwIO $ userError $
              "resolveSockAddr: no numeric host for " ++ show sockAddr
      _ -> error $ concat
        [ "resolveSockAddr: unexpected resolution "
        , show sockAddr
        , " -> "
        , show mResolvedHost
        , ", "
        , show mResolvedPort
        ]
  _ -> return Nothing

-- | Encode end point address.
--
-- The address has the form @host:port:endpointid@, with the endpoint id
-- written in decimal. The string is packed with 'BSC.pack', which keeps only
-- the low 8 bits of each character.
encodeEndPointAddress :: N.HostName
                      -> N.ServiceName
                      -> EndPointId
                      -> EndPointAddress
encodeEndPointAddress host port ix =
  EndPointAddress (BSC.pack (concat [host, ":", port, ":", show ix]))

-- | Decode end point address.
--
-- The inverse of 'encodeEndPointAddress'. The address is split at its last
-- two @':'@ characters, so a host that itself contains colons is kept
-- intact. The final segment must be a non-empty string of decimal digits
-- whose value fits in an 'EndPointId'. Returns 'Nothing' if the address
-- does not have this shape.
decodeEndPointAddress :: EndPointAddress
                      -> Maybe (N.HostName, N.ServiceName, EndPointId)
decodeEndPointAddress (EndPointAddress bs) =
  case splitMaxFromEnd (== ':') 2 (BSC.unpack bs) of
    [host, port, endPointIdStr] -> do
      endPointId <- parseEndPointId endPointIdStr
      return (host, port, endPointId)
    _ -> Nothing
  where
    -- Parse via 'Integer' and range-check explicitly. Reading directly at
    -- 'Word32' silently wraps out-of-range input such as "-1" or
    -- "4294967296" modulo 2^32, so distinct strings would decode to the same
    -- endpoint id.
    parseEndPointId :: String -> Maybe EndPointId
    parseEndPointId str
      | null str || not (all (`elem` "0123456789") str) = Nothing
      | otherwise = case reads str :: [(Integer, String)] of
          [(n, "")] | n <= toInteger (maxBound :: EndPointId) -> Just (fromInteger n)
          _                                                   -> Nothing

-- | @splitMaxFromEnd p n xs@ splits list @xs@ at elements matching @p@,
-- counting from the /end/. Only the last @n@ matching elements act as
-- separators, so the result has at most @n + 1@ segments. A negative @n@
-- splits at every match. Separators are not included in the result.
--
-- > splitMaxFromEnd (== ':') 2 "ab:cd:ef:gh" == ["ab:cd", "ef", "gh"]
splitMaxFromEnd :: (a -> Bool) -> Int -> [a] -> [[a]]
splitMaxFromEnd p n = go [] [] n . reverse
  where
    -- Walk the reversed input. 'cur' is the segment currently being built
    -- (already in original order, because we cons onto it while walking
    -- backwards). 'done' holds the completed segments to its right, and 'k'
    -- is the number of separators still allowed.
    go cur done _ [] = cur : done
    go cur done 0 xs = (reverse xs ++ cur) : done
    go cur done k (x:xs)
      | p x       = go [] (cur : done) (k - 1) xs
      | otherwise = go (x : cur) done k xs