module Pinch.Protocol.Binary (binaryProtocol) where
#if __GLASGOW_HASKELL__ < 709
import Control.Applicative
#endif
import Control.Monad
import Data.Bits (shiftR, (.&.))
import Data.ByteString (ByteString)
import Data.HashMap.Strict (HashMap)
import Data.HashSet (HashSet)
import Data.Int (Int16, Int8)
import Data.Vector (Vector)
import qualified Data.ByteString as B
import qualified Data.HashMap.Strict as M
import qualified Data.HashSet as S
import qualified Data.Text.Encoding as TE
import qualified Data.Vector as V
import Pinch.Internal.Builder (Build)
import Pinch.Internal.Message
import Pinch.Internal.Parser (Parser, runParser)
import Pinch.Internal.TType
import Pinch.Internal.Value
import Pinch.Protocol (Protocol (..))
import qualified Pinch.Internal.Builder as BB
import qualified Pinch.Internal.Parser as P
-- | 'Protocol' implementation for the binary encoding defined in this module.
--
-- Values are written with 'binarySerialize' and read back with
-- 'binaryParser'. Messages are written by 'binarySerializeMessage' and read
-- by 'binaryMessageParser'.
binaryProtocol :: Protocol
binaryProtocol =
    Protocol
        { serializeMessage   = BB.run . binarySerializeMessage
        , deserializeMessage = binaryDeserializeMessage
        , serializeValue     = BB.run . binarySerialize
        , deserializeValue   = binaryDeserialize ttype
        }
-- | Write a message using the non-strict envelope, in this order:
--
-- * the name, UTF-8 encoded, as a 32-bit length-prefixed binary value;
-- * the message type code as a single byte;
-- * the 32-bit message ID;
-- * the payload value.
binarySerializeMessage :: Message -> Build
binarySerializeMessage msg = do
    writeName (messageName msg)
    BB.int8 . messageCode . messageType $ msg
    BB.int32 (messageId msg)
    binarySerialize (messagePayload msg)
  where
    -- The name goes out exactly like a binary value: length prefix, then bytes.
    writeName = binarySerialize . VBinary . TE.encodeUtf8
-- | Decode a complete message from a strict 'ByteString' by running
-- 'binaryMessageParser' over it. Failures are reported as @Left@.
binaryDeserializeMessage :: ByteString -> Either String Message
binaryDeserializeMessage input = runParser binaryMessageParser input
-- | Parse a message envelope followed by its payload.
--
-- The sign of the leading 32-bit word selects the layout:
--
-- * Negative (strict layout): bits 16-30 of the word hold the protocol
--   version, which must be 1, and the low byte holds the message type code.
--   Then come a 32-bit length-prefixed name and the 32-bit message ID.
-- * Non-negative (non-strict layout): the word is the byte length of the
--   name that follows. Then come an 8-bit message type code and the 32-bit
--   message ID.
--
-- In both layouts the payload is parsed with 'binaryParser'. A name that is
-- not valid UTF-8, or that has a negative length, makes the parser fail
-- instead of throwing an exception.
binaryMessageParser :: Parser Message
binaryMessageParser = do
    size <- P.int32
    if size < 0
        then parseStrict size
        else parseNonStrict size
  where
    parseStrict versionAndType = do
        unless (version == 1) $
            fail $ "Unsupported version: " ++ show version
        Message
            <$> (P.int32 >>= parseName)
            <*> typ
            <*> P.int32
            <*> binaryParser ttype
      where
        version = (0x7fff0000 .&. versionAndType) `shiftR` 16
        code = fromIntegral $ 0x00ff .&. versionAndType
        typ = case fromMessageCode code of
            Nothing -> fail $ "Unknown message type: " ++ show code
            Just t -> return t

    parseNonStrict nameLength =
        Message
            <$> parseName nameLength
            <*> parseMessageType
            <*> P.int32
            <*> binaryParser ttype

    -- Read a name of the given byte length and decode it as UTF-8.
    -- The total 'TE.decodeUtf8'' is used because the partial 'TE.decodeUtf8'
    -- would throw a pure exception that escapes the parser's error channel.
    parseName len = do
        when (len < 0) $
            fail $ "Negative message name length: " ++ show len
        bytes <- P.take (fromIntegral len)
        either (fail . ("Invalid UTF-8 in message name: " ++) . show) return
            (TE.decodeUtf8' bytes)
-- | Read one byte and map it to a 'MessageType' via 'fromMessageCode'.
-- Fails if the code is not recognized.
parseMessageType :: Parser MessageType
parseMessageType = do
    code <- P.int8
    maybe (fail $ "Unknown message type: " ++ show code) return
        (fromMessageCode code)
-- | Decode a single value of the given type from a strict 'ByteString' by
-- running 'binaryParser' over it. Failures are reported as @Left@.
binaryDeserialize :: TType a -> ByteString -> Either String (Value a)
binaryDeserialize = runParser . binaryParser
-- | Select the parser for a value of the given Thrift type. The type index
-- of the 'TType' fixes the type index of the resulting 'Value'.
binaryParser :: TType a -> Parser (Value a)
binaryParser TBool   = parseBool
binaryParser TByte   = parseByte
binaryParser TDouble = parseDouble
binaryParser TInt16  = parseInt16
binaryParser TInt32  = parseInt32
binaryParser TInt64  = parseInt64
binaryParser TBinary = parseBinary
binaryParser TStruct = parseStruct
binaryParser TMap    = parseMap
binaryParser TSet    = parseSet
binaryParser TList   = parseList
-- | Convert a type code into a 'SomeTType' using 'fromTypeCode'.
-- Fails inside the parser if the code is unknown.
getTType :: Int8 -> Parser SomeTType
getTType code = case fromTypeCode code of
    Just t -> return t
    Nothing -> fail $ "Unknown TType: " ++ show code
-- | Read a one-byte type code and resolve it with 'getTType'.
parseTType :: Parser SomeTType
parseTType = getTType =<< P.int8
-- | Read a boolean stored as one byte. Only the byte value 1 decodes as
-- 'True'; every other value decodes as 'False'.
parseBool :: Parser (Value TBool)
parseBool = fmap (\b -> VBool (b == 1)) P.int8
-- | Read a single signed byte.
parseByte :: Parser (Value TByte)
parseByte = fmap VByte P.int8
-- | Read a double using the parser's 'P.double' primitive.
parseDouble :: Parser (Value TDouble)
parseDouble = fmap VDouble P.double
-- | Read a 16-bit signed integer.
parseInt16 :: Parser (Value TInt16)
parseInt16 = fmap VInt16 P.int16
-- | Read a 32-bit signed integer.
parseInt32 :: Parser (Value TInt32)
parseInt32 = fmap VInt32 P.int32
-- | Read a 64-bit signed integer.
parseInt64 :: Parser (Value TInt64)
parseInt64 = fmap VInt64 P.int64
-- | Read a binary value: a 32-bit length followed by that many raw bytes.
-- A negative length is rejected explicitly instead of being passed on to
-- 'P.take'.
parseBinary :: Parser (Value TBinary)
parseBinary = do
    len <- P.int32
    when (len < 0) $
        fail $ "Negative binary length: " ++ show len
    VBinary <$> P.take (fromIntegral len)
-- | Read a map: key type code, value type code, a 32-bit entry count, then
-- that many key/value pairs.
--
-- A negative count is rejected. Without the check, 'replicateM' would
-- silently treat it as an empty map. When a key repeats, 'M.fromList' keeps
-- the later value.
parseMap :: Parser (Value TMap)
parseMap = do
    ktype' <- parseTType
    vtype' <- parseTType
    count <- P.int32
    when (count < 0) $
        fail $ "Negative map size: " ++ show count
    case (ktype', vtype') of
        (SomeTType ktype, SomeTType vtype) -> do
            pairs <- replicateM (fromIntegral count) $
                (,) <$> binaryParser ktype
                    <*> binaryParser vtype
            return $ VMap (M.fromList pairs)
-- | Read a set: element type code, a 32-bit element count, then that many
-- elements.
--
-- A negative count is rejected. Without the check, 'replicateM' would
-- silently treat it as an empty set. Duplicate elements collapse in
-- 'S.fromList'.
parseSet :: Parser (Value TSet)
parseSet = do
    vtype' <- parseTType
    count <- P.int32
    when (count < 0) $
        fail $ "Negative set size: " ++ show count
    case vtype' of
        SomeTType vtype -> do
            items <- replicateM (fromIntegral count) (binaryParser vtype)
            return $ VSet (S.fromList items)
-- | Read a list: element type code, a 32-bit element count, then that many
-- elements, kept in wire order.
--
-- A negative count is rejected. Without the check, 'V.replicateM' would
-- silently produce an empty list.
parseList :: Parser (Value TList)
parseList = do
    vtype' <- parseTType
    count <- P.int32
    when (count < 0) $
        fail $ "Negative list size: " ++ show count
    case vtype' of
        SomeTType vtype ->
            VList <$> V.replicateM (fromIntegral count) (binaryParser vtype)
-- | Read a struct as a sequence of fields terminated by a type code of 0.
--
-- Each field is written as an 8-bit type code, a 16-bit field ID, and then
-- the value. If a field ID repeats, the later value replaces the earlier one
-- ('M.insert' semantics).
parseStruct :: Parser (Value TStruct)
parseStruct = readFields M.empty
  where
    readFields :: HashMap Int16 SomeValue -> Parser (Value TStruct)
    readFields acc = do
        code <- P.int8
        if code == 0
            then return (VStruct acc)
            else do
                vtype' <- getTType code
                fieldId <- P.int16
                case vtype' of
                    SomeTType vtype -> do
                        value <- SomeValue <$> binaryParser vtype
                        readFields (M.insert fieldId value acc)
-- | Encode a single value.
--
-- Scalars are written directly. A boolean is one byte: 1 for 'True', 0 for
-- 'False'. Binary values are a 32-bit length followed by the raw bytes.
-- Structs, lists, maps and sets are handled by their own serializers, with
-- the element type codes taken from 'ttype'.
binarySerialize :: Value a -> Build
binarySerialize (VBinary bytes) = do
    BB.int32 (fromIntegral (B.length bytes))
    BB.byteString bytes
binarySerialize (VBool b)         = BB.int8 (if b then 1 else 0)
binarySerialize (VByte n)         = BB.int8 n
binarySerialize (VDouble d)       = BB.double d
binarySerialize (VInt16 n)        = BB.int16 n
binarySerialize (VInt32 n)        = BB.int32 n
binarySerialize (VInt64 n)        = BB.int64 n
binarySerialize (VStruct fields)  = serializeStruct fields
binarySerialize (VList items)     = serializeList ttype items
binarySerialize (VMap pairs)      = serializeMap ttype ttype pairs
binarySerialize (VSet items)      = serializeSet ttype items
-- | Encode a struct. Fields are written in 'M.toList' order, each as its
-- type code, its 16-bit field ID, and its value. The struct ends with a
-- single 0 byte.
serializeStruct :: HashMap Int16 SomeValue -> Build
serializeStruct fields = mapM_ writeEntry (M.toList fields) >> BB.int8 0
  where
    -- Unpack the existential so the field's type code can be recovered
    -- through 'ttype'.
    writeEntry :: (Int16, SomeValue) -> Build
    writeEntry (fieldId, SomeValue fieldValue) =
        writeField fieldId ttype fieldValue

    -- The shared type index forces @fieldType@ to match @fieldValue@.
    writeField :: Int16 -> TType a -> Value a -> Build
    writeField fieldId fieldType fieldValue = do
        BB.int8 (toTypeCode fieldType)
        BB.int16 fieldId
        binarySerialize fieldValue
-- | Encode a list: element type code, 32-bit element count, then each
-- element in vector order.
serializeList :: TType a -> Vector (Value a) -> Build
serializeList vtype xs = do
    BB.int8 (toTypeCode vtype)
    BB.int32 . fromIntegral $ V.length xs
    V.mapM_ binarySerialize xs
-- | Encode a map: key type code, value type code, 32-bit entry count, then
-- each key followed by its value, in 'M.toList' order.
serializeMap :: TType k -> TType v -> HashMap (Value k) (Value v) -> Build
serializeMap kt vt xs = do
    BB.int8 (toTypeCode kt)
    BB.int8 (toTypeCode vt)
    BB.int32 . fromIntegral $ M.size xs
    mapM_ (uncurry writePair) (M.toList xs)
  where
    writePair k v = binarySerialize k >> binarySerialize v
-- | Encode a set: element type code, 32-bit element count, then each
-- element in 'S.toList' order.
serializeSet :: TType a -> HashSet (Value a) -> Build
serializeSet vtype xs = header >> mapM_ binarySerialize (S.toList xs)
  where
    header = BB.int8 (toTypeCode vtype) >> BB.int32 (fromIntegral (S.size xs))
-- | Wire code for each 'MessageType'. This is the inverse of
-- 'fromMessageCode'.
messageCode :: MessageType -> Int8
messageCode t = case t of
    Call      -> 1
    Reply     -> 2
    Exception -> 3
    Oneway    -> 4
-- | Map a wire code back to its 'MessageType'. Returns 'Nothing' for any
-- code other than 1 to 4.
fromMessageCode :: Int8 -> Maybe MessageType
fromMessageCode code = lookup code codeTable
  where
    codeTable =
        [ (1, Call)
        , (2, Reply)
        , (3, Exception)
        , (4, Oneway)
        ]
-- | Wire code for each Thrift type. This is the inverse of 'fromTypeCode'.
toTypeCode :: TType a -> Int8
toTypeCode t = case t of
    TBool   -> 2
    TByte   -> 3
    TDouble -> 4
    TInt16  -> 6
    TInt32  -> 8
    TInt64  -> 10
    TBinary -> 11
    TStruct -> 12
    TMap    -> 13
    TSet    -> 14
    TList   -> 15
-- | Map a wire type code back to its Thrift type. Returns 'Nothing' for
-- unrecognized codes, including 0, which marks the end of a struct.
fromTypeCode :: Int8 -> Maybe SomeTType
fromTypeCode code = case code of
    2  -> Just (SomeTType TBool)
    3  -> Just (SomeTType TByte)
    4  -> Just (SomeTType TDouble)
    6  -> Just (SomeTType TInt16)
    8  -> Just (SomeTType TInt32)
    10 -> Just (SomeTType TInt64)
    11 -> Just (SomeTType TBinary)
    12 -> Just (SomeTType TStruct)
    13 -> Just (SomeTType TMap)
    14 -> Just (SomeTType TSet)
    15 -> Just (SomeTType TList)
    _  -> Nothing