{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}

module Data.ProtocolBuffers.Decode
  ( Decode(..)
  , decodeMessage
  , decodeLengthPrefixedMessage
  , GDecode(..)
  , fieldDecode
  ) where

import Control.Applicative
import Control.Monad
import qualified Data.ByteString as B
import Data.Foldable
import Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import Data.Int (Int32, Int64)
import Data.Maybe (fromMaybe)
import Data.Monoid
import Data.Proxy
import Data.Serialize.Get
import Data.Traversable (traverse)

import GHC.Generics
import GHC.TypeLits

import Data.ProtocolBuffers.Types
import Data.ProtocolBuffers.Wire

-- |
-- Decode a Protocol Buffers message.
--
-- Wire fields are read one after another and grouped by tag, keeping fields
-- that share a tag in the order they were read; the resulting map is then
-- handed to 'decode'. Any failure of 'getWireField' (presumably this is also
-- how end of input shows up) ends the field loop rather than failing the
-- whole decode.
decodeMessage :: Decode a => Get a
{-# INLINE decodeMessage #-}
decodeMessage = decode =<< HashMap.map reverse <$> go HashMap.empty where
  -- Fields are prepended per tag for O(1) insertion, so each list is
  -- newest-first while looping and is reversed once above. 'insertWith'
  -- calls its combining function as @f new old@ with @new = [v]@, so '(++)'
  -- prepends the new field and is total.
  go :: HashMap Tag [WireField] -> Get (HashMap Tag [WireField])
  go msg = do
    mfield <- optional getWireField
    case mfield of
      Just v  -> go $! HashMap.insertWith (++) (wireFieldTag v) [v] msg
      Nothing -> return msg

-- |
-- Decode a Protocol Buffers message prefixed with a varint-encoded length
-- (in bytes) of the message body.
--
-- The body must be consumed completely by 'decodeMessage'; leftover bytes
-- are reported as a failure. A length prefix that is negative or does not
-- fit in an 'Int' is rejected before any bytes are requested.
decodeLengthPrefixedMessage :: Decode a => Get a
{-# INLINE decodeLengthPrefixedMessage #-}
decodeLengthPrefixedMessage = do
  len :: Int64 <- getVarInt
  -- The prefix is read as a signed Int64: guard against negative values and
  -- values that would wrap when narrowed to Int for 'getBytes'.
  when (len < 0 || len > fromIntegral (maxBound :: Int)) $
    fail $ "Invalid length prefix in decodeLengthPrefixedMessage: " ++ show len
  bs <- getBytes $ fromIntegral len
  case runGetState decodeMessage bs 0 of
    Right (val, bs')
      | B.null bs' -> return val
      | otherwise  -> fail $ "Unparsed bytes leftover in decodeLengthPrefixedMessage: " ++ show (B.length bs')
    Left err  -> fail err

-- | Types that can be built from a message whose wire fields have already
-- been grouped by tag. Types with a 'Generic' instance get 'decode' for free
-- through their 'GDecode' representation.
class Decode (a :: *) where
  decode :: HashMap Tag [WireField] -> Get a
  default decode :: (Generic a, GDecode (Rep a)) => HashMap Tag [WireField] -> Get a
  decode msg = to <$> gdecode msg

-- | Untyped message decoding: the grouped wire fields are handed back
-- unchanged, i.e. @'decode' = 'pure'@.
instance Decode (HashMap Tag [WireField]) where
  decode = return

-- | Decoding of a 'GHC.Generics' representation type from wire fields
-- grouped by tag.
class GDecode (f :: * -> *) where
  gdecode :: HashMap Tag [WireField] -> Get (f p)

-- | Metadata wrappers are transparent: decode the wrapped representation.
instance GDecode a => GDecode (M1 i c a) where
  gdecode msg = M1 <$> gdecode msg

-- | Products: both halves are decoded from the same field map.
instance (GDecode a, GDecode b) => GDecode (a :*: b) where
  gdecode msg = (:*:) <$> gdecode msg <*> gdecode msg

-- | Sums: try the left alternative first and fall back to the right one if
-- it fails.
instance (GDecode x, GDecode y) => GDecode (x :+: y) where
  gdecode msg = fmap L1 (gdecode msg) <|> fmap R1 (gdecode msg)

-- | Decode a single (non-repeated) field whose tag is the type-level 'Nat'
-- @n@. Every wire value recorded for that tag is decoded with 'decodeWire'
-- and the results are combined left to right with the 'Monoid' instance of
-- @a@, then wrapped with @c@. Fails via 'empty' when the tag is absent, so
-- callers can supply a default with '<|>'.
fieldDecode
  :: forall a b i n p . (DecodeWire a, Monoid a, KnownNat n)
  => (a -> b)
  -> HashMap Tag [WireField]
  -> Get (K1 i (Field n b) p)
{-# INLINE fieldDecode #-}
fieldDecode c msg = maybe empty decodeAll (HashMap.lookup fieldTag msg)
  where
    fieldTag = fromIntegral (natVal (Proxy :: Proxy n))
    decodeAll vals = K1 . Field . c <$> foldMapM decodeWire vals

-- | Optional scalar field: decoded with 'fieldDecode', falling back to
-- 'mempty' when that fails. NOTE(review): with cereal's backtracking '<|>',
-- a field that is present but fails to decode also falls back to 'mempty'
-- instead of being reported.
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Field n (OptionalField (Last (Value a))))) where
  gdecode msg = fieldDecode Optional msg <|> absent
    where
      absent = return (K1 mempty)

-- | Required enum field. The value is decoded from the wire as an 'Int32'
-- and converted with 'toEnum'. NOTE(review): 'toEnum' is not range-checked
-- here; for types whose 'toEnum' is partial (e.g. derived 'Enum'), an
-- unknown wire value only raises an error once the field is forced.
instance (Enum a, KnownNat n) => GDecode (K1 i (Field n (RequiredField (Always (Enumeration a))))) where
  gdecode msg = fieldDecode Required msg >>= \(K1 raw) ->
    case raw :: Field n (RequiredField (Always (Value Int32))) of
      Field (Required (Always (Value code))) ->
        return (K1 (Field (Required (Always (Enumeration (toEnum (fromIntegral code)))))))

-- | Optional enum field. The raw value is decoded as an 'Int32' (defaulting
-- to 'mempty' when 'fieldDecode' fails) and, if one is present, converted
-- with 'toEnum'; otherwise the field decodes to 'mempty'.
-- NOTE(review): 'toEnum' is not range-checked here; for types whose
-- 'toEnum' is partial, an unknown wire value only errors when forced.
instance (Enum a, KnownNat n) => GDecode (K1 i (Field n (OptionalField (Last (Enumeration a))))) where
  gdecode msg = (fieldDecode Optional msg <|> return (K1 mempty)) >>= \(K1 raw) ->
    case raw :: Field n (OptionalField (Last (Value Int32))) of
      Field (Optional (Last (Just (Value code)))) ->
        return . K1 . Field . Optional . Last . Just . Enumeration . toEnum . fromIntegral $ code
      _ ->
        return (K1 mempty)

-- | Repeated field: every wire value recorded for the tag is decoded with
-- 'decodeWire' in list order, and any decode failure propagates. An absent
-- tag decodes to 'mempty'.
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Repeated n a)) where
  gdecode msg = maybe (pure (K1 mempty)) decodeEach (HashMap.lookup fieldTag msg)
    where
      fieldTag = fromIntegral (natVal (Proxy :: Proxy n))
      decodeEach vals = K1 . Field . Repeated <$> traverse decodeWire vals

-- | Required scalar field: fails (through 'fieldDecode') when the tag is
-- absent from the message.
instance (DecodeWire a, KnownNat n) => GDecode (K1 i (Field n (RequiredField (Always (Value a))))) where
  gdecode = fieldDecode Required

-- | Packed repeated field: decoded with 'fieldDecode', falling back to
-- 'mempty' when that fails (including, with cereal's backtracking '<|>', a
-- present field that fails to decode).
instance (DecodeWire (PackedList a), KnownNat n) => GDecode (K1 i (Packed n a)) where
  gdecode msg = fieldDecode PackedField msg <|> absent
    where
      absent = return (K1 mempty)

-- | Constructors without fields need nothing from the message.
instance GDecode U1 where
  gdecode = const (pure U1)

-- |
-- Monadic 'foldMap' that combines results left to right with 'mappend'.
--
-- 'mempty' is only used when the 'Foldable' is empty; otherwise the first
-- result seeds the accumulator. This lets types with only a pretend 'Monoid'
-- instance (such as 'Always') be combined as long as at least one element is
-- present, while each intermediate accumulator is forced to WHNF before the
-- next element is processed.
foldMapM :: (Monad m, Foldable t, Monoid b) => (a -> m b) -> t a -> m b
foldMapM f xs = do
  result <- foldlM step Nothing xs
  return (fromMaybe mempty result)
  where
    step Nothing     x = Just `liftM` f x
    step (Just !acc) x = (Just . mappend acc) `liftM` f x