{-# LANGUAGE CPP #-}
module Data.ProtoLens.Encoding (
    encodeMessage,
    buildMessage,
    decodeMessage,
    parseMessage,
    decodeMessageOrDie,
    -- ** Delimited messages
    buildMessageDelimited,
    parseMessageDelimited,
    decodeMessageDelimitedH,
    ) where

import System.IO (Handle)

import Data.ProtoLens.Message (Message(..))
import Data.ProtoLens.Encoding.Bytes (Parser, Builder)
import qualified Data.ProtoLens.Encoding.Bytes as Bytes

import Control.Monad.IO.Class (liftIO)
import Control.Monad.Trans.Except (runExceptT, ExceptT(..))
import qualified Data.ByteString as B
#if !MIN_VERSION_base(4,11,0)
import Data.Semigroup ((<>))
#endif

-- | Decode a message from its wire format.  Returns 'Left' with an error
-- description if the decoding fails.
--
-- This runs the message's 'parseMessage' parser over the input bytes via
-- 'Bytes.runParser'.
decodeMessage :: Message msg => B.ByteString -> Either String msg
decodeMessage = Bytes.runParser parseMessage

-- | Decode a message from its wire format.  Throws an error (via 'error')
-- if the decoding fails; the error text is the 'decodeMessage' failure
-- message prefixed with @"decodeMessageOrDie: "@.
--
-- Prefer 'decodeMessage' when the input is not known to be well-formed.
decodeMessageOrDie :: Message msg => B.ByteString -> msg
decodeMessageOrDie bs = case decodeMessage bs of
    Left e -> error $ "decodeMessageOrDie: " ++ e
    Right x -> x

-- | Encode a message to the wire format as a strict 'ByteString'.
--
-- This renders the message's 'buildMessage' 'Builder' with
-- 'Bytes.runBuilder'.
encodeMessage :: Message msg => msg -> B.ByteString
encodeMessage = Bytes.runBuilder . buildMessage

-- | Encode a message to the wire format, prefixed by its size as a VarInt,
-- as part of a 'Builder'.
--
-- This can be used to build up streams of messages in the size-delimited
-- format expected by some protocols.  The inverse operations are
-- 'parseMessageDelimited' and 'decodeMessageDelimitedH'.
buildMessageDelimited :: Message msg => msg -> Builder
buildMessageDelimited msg =
    -- The message is encoded to a strict ByteString up front because its
    -- byte length must be known before the VarInt prefix can be written.
    let b = encodeMessage msg
    in Bytes.putVarInt (fromIntegral $ B.length b) <> Bytes.putBytes b

-- | Parse a message that is prefixed by its size as a VarInt, i.e. the
-- format produced by 'buildMessageDelimited'.
--
-- Fails the parse if the length prefix does not fit in an 'Int', if fewer
-- bytes than announced are available (as reported by 'Bytes.getBytes'), or
-- if the announced bytes do not decode as a message.
parseMessageDelimited :: Message msg => Parser msg
parseMessageDelimited = do
    len <- Bytes.getVarInt
    -- The prefix is a Word64; converting an out-of-range value with
    -- fromIntegral would silently wrap (possibly to a negative Int), so
    -- reject it explicitly instead.
    if len > fromIntegral (maxBound :: Int)
        then fail $ "parseMessageDelimited: message length too large: "
                ++ show len
        else do
            bytes <- Bytes.getBytes (fromIntegral len)
            either fail return (decodeMessage bytes)

-- | Same as 'decodeMessage', but reads one size-delimited message (a VarInt
-- length followed by that many bytes, as produced by
-- 'buildMessageDelimited') from a 'Handle'.
--
-- Returns 'Left' if the length prefix cannot be read, if the length does
-- not fit in an 'Int', if the handle hits end-of-file before the announced
-- number of bytes has been read, or if the bytes do not decode as a message.
-- Exceptions raised by the underlying handle operations are not caught.
decodeMessageDelimitedH :: Message msg => Handle -> IO (Either String msg)
decodeMessageDelimitedH h = runExceptT $ do
    len <- Bytes.getVarIntH h
    -- Guard the Word64 -> Int conversion: a wrapped (negative) size would
    -- make B.hGet throw instead of producing a Left.
    size <- if len > fromIntegral (maxBound :: Int)
        then failWith $ "decodeMessageDelimitedH: message length too large: "
                ++ show len
        else return (fromIntegral len)
    bytes <- liftIO $ B.hGet h size
    -- B.hGet returns fewer bytes at EOF; decoding a truncated message could
    -- otherwise "succeed" with missing fields, so report it instead.
    if B.length bytes < size
        then failWith $ "decodeMessageDelimitedH: unexpected end of input: expected "
                ++ show size ++ " bytes, got " ++ show (B.length bytes)
        else ExceptT $ return $ decodeMessage bytes
  where
    failWith :: String -> ExceptT String IO a
    failWith = ExceptT . return . Left