{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.ProtoLens.Encoding.Wire
( Tag(..)
, TaggedValue(..)
, WireValue(..)
, FieldSet
, splitTypeAndTag
, joinTypeAndTag
, parseFieldSet
, buildFieldSet
, buildMessageSet
, parseTaggedValueFromWire
, parseMessageSetTaggedValueFromWire
) where
import Control.DeepSeq (NFData(..))
import Data.Bits ((.&.), (.|.), shiftL, shiftR)
import qualified Data.ByteString as B
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif
import Data.Word (Word8, Word32, Word64)
import Data.ProtoLens.Encoding.Bytes
-- | A protobuf field number, as stored in the upper bits of a field key
-- (see 'joinTypeAndTag').  'Num' is derived so that numeric literals can
-- stand for tags directly, e.g. @TaggedValue 1 StartGroup@.
newtype Tag = Tag { unTag :: Int }
    deriving (Show, Eq, Ord, Num, NFData)
-- | The payload of a single field as it appears on the wire, with one
-- constructor per protobuf wire type (numbered by 'wireValueToInt').
--
-- The constructors are declared in wire-type order (0 through 5), so the
-- derived 'Ord' instance compares values by wire type first.  Do not
-- reorder them without considering that.
data WireValue
    = VarInt !Word64
      -- ^ Wire type 0: a variable-length integer (see 'putVarInt').
    | Fixed64 !Word64
      -- ^ Wire type 1: a 64-bit value written with 'putFixed64'.
    | Lengthy !B.ByteString
      -- ^ Wire type 2: raw bytes, written after a varint length prefix.
    | StartGroup
      -- ^ Wire type 3: opens a group; carries no payload of its own.
    | EndGroup
      -- ^ Wire type 4: closes a group; carries no payload of its own.
    | Fixed32 !Word32
      -- ^ Wire type 5: a 32-bit value written with 'putFixed32'.
    deriving (Eq, Ord)
-- | One field as read from or written to the wire: its field number paired
-- with its payload.  Both components are strict.
data TaggedValue = TaggedValue !Tag !WireValue
    deriving (Eq, Ord)
-- | The fields of a message, in the order they appear on the wire (the
-- order 'parseFieldSet' returns them in).
type FieldSet = [TaggedValue]
-- | Both fields of 'TaggedValue' are strict, as are the fields of every
-- 'WireValue' constructor.  Evaluating to weak head normal form therefore
-- already evaluates the whole value.
instance NFData TaggedValue where
    rnf tv = tv `seq` ()
-- | Every 'WireValue' constructor has only strict fields (fixed-width words
-- or a strict 'B.ByteString').  Forcing the constructor is therefore
-- enough.
instance NFData WireValue where
    rnf !_ = ()
-- | Serialize one field: the varint key (field number and wire type, joined
-- by 'joinTypeAndTag') followed by the payload from 'buildWireValue'.
buildTaggedValue :: TaggedValue -> Builder
buildTaggedValue (TaggedValue tag value) = key <> payload
  where
    key = putVarInt $ joinTypeAndTag tag (wireValueToInt value)
    payload = buildWireValue value
-- | Serialize a field as a message-set item.  The item is a group with
-- field number 1 that wraps the original field number as field 2 (a
-- varint) and the payload as field 3.
--
-- NOTE(review): the payload keeps its own wire type here, but
-- 'parseMessageSetTaggedValueFromWire' only reads back items whose field 3
-- is 'Lengthy'.  This presumably relies on callers passing only
-- length-delimited values; confirm.
buildTaggedValueAsMessageSet :: TaggedValue -> Builder
buildTaggedValueAsMessageSet (TaggedValue (Tag fieldNum) value) =
    mconcat $ map buildTaggedValue
        [ TaggedValue 1 StartGroup
        , TaggedValue 2 (VarInt (fromIntegral fieldNum))
        , TaggedValue 3 value
        , TaggedValue 1 EndGroup
        ]
-- | Serialize only the payload of a value (no field key).  Length-delimited
-- bytes get a varint length prefix; group markers produce no bytes of their
-- own.
buildWireValue :: WireValue -> Builder
buildWireValue wv = case wv of
    VarInt n -> putVarInt n
    Fixed64 n -> putFixed64 n
    Fixed32 n -> putFixed32 n
    Lengthy bytes -> putVarInt (fromIntegral (B.length bytes)) <> putBytes bytes
    StartGroup -> mempty
    EndGroup -> mempty
-- | The protobuf wire-type number of a value, as stored in the low three
-- bits of a field key.
wireValueToInt :: WireValue -> Word8
wireValueToInt wv = case wv of
    VarInt _ -> 0
    Fixed64 _ -> 1
    Lengthy _ -> 2
    StartGroup -> 3
    EndGroup -> 4
    Fixed32 _ -> 5
-- | Parse one complete field: read its varint key, then decode the payload
-- that the key's wire type calls for.
parseTaggedValue :: Parser TaggedValue
parseTaggedValue = do
    key <- getVarInt
    parseTaggedValueFromWire key
-- | Parse the payload of a field whose key has already been read.  The key
-- is the varint combining field number and wire type; see
-- 'splitTypeAndTag'.
--
-- Fails on an unknown wire type.  Also fails on a length-delimited field
-- whose declared length does not fit in an 'Int'.
parseTaggedValueFromWire :: Word64 -> Parser TaggedValue
parseTaggedValueFromWire t =
    let (tag, w) = splitTypeAndTag t
    in TaggedValue tag <$> case w of
        0 -> VarInt <$> getVarInt
        1 -> Fixed64 <$> getFixed64
        2 -> Lengthy <$> do
                len <- getVarInt
                -- The length comes straight from the input.  Converting an
                -- oversized value with 'fromIntegral' would silently wrap it
                -- to a negative (or, on 32-bit, truncated) 'Int'.  Reject it
                -- here rather than rely on 'getBytes' to cope with that.
                if len > fromIntegral (maxBound :: Int)
                    then fail $ "Length-delimited field too long: " ++ show len
                    else getBytes $ fromIntegral len
        3 -> return StartGroup
        4 -> return EndGroup
        5 -> Fixed32 <$> getFixed32
        _ -> fail $ "Unknown wire type " ++ show w
-- | Like 'parseTaggedValueFromWire', but also understands the message-set
-- item encoding written by 'buildTaggedValueAsMessageSet'.  Such an item is
-- a group with field number 1 containing:
--
-- * the real field number as field 2 (a varint), and
-- * the payload as field 3 (length-delimited bytes).
--
-- The item is collapsed into a single 'Lengthy' value tagged with that
-- field number.  Any other field is returned unchanged.
--
-- Fields 2 and 3 are accepted in either order, since protobuf parsers must
-- not assume a particular field order.  If either is repeated, the last
-- occurrence wins.
parseMessageSetTaggedValueFromWire :: Word64 -> Parser TaggedValue
parseMessageSetTaggedValueFromWire t =
    parseTaggedValueFromWire t >>= \v -> case v of
        TaggedValue 1 StartGroup -> parseItem Nothing Nothing
        _ -> return v
  where
    -- Collect the item's field number and payload until the end_group for
    -- field 1 is reached, then check that both were present.
    parseItem :: Maybe Word64 -> Maybe B.ByteString -> Parser TaggedValue
    parseItem fieldNum payload = parseTaggedValue >>= \tv -> case tv of
        TaggedValue 2 (VarInt f) -> parseItem (Just f) payload
        TaggedValue 3 (Lengthy b) -> parseItem fieldNum (Just b)
        TaggedValue 1 EndGroup -> case (fieldNum, payload) of
            (Nothing, _) -> fail "missing field tag"
            (_, Nothing) -> fail "missing message"
            (Just f, Just b) ->
                return $ TaggedValue (Tag $ fromIntegral f) (Lengthy b)
        _ -> fail "missing end_group"
-- | Split a field key into its field number (everything above the low
-- three bits) and its wire type (the low three bits).  This is the
-- counterpart of 'joinTypeAndTag'.
splitTypeAndTag :: Word64 -> (Tag, Word8)
splitTypeAndTag key = (fieldNum, wireType)
  where
    fieldNum = fromIntegral (key `shiftR` 3)
    wireType = fromIntegral (key .&. 7)
-- | Combine a field number and a wire type into a field key, as
-- @fieldNum << 3 .|. wireType@.  The wire type is not masked, so it is
-- expected to be below 8; higher bits would overlap the field number.
joinTypeAndTag :: Tag -> Word8 -> Word64
joinTypeAndTag (Tag fieldNum) wireType = shiftedNum .|. fromIntegral wireType
  where
    shiftedNum :: Word64
    shiftedNum = shiftL (fromIntegral fieldNum) 3
-- | Parse fields until the input is exhausted, returning them in the order
-- they appeared on the wire.
parseFieldSet :: Parser FieldSet
parseFieldSet = go []
  where
    -- Fields are accumulated in reverse and flipped once at the end.  Each
    -- field is forced before it is consed onto the accumulator.
    go acc = do
        done <- atEnd
        if done
            then return $! reverse acc
            else parseTaggedValue >>= \tv -> tv `seq` go (tv : acc)
-- | Serialize every field, in list order, with 'buildTaggedValue'.
buildFieldSet :: FieldSet -> Builder
buildFieldSet fields = foldMap buildTaggedValue fields
-- | Serialize every field, in list order, as a message-set item (see
-- 'buildTaggedValueAsMessageSet').
buildMessageSet :: FieldSet -> Builder
buildMessageSet fields = foldMap buildTaggedValueAsMessageSet fields