module Database.Redis.PubSub (
publish,
pubSub,
Message(..),
PubSub(),
subscribe, unsubscribe, psubscribe, punsubscribe
) where
import Control.Applicative
import Control.Monad
import Control.Monad.State
import Data.ByteString.Char8 (ByteString)
import Data.Monoid
import qualified Database.Redis.Core as Core
import Database.Redis.Protocol (Reply(..))
import Database.Redis.Types
-- | Bookkeeping threaded through the 'pubSub' loop.
data PubSubState = PubSubState
    { subCnt  :: Int
      -- ^ Count carried by the most recent (p)unsubscribe reply.
    , pending :: Int
      -- ^ Subscribe confirmations that have not been received yet.
    }
-- | Apply a function to the number of outstanding subscribe confirmations.
modifyPending :: (MonadState PubSubState m) => (Int -> Int) -> m ()
modifyPending f = modify bump
  where
    bump st@PubSubState{ pending = p } = st{ pending = f p }
-- | Overwrite the stored subscription count with a new value.
putSubCnt :: (MonadState PubSubState m) => Int -> m ()
putSubCnt n = modify setCount
  where
    setCount st = st{ subCnt = n }
-- Phantom tags for the type parameters of 'Cmd'. They have no values.
-- The first parameter distinguishes subscribe from unsubscribe commands.
-- The second distinguishes channel commands from pattern commands.
data Subscribe
data Unsubscribe
data Channel
data Pattern
-- | A batch of (un)subscribe commands to send to the server.
--
-- Values are built with 'subscribe', 'unsubscribe', 'psubscribe' and
-- 'punsubscribe', then combined with the 'Monoid' instance.
data PubSub = PubSub
    { subs    :: Cmd Subscribe Channel    -- ^ Sent as SUBSCRIBE.
    , unsubs  :: Cmd Unsubscribe Channel  -- ^ Sent as UNSUBSCRIBE.
    , psubs   :: Cmd Subscribe Pattern    -- ^ Sent as PSUBSCRIBE.
    , punsubs :: Cmd Unsubscribe Pattern  -- ^ Sent as PUNSUBSCRIBE.
    } deriving (Eq)
-- | Combine two batches field by field, using each field's own 'Monoid'.
instance Monoid PubSub where
    mempty = PubSub
        { subs    = mempty
        , unsubs  = mempty
        , psubs   = mempty
        , punsubs = mempty
        }
    -- Lazy patterns keep this as non-strict as the field-accessor version.
    mappend ~(PubSub s u ps pu) ~(PubSub s' u' ps' pu') =
        PubSub (s `mappend` s')
               (u `mappend` u')
               (ps `mappend` ps')
               (pu `mappend` pu')
-- | One command of the kind selected by the phantom parameters.
--
-- 'DoNothing' sends nothing. @Cmd xs@ sends the command with arguments @xs@.
-- For unsubscribe commands, @Cmd []@ sends the command with no arguments
-- (which Redis treats as "all"). It also absorbs any other unsubscribe
-- command it is combined with.
--
-- NOTE: 'changes' is a partial field; it fails on 'DoNothing'.
data Cmd a b = DoNothing | Cmd { changes :: [ByteString] } deriving (Eq)
-- | Subscribe commands combine by concatenating their arguments.
-- 'DoNothing' is the identity.
instance Monoid (Cmd Subscribe a) where
    mempty = DoNothing
    mappend lhs rhs = case (lhs, rhs) of
        (DoNothing, _)   -> rhs
        (_, DoNothing)   -> lhs
        (Cmd xs, Cmd ys) -> Cmd (xs ++ ys)
-- | Unsubscribe commands combine by concatenating their arguments.
-- @Cmd []@ (unsubscribe from everything) absorbs the other operand, and
-- 'DoNothing' is the identity.
instance Monoid (Cmd Unsubscribe a) where
    mempty = DoNothing
    mappend DoNothing rhs = rhs
    mappend lhs DoNothing = lhs
    mappend (Cmd xs) (Cmd ys)
        | null xs || null ys = Cmd []
        | otherwise          = Cmd (xs ++ ys)
-- | Per-kind behaviour of a 'Cmd'.
class Command a where
    -- | Name of the Redis command to send. The instances in this module
    --   ignore the argument.
    redisCmd :: a -> ByteString
    -- | How sending this command changes the count of subscribe
    --   confirmations the 'pubSub' loop still waits for.
    updatePending :: a -> Int -> Int
-- | Send one command, unless it is 'DoNothing'.
-- After sending, adjust the pending-confirmation count via 'updatePending'.
sendCmd :: (Command (Cmd a b)) => Cmd a b -> StateT PubSubState Core.Redis ()
sendCmd cmd = case cmd of
    DoNothing -> return ()
    Cmd args  -> do
        lift (Core.send (redisCmd cmd : args))
        modifyPending (updatePending cmd)
-- | Add the number of arguments (channels or patterns) carried by a command.
-- 'DoNothing' leaves the count unchanged.
plusChangeCnt :: Cmd a b -> Int -> Int
plusChangeCnt cmd = case cmd of
    DoNothing -> id
    Cmd cs    -> \n -> n + length cs
-- | SUBSCRIBE: one confirmation is expected per channel.
instance Command (Cmd Subscribe Channel) where
    redisCmd _    = "SUBSCRIBE"
    updatePending = plusChangeCnt
-- | PSUBSCRIBE: one confirmation is expected per pattern.
instance Command (Cmd Subscribe Pattern) where
    redisCmd _    = "PSUBSCRIBE"
    updatePending = plusChangeCnt
-- | UNSUBSCRIBE: does not change the pending-confirmation count.
instance Command (Cmd Unsubscribe Channel) where
    redisCmd _      = "UNSUBSCRIBE"
    updatePending _ = id
-- | PUNSUBSCRIBE: does not change the pending-confirmation count.
instance Command (Cmd Unsubscribe Pattern) where
    redisCmd _      = "PUNSUBSCRIBE"
    updatePending _ = id
-- | A message delivered to the 'pubSub' callback.
--
-- 'Message' carries the channel and the payload. 'PMessage' additionally
-- carries the pattern that matched the channel.
--
-- NOTE: 'msgPattern' is a partial field; it fails on 'Message'.
data Message = Message  { msgChannel, msgMessage :: ByteString}
             | PMessage { msgPattern, msgChannel, msgMessage :: ByteString}
    deriving (Show)
-- | Internal classification of a reply read in pub/sub mode (see 'decodeMsg').
-- 'Unsubscribed' carries the count from a (p)unsubscribe reply.
data PubSubReply = Subscribed | Unsubscribed Int | Msg Message
-- | Post a message to a channel with PUBLISH.
--
-- Per the Redis documentation, the integer reply is the number of clients
-- that received the message.
publish
    :: (Core.RedisCtx m f)
    => ByteString -- ^ Channel.
    -> ByteString -- ^ Message payload.
    -> m (f Integer)
publish channel message = Core.sendRequest request
  where
    request = ["PUBLISH", channel, message]
-- | Subscribe to the given channels.
-- An empty list yields 'mempty', so nothing is sent.
subscribe
    :: [ByteString] -- ^ Channels.
    -> PubSub
subscribe cs
    | null cs   = mempty
    | otherwise = mempty{ subs = Cmd cs }
-- | Unsubscribe from the given channels.
--
-- An empty list sends UNSUBSCRIBE with no arguments, which Redis treats as
-- "all channels". When combined with other unsubscribes, it absorbs them.
unsubscribe
    :: [ByteString] -- ^ Channels; @[]@ means all of them.
    -> PubSub
unsubscribe = withUnsubs . Cmd
  where
    withUnsubs c = mempty{ unsubs = c }
-- | Subscribe to channels matching the given glob patterns.
-- An empty list yields 'mempty', so nothing is sent.
psubscribe
    :: [ByteString] -- ^ Patterns.
    -> PubSub
psubscribe ps
    | null ps   = mempty
    | otherwise = mempty{ psubs = Cmd ps }
-- | Unsubscribe from the given patterns.
--
-- An empty list sends PUNSUBSCRIBE with no arguments, which Redis treats as
-- "all patterns". When combined with other pattern unsubscribes, it absorbs
-- them.
punsubscribe
    :: [ByteString] -- ^ Patterns; @[]@ means all of them.
    -> PubSub
punsubscribe = withPunsubs . Cmd
  where
    withPunsubs c = mempty{ punsubs = c }
-- | Enter pub/sub mode and process replies until the loop decides to stop.
--
-- The @initial@ commands are sent first. Every 'Message' received is passed
-- to the callback. The 'PubSub' returned by the callback is sent before the
-- next reply is read; return 'mempty' to send nothing.
--
-- If @initial@ is 'mempty', nothing is sent and the action returns at once.
pubSub
    :: PubSub                 -- ^ Initial (un)subscriptions.
    -> (Message -> IO PubSub) -- ^ Callback, run once per received message.
    -> Core.Redis ()
pubSub initial callback
    | initial == mempty = return ()
    | otherwise = evalStateT (send initial) (PubSubState 0 0)
  where
    -- Send the four command kinds in a fixed order, then read the next
    -- reply. Channel commands go before pattern commands, and within each,
    -- subscribe goes before unsubscribe.
    send :: PubSub -> StateT PubSubState Core.Redis ()
    send PubSub{..} = do
        sendCmd subs
        sendCmd unsubs
        sendCmd psubs
        sendCmd punsubs
        recv

    recv :: StateT PubSubState Core.Redis ()
    recv = do
        reply <- lift Core.recv
        case decodeMsg reply of
            -- Hand the message to the user, then send whatever they asked for.
            Msg msg -> liftIO (callback msg) >>= send
            -- One outstanding subscribe confirmation has arrived.
            Subscribed -> modifyPending (subtract 1) >> recv
            -- Record the count carried by the (p)unsubscribe reply. Stop only
            -- when that count is zero and no subscribe confirmations are
            -- still outstanding; otherwise keep reading.
            Unsubscribed n -> do
                putSubCnt n
                PubSubState{..} <- get
                unless (subCnt == 0 && pending == 0) recv
-- | Classify a raw reply received while in pub/sub mode.
--
-- Pub/sub replies are multi-bulk replies whose first element names the kind
-- of reply. Any reply that does not have one of the expected shapes is
-- reported through 'errMsg', which raises an error naming the bad reply.
decodeMsg :: Reply -> PubSubReply
decodeMsg r@(MultiBulk (Just (r0:r1:r2:rs))) = either (errMsg r) id $ do
    kind <- decode r0
    case kind :: ByteString of
        "message"      -> Msg <$> decodeMessage
        "pmessage"     -> Msg <$> decodePMessage
        "subscribe"    -> return Subscribed
        "psubscribe"   -> return Subscribed
        "unsubscribe"  -> Unsubscribed <$> decodeCnt
        "punsubscribe" -> Unsubscribed <$> decodeCnt
        _              -> errMsg r
  where
    -- Layout: [kind, channel, message].
    decodeMessage = Message <$> decode r1 <*> decode r2
    -- Layout: [kind, pattern, channel, message]. The outer pattern only
    -- guarantees three elements, so the fourth is matched explicitly. This
    -- avoids the partial 'head', which would crash with an unhelpful
    -- "empty list" error.
    decodePMessage = case rs of
        (r3:_) -> PMessage <$> decode r1 <*> decode r2 <*> decode r3
        []     -> errMsg r
    -- Layout: [kind, channel-or-pattern, count].
    decodeCnt = fromInteger <$> decode r2
decodeMsg r = errMsg r
-- | Abort with an error describing a reply that is not a valid pub/sub reply.
errMsg :: Reply -> a
errMsg reply = error (prefix ++ show reply)
  where
    prefix = "Hedis: expected pub/sub-message but got: "