{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE Rank2Types #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE AllowAmbiguousTypes #-}

module Data.ZRE (
    zreVer
  , newZRE
  , parseZRE
  , encodeZRE
  , zreBeacon
  , parseBeacon
  , Name
  , Headers
  , Content
  , Group
  , mkGroup
  , unGroup
  , Groups
  , Seq
  , GroupSeq
  , ZREMsg(..)
  , ZRECmd(..)
  , SymbolicGroup
  , KnownGroup
  , knownToGroup
  ) where

import Prelude hiding (putStrLn, take)
import Data.ByteString (ByteString)

import qualified Data.ByteString.Char8 as B
import qualified Data.ByteString.Lazy as BL

import GHC.Word

import Data.Map (Map)
import Data.Set (Set)

import qualified Data.Set

import Data.UUID
import Data.Time.Clock

import System.ZMQ4.Endpoint
import Data.ZMQParse

import GHC.TypeLits
import Data.Proxy

-- | ZRE protocol version spoken by this module. 'encodeZRE' writes it into
-- every message header, and 'parseCmd' rejects messages carrying any other
-- value.
zreVer :: Int
zreVer = 0x02
-- | Signature word that opens every ZRE message header. 'encodeZRE' writes
-- it big-endian, and 'parseZre' rejects headers that do not start with it.
zreSig :: Word16
zreSig = 0xAAA1

-- | Message sequence number. It is stored in 'msgSeq' and written as a
-- 16-bit field by 'encodeZRE'.
type Seq = Int
-- | Status sequence carried by 'Hello', 'Join' and 'Leave'. 'encodeCmd'
-- writes it as a single byte.
type GroupSeq = Int

-- | Kind of type-level group names (see 'knownToGroup').
type SymbolicGroup = Symbol
-- | Constraint on type-level group names that can be reflected into a
-- value-level 'Group' via 'knownToGroup'.
type KnownGroup = KnownSymbol

-- | Reflect a type-level group name (a 'KnownGroup' symbol) into a
-- value-level 'Group'. Select the name with a type application.
--
-- The symbol's characters are packed with "Data.ByteString.Char8", so each
-- character keeps only its low 8 bits.
knownToGroup :: forall n. KnownGroup n => Group
knownToGroup = Group (B.pack (symbolVal (Proxy :: Proxy n)))

-- | Name of a ZRE group: a thin wrapper around the raw 'ByteString' that
-- 'putGroup' writes to the wire. Build one with 'mkGroup' or
-- 'knownToGroup'; unwrap it with 'unGroup'.
newtype Group = Group ByteString
  deriving (Eq, Ord, Show)

-- | Wrap a raw group name as a 'Group'. No validation is performed.
mkGroup :: ByteString -> Group
mkGroup raw = Group raw

-- | Unwrap a 'Group' back into its raw name bytes (inverse of 'mkGroup').
unGroup :: Group -> ByteString
unGroup grp = case grp of
  Group raw -> raw

-- | Set of groups, as carried in a 'Hello' command.
type Groups = Set Group

-- | Peer name carried in a 'Hello' command.
type Name = ByteString
-- | Key/value headers carried in a 'Hello' command.
type Headers = Map ByteString ByteString
-- | Extra message frames that follow the header frame (see 'getContent').
type Content = [ByteString]

-- | A ZRE message: envelope metadata plus the command it carries.
--
-- Field order matters: the constructor is applied positionally elsewhere,
-- and the derived 'Ord' compares fields in this order.
data ZREMsg = ZREMsg
  { msgFrom :: Maybe UUID     -- ^ Sender; 'Just' after 'parseZRE', 'Nothing' from 'newZRE'.
  , msgSeq  :: Seq            -- ^ Sequence number, encoded as a 16-bit header field.
  , msgTime :: Maybe UTCTime  -- ^ Timestamp; left as 'Nothing' by 'newZRE' and 'parseZRE'.
  , msgCmd  :: ZRECmd         -- ^ The command and its payload.
  } deriving (Eq, Ord, Show)

-- | Commands a ZRE peer can send. 'cmdCode' gives each one's wire code.
--
-- Constructor order matters for the derived 'Ord' instance.
data ZRECmd
  = Hello Endpoint Groups GroupSeq Name Headers
    -- ^ Endpoint, groups, status sequence, name and headers, written in that order by 'encodeCmd'.
  | Whisper Content
    -- ^ No header fields; the payload travels as extra content frames.
  | Shout Group Content
    -- ^ Group name in the header; the payload travels as extra content frames.
  | Join Group GroupSeq
    -- ^ Group name and status sequence.
  | Leave Group GroupSeq
    -- ^ Group name and status sequence.
  | Ping
    -- ^ No payload.
  | PingOk
    -- ^ No payload.
  deriving (Eq, Ord, Show)

-- | Serialise a beacon. It consists of:
--
-- * the 3-byte prefix @ZRE@;
-- * a version byte;
-- * the caller-supplied UUID bytes, written verbatim (presumably 16 bytes,
--   matching what 'parseBeacon' reads);
-- * the port as a big-endian 16-bit integer.
zreBeacon :: ByteString -> Port -> ByteString
zreBeacon uuid port = runPut beacon
  where
    beacon = do
      putByteString "ZRE"
      -- XXX: for compatibility with zyre implementation
      -- this should use 0x01 instead, but why when
      -- we can stick zre version there and use it for filtering?
      -- for now leave in compat mode as we don't
      -- assert this but zyre does
      putInt8 (fromIntegral (0x01 :: Int)) -- compat
      --putInt8 $ fromIntegral zreVer -- non-compat
      putByteString uuid
      putInt16be (fromIntegral port)

-- | Read 16 bytes and decode them with 'fromByteString'. Fails the parser
-- with "Unable to parse UUID" if 'fromByteString' rejects them.
parseUUID :: Get UUID
parseUUID =
  getByteString 16 >>= maybe (fail "Unable to parse UUID") return . fromByteString . BL.fromStrict

-- | Decode a beacon as produced by 'zreBeacon'. The result tuple is:
--
-- * the 3-byte prefix (returned as-is, not validated here);
-- * the version byte;
-- * the sender 'UUID';
-- * the 16-bit port.
parseBeacon :: ByteString -> Either String (ByteString, Integer, UUID, Integer)
parseBeacon = runGet beaconParser
  where
    -- Fields are read strictly left to right, in wire order.
    beaconParser =
      (,,,) <$> getByteString 3
            <*> getInt8
            <*> parseUUID
            <*> getInt16

-- | Wire code identifying a command in the message header. This is the
-- inverse of the dispatch in 'parseCmd'.
cmdCode :: ZRECmd -> Word8
cmdCode cmd = case cmd of
  Hello {}   -> 0x01
  Whisper {} -> 0x02
  Shout {}   -> 0x03
  Join {}    -> 0x04
  Leave {}   -> 0x05
  Ping       -> 0x06
  PingOk     -> 0x07

-- | Extra content frames carried by a command. Only 'Whisper' and 'Shout'
-- have any; every other command yields an empty list.
getContent :: ZRECmd -> Content
getContent cmd = case cmd of
  Whisper payload -> payload
  Shout _ payload -> payload
  _               -> []

-- | Build an outgoing message with the given sequence number and command.
-- Sender and timestamp are left as 'Nothing'.
newZRE :: Seq -> ZRECmd -> ZREMsg
newZRE seqNum cmd = ZREMsg
  { msgFrom = Nothing
  , msgSeq  = seqNum
  , msgTime = Nothing
  , msgCmd  = cmd
  }

-- | Serialise a message into frames: one header frame followed by the
-- command's content frames (see 'getContent').
--
-- The header is written in this order:
--
-- * 'zreSig' as a big-endian 16-bit word;
-- * the 'cmdCode' byte;
-- * the 'zreVer' byte;
-- * the sequence number as a big-endian 16-bit integer;
-- * the command-specific fields written by 'encodeCmd'.
--
-- 'msgFrom' and 'msgTime' are not encoded.
encodeZRE :: ZREMsg -> [ByteString]
encodeZRE zmsg = header : getContent cmd
  where
    cmd = msgCmd zmsg

    header :: ByteString
    header = runPut $ do
      putWord16be zreSig
      putWord8 (cmdCode cmd)
      putInt8 (fromIntegral zreVer)
      putInt16be (fromIntegral (msgSeq zmsg))
      encodeCmd cmd

-- | Write the command-specific header fields (everything after the sequence
-- number).
--
-- 'Whisper' writes nothing here because its payload travels only as extra
-- frames. 'Ping' and 'PingOk' carry nothing.
encodeCmd :: ZRECmd -> PutM ()
encodeCmd cmd = case cmd of
  Hello endpoint groups statusSeq name headers -> do
    putByteStringLen (pEndpoint endpoint)
    putByteStrings (Data.Set.map unGroup groups)
    putStatus statusSeq
    putByteStringLen name
    putMap headers
  Shout grp _ -> putGroup grp
  Join grp statusSeq -> putGroup grp >> putStatus statusSeq
  Leave grp statusSeq -> putGroup grp >> putStatus statusSeq
  Whisper _ -> return ()
  Ping -> return ()
  PingOk -> return ()
  where
    -- The status sequence is a single byte on the wire, so larger values
    -- are truncated by the conversion.
    putStatus :: GroupSeq -> PutM ()
    putStatus = putInt8 . fromIntegral

-- | Write a group name as a length-prefixed string ('putByteStringLen').
putGroup :: Group -> PutM ()
putGroup = putByteStringLen . unGroup

-- | Parse the fields of a 'Hello' command in wire order: endpoint, groups,
-- status sequence (one byte), name and headers.
parseHello :: Get ZRECmd
parseHello = do
  endpoint  <- parseEndpoint'
  groups    <- Data.Set.fromList . map Group <$> parseStrings
  statusSeq <- getInt8
  name      <- parseString
  headers   <- parseMap
  return (Hello endpoint groups statusSeq name headers)
  where
    -- The endpoint arrives as a string and is parsed with
    -- 'parseAttoEndpoint'; a parse error fails the whole decoder.
    parseEndpoint' :: Get Endpoint
    parseEndpoint' =
      parseString >>= either (fail . ("Unable to parse endpoint: " ++)) return . parseAttoEndpoint

-- | Parse a length-prefixed group name (the inverse of 'putGroup').
parseGroup :: Get Group
parseGroup = fmap mkGroup parseString

-- | Parse a 'Shout'. Only the group name is in the header; the payload is
-- the extra frames passed in by the caller.
parseShout :: Content -> Get ZRECmd
parseShout frames = do
  grp <- parseGroup
  return (Shout grp frames)

-- | Parse a 'Join': group name followed by a one-byte status sequence.
parseJoin :: Get ZRECmd
parseJoin = do
  grp       <- parseGroup
  statusSeq <- getInt8
  return (Join grp statusSeq)

-- | Parse a 'Leave': group name followed by a one-byte status sequence.
parseLeave :: Get ZRECmd
parseLeave = do
  grp       <- parseGroup
  statusSeq <- getInt8
  return (Leave grp statusSeq)

-- | Decode the part of a header that follows the signature, together with
-- the sender identity and the extra content frames.
--
-- The steps, in order:
--
-- 1. Read the command byte, the protocol version byte and the 16-bit
--    sequence number from the header.
-- 2. Decode the sender 'UUID' from @from@.
-- 3. Reject any version other than 'zreVer'.
-- 4. Parse the command-specific fields.
--
-- The resulting message has no timestamp.
parseCmd :: ByteString -> Content -> Get ZREMsg
parseCmd from frames = do
    code     <- getInt8 :: Get Int
    protoVer <- getInt8
    sqn      <- getInt16
    uuid     <- either (fail . ("No UUID: " ++)) return (runGet parseUUID from)
    if protoVer /= zreVer
      then fail "Protocol version mismatch"
      else ZREMsg (Just uuid) sqn Nothing <$> commandBody code
  where
    -- Dispatch on the command code; this is the inverse of 'cmdCode'.
    commandBody :: Int -> Get ZRECmd
    commandBody 0x01 = parseHello
    commandBody 0x02 = return (Whisper frames) -- parseWhisper
    commandBody 0x03 = parseShout frames
    commandBody 0x04 = parseJoin
    commandBody 0x05 = parseLeave
    commandBody 0x06 = return Ping
    commandBody 0x07 = return PingOk
    commandBody _    = fail "Unknown command"

-- | Decode a received frame list. It must contain at least a routing frame
-- and a header frame; any further frames are passed on as content.
parseZRE :: [ByteString] -> Either String ZREMsg
parseZRE frames = case frames of
  sender : header : rest -> parseZre sender header rest
  _                      -> Left "empty message"

-- | Decode one message from its three parts:
--
-- * @from@: the routing frame;
-- * @msg@: the header frame;
-- * @frames@: any remaining content frames.
--
-- The header must start with 'zreSig'; the rest of it is handed to
-- 'parseCmd'. Malformed input is reported as 'Left' and never throws.
parseZre :: ByteString -> ByteString -> Content -> Either String ZREMsg
parseZre from msg frames = flip runGet msg $ do
  sig <- getInt16
  -- We need to drop the 1st byte of the routing frame, which is '1':UUID
  -- (17 bytes). 'B.drop' is used instead of the partial 'B.tail': with
  -- 'B.tail', an empty routing frame threw an exception out of this
  -- 'Either'-returning function. With 'B.drop' it now fails cleanly in
  -- UUID parsing ("No UUID: ...").
  if sig /= zreSig
    then fail "Signature mismatch"
    else parseCmd (B.drop 1 from) frames