{-# LANGUAGE DeriveGeneric #-}

-- |
-- Module      : PostgresWebsockets.Broadcast
-- Description : Distribute messages from one producer to several consumers.
--
-- PostgresWebsockets functions to broadcast messages to several listening clients.
-- This module provides a type called Multiplexer, which contains a map of
-- channels and a producer thread.
--
-- This module avoids any database implementation details; it is used by
-- HasqlBroadcast, which combines it with the database logic.
module PostgresWebsockets.Broadcast
  ( Multiplexer,
    Message (..),
    newMultiplexer,
    onMessage,
    relayMessages,
    relayMessagesForever,
    superviseMultiplexer,

    -- * Re-exports
    readTQueue,
    writeTQueue,
    readTChan,
  )
where

import Control.Concurrent.STM.TChan
import Control.Concurrent.STM.TQueue
import qualified Data.Aeson as A
import Protolude hiding (toS)
import Protolude.Conv (toS)
import qualified StmContainers.Map as M

-- | A notification travelling from the producer to the listeners of one
-- channel.
data Message = Message
  { channel :: Text
  -- ^ Name of the channel the message is routed to.
  , payload :: Text
  -- ^ Message content, delivered to listeners unchanged.
  }
  deriving (Eq, Show)

-- | Routes messages coming from a single producer thread to the listeners of
-- per-name channels.
data Multiplexer = Multiplexer
  { channels :: M.Map Text Channel
  -- ^ Currently open channels, keyed by channel name.
  , messages :: TQueue Message
  -- ^ Queue handed to the producer and drained by 'relayMessages'.
  , producerThreadId :: MVar ThreadId
  -- ^ Thread currently running the producer.
  , reopenProducer :: IO ThreadId
  -- ^ Forks a fresh producer thread writing into 'messages'.
  }

-- | Point-in-time summary of a 'Multiplexer', encoded as JSON for log output.
data MultiplexerSnapshot = MultiplexerSnapshot
  { channelsSize :: Int
  -- ^ Number of open channels.
  , messageQueueEmpty :: Bool
  -- ^ Whether the producer queue had no pending messages.
  , producerId :: Text
  -- ^ Textual ('show'n) form of the producer's 'ThreadId'.
  }
  deriving (Generic)

-- | State kept for one open channel.
data Channel = Channel
  { broadcast :: TChan Message
  -- ^ Broadcast source that 'relayMessages' writes to; each listener reads
  -- from its own 'dupTChan' copy.
  , listeners :: Integer
  -- ^ Number of listeners currently registered through 'onMessage'.
  }

-- | Generic JSON encoding of a snapshot (one key per record field); used when
-- logging producer restarts.
instance A.ToJSON MultiplexerSnapshot where
  toJSON = A.genericToJSON A.defaultOptions

-- | Capture a 'MultiplexerSnapshot' of the given multiplexer, suitable for
-- printing while debugging or logging.
--
-- The channel count, queue emptiness and producer id are read one after the
-- other (the first two in separate STM transactions), so the result is not a
-- single atomic view of the multiplexer.
takeSnapshot :: Multiplexer -> IO MultiplexerSnapshot
takeSnapshot multi = do
  channelCount <- atomically $ M.size (channels multi)
  queueIsEmpty <- atomically $ isEmptyTQueue (messages multi)
  producer <- readMVar (producerThreadId multi)
  pure
    MultiplexerSnapshot
      { channelsSize = channelCount
      , messageQueueEmpty = queueIsEmpty
      , producerId = show producer
      }

-- | Fork a thread that calls 'relayMessages' in an endless loop, moving every
-- message from the producer queue to its channel. Returns the forked thread's
-- id.
relayMessagesForever :: Multiplexer -> IO ThreadId
relayMessagesForever multi = forkIO (forever (relayMessages multi))

-- | Relay a single message from the producer queue to the channel it is
-- addressed to.
--
-- The transaction retries (blocks) until a message is available. If no
-- channel with the message's name is open, the message is dropped.
relayMessages :: Multiplexer -> IO ()
relayMessages multi = atomically $ do
  msg <- readTQueue (messages multi)
  M.lookup (channel msg) (channels multi) >>= \case
    Just target -> writeTChan (broadcast target) msg
    Nothing -> pure ()

-- | Create a multiplexer and immediately start its producer.
--
-- @openProducer@ runs on a newly forked thread and receives the queue that
-- 'relayMessages' drains; @closeProducer@ runs when that thread finishes,
-- whether normally or with an exception. The same fork action is stored in the
-- multiplexer so that 'superviseMultiplexer' can start a replacement producer.
-- The returned multiplexer starts with no open channels.
newMultiplexer ::
  (TQueue Message -> IO a) ->
  (Either SomeException a -> IO ()) ->
  IO Multiplexer
newMultiplexer openProducer closeProducer = do
  queue <- newTQueueIO
  let spawnProducer = forkFinally (openProducer queue) closeProducer
  firstProducer <- spawnProducer
  channelMap <- M.newIO
  producerRef <- newMVar firstProducer
  pure
    Multiplexer
      { channels = channelMap
      , messages = queue
      , producerThreadId = producerRef
      , reopenProducer = spawnProducer
      }

-- | Supervise the producer thread of a multiplexer.
--
-- Every @msInterval@ milliseconds the @shouldRestart@ computation is run. When
-- it returns 'True', the current producer thread is killed, a new producer is
-- started with the multiplexer's reopen action, and a before/after snapshot of
-- the multiplexer is printed.
--
-- When the interval is 0 (or negative) this is a no-op and no supervisor
-- thread is started, so the minimum effective interval is 1ms.
--
-- Call this when you want the producer thread to be killed and restarted
-- under a certain condition. The checks run on a forked thread, so this
-- function returns immediately.
superviseMultiplexer :: Multiplexer -> Int -> IO Bool -> IO ()
superviseMultiplexer multi msInterval shouldRestart =
  -- Honour the documented no-op: without this guard an interval of 0 makes
  -- 'threadDelay' return at once and the loop below spins continuously.
  when (msInterval > 0) $
    void $
      forkIO $
        forever $ do
          threadDelay $ msInterval * 1000
          sr <- shouldRestart
          when sr restartProducer
  where
    -- Replace the running producer with a fresh one and log the transition.
    restartProducer :: IO ()
    restartProducer = do
      snapBefore <- takeSnapshot multi
      -- Bind the ThreadId and actually run 'killThread'. Mapping it over the
      -- read (@killThread <$> readMVar ...@) only builds an @IO (IO ())@
      -- whose inner kill action is discarded, leaving the old producer alive.
      readMVar (producerThreadId multi) >>= killThread
      new <- reopenProducer multi
      void $ swapMVar (producerThreadId multi) new
      snapAfter <- takeSnapshot multi
      putStrLn $
        "Restarting producer. Multiplexer updated: "
          <> A.encode snapBefore
          <> " -> "
          <> A.encode snapAfter

-- | Create a channel named @chan@ with zero listeners, store it in the
-- multiplexer's channel map and return it.
--
-- The channel's source is made with 'newBroadcastTChan', so it is only written
-- to; listeners read from their own 'dupTChan' copies.
openChannel :: Multiplexer -> Text -> STM Channel
openChannel multi chan = do
  source <- newBroadcastTChan
  let fresh = Channel source 0
  fresh <$ M.insert fresh chan (channels multi)

-- | Add a listener to the multiplexer channel named @chan@.
--
-- The listener is the function @action@, which is applied to every 'Message'
-- broadcast on the channel after registration and may perform any IO. Each
-- listener runs on its own thread.
--
-- The first listener opens the channel. When a listener's thread ends (for
-- example because @action@ throws), the channel's listener count is
-- decremented, and the channel is closed once no listeners remain.
onMessage :: Multiplexer -> Text -> (Message -> IO ()) -> IO ()
onMessage multi chan action = do
  subscription <- atomically $ lookupOrOpen >>= register
  void $
    forkFinally
      (forever $ atomically (readTChan subscription) >>= action)
      unregister
  where
    chans = channels multi

    -- Existing channel for this name, or a freshly opened one.
    lookupOrOpen :: STM Channel
    lookupOrOpen = M.lookup chan chans >>= maybe (openChannel multi chan) pure

    -- Bump the listener count and hand out a private read end of the channel.
    register :: Channel -> STM (TChan Message)
    register ch = do
      let updated = ch {listeners = listeners ch + 1}
      M.delete chan chans
      M.insert updated chan chans
      dupTChan (broadcast updated)

    -- Drop one listener; remove the channel entirely when it was the last.
    unregister _ = atomically $ do
      found <- M.lookup chan chans
      let ch = fromMaybe (panic $ "trying to remove listener from non existing channel: " <> toS chan) found
          remaining = listeners ch - 1
      M.delete chan chans
      when (remaining > 0) $
        M.insert ch {listeners = remaining} chan chans