{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
module Data.ProtocolBuffers.Decode
( Decode(..)
, decodeMessage
, decodeLengthPrefixedMessage
, GDecode(..)
, fieldDecode
) where
import Control.Applicative
import Control.Monad
import qualified Data.ByteString as B
import Data.Foldable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.Int (Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Proxy
import Data.Serialize.Get
import Data.Traversable (traverse)
import GHC.Generics
import GHC.TypeLits
import Data.ProtocolBuffers.Types
import Data.ProtocolBuffers.Wire
decodeMessage :: Decode a => Get a
{-# INLINE decodeMessage #-}
decodeMessage = decode =<< HashMap.map reverse <$> go HashMap.empty where
go :: HashMap Tag [WireField] -> Get (HashMap Tag [WireField])
go msg = do
mfield <- Just <$> getWireField <|> return Nothing
case mfield of
Just v -> go $! HashMap.insertWith (\(x:[]) xs -> x:xs) (wireFieldTag v) [v] msg
Nothing -> return msg
decodeLengthPrefixedMessage :: Decode a => Get a
{-# INLINE decodeLengthPrefixedMessage #-}
decodeLengthPrefixedMessage = do
len :: Int64 <- getVarInt
bs <- getBytes $ fromIntegral len
case runGetState decodeMessage bs 0 of
Right (val, bs')
| B.null bs' -> return val
| otherwise -> fail $ "Unparsed bytes leftover in decodeLengthPrefixedMessage: " ++ show (B.length bs')
Left err -> fail err
class Decode (a :: *) where
decode :: HashMap Tag [WireField] -> Get a
default decode :: (Generic a, GDecode (Rep a)) => HashMap Tag [WireField] -> Get a
decode = fmap to . gdecode
instance Decode (HashMap Tag [WireField]) where
decode = pure
class GDecode (f :: * -> *) where
gdecode :: HashMap Tag [WireField] -> Get (f a)
instance GDecode a => GDecode (M1 i c a) where
gdecode = fmap M1 . gdecode
instance (GDecode a, GDecode b) => GDecode (a :*: b) where
gdecode msg = liftA2 (:*:) (gdecode msg) (gdecode msg)
instance (GDecode x, GDecode y) => GDecode (x :+: y) where
gdecode msg = L1 <$> gdecode msg <|> R1 <$> gdecode msg
fieldDecode
:: forall a b i n p . (DecodeWire a, Monoid a, KnownNat n)
=> (a -> b)
-> HashMap Tag [WireField]
-> Get (K1 i (Field n b) p)
{-# INLINE fieldDecode #-}
fieldDecode c msg =
let tag = fromIntegral $ natVal (Proxy :: Proxy n)
in case HashMap.lookup tag msg of
Just val -> K1 . Field . c <$> foldMapM decodeWire val
Nothing -> empty
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Field n (OptionalField (Last (Value a))))) where
gdecode msg = fieldDecode Optional msg <|> pure (K1 mempty)
instance (Enum a, KnownNat n) => GDecode (K1 i (Field n (RequiredField (Always (Enumeration a))))) where
gdecode msg = do
K1 mx <- fieldDecode Required msg
case mx :: Field n (RequiredField (Always (Value Int32))) of
Field (Required (Always (Value x))) ->
return . K1 . Field . Required . Always . Enumeration . toEnum $ fromIntegral x
instance (Enum a, KnownNat n) => GDecode (K1 i (Field n (OptionalField (Last (Enumeration a))))) where
gdecode msg = do
K1 mx <- fieldDecode Optional msg <|> pure (K1 mempty)
case mx :: Field n (OptionalField (Last (Value Int32))) of
Field (Optional (Last (Just (Value x)))) ->
return . K1 . Field . Optional . Last . Just . Enumeration . toEnum $ fromIntegral x
_ -> pure (K1 mempty)
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Repeated n a)) where
gdecode msg =
let tag = fromIntegral $ natVal (Proxy :: Proxy n)
in case HashMap.lookup tag msg of
Just val -> K1 . Field . Repeated <$> traverse decodeWire val
Nothing -> pure $ K1 mempty
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Field n (RequiredField (Always (Value a))))) where
gdecode msg = fieldDecode Required msg
instance (DecodeWire (PackedList a), KnownNat n) => GDecode (K1 i (Packed n a)) where
gdecode msg = fieldDecode PackedField msg <|> pure (K1 mempty)
instance GDecode U1 where
gdecode _ = return U1
foldMapM :: (Monad m, Foldable t, Monoid b) => (a -> m b) -> t a -> m b
foldMapM f = liftM (fromMaybe mempty) . foldlM go Nothing where
go (Just !acc) = liftM (Just . mappend acc) . f
go Nothing = liftM Just . f