{-# Language ScopedTypeVariables, MultiWayIf, TypeFamilies, FlexibleContexts #-}
module Data.Bond.Internal.TaggedProtocol where
import Data.Bond.Schema.BondDataType
import Data.Bond.Struct
import Data.Bond.TypedSchema
import Data.Bond.Types
import Data.Bond.Internal.BinaryUtils
import Data.Bond.Internal.Default
import Data.Bond.Internal.OrdinalSet
import Data.Bond.Internal.Protocol
import Data.Bond.Internal.SchemaOps
import Data.Bond.Internal.SchemaUtils
import Control.Applicative
import Control.Monad
import Control.Monad.Error
import Data.Bits
import Data.Proxy
import Prelude
import qualified Data.Binary.Get as B
import qualified Data.Binary.Put as B
import qualified Data.ByteString.Lazy as BL
import qualified Data.Map as M
import qualified Data.Map.Strict as MS
-- | Where a struct sits in an inheritance chain while it is being
-- (de)serialized; this decides which terminator tag closes it.
data StructLevel
    = TopLevelStruct -- ^ Outermost struct; closed by @BT_STOP@.
    | BaseStruct     -- ^ Base part of a derived struct; closed by @BT_STOP_BASE@.
    deriving (Show, Eq)
-- | Extra operations needed by protocols that prefix every field with a
-- wire-type tag and an ordinal, on top of the basic 'Protocol' interface.
class Protocol t => TaggedProtocol t where
    -- Reading ---------------------------------------------------------

    -- | Read the next field header as (wire type, ordinal). Struct
    -- terminators are reported through the wire type (@BT_STOP@ or
    -- @BT_STOP_BASE@).
    getFieldHeader :: BondGet t (BondDataType, Ordinal)

    -- | Read a list/set header as (element wire type, element count).
    getListHeader :: BondGet t (BondDataType, Int)

    -- | Read a whole struct into the schema-less 'Struct' representation.
    getTaggedStruct :: BondGet t Struct

    -- Writing ---------------------------------------------------------

    -- | Write the header for a field with the given wire type and ordinal.
    putFieldHeader :: BondDataType -> Ordinal -> BondPut t

    -- | Write a list/set header from the element wire type and element count.
    putListHeader :: (Integral a, FiniteBits a) => BondDataType -> a -> BondPut t

    -- | Write a schema-less 'Struct'.
    putTaggedStruct :: Struct -> BondPut t

    -- Skipping --------------------------------------------------------

    -- | Skip a struct value without decoding it (presumably a whole
    -- struct including its terminator; confirm against the instances).
    skipStruct :: BondGet t ()

    -- | Skip whatever is left of the struct currently being read.
    -- 'getStruct' calls this when it meets @BT_STOP_BASE@ at top level.
    skipRestOfStruct :: BondGet t ()

    -- | Skip over one value of the given wire type.
    skipType :: BondDataType -> BondGet t ()
-- | Deserialize a struct of type @a@ from a tagged stream.
--
-- The base part is read first with 'bondStructGetBase'. Field headers are
-- then consumed until a terminator. Fields whose ordinal is not in the
-- schema are skipped with 'skipType'. How a terminator is treated depends
-- on the 'StructLevel':
--
--   * 'TopLevelStruct': @BT_STOP@ ends the struct; @BT_STOP_BASE@ makes
--     the reader discard the rest via 'skipRestOfStruct'.
--   * 'BaseStruct': @BT_STOP_BASE@ ends the struct; @BT_STOP@ is an error.
--
-- Fails if an ordinal from 'structRequiredOrdinals' is never encountered.
getStruct :: forall a t. (Functor (ReaderM t), Monad (ReaderM t), TaggedProtocol t, BondStruct a) => StructLevel -> BondGet t a
getStruct level = do
    initial <- bondStructGetBase defaultValue
    (value, missing) <- readFields initial (structRequiredOrdinals schema)
    unless (isEmptySet missing) $
        fail $ "required fields not read: " ++ show (map (getFieldName schema) (toOrdinalList missing))
    return value
  where
    schema = getSchema (Proxy :: Proxy a)

    -- Decode a field the schema knows about; otherwise skip its payload.
    readField wiretype ordinal s
        | M.member ordinal (structFields schema) = bondStructGetField ordinal s
        | otherwise = skipType wiretype >> return s

    -- Consume fields until a terminator. Every ordinal seen is removed from
    -- the pending set of required ordinals.
    readFields s pending = getFieldHeader >>= dispatch
      where
        dispatch (wiretype, ordinal)
            | wiretype == bT_STOP, level == TopLevelStruct = return (s, pending)
            | wiretype == bT_STOP = fail "BT_STOP found where BT_STOP_BASE expected"
            | wiretype == bT_STOP_BASE, level == TopLevelStruct = skipRestOfStruct >> return (s, pending)
            | wiretype == bT_STOP_BASE = return (s, pending)
            | otherwise = do
                s' <- readField wiretype ordinal s
                readFields s' (deleteOrdinal ordinal pending)
-- | Serialize a struct with 'bondStructPut', then write the terminator for
-- its level: @BT_STOP@ for 'TopLevelStruct', @BT_STOP_BASE@ for 'BaseStruct'.
putStruct :: (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t, BondStruct a) => StructLevel -> a -> BondPut t
putStruct level a = bondStructPut a >> putTag terminator
  where
    terminator = case level of
        TopLevelStruct -> bT_STOP
        BaseStruct -> bT_STOP_BASE
-- | Serialize a struct as the base part of a derived struct, so it is
-- closed with @BT_STOP_BASE@ rather than @BT_STOP@.
putBaseStruct :: (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t, BondStruct a) => a -> BondPut t
putBaseStruct a = putStruct BaseStruct a
-- | Write field @ordinal@ of struct type @b@ with the given value.
--
-- The field header and value are skipped only when the value equals the
-- schema default ('equalToDefault') AND the field is 'FieldOptional'.
-- Otherwise both are written. Looking up an ordinal missing from @b@'s
-- schema raises 'error'.
putField :: forall a b t. (Monad (BondPutM t), TaggedProtocol t, BondType a, BondStruct b) => Proxy b -> Ordinal -> a -> BondPut t
putField p ordinal value =
    when needToSave $ putFieldHeader wireType ordinal >> bondPut value
  where
    wireType = getWireType (Proxy :: Proxy a)
    info = M.findWithDefault (error "internal error: unknown field ordinal") ordinal (structFields (getSchema p))
    -- Default check first, modifier second: same short-circuit order as before.
    omittable = equalToDefault (fieldType info) value && fieldModifier info == FieldOptional
    needToSave = not omittable
-- | Write a 'BondDataType' as one raw byte holding its 'fromEnum' value.
-- Used for struct terminators and for map key/value element types.
putTag :: WriterM t ~ ErrorT String B.PutM => BondDataType -> BondPut t
putTag tag = putWord8 (fromIntegral (fromEnum tag))
-- | Decode a struct from a lazy 'BL.ByteString' with protocol @t@. The
-- first argument is only a type witness.
--
-- Returns 'Left' on a parse failure (with the byte offset reported by
-- 'B.runGetOrFail'), or when input remains after the struct was read.
binaryDecode :: forall a t. (ReaderM t ~ B.Get, BondStruct a, Protocol t) => t -> BL.ByteString -> Either String a
binaryDecode _ s = case B.runGetOrFail parser s of
    Left (_, used, msg) -> Left $ "parse error at " ++ show used ++ ": " ++ msg
    Right (rest, used, a)
        | BL.null rest -> Right a
        | otherwise -> Left $ "incomplete parse, used " ++ show used ++ ", left " ++ show (BL.length rest)
  where
    BondGet parser = bondGetStruct :: BondGet t a
-- | Encode a struct with protocol @t@ (the first argument is only a type
-- witness) by running 'bondPutStruct' through 'tryPut'. Writer errors
-- come back as 'Left'.
binaryEncode :: forall a t. (WriterM t ~ ErrorT String B.PutM, BondStruct a, Protocol t) => t -> a -> Either String BL.ByteString
binaryEncode _ a = tryPut writer
  where
    BondPut writer = bondPutStruct a :: BondPut t
-- | Read a struct from a tagged stream into the schema-less 'Struct' form.
--
-- Fields are collected by ordinal until @BT_STOP@. Each @BT_STOP_BASE@
-- wraps everything read so far as the base struct and starts a new, empty
-- level on top of it. A later field with the same ordinal replaces an
-- earlier one. Fails on an unrecognized wire type.
getTaggedData :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t) => BondGet t Struct
getTaggedData = fieldLoop (Struct Nothing M.empty)
  where
    fieldLoop :: Struct -> BondGet t Struct
    fieldLoop s = getFieldHeader >>= step
      where
        step (wt, o)
            | wt == bT_STOP = return s
            | wt == bT_STOP_BASE = fieldLoop (Struct (Just s) M.empty)
            | otherwise = do
                v <- getValue wt
                fieldLoop (s { fields = MS.insert o v (fields s) })

    -- Decode one value whose wire type is already known.
    getValue :: BondDataType -> BondGet t Value
    getValue wt
        | wt == bT_BOOL = BOOL <$> bondGetBool
        | wt == bT_UINT8 = UINT8 <$> bondGetUInt8
        | wt == bT_UINT16 = UINT16 <$> bondGetUInt16
        | wt == bT_UINT32 = UINT32 <$> bondGetUInt32
        | wt == bT_UINT64 = UINT64 <$> bondGetUInt64
        | wt == bT_FLOAT = FLOAT <$> bondGetFloat
        | wt == bT_DOUBLE = DOUBLE <$> bondGetDouble
        | wt == bT_STRING = STRING <$> bondGetString
        | wt == bT_STRUCT = STRUCT <$> getTaggedStruct
        | wt == bT_LIST = getSequence LIST
        | wt == bT_SET = getSequence SET
        | wt == bT_MAP = getMap
        | wt == bT_INT8 = INT8 <$> bondGetInt8
        | wt == bT_INT16 = INT16 <$> bondGetInt16
        | wt == bT_INT32 = INT32 <$> bondGetInt32
        | wt == bT_INT64 = INT64 <$> bondGetInt64
        | wt == bT_WSTRING = WSTRING <$> bondGetWString
        | otherwise = fail $ "invalid field type " ++ bondTypeName wt

    -- LIST and SET are both read as a list header followed by n elements.
    getSequence :: (BondDataType -> [Value] -> Value) -> BondGet t Value
    getSequence con = do
        (td, n) <- getListHeader
        con td <$> replicateM n (getValue td)

    -- MAP: key type byte, value type byte, varint count, then n key/value pairs.
    getMap :: BondGet t Value
    getMap = do
        tk <- getElemType
        tv <- getElemType
        n <- getVarInt
        MAP tk tv <$> replicateM n (liftM2 (,) (getValue tk) (getValue tv))

    getElemType :: BondGet t BondDataType
    getElemType = BondDataType . fromIntegral <$> getWord8
-- | Parse a complete tagged struct into a schema-less 'Struct' using
-- protocol @t@. The first argument is only a type witness.
--
-- Returns 'Left' on a parse failure (with the byte offset) or when input
-- remains after the struct.
readTagged :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t) => t -> BL.ByteString -> Either String Struct
readTagged _ s = case B.runGetOrFail parser s of
    Left (_, used, msg) -> Left $ "parse error at " ++ show used ++ ": " ++ msg
    Right (rest, used, a)
        | BL.null rest -> Right a
        | otherwise -> Left $ "incomplete parse, used " ++ show used ++ ", left " ++ show (BL.length rest)
  where
    BondGet parser = getTaggedStruct :: BondGet t Struct
-- | Like 'readTagged', but the parsed 'Struct' is then passed through
-- 'checkStructSchema' against the given schema. The errors from either
-- stage are returned as 'Left'.
readTaggedWithSchema :: forall t. (ReaderM t ~ B.Get, TaggedProtocol t) => t -> StructSchema -> BL.ByteString -> Either String Struct
readTaggedWithSchema proto schema s = readTagged proto s >>= checkStructSchema schema
-- | Write a schema-less 'Struct' in tagged form.
--
-- The base chain is written first, each base level followed by a
-- @BT_STOP_BASE@ tag. This level's fields follow in ascending ordinal
-- order, each as a field header plus value. This function itself writes
-- no final @BT_STOP@.
--
-- Fails via 'throwError' when a container element's wire type differs
-- from the container's declared element type, or when a 'BONDED' value
-- is not a 'BondedObject'.
putTaggedData :: forall t. (MonadError String (BondPutM t), WriterM t ~ ErrorT String B.PutM, TaggedProtocol t) => Struct -> BondPut t
putTaggedData s = do
    maybe (return ()) putBase (base s)
    mapM_ putEntry (M.toList (fields s))
  where
    putBase :: Struct -> BondPut t
    putBase b = putTaggedData b >> putTag bT_STOP_BASE

    putEntry :: (Ordinal, Value) -> BondPut t
    putEntry (o, v) =
        let (typ, writer) = saveValue v
        in putFieldHeader typ o >> writer

    -- Pair each value with its wire type and the action that writes it.
    saveValue :: Value -> (BondDataType, BondPut t)
    saveValue val = case val of
        BOOL v -> (bT_BOOL, bondPutBool v)
        INT8 v -> (bT_INT8, bondPutInt8 v)
        INT16 v -> (bT_INT16, bondPutInt16 v)
        INT32 v -> (bT_INT32, bondPutInt32 v)
        INT64 v -> (bT_INT64, bondPutInt64 v)
        UINT8 v -> (bT_UINT8, bondPutUInt8 v)
        UINT16 v -> (bT_UINT16, bondPutUInt16 v)
        UINT32 v -> (bT_UINT32, bondPutUInt32 v)
        UINT64 v -> (bT_UINT64, bondPutUInt64 v)
        FLOAT v -> (bT_FLOAT, bondPutFloat v)
        DOUBLE v -> (bT_DOUBLE, bondPutDouble v)
        STRING v -> (bT_STRING, bondPutString v)
        WSTRING v -> (bT_WSTRING, bondPutWString v)
        STRUCT v -> (bT_STRUCT, putTaggedStruct v)
        LIST td xs -> (bT_LIST, putElements td xs)
        SET td xs -> (bT_SET, putElements td xs)
        MAP tk tv xs -> (bT_MAP, putMap tk tv xs)
        BONDED (BondedObject v) -> (bT_STRUCT, putTaggedStruct v)
        BONDED _ -> (bT_STRUCT, throwError "not implemented: should decode bonded values before tagged writes")

    -- LIST/SET body: list header, then every element type-checked against td.
    putElements :: BondDataType -> [Value] -> BondPut t
    putElements td xs = putListHeader td (length xs) >> mapM_ (saveTypedValue td) xs

    -- MAP body: key type tag, value type tag, varint count, then the pairs.
    putMap :: BondDataType -> BondDataType -> [(Value, Value)] -> BondPut t
    putMap tk tv xs = do
        putTag tk
        putTag tv
        putVarInt $ length xs
        mapM_ (\ (k, v) -> saveTypedValue tk k >> saveTypedValue tv v) xs

    -- Write a container element, rejecting it if its wire type is not td.
    saveTypedValue :: BondDataType -> Value -> BondPut t
    saveTypedValue td v
        | td == realtd = writer
        | otherwise = throwError $ "element type do not match container type: " ++ bondTypeName td ++ " expected, " ++ bondTypeName realtd ++ " found"
      where
        (realtd, writer) = saveValue v
-- | Serialize a schema-less 'Struct' with protocol @t@ (the first argument
-- is only a type witness) by running 'putTaggedStruct' through 'tryPut'.
writeTagged :: forall t. (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t) => t -> Struct -> Either String BL.ByteString
writeTagged _ a = tryPut writer
  where
    BondPut writer = putTaggedStruct a :: BondPut t
-- | Validate a 'Struct' with 'checkStructSchema' and, if it passes, serialize
-- the checked result with 'writeTagged'. The first failure is returned as 'Left'.
writeTaggedWithSchema :: (WriterM t ~ ErrorT String B.PutM, TaggedProtocol t) => t -> StructSchema -> Struct -> Either String BL.ByteString
writeTaggedWithSchema t schema struct = do
    checked <- checkStructSchema schema struct
    writeTagged t checked