{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies     #-}
{-|
Module: Capnp.Convert
Description: Convert between messages, typed capnproto values, (lazy) bytestrings, and bytestring builders.

This module provides various helper functions to convert between messages, types defined
in capnproto schema (called "values" in the rest of this module's documentation),
bytestrings (both lazy and strict), and bytestring builders.

Note that most of the functions which decode messages or raw bytes do *not* need to be
run in a 'MonadLimit' monad; they choose an appropriate traversal limit based on the
size of the input.

Note that not all conversions exist or necessarily make sense.
-}
module Capnp.Convert
    ( msgToBuilder
    , msgToLBS
    , msgToBS
    , msgToValue
    , bsToMsg
    , bsToValue
    , lbsToMsg
    , lbsToValue
    , valueToBuilder
    , valueToBS
    , valueToLBS
    , valueToMsg
    ) where

import Control.Monad         ((>=>))
import Control.Monad.Catch   (MonadThrow)
import Data.Foldable         (foldlM)
import Data.Functor.Identity (runIdentity)

import qualified Data.ByteString         as BS
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy    as LBS

import Capnp.Classes

import Capnp.Bits           (WordCount)
import Capnp.TraversalLimit (LimitT, MonadLimit, evalLimitT)
import Codec.Capnp          (getRoot, setRoot)
import Data.Mutable         (freeze)

import qualified Capnp.Message as M

-- | Compute a reasonable traversal limit based on the size of a message.
--
-- The limit is the total number of words in all of the message's segments,
-- multiplied by 10 to provide some slack for decoding default values.
--
-- Segments are fetched with 'M.getSegment', which requires 'MonadThrow';
-- any failure it reports propagates to the caller.
limitFromMsg :: (MonadThrow m, M.Message m msg) => msg -> m WordCount
limitFromMsg msg = (* 10) <$> countMessageWords
  where
    -- Sum 'M.numWords' over segments @0 .. numSegs - 1@. A message with no
    -- segments yields an empty index range and therefore a total of 0.
    countMessageWords = do
        segCount <- M.numSegs msg
        foldlM addSegment 0 [0 .. segCount - 1]

    -- Add the word count of segment @i@ to the running total. The sum is
    -- forced at each step so no chain of thunks accumulates across segments.
    addSegment total i = do
        segWords <- M.getSegment msg i >>= M.numWords
        pure $! total + segWords

-- | Convert an immutable message to a bytestring 'BB.Builder'.
-- To convert a mutable message, 'freeze' it first.
--
-- 'M.encode' works in any 'Monad'; it is run in 'Data.Functor.Identity.Identity'
-- here so that the conversion is pure.
msgToBuilder :: M.ConstMsg -> BB.Builder
msgToBuilder msg = runIdentity (M.encode msg)

-- | Convert an immutable message to a lazy 'LBS.ByteString'.
-- To convert a mutable message, 'freeze' it first.
--
-- This runs the 'BB.Builder' produced by 'msgToBuilder'.
msgToLBS :: M.ConstMsg -> LBS.ByteString
msgToLBS = BB.toLazyByteString . msgToBuilder

-- | Convert an immutable message to a strict 'BS.ByteString'.
-- To convert a mutable message, 'freeze' it first.
--
-- The encoded message is first built as a lazy bytestring ('msgToLBS') and
-- then copied into a single strict chunk with 'LBS.toStrict'.
msgToBS :: M.ConstMsg -> BS.ByteString
msgToBS = LBS.toStrict . msgToLBS

-- | Convert a message to a value by reading its root struct.
--
-- The caller does not need to supply a traversal limit: one is derived from
-- the message's size with 'limitFromMsg', and 'getRoot' is run under it via
-- 'evalLimitT'. Failures are reported through 'MonadThrow'.
msgToValue :: (MonadThrow m, M.Message (LimitT m) msg, M.Message m msg, FromStruct msg a) => msg -> m a
msgToValue msg = do
    limit <- limitFromMsg msg
    evalLimitT limit (getRoot msg)

-- | Convert a strict 'BS.ByteString' to a message.
--
-- This is 'M.decode'; decoding failures are reported through 'MonadThrow'.
bsToMsg :: MonadThrow m => BS.ByteString -> m M.ConstMsg
bsToMsg = M.decode

-- | Convert a strict 'BS.ByteString' to a value.
--
-- Decodes the bytes into a message with 'bsToMsg', then reads the root
-- struct with 'msgToValue' (which picks a traversal limit from the message
-- size).
bsToValue :: (MonadThrow m, FromStruct M.ConstMsg a) => BS.ByteString -> m a
bsToValue = bsToMsg >=> msgToValue

-- | Convert a lazy 'LBS.ByteString' to a message.
--
-- The input is copied into a strict bytestring with 'LBS.toStrict' and then
-- decoded by 'bsToMsg'.
lbsToMsg :: MonadThrow m => LBS.ByteString -> m M.ConstMsg
lbsToMsg = bsToMsg . LBS.toStrict

-- | Convert a lazy 'LBS.ByteString' to a value.
--
-- The input is copied into a strict bytestring with 'LBS.toStrict' and then
-- handed to 'bsToValue'.
lbsToValue :: (MonadThrow m, FromStruct M.ConstMsg a) => LBS.ByteString -> m a
lbsToValue = bsToValue . LBS.toStrict

-- | Convert a value to a bytestring 'BB.Builder'.
--
-- The value is serialized into a fresh mutable message with 'valueToMsg'.
-- That message is then frozen with 'freeze' and encoded with 'msgToBuilder'.
valueToBuilder :: (MonadLimit m, M.WriteCtx m s, Cerialize s a, ToStruct (M.MutMsg s) (Cerial (M.MutMsg s) a)) => a -> m BB.Builder
valueToBuilder val = msgToBuilder <$> (valueToMsg val >>= freeze)

-- | Convert a value to a strict 'BS.ByteString'.
--
-- Serializes via 'valueToLBS', then copies the result into a single strict
-- chunk with 'LBS.toStrict'.
valueToBS :: (MonadLimit m, M.WriteCtx m s, Cerialize s a, ToStruct (M.MutMsg s) (Cerial (M.MutMsg s) a)) => a -> m BS.ByteString
valueToBS = fmap LBS.toStrict . valueToLBS

-- | Convert a value to a lazy 'LBS.ByteString'.
--
-- Serializes via 'valueToBuilder' and runs the resulting builder with
-- 'BB.toLazyByteString'.
valueToLBS :: (MonadLimit m, M.WriteCtx m s, Cerialize s a, ToStruct (M.MutMsg s) (Cerial (M.MutMsg s) a)) => a -> m LBS.ByteString
valueToLBS = fmap BB.toLazyByteString . valueToBuilder

-- | Convert a value to a (mutable) message.
--
-- The steps are:
--
-- 1. Allocate a new message with 'M.newMessage', passing 'Nothing' for its
--    optional 'WordCount' argument.
-- 2. Serialize the value into that message with 'cerialize'.
-- 3. Install the serialized struct as the message's root with 'setRoot'.
--
-- The result is returned unfrozen; use 'freeze' to obtain an immutable
-- message.
valueToMsg :: (MonadLimit m, M.WriteCtx m s, Cerialize s a, ToStruct (M.MutMsg s) (Cerial (M.MutMsg s) a)) => a -> m (M.MutMsg s)
valueToMsg val = do
    msg <- M.newMessage Nothing
    root <- cerialize msg val
    setRoot root
    pure msg