{-# Language BangPatterns #-}
module Text.ProtocolBuffers.WireMessage
(
messageSize,messagePut,messageGet,messagePutM,messageGetM
, messageWithLengthSize,messageWithLengthPut,messageWithLengthGet,messageWithLengthPutM,messageWithLengthGetM
, messageAsFieldSize,messageAsFieldPutM,messageAsFieldGetM
, Put,PutM,Get,runPut,runPutM,runGet,runGetOnLazy,getFromBS
, Wire(..)
, size'WireTag,size'WireSize,toWireType,toWireTag,toPackedWireTag,mkWireTag
, prependMessageSize,putSize,putVarUInt,getVarInt,putLazyByteString,splitWireTag,fieldIdOf
, wireSizeReq,wireSizeOpt,wireSizeRep,wireSizePacked
, wirePutReq,wirePutOpt,wirePutRep,wirePutPacked
, wirePutReqWithSize,wirePutOptWithSize,wirePutRepWithSize,wirePutPackedWithSize
, sequencePutWithSize
, wireSizeErr,wirePutErr,wireGetErr
, getMessageWith,getBareMessageWith,wireGetEnum,wireGetPackedEnum
, unknownField,unknown,wireGetFromWire
, castWord64ToDouble,castWord32ToFloat,castDoubleToWord64,castFloatToWord32
, zzEncode64,zzEncode32,zzDecode64,zzDecode32
) where
import Control.Monad(when,foldM)
import Control.Monad.Error.Class(throwError)
import Control.Monad.ST
import Data.Array.ST(newArray,readArray)
import Data.Array.Unsafe(castSTUArray)
import Data.Bits (Bits(..))
import qualified Data.ByteString.Lazy as BS (length)
import qualified Data.Foldable as F(foldl', Foldable)
import Data.Maybe(fromMaybe)
import Data.Sequence ((|>))
import qualified Data.Sequence as Seq(length,empty)
import qualified Data.Set as Set(delete,null)
import Data.Typeable (Typeable,typeOf)
import Data.Binary.Put (Put,PutM,runPutM,runPut,putWord8,putWord32le,putWord64le,putLazyByteString)
import Text.ProtocolBuffers.Basic
import Text.ProtocolBuffers.Get as Get (Result(..),Get,runGet,runGetAll,bytesRead,isReallyEmpty,decode7unrolled
,spanOf,skip,lookAhead,highBitRun
,getWord32le,getWord64le,getLazyByteString)
import Text.ProtocolBuffers.Reflections(ReflectDescriptor(reflectDescriptorInfo,getMessageInfo)
,DescriptorInfo(..),GetMessageInfo(..))
-- | Disabled debugging hook with the same shape as 'Debug.Trace.trace': the
-- message argument is ignored and the second argument is returned as is.
-- NOTE(review): presumably kept so the trace call sites in this module can be
-- re-enabled by swapping this definition for the real one.
trace :: a -> b -> b
trace _message result = result
-- | Encoded size of @msg@ when written with field type 10 (as 'messagePutM' does).
messageSize :: (ReflectDescriptor msg,Wire msg) => msg -> WireSize
messageSize = wireSize 10

-- | Encoded size of @msg@ when written with field type 11 (as
-- 'messageWithLengthPutM' does).
messageWithLengthSize :: (ReflectDescriptor msg,Wire msg) => msg -> WireSize
messageWithLengthSize = wireSize 11

-- | Encoded size of @msg@ written as field @fi@ (as 'messageAsFieldPutM'
-- does): the wire tag built for field type 11, plus 'messageWithLengthSize'.
messageAsFieldSize :: (ReflectDescriptor msg,Wire msg) => FieldId -> msg -> WireSize
messageAsFieldSize fi msg = size'WireTag (toWireTag fi 11) + messageWithLengthSize msg
-- | Serialize @msg@ to a lazy 'ByteString' by running 'messagePutM'.
messagePut :: (ReflectDescriptor msg, Wire msg) => msg -> ByteString
messagePut = runPut . messagePutM

-- | Serialize @msg@ to a lazy 'ByteString' by running 'messageWithLengthPutM'.
messageWithLengthPut :: (ReflectDescriptor msg, Wire msg) => msg -> ByteString
messageWithLengthPut = runPut . messageWithLengthPutM

-- | 'Put' action that writes @msg@ with field type 10.
messagePutM :: (ReflectDescriptor msg, Wire msg) => msg -> Put
messagePutM = wirePut 10

-- | 'Put' action that writes @msg@ with field type 11.
messageWithLengthPutM :: (ReflectDescriptor msg, Wire msg) => msg -> Put
messageWithLengthPutM = wirePut 11

-- | 'Put' action that writes @msg@ as field @fi@. It writes the wire tag
-- built for field type 11, followed by the field-type-11 encoding of @msg@.
messageAsFieldPutM :: (ReflectDescriptor msg, Wire msg) => FieldId -> msg -> Put
messageAsFieldPutM fi = wirePutReq (toWireTag fi 11) 11
-- | Parse a message with 'messageGetM' from a lazy 'ByteString'. Returns the
-- message together with the unconsumed input, or the failure text.
messageGet :: (ReflectDescriptor msg, Wire msg) => ByteString -> Either String (msg,ByteString)
messageGet = runGetOnLazy messageGetM

-- | Parse a message with 'messageWithLengthGetM' from a lazy 'ByteString'.
-- Returns the message together with the unconsumed input, or the failure text.
messageWithLengthGet :: (ReflectDescriptor msg, Wire msg) => ByteString -> Either String (msg,ByteString)
messageWithLengthGet = runGetOnLazy messageWithLengthGetM

-- | Decoder for a message encoded with field type 10.
messageGetM :: (ReflectDescriptor msg, Wire msg) => Get msg
messageGetM = wireGet 10

-- | Decoder for a message encoded with field type 11.
messageWithLengthGetM :: (ReflectDescriptor msg, Wire msg) => Get msg
messageWithLengthGetM = wireGet 11
-- | Read a wire tag, then a message decoded with field type 11. Returns the
-- tag's field id together with the message. Fails when the tag's wire type
-- is not 2.
messageAsFieldGetM :: (ReflectDescriptor msg, Wire msg) => Get (FieldId,msg)
messageAsFieldGetM = do
  (fieldId, wireType) <- fmap (splitWireTag . WireTag) getVarInt
  when (wireType /= 2) $
    throwError ("messageAsFieldGetM: wireType was not 2 " ++ show (fieldId, wireType))
  msg <- wireGet 11
  return (fieldId, msg)
-- | Run @parser@ over a lazy 'ByteString' and return its result, discarding
-- any leftover input. A parse failure is raised with 'error' using the
-- message from 'runGetOnLazy'.
getFromBS :: Get r -> ByteString -> r
getFromBS parser bs = either error fst (runGetOnLazy parser bs)
-- | Run @parser@ over a lazy 'ByteString' using 'runGetAll'.
--
-- * A finished parse yields the value together with the unconsumed input.
-- * A failure yields its position and message.
-- * A 'Partial' result is resumed with 'Nothing' until it settles
--   (presumably meaning "no more input" — confirm against "Text.ProtocolBuffers.Get").
runGetOnLazy :: Get r -> ByteString -> Either String (r,ByteString)
runGetOnLazy parser = settle . runGetAll parser
  where
    settle :: Result r -> Either String (r,ByteString)
    settle outcome = case outcome of
      Partial resume -> settle (resume Nothing)
      Finished rest _consumed value -> Right (value, rest)
      Failed pos err -> Left ("Failed at " ++ show pos ++ " : " ++ err)
-- | Total size of an @n@-byte payload once its varint length prefix
-- (sized by 'size'WireSize') is put in front of it.
prependMessageSize :: WireSize -> WireSize
prependMessageSize payload = size'WireSize payload + payload
{-# INLINE sequencePutWithSize #-}
-- | Run the size-reporting 'PutM' actions left to right and return the sum of
-- the sizes they report. The running total is forced after each step.
sequencePutWithSize :: F.Foldable f => f (PutM WireSize) -> PutM WireSize
sequencePutWithSize = foldM accumulate 0
  where
    accumulate total action = do
      reported <- action
      return $! total + reported
{-# INLINE wirePutReqWithSize #-}
-- | Write one field occurrence and return the number of bytes written: the
-- wire tag as a varint, then the value encoded with @fieldType@.
--
-- For field type 10 a closing tag is also written. Its value is the opening
-- tag plus one ('succ').
wirePutReqWithSize :: Wire v => WireTag -> FieldType -> v -> PutM WireSize
wirePutReqWithSize wireTag fieldType v = sequencePutWithSize (opening : payload : closing)
  where
    startTag = getWireTag wireTag
    emitTag tag = putVarUInt tag >> return (size'Word32 tag)
    opening = emitTag startTag
    payload = wirePutWithSize fieldType v
    closing
      | fieldType == 10 = [emitTag (succ startTag)]
      | otherwise = []
{-# INLINE wirePutOptWithSize #-}
-- | Write an optional field. 'Nothing' writes nothing and reports 0 bytes;
-- @Just v@ behaves like 'wirePutReqWithSize'.
wirePutOptWithSize :: Wire v => WireTag -> FieldType -> Maybe v -> PutM WireSize
wirePutOptWithSize wireTag fieldType mv = maybe (return 0) (wirePutReqWithSize wireTag fieldType) mv
{-# INLINE wirePutRepWithSize #-}
-- | Write a non-packed repeated field: each element becomes its own tagged
-- occurrence, written with 'wirePutReqWithSize'. Returns the total byte count.
wirePutRepWithSize :: Wire v => WireTag -> FieldType -> Seq v -> PutM WireSize
wirePutRepWithSize wireTag fieldType vs = sequencePutWithSize (fmap putOne vs)
  where
    putOne = wirePutReqWithSize wireTag fieldType
{-# INLINE wirePutPackedWithSize #-}
-- | Write a packed repeated field and return the total number of bytes
-- written. The layout is:
--
-- * the wire tag;
-- * the payload byte count as a varint ('putSize');
-- * every element encoded back to back with @fieldType@ (no per-element tags).
--
-- The payload byte count is summed from 'wireSize', the same way
-- 'wireSizePacked' computes it. The previous version ran the whole element
-- writer chain an extra time (through 'runPutM') just to learn that number.
wirePutPackedWithSize :: Wire v => WireTag -> FieldType -> Seq v -> PutM WireSize
wirePutPackedWithSize wireTag fieldType vs =
  let payloadSize = F.foldl' (\n v -> n + wireSize fieldType v) 0 vs
      putTag = putVarUInt (getWireTag wireTag) >> return (size'WireTag wireTag)
      putLength = putSize payloadSize >> return (size'WireSize payloadSize)
      putElements = sequencePutWithSize (fmap (wirePutWithSize fieldType) vs)
  in sequencePutWithSize [putTag, putLength, putElements]
{-# INLINE wirePutReq #-}
-- | 'wirePutReqWithSize' with the reported byte count discarded.
wirePutReq :: Wire v => WireTag -> FieldType -> v -> Put
wirePutReq wireTag fieldType v = fmap (const ()) (wirePutReqWithSize wireTag fieldType v)

{-# INLINE wirePutOpt #-}
-- | 'wirePutOptWithSize' with the reported byte count discarded.
wirePutOpt :: Wire v => WireTag -> FieldType -> Maybe v -> Put
wirePutOpt wireTag fieldType v = fmap (const ()) (wirePutOptWithSize wireTag fieldType v)

{-# INLINE wirePutRep #-}
-- | 'wirePutRepWithSize' with the reported byte count discarded.
wirePutRep :: Wire v => WireTag -> FieldType -> Seq v -> Put
wirePutRep wireTag fieldType vs = fmap (const ()) (wirePutRepWithSize wireTag fieldType vs)

{-# INLINE wirePutPacked #-}
-- | 'wirePutPackedWithSize' with the reported byte count discarded.
wirePutPacked :: Wire v => WireTag -> FieldType -> Seq v -> Put
wirePutPacked wireTag fieldType vs = fmap (const ()) (wirePutPackedWithSize wireTag fieldType vs)
{-# INLINE wireSizeReq #-}
-- | Byte size of one field occurrence whose tag takes @tagSize@ bytes.
-- Field type 10 is counted with the tag twice, because
-- 'wirePutReqWithSize' also writes a closing tag for it.
wireSizeReq :: Wire v => Int64 -> FieldType -> v -> Int64
wireSizeReq tagSize fieldType v
  | fieldType == 10 = tagSize + wireSize fieldType v + tagSize
  | otherwise = tagSize + wireSize fieldType v

{-# INLINE wireSizeOpt #-}
-- | Byte size of an optional field: 0 when absent, else as 'wireSizeReq'.
wireSizeOpt :: Wire v => Int64 -> FieldType -> Maybe v -> Int64
wireSizeOpt tagSize i mv = maybe 0 (wireSizeReq tagSize i) mv

{-# INLINE wireSizeRep #-}
-- | Byte size of a non-packed repeated field: the sum of 'wireSizeReq' over
-- all elements.
wireSizeRep :: Wire v => Int64 -> FieldType -> Seq v -> Int64
wireSizeRep tagSize i vs = F.foldl' addOne 0 vs
  where
    addOne total v = total + wireSizeReq tagSize i v

{-# INLINE wireSizePacked #-}
-- | Byte size of a packed repeated field. It is one tag, plus the
-- length-prefixed concatenation of the untagged elements.
wireSizePacked :: Wire v => Int64 -> FieldType -> Seq v -> Int64
wireSizePacked tagSize i vs = tagSize + prependMessageSize payload
  where
    payload = F.foldl' (\total v -> total + wireSize i v) 0 vs
{-# INLINE putSize #-}
-- | Write a byte count as an unsigned varint ('putVarUInt').
putSize :: WireSize -> Put
putSize size = putVarUInt size
-- | Wire tag for field @fieldId@ with wire type 2, as used for packed fields.
toPackedWireTag :: FieldId -> WireTag
toPackedWireTag fieldId = mkWireTag fieldId 2

-- | Wire tag for field @fieldId@, using the wire type that 'toWireType'
-- assigns to @fieldType@.
toWireTag :: FieldId -> FieldType -> WireTag
toWireTag fieldId = mkWireTag fieldId . toWireType

-- | Combine a field id and a wire type into a tag. The field id is shifted
-- left by 3 and the wire type is OR-ed into the low bits.
mkWireTag :: FieldId -> WireType -> WireTag
mkWireTag fieldId wireType = (idBits `shiftL` 3) .|. typeBits
  where
    idBits = fromIntegral (getFieldId fieldId)
    typeBits = fromIntegral (getWireType wireType)

-- | Inverse of 'mkWireTag'. The field id is the tag shifted right by 3, and
-- the wire type is the low 3 bits.
splitWireTag :: WireTag -> (FieldId,WireType)
splitWireTag (WireTag tag) = (fieldId, wireType)
  where
    fieldId = FieldId (fromIntegral (tag `shiftR` 3))
    wireType = WireType (fromIntegral (tag .&. 7))

-- | Field id part of a wire tag.
fieldIdOf :: WireTag -> FieldId
fieldIdOf tag = fst (splitWireTag tag)
{-# INLINE wireGetPackedEnum #-}
-- | Decode a packed repeated enum field. The input is a varint byte count
-- followed by enum values read back to back with 'wireGetEnum'. Values are
-- read until exactly that many bytes have been consumed.
--
-- Fails when a value read moves the position past the declared end. The
-- error names the enum type and how many values were read.
wireGetPackedEnum :: (Typeable e,Enum e) => (Int -> Maybe e) -> Get (Seq e)
wireGetPackedEnum toMaybe'Enum = do
    packedLength <- getVarInt
    start <- bytesRead
    let stop = packedLength+start
        next !soFar = do
          here <- bytesRead
          case compare stop here of
            EQ -> return soFar
            LT -> tooMuchData packedLength soFar start here
            GT -> do
              value <- wireGetEnum toMaybe'Enum
              seq value $ next (soFar |> value)
    next Seq.empty
  where
    -- Type witness only: never evaluated, it just selects the 'Typeable'
    -- instance of the enum type for the error message. It replaces the former
    -- partial pattern binding `Just e = undefined `asTypeOf` ...`.
    enumWitness :: (Int -> Maybe a) -> a
    enumWitness _ = undefined
    tooMuchData packedLength soFar start here =
      throwError ("Text.ProtocolBuffers.WireMessage.wireGetPackedEnum: overran expected length."
                  ++ "\n The type and count of values so far is " ++ show (typeOf (enumWitness toMaybe'Enum),Seq.length soFar)
                  ++ "\n at (packedLength,start,here) == " ++ show (packedLength,start,here))
{-# INLINE genericPacked #-}
-- | Decode a packed repeated field. The input is a varint byte count followed
-- by elements read with @'wireGet' ft@. Elements are read until exactly that
-- many bytes have been consumed. Fails if an element read carries the
-- position past the declared end.
genericPacked :: Wire a => FieldType -> Get (Seq a)
genericPacked ft = do
    byteCount <- getVarInt
    origin <- bytesRead
    let limit = byteCount + origin
        loop !acc = do
          pos <- bytesRead
          case compare pos limit of
            EQ -> return acc
            GT -> overrun byteCount acc origin pos
            LT -> do
              value <- wireGet ft
              seq value $! loop $! acc |> value
    loop Seq.empty
  where
    overrun packedLength soFar start here =
      throwError ("Text.ProtocolBuffers.WireMessage.genericPacked: overran expected length."
                  ++ "\n The FieldType and count of values so far are " ++ show (ft,Seq.length soFar)
                  ++ "\n at (packedLength,start,here) == " ++ show (packedLength,start,here))
-- | Decode a length-delimited message.
--
-- It first reads a varint byte count. It then repeatedly reads a wire tag
-- and passes it to @updater@ together with the message built so far,
-- starting from 'defaultValue'.
--
-- Decoding succeeds when exactly the declared number of bytes has been
-- consumed. It fails in two cases:
--
-- * the read position moves past that point;
-- * the end is reached while some of the message's required tags (from
--   'getMessageInfo') have not been seen.
getMessageWith :: (Default message, ReflectDescriptor message)
               => (WireTag -> message -> Get message)
               -> Get message
getMessageWith updater = do
    messageLength <- getVarInt
    start <- bytesRead
    let stop = messageLength+start
        -- Some required tags are still unseen: reaching the end is an error.
        go reqs !message
          | Set.null reqs = go' message
          | otherwise = do
              here <- bytesRead
              case compare stop here of
                EQ -> notEnoughData messageLength start
                LT -> tooMuchData messageLength start here
                GT -> do
                  wireTag <- fmap WireTag getVarInt
                  let reqs' = Set.delete wireTag reqs
                  updater wireTag message >>= go reqs'
        -- All required tags seen: stop cleanly at the end.
        go' !message = do
          here <- bytesRead
          case compare stop here of
            EQ -> return message
            LT -> tooMuchData messageLength start here
            GT -> do
              wireTag <- fmap WireTag getVarInt
              updater wireTag message >>= go'
    go required initialMessage
  where
    initialMessage = defaultValue
    (GetMessageInfo {requiredTags=required}) = getMessageInfo initialMessage
    notEnoughData messageLength start =
      throwError ("Text.ProtocolBuffers.WireMessage.getMessageWith: Required fields missing when processing "
                  ++ (show . descName . reflectDescriptorInfo $ initialMessage)
                  ++ "\n at (messageLength,start) == " ++ show (messageLength,start))
    -- The trailing space after "processing" was previously missing, which
    -- glued the type name onto the word in the error text.
    tooMuchData messageLength start here =
      throwError ("Text.ProtocolBuffers.WireMessage.getMessageWith: overran expected length when processing "
                  ++ (show . descName . reflectDescriptorInfo $ initialMessage)
                  ++ "\n at (messageLength,start,here) == " ++ show (messageLength,start,here))
-- | Decode a message that has no length prefix. Fields are fed to @updater@,
-- starting from 'defaultValue'. Decoding ends when the input is exhausted
-- ('isReallyEmpty') or a tag with wire type 4 is read; that tag has already
-- been consumed at that point. Fails if either happens while some required
-- tags (from 'getMessageInfo') are still unseen.
getBareMessageWith :: (Default message, ReflectDescriptor message)
                   => (WireTag -> message -> Get message)
                   -> Get message
getBareMessageWith updater = go required initialMessage
  where
    go reqs !message
      | Set.null reqs = go' message
      | otherwise = do
          done <- isReallyEmpty
          if done
            then notEnoughData
            else do
              wireTag <- fmap WireTag getVarInt
              if isStopTag wireTag
                then notEnoughData
                else updater wireTag message >>= go (Set.delete wireTag reqs)
    go' !message = do
      done <- isReallyEmpty
      if done
        then return message
        else do
          wireTag <- fmap WireTag getVarInt
          if isStopTag wireTag
            then return message
            else updater wireTag message >>= go'
    -- Wire type 4 is the end-of-group marker (STOP_GROUP).
    isStopTag wireTag = snd (splitWireTag wireTag) == 4
    initialMessage = defaultValue
    (GetMessageInfo {requiredTags=required}) = getMessageInfo initialMessage
    notEnoughData = throwError ("Text.ProtocolBuffers.WireMessage.getBareMessageWith: Required fields missing when processing "
                                ++ (show . descName . reflectDescriptorInfo $ initialMessage))
-- | Always fails. It reports that the updater for the type of @msg@ saw
-- field @fieldId@ as unknown, together with the current byte offset.
unknownField :: Typeable a => a -> FieldId -> Get a
unknownField msg fieldId = bytesRead >>= \here ->
  throwError ("Impossible? Text.ProtocolBuffers.WireMessage.unknownField"
              ++ "\n Updater for " ++ show (typeOf msg) ++ " claims there is an unknown field id on wire: " ++ show fieldId
              ++ "\n at a position just before byte location " ++ show here)
-- | Always fails. It reports an unknown or unparsable field: the message
-- type, field id, wire type, current byte offset and descriptor name of
-- @initialMessage@.
unknown :: (Typeable a,ReflectDescriptor a) => FieldId -> WireType -> a -> Get a
unknown fieldId wireType initialMessage = bytesRead >>= report
  where
    report here =
      throwError ("Text.ProtocolBuffers.WireMessage.unknown: Unknown field found or failure parsing field (e.g. unexpected Enum value):"
                  ++ "\n (message type name,field id number,wire type code,bytes read) == "
                  ++ show (typeOf initialMessage,fieldId,wireType,here)
                  ++ "\n when processing "
                  ++ (show . descName . reflectDescriptorInfo $ initialMessage))
{-# INLINE castWord32ToFloat #-}
-- | Reinterpret the bits of a 'Word32' as a 'Float'. The word is stored in a
-- one-element unboxed array, which is then read back through a
-- 'castSTUArray' view.
castWord32ToFloat :: Word32 -> Float
castWord32ToFloat w = runST $ do
  cell <- newArray (0::Int,0) w
  view <- castSTUArray cell
  readArray view 0

{-# INLINE castFloatToWord32 #-}
-- | Reinterpret the bits of a 'Float' as a 'Word32' (inverse of 'castWord32ToFloat').
castFloatToWord32 :: Float -> Word32
castFloatToWord32 f = runST $ do
  cell <- newArray (0::Int,0) f
  view <- castSTUArray cell
  readArray view 0

{-# INLINE castWord64ToDouble #-}
-- | Reinterpret the bits of a 'Word64' as a 'Double'.
castWord64ToDouble :: Word64 -> Double
castWord64ToDouble w = runST $ do
  cell <- newArray (0::Int,0) w
  view <- castSTUArray cell
  readArray view 0

{-# INLINE castDoubleToWord64 #-}
-- | Reinterpret the bits of a 'Double' as a 'Word64' (inverse of 'castWord64ToDouble').
castDoubleToWord64 :: Double -> Word64
castDoubleToWord64 d = runST $ do
  cell <- newArray (0::Int,0) d
  view <- castSTUArray cell
  readArray view 0
-- | Raised (via 'error') when 'wireSize' receives a field type code that the
-- value's Haskell type does not support.
wireSizeErr :: Typeable a => FieldType -> a -> WireSize
wireSizeErr ft x = error message
  where
    message = concat [ "Impossible? wireSize field type mismatch error: Field type number ", show ft
                     , " does not match internal type ", show (typeOf x) ]

-- | Failure (via 'fail' in 'PutM') when 'wirePut' receives a field type code
-- that the value's Haskell type does not support.
wirePutErr :: Typeable a => FieldType -> a -> PutM b
wirePutErr ft x = fail message
  where
    message = concat [ "Impossible? wirePut field type mismatch error: Field type number ", show ft
                     , " does not match internal type ", show (typeOf x) ]

-- | Decoder failure when 'wireGet' receives a field type code that the
-- result type does not support. The result type is named in the message.
wireGetErr :: Typeable a => FieldType -> Get a
wireGetErr ft = failure
  where
    failure = throwError $ concat [ "Impossible? wireGet field type mismatch error: Field type number ", show ft
                                  , " does not match internal type ", show (typeOf (resultOf failure)) ]
    -- Never evaluated: it only recovers the result type of 'failure' for 'typeOf'.
    resultOf :: Get a -> a
    resultOf = undefined
-- | Values that can be encoded to and decoded from the wire format. Every
-- method takes the 'FieldType' code that selects the encoding. The instances
-- in this module answer unsupported codes with 'wireSizeErr', 'wirePutErr'
-- or 'wireGetErr'.
class Wire b where
  {-# MINIMAL wireGet, wireSize, (wirePut | wirePutWithSize) #-}
  -- | Number of bytes the value occupies when written with this field type.
  wireSize :: FieldType -> b -> WireSize
  {-# INLINE wirePut #-}
  -- | Write the value with this field type.
  wirePut :: FieldType -> b -> Put
  wirePut ft x = wirePutWithSize ft x >> return ()
  {-# INLINE wirePutWithSize #-}
  -- | Write the value and return its size. The default reports 'wireSize'
  -- rather than measuring the output.
  wirePutWithSize :: FieldType -> b -> PutM WireSize
  wirePutWithSize ft x = wirePut ft x >> return (wireSize ft x)
  -- | Read a value encoded with this field type.
  wireGet :: FieldType -> Get b
  {-# INLINE wireGetPacked #-}
  -- | Read a packed repeated field of this field type. The default always
  -- fails; the scalar instances in this module override it.
  wireGetPacked :: FieldType -> Get (Seq b)
  -- The message previously named "Text.ProtocolBuffers.ProtoCompile.Basic"
  -- (not where this default lives) and read "is has been".
  wireGetPacked ft = throwError ("Text.ProtocolBuffers.WireMessage.wireGetPacked default:"
                                 ++ "\n There is no way to get a packed FieldType of "++show ft
                                 ++ ".\n Either there is a bug in this library or the wire format has been updated.")
-- | Field type 1: the bit pattern of the 'Double' as a little-endian 64-bit word.
instance Wire Double where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    1 -> 8
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    1 -> putWord64le (castDoubleToWord64 x)
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    1 -> fmap castWord64ToDouble getWord64le
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    1 -> genericPacked 1
    _ -> wireGetErr ft
-- | Field type 2: the bit pattern of the 'Float' as a little-endian 32-bit word.
instance Wire Float where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    2 -> 4
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    2 -> putWord32le (castFloatToWord32 x)
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    2 -> fmap castWord32ToFloat getWord32le
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    2 -> genericPacked 2
    _ -> wireGetErr ft
-- | 'Int64' encodings:
--
-- * field type 3: signed varint;
-- * field type 18: zigzag-encoded varint;
-- * field type 16: little-endian 8 bytes.
instance Wire Int64 where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    3 -> size'Int64 x
    18 -> size'Word64 (zzEncode64 x)
    16 -> 8
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    3 -> putVarSInt x
    18 -> putVarUInt (zzEncode64 x)
    16 -> putWord64le (fromIntegral x)
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    3 -> getVarInt
    18 -> fmap zzDecode64 getVarInt
    16 -> fmap fromIntegral getWord64le
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    3 -> genericPacked 3
    18 -> genericPacked 18
    16 -> genericPacked 16
    _ -> wireGetErr ft
-- | 'Int32' encodings:
--
-- * field type 5: signed varint;
-- * field type 17: zigzag-encoded varint;
-- * field type 15: little-endian 4 bytes.
instance Wire Int32 where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    5 -> size'Int32 x
    17 -> size'Word32 (zzEncode32 x)
    15 -> 4
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    5 -> putVarSInt x
    17 -> putVarUInt (zzEncode32 x)
    15 -> putWord32le (fromIntegral x)
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    5 -> getVarInt
    17 -> fmap zzDecode32 getVarInt
    15 -> fmap fromIntegral getWord32le
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    5 -> genericPacked 5
    17 -> genericPacked 17
    15 -> genericPacked 15
    _ -> wireGetErr ft
-- | 'Word64' encodings: field type 4 is an unsigned varint; field type 6 is
-- little-endian 8 bytes.
instance Wire Word64 where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    4 -> size'Word64 x
    6 -> 8
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    4 -> putVarUInt x
    6 -> putWord64le x
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    6 -> getWord64le
    4 -> getVarInt
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    6 -> genericPacked 6
    4 -> genericPacked 4
    _ -> wireGetErr ft
-- | 'Word32' encodings: field type 13 is an unsigned varint; field type 7 is
-- little-endian 4 bytes.
instance Wire Word32 where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    13 -> size'Word32 x
    7 -> 4
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    13 -> putVarUInt x
    7 -> putWord32le x
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    13 -> getVarInt
    7 -> getWord32le
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    13 -> genericPacked 13
    7 -> genericPacked 7
    _ -> wireGetErr ft
-- | Field type 8.
--
-- * Writing emits a single byte, 0 or 1.
-- * Reading accepts any varint (decoded as 'Int32'); every non-zero value
--   means 'True'.
instance Wire Bool where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    8 -> 1
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    8 -> if x then putWord8 1 else putWord8 0
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    8 -> do
      n <- getVarInt :: Get Int32
      return $! n /= 0
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    8 -> genericPacked 8
    _ -> wireGetErr ft
-- | Field type 9: a varint byte length followed by the raw bytes. Decoding
-- validates UTF-8 with 'verifyUtf8'. There is no packed form.
instance Wire Utf8 where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    9 -> prependMessageSize (BS.length (utf8 x))
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    9 -> let bytes = utf8 x in putVarUInt (BS.length bytes) >> putLazyByteString bytes
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    9 -> getVarInt >>= getLazyByteString >>= verifyUtf8
    _ -> wireGetErr ft
-- | Field type 12: a varint byte length followed by the raw bytes. There is
-- no packed form.
instance Wire ByteString where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    12 -> prependMessageSize (BS.length x)
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    12 -> putVarUInt (BS.length x) >> putLazyByteString x
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    12 -> getVarInt >>= getLazyByteString
    _ -> wireGetErr ft
-- | Field type 14 (the code 'wireGetEnum' reads with): a signed varint.
instance Wire Int where
  {-# INLINE wireSize #-}
  wireSize ft x = case ft of
    14 -> size'Int x
    _ -> wireSizeErr ft x
  {-# INLINE wirePut #-}
  wirePut ft x = case ft of
    14 -> putVarSInt x
    _ -> wirePutErr ft x
  {-# INLINE wireGet #-}
  wireGet ft = case ft of
    14 -> getVarInt
    _ -> wireGetErr ft
  {-# INLINE wireGetPacked #-}
  wireGetPacked ft = case ft of
    14 -> genericPacked 14
    _ -> wireGetErr ft
{-# INLINE verifyUtf8 #-}
-- | Wrap @bs@ as 'Utf8' when 'isValidUTF8' reports no error. Otherwise fail,
-- naming the position 'isValidUTF8' returned.
verifyUtf8 :: ByteString -> Get Utf8
verifyUtf8 bs = maybe (return (Utf8 bs)) invalidAt (isValidUTF8 bs)
  where
    invalidAt i = throwError $ "Text.ProtocolBuffers.WireMessage.verifyUtf8: ByteString is not valid utf8 at position "++show i
{-# INLINE wireGetEnum #-}
-- | Read an 'Int' with field type 14 and convert it with @toMaybe'Enum@.
-- When the conversion yields 'Nothing', fail with a message naming the enum
-- type and the offending value. A recognised value is forced before being
-- returned.
wireGetEnum :: (Typeable e, Enum e) => (Int -> Maybe e) -> Get e
wireGetEnum toMaybe'Enum = wireGet 14 >>= convert
  where
    convert int = maybe (throwError (msg ++ show int)) (\ !v -> return v) (toMaybe'Enum int)
    msg = "Bad wireGet of Enum "++show (typeOf (undefined `asTypeOf` typeHack toMaybe'Enum))++", unrecognized Int value is "
    -- Type witness only; 'typeOf' never evaluates its argument.
    typeHack :: (Int -> Maybe e) -> e
    typeHack f = fromMaybe undefined (f undefined)
-- | Number of bytes the wire tag occupies as an unsigned varint.
size'WireTag :: WireTag -> Int64
size'WireTag tag = size'Word32 (getWireTag tag)
-- | Length in bytes of the unsigned varint encoding of @b@ (7 payload bits
-- per byte, as 'putVarUInt' writes it): 1 to 5.
size'Word32 :: Word32 -> Int64
size'Word32 b = countGroups 1 b
  where
    -- Keep dropping 7-bit groups until what is left fits in one byte.
    countGroups :: Int64 -> Word32 -> Int64
    countGroups n w
      | w <= 0x7F = n
      | otherwise = countGroups (n + 1) (w `shiftR` 7)
-- | Length in bytes of the signed varint encoding of @b@. Negative values
-- take 10 bytes, matching 'putVarSInt', which always writes 10 bytes for
-- them. Other values take one byte per 7-bit group (at most 5).
size'Int32 :: Int32 -> Int64
size'Int32 b
  | b < 0 = 10
  | otherwise = countGroups 1 b
  where
    countGroups :: Int64 -> Int32 -> Int64
    countGroups n v
      | v <= 0x7F = n
      | otherwise = countGroups (n + 1) (v `shiftR` 7)
-- | Length in bytes of the unsigned varint encoding of @b@ (7 payload bits
-- per byte): 1 to 10.
size'Word64 :: Word64 -> Int64
size'Word64 b = countGroups 1 b
  where
    countGroups :: Int64 -> Word64 -> Int64
    countGroups n w
      | w <= 0x7F = n
      | otherwise = countGroups (n + 1) (w `shiftR` 7)
-- | Length in bytes of the signed varint encoding of an 'Int'. Negative
-- values take 10 bytes, matching 'putVarSInt'. Other values take one byte
-- per 7-bit group, whatever the platform width of 'Int'.
size'Int :: Int -> Int64
size'Int b
  | b < 0 = 10
  | otherwise = countGroups 1 b
  where
    countGroups :: Int64 -> Int -> Int64
    countGroups n v
      | v <= 0x7F = n
      | otherwise = countGroups (n + 1) (v `shiftR` 7)
-- | Length in bytes of the signed varint encoding of an 'Int64'. Negative
-- values take 10 bytes, matching 'putVarSInt'. Non-negative values take one
-- byte per 7-bit group (at most 9). 'size'WireSize' is the same function,
-- used for length prefixes.
size'Int64,size'WireSize :: Int64 -> Int64
size'WireSize = size'Int64
size'Int64 b
  | b < 0 = 10
  | otherwise = countGroups 1 b
  where
    countGroups :: Int64 -> Int64 -> Int64
    countGroups n v
      | v <= 0x7F = n
      | otherwise = countGroups (n + 1) (v `shiftR` 7)
-- | ZigZag-encode an 'Int32' so that small magnitudes get small codes:
-- 0 -> 0, -1 -> 1, 1 -> 2, -2 -> 3, and so on.
zzEncode32 :: Int32 -> Word32
zzEncode32 n = fromIntegral ((n `shiftR` 31) `xor` (n `shiftL` 1))

-- | ZigZag-encode an 'Int64' (same mapping as 'zzEncode32').
zzEncode64 :: Int64 -> Word64
zzEncode64 n = fromIntegral ((n `shiftR` 63) `xor` (n `shiftL` 1))

-- | Inverse of 'zzEncode32'.
zzDecode32 :: Word32 -> Int32
zzDecode32 w = negate (fromIntegral (w .&. 1)) `xor` fromIntegral (w `shiftR` 1)

-- | Inverse of 'zzEncode64'.
zzDecode64 :: Word64 -> Int64
zzDecode64 w = negate (fromIntegral (w .&. 1)) `xor` fromIntegral (w `shiftR` 1)
-- | Read a varint with 'decode7unrolled' (presumably base-128, low group
-- first; see "Text.ProtocolBuffers.Get") into the requested integral type.
-- The 'trace' call is a no-op debugging hook.
getVarInt :: (Show a, Integral a, Bits a) => Get a
{-# INLINE getVarInt #-}
getVarInt = decode7unrolled >>= \a -> trace ("getVarInt: " ++ show a) (return a)
{-# INLINE putVarSInt #-}
-- | Write a signed integer as a varint.
--
-- * Zero is one 0x00 byte.
-- * Positive values go through 'putVarUInt'.
-- * Negative values are widened to 'Int64' and always written as exactly
--   10 bytes. The first 9 bytes carry 7 bits each with the high bit set;
--   the last byte carries only the lowest remaining bit.
putVarSInt :: (Integral a, Bits a) => a -> Put
putVarSInt bIn
  | bIn < 0 = emitNegative (fromIntegral bIn :: Int64) (10 :: Int)
  | bIn == 0 = putWord8 0
  | otherwise = putVarUInt bIn
  where
    emitNegative !chunk 1 = putWord8 (fromIntegral (chunk .&. 1))
    emitNegative !chunk remaining =
      putWord8 (fromIntegral (chunk .&. 0x7F) .|. 0x80) >> emitNegative (chunk `shiftR` 7) (remaining - 1)
{-# INLINE putVarUInt #-}
-- | Write a value as a base-128 varint. The lowest 7-bit group comes first,
-- and every byte except the last has its high bit set.
-- NOTE(review): a negative argument is written as a single truncated byte
-- (it is below 0x80); callers in this module only pass non-negative values.
putVarUInt :: (Integral a, Bits a) => a -> Put
putVarUInt i
  | i >= 0x80 = putWord8 (fromIntegral (i .&. 0x7F) .|. 0x80) >> putVarUInt (i `shiftR` 7)
  | otherwise = putWord8 (fromIntegral i)
-- | Return the raw bytes of one field value of wire type @wt@. The field's
-- tag is presumably already consumed by the caller.
--
-- The byte length is worked out first, then that many bytes are taken:
--
-- * wire type 0: 'highBitRun';
-- * wire type 1: 8 bytes;
-- * wire type 2: the varint length prefix plus the payload. Measured under
--   'lookAhead', so the prefix is included in the result;
-- * wire type 3: everything up to the point where 'skipGroup' stops,
--   measured under 'lookAhead';
-- * wire type 5: 4 bytes.
--
-- Wire type 4 and unknown wire types fail.
wireGetFromWire :: FieldId -> WireType -> Get ByteString
wireGetFromWire fi wt = calcLen >>= getLazyByteString
  where
    calcLen = case wt of
      0 -> highBitRun
      1 -> return 8
      2 -> lookAhead $ do
        before <- bytesRead
        payloadLen <- getVarInt
        after <- bytesRead
        return (payloadLen + (after - before))
      3 -> lenOf (skipGroup fi)
      4 -> throwError $ "Cannot wireGetFromWire with wireType of STOP_GROUP: " ++ show (fi,wt)
      5 -> return 4
      wtf -> throwError $ "Invalid wire type (expected 0,1,2,3,or 5) found: " ++ show (fi,wtf)
    -- Bytes that @g@ would consume, measured without consuming them.
    lenOf g = do
      here <- bytesRead
      there <- lookAhead (g >> bytesRead)
      trace (":wireGetFromWire.lenOf: " ++ show ((fi,wt),(here,there,there-here))) $ return (there-here)
-- | Consume input up to and including the wire-type-4 (STOP_GROUP) tag that
-- closes a group for field @start_fi@.
--
-- * Nested groups (wire type 3) are skipped recursively.
-- * Fails if a STOP_GROUP tag names a different field id.
-- * Fails on an unknown wire type.
skipGroup :: FieldId -> Get ()
skipGroup start_fi = go
  where
    go = do
      (fieldId,wireType) <- fmap (splitWireTag . WireTag) getVarInt
      case wireType of
        -- Varint: skip the bytes with the high bit set, then the final byte
        -- (assumes 'spanOf' consumes the matching prefix).
        0 -> spanOf (>=128) >> skip 1 >> go
        1 -> skip 8 >> go
        -- Length-delimited: read the length, then skip that many bytes.
        2 -> getVarInt >>= skip >> go
        3 -> skipGroup fieldId >> go
        -- Fixed typo in the message: "bewteen" -> "between".
        4 | start_fi /= fieldId -> throwError $ "skipGroup failed, fieldId mismatch between START_GROUP and STOP_GROUP: "++show (start_fi,(fieldId,wireType))
          | otherwise -> return ()
        5 -> skip 4 >> go
        wtf -> throwError $ "Invalid wire type (expected 0,1,2,3,4,or 5) found: "++show (fieldId,wtf)
-- | Wire type used on the wire for each field type code handled by this
-- module:
--
-- * 0 (varint) for 3, 4, 5, 8, 13, 14, 17, 18;
-- * 1 (8 bytes) for 1, 6, 16;
-- * 2 (length-delimited) for 9, 11, 12;
-- * 3 (group start) for 10;
-- * 5 (4 bytes) for 2, 7, 15.
--
-- Any other code raises 'error'.
toWireType :: FieldType -> WireType
toWireType 1 = 1
toWireType 2 = 5
toWireType 3 = 0
toWireType 4 = 0
toWireType 5 = 0
toWireType 6 = 1
toWireType 7 = 5
toWireType 8 = 0
toWireType 9 = 2
toWireType 10 = 3
toWireType 11 = 2
toWireType 12 = 2
toWireType 13 = 0
toWireType 14 = 0
toWireType 15 = 5
toWireType 16 = 1
toWireType 17 = 0
toWireType 18 = 0
-- The message previously read "Text.ProcolBuffers.Basic.toWireType":
-- misspelled, and naming a module this function does not live in.
toWireType x = error $ "Text.ProtocolBuffers.WireMessage.toWireType: Bad FieldType: "++show x