-- | Multicast utilities
{-# LANGUAGE RankNTypes #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
module Control.Distributed.Process.Backend.SimpleLocalnet.Internal.Multicast (initMulticast) where

import Data.Map (Map)
import qualified Data.Map as Map (empty)
import Data.Binary (Binary, decode, encode)
import Data.IORef (IORef, newIORef, readIORef, modifyIORef)
import qualified Data.ByteString as BSS (ByteString, concat)
import qualified Data.ByteString.Lazy as BSL
  ( ByteString
  , empty
  , append
  , fromChunks
  , toChunks
  , length
  , splitAt
  )
import Data.Accessor (Accessor, (^:), (^.), (^=))
import qualified Data.Accessor.Container as DAC (mapDefault)
import Network.Socket (HostName, PortNumber, Socket, SockAddr)
import qualified Network.Socket.ByteString as NBS (recvFrom, sendManyTo)
import Network.Transport.Internal (decodeNum32, encodeEnum32)
import Network.Multicast (multicastSender, multicastReceiver)

--------------------------------------------------------------------------------
-- Top-level API                                                              --
--------------------------------------------------------------------------------

-- | Given a hostname and a port number, initialize the multicast system.
--
-- Note: it is important that you never send messages larger than the maximum
-- message size; if you do, all subsequent communication will probably fail.
--
-- Returns a reader and a writer. The reader yields the next decoded message
-- together with the address it was received from; the writer encodes a value
-- and sends it to the address returned by 'multicastSender'.
--
-- NOTE: By rights the two functions should be "locally" polymorphic in 'a',
-- but this requires impredicative types.
initMulticast :: forall a. Binary a
              => HostName    -- ^ Multicast IP
              -> PortNumber  -- ^ Port number
              -> Int         -- ^ Maximum message size
              -> IO (IO (a, SockAddr), a -> IO ())
initMulticast host port bufferSize = do
    (sendSock, sendAddr) <- multicastSender host port
    readSock <- multicastReceiver host port
    -- Per-sender buffers of bytes that were received but not yet consumed;
    -- shared by every invocation of the returned reader.
    st <- newIORef Map.empty
    return (recvBinary readSock st bufferSize, writer sendSock sendAddr)
  where
    -- Send a value as a length prefix (built with 'encodeEnum32') followed by
    -- the chunks of its 'Binary' encoding, in a single 'NBS.sendManyTo' call.
    writer :: forall a. Binary a => Socket -> SockAddr -> a -> IO ()
    writer sock addr val = do
      let bytes = encode val
          len   = encodeEnum32 (BSL.length bytes)
      NBS.sendManyTo sock (len : BSL.toChunks bytes) addr

--------------------------------------------------------------------------------
-- UDP multicast read, dealing with multiple senders                          --
--------------------------------------------------------------------------------

-- | Receive-side state: for every sender address, the bytes received from
-- that address that have not been consumed yet.
type UDPState = Map SockAddr BSL.ByteString

-- | Accessor for the buffer associated with a sender address.
--
-- Built with 'DAC.mapDefault', using 'BSL.empty' as the default value for
-- addresses that have no entry in the map.
bufferFor :: SockAddr -> Accessor UDPState BSL.ByteString
bufferFor = DAC.mapDefault BSL.empty

-- | Append a freshly received strict chunk to the end of the buffer kept for
-- the given sender address.
bufferAppend :: SockAddr -> BSS.ByteString -> UDPState -> UDPState
bufferAppend addr bytes =
  -- 'flip' so that the new chunk goes /after/ the bytes already buffered.
  bufferFor addr ^: flip BSL.append (BSL.fromChunks [bytes])

-- | Receive one length-prefixed message (see 'recvWithLength') and decode its
-- payload, returning the decoded value together with the sender's address.
--
-- NOTE: the payload is decoded with 'decode', which is not total; a payload
-- that is not a valid encoding of @a@ results in an error when the value is
-- evaluated.
recvBinary :: Binary a => Socket -> IORef UDPState -> Int -> IO (a, SockAddr)
recvBinary sock st bufferSize = do
  (bytes, addr) <- recvWithLength sock st bufferSize
  return (decode bytes, addr)

-- | Receive one length-prefixed message.
--
-- First reads a 4-byte prefix from whichever sender the next datagram comes
-- from, interprets it with 'decodeNum32' as the payload length, and then reads
-- exactly that many bytes from the same sender. Returns the payload and the
-- sender's address.
recvWithLength :: Socket
               -> IORef UDPState
               -> Int
               -> IO (BSL.ByteString, SockAddr)
recvWithLength sock st bufferSize = do
  (len, addr) <- recvExact sock 4 st bufferSize
  -- 'decodeNum32' works on a strict ByteString, so flatten the lazy prefix.
  let n = decodeNum32 . BSS.concat . BSL.toChunks $ len
  bytes <- recvExactFrom addr sock n st bufferSize
  return (bytes, addr)

-- | Perform a single 'NBS.recvFrom' on the socket (asking for at most
-- @bufferSize@ bytes), append whatever was received to the buffer of the
-- sender it came from, and return that sender's address.
recvAll :: Socket -> IORef UDPState -> Int -> IO SockAddr
recvAll sock st bufferSize = do
  (bytes, addr) <- NBS.recvFrom sock bufferSize
  modifyIORef st $ bufferAppend addr bytes
  return addr

-- | Read exactly @n@ bytes from whichever sender the next received data
-- belongs to.
--
-- Always performs one receive first (via 'recvAll') to determine the sender,
-- then takes @n@ bytes from that sender's buffer with 'recvExactFrom'.
-- Returns the bytes together with the sender's address.
recvExact :: Socket
          -> Int
          -> IORef UDPState
          -> Int
          -> IO (BSL.ByteString, SockAddr)
recvExact sock n st bufferSize = do
  addr  <- recvAll sock st bufferSize
  bytes <- recvExactFrom addr sock n st bufferSize
  return (bytes, addr)

-- | Read exactly @n@ bytes that were sent from the given address.
--
-- If the sender's buffer already holds at least @n@ bytes, they are removed
-- from the front of the buffer and returned. Otherwise more data is received
-- from the socket (which may come from, and be buffered for, other senders)
-- and the check is repeated.
recvExactFrom :: SockAddr
              -> Socket
              -> Int
              -> IORef UDPState
              -> Int
              -> IO BSL.ByteString
recvExactFrom addr sock n st bufferSize = go
  where
    go :: IO BSL.ByteString
    go = do
      accAddr <- (^. bufferFor addr) <$> readIORef st
      if BSL.length accAddr >= fromIntegral n
        then do
          -- Enough data buffered: hand out the first n bytes and keep the rest.
          let (bytes, accAddr') = BSL.splitAt (fromIntegral n) accAddr
          modifyIORef st $ bufferFor addr ^= accAddr'
          return bytes
        else do
          -- Not enough yet: pull in more data and try again. The sender of
          -- that data is irrelevant here, since it is stored per address.
          _ <- recvAll sock st bufferSize
          go