{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NoMonomorphismRestriction #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.ProtoLens.Compiler.Generate.Encoding
( generatedParser
, generatedBuilder
) where
import Data.Int (Int32)
import qualified Data.Map as Map
import Data.Semigroup ((<>))
import qualified Data.Text as Text
import Lens.Family2 (view, (^.))
import Data.ProtoLens.Compiler.Combinators
import Data.ProtoLens.Compiler.Definitions
import Data.ProtoLens.Compiler.Generate.Field
import Data.ProtoLens.Encoding.Wire (joinTypeAndTag)
import Proto.Google.Protobuf.Descriptor_Fields
( name
, number
, type'
)
-- | Generate the expression for a message's wire-format parser.
--
-- The generated code first runs the setup statements from 'startParse'
-- (one @Data.ProtoLens.Encoding.Growing.new@ per repeated field), then calls
-- a locally-defined recursive function @loop@. The arguments of @loop@ are
-- the current 'ParseState', flattened by 'loopArgs':
--
-- * the partially-built message;
-- * one flag per required field;
-- * one growing vector per repeated field.
--
-- The whole parser is labelled with the message's descriptor name via
-- @Data.ProtoLens.Encoding.Bytes.<?>@.
generatedParser :: Env QName -> MessageInfo Name -> Exp
generatedParser env m =
    let' [typeSig [loop] loopSig
         , funBind [match loop (fmap pVar $ loopArgs names) loopExpr]
         ]
        $ "Data.ProtoLens.Encoding.Bytes.<?>"
            @@ do' (startStmts ++ [stmt $ continue startExp])
            @@ stringExp msgName
  where
    ty = tyCon (unQual $ messageName m)
    msgName = Text.unpack (messageDescriptor m ^. name)
    -- Type of @loop@: one argument per 'loopArgs' entry, returning a
    -- @Parser@ for the message type.
    loopSig = foldr tyFun
                ("Data.ProtoLens.Encoding.Bytes.Parser" @@ ty)
                (loopArgs $ parseStateTypes env m)
    names = parseStateNames m
    -- The loop's parameters, as variable expressions.
    exprs = fmap (var . unQual) names
    -- These names are used both as patterns and as expressions, which is
    -- why the module enables NoMonomorphismRestriction.
    tag = "tag"
    end = "end"
    loop = "loop"
    (startStmts, startExp) = startParse names
    -- Recurse: apply @loop@ to every argument of the given state.
    continue :: ParseState Exp -> Exp
    continue s = foldl (@@) loop (loopArgs s)
    loopExpr
        -- Group message: stop at the matching group-end tag instead of
        -- checking for end of input.
        | Just g <- groupFieldNumber m = do'
            [ tag <-- getVarInt'
            , stmt $ case' tag $
                (pLitInt (groupEndTag g) --> finish m exprs)
                : parseTagCases continue exprs m
            ]
        -- Other messages: stop when @atEnd@ reports the input is exhausted.
        | otherwise = do'
            [ end <-- "Data.ProtoLens.Encoding.Bytes.atEnd"
            , stmt $
                if' end (finish m exprs)
                $ do'
                    [ tag <-- getVarInt'
                    , stmt $ case' tag $ parseTagCases continue exprs m
                    ]
            ]
-- | Generate the statements run once a message has been fully read.
--
-- The generated code:
--
-- 1. freezes each repeated field's growing vector with
--    @Data.ProtoLens.Encoding.Growing.unsafeFreeze@ (lifted by
--    @unsafeLiftIO@);
-- 2. fails if any required field is still unset (see 'checkMissingFields');
-- 3. returns the partial message with each frozen vector stored into its
--    field (via the field's @vec'@ lens) and with its unknown fields reversed.
finish :: MessageInfo Name -> ParseState Exp -> Exp
finish m s = do' $
    [ pVar frozen <-- unsafeLiftIO' @@
        ("Data.ProtoLens.Encoding.Growing.unsafeFreeze"
            @@ mutable)
    | (frozen, mutable) <- Map.elems $ Map.intersectionWith (,)
                            frozenNames (repeatedFieldMVectors s)
    ]
    ++
    [ stmt $ checkMissingFields s
    , stmt $ "Prelude.return" @@
        -- The unknown-field alternative of the parse loop conses each new
        -- entry onto the front, so reversing restores the order of
        -- appearance.
        (over' unknownFields' "Prelude.reverse"
            @@(foldr (@@)
                (partialMessage s)
                (Map.intersectionWith
                    (\finfo frozen ->
                        "Lens.Family2.set"
                            @@ fieldOfVector finfo
                            @@ var (unQual frozen))
                    repeatedInfos frozenNames)))
    ]
  where
    repeatedInfos = repeatedFields m
    -- One @frozen'<field>@ variable name per repeated field.
    frozenNames = (\f -> nameFromSymbol $ "frozen'" <> overloadedFieldName f)
                    <$> repeatedInfos
-- | The state threaded through the generated parse loop.
--
-- It is parameterized over how each piece is represented. This file uses it
-- at 'Name' (variable names), 'Exp' (expressions) and 'Type' (argument
-- types).
data ParseState v = ParseState
    { partialMessage :: v
      -- ^ The message built so far.
    , requiredFieldsUnset :: Map.Map FieldId v
      -- ^ One entry per required field. It starts as @Prelude.True@ and is
      -- set to @Prelude.False@ once the field has been parsed.
    , repeatedFieldMVectors :: Map.Map FieldId v
      -- ^ One entry per repeated field: its growing vector.
    } deriving Functor
-- | Flatten a 'ParseState' into the argument list of the generated @loop@.
--
-- The order is: the partial message, then the required-field flags, then
-- the repeated-field vectors. Each map is listed in ascending 'FieldId'
-- order, so every 'ParseState' built for the same message yields its
-- arguments in the same order.
loopArgs :: ParseState v -> [v]
loopArgs s =
    partialMessage s
        : concatMap Map.elems [requiredFieldsUnset s, repeatedFieldMVectors s]
-- | Identifies a field within a message by the name in its descriptor (see
-- 'fieldId').
--
-- It is the key type of the 'ParseState' maps; the derived 'Ord' determines
-- their iteration order.
newtype FieldId = FieldId Text.Text
    deriving (Eq, Ord)
-- | Key a field by the @name@ recorded in its descriptor.
fieldId :: PlainFieldInfo -> FieldId
fieldId = FieldId . view name . fieldDescriptor . plainFieldInfo
-- | The variable names the generated @loop@ binds for its state.
--
-- * @x@ for the partial message.
-- * @required'@ plus the overloaded field name for each required-field flag.
-- * @mutable'@ plus the overloaded field name for each repeated field's
--   vector.
parseStateNames :: MessageInfo Name -> ParseState Name
parseStateNames m = ParseState
    { partialMessage = "x"
    , requiredFieldsUnset = Map.fromList
        [ (fieldId f, prefixed "required'" (plainFieldInfo f))
        | f <- messageFields m
        , RequiredField <- [plainFieldKind f]
        ]
    , repeatedFieldMVectors = Map.map (prefixed "mutable'") (repeatedFields m)
    }
  where
    -- Build a variable name from a prefix and the field's overloaded name.
    prefixed prefix info = nameFromSymbol $ prefix <> overloadedFieldName info
-- | The repeated fields of a message, keyed by 'FieldId'.
--
-- Only fields of kind @RepeatedField@ are included; map fields are a
-- separate kind and are excluded.
repeatedFields :: MessageInfo Name -> Map.Map FieldId FieldInfo
repeatedFields m =
    Map.fromList
        [ (fieldId f, plainFieldInfo f)
        | f <- filter isRepeated (messageFields m)
        ]
  where
    isRepeated f = case plainFieldKind f of
        RepeatedField{} -> True
        _ -> False
-- | The statements that begin a parse, together with the initial loop state.
--
-- * Each repeated field gets a fresh @Data.ProtoLens.Encoding.Growing.new@
--   vector, bound to its name from @names@.
-- * Every required-field flag starts as @Prelude.True@ (unset).
-- * The partial message starts as @Data.ProtoLens.defMessage@.
startParse :: ParseState Name -> ([Stmt], ParseState Exp)
startParse names = (allocations, initialState)
  where
    vectorNames = repeatedFieldMVectors names
    allocations =
        [ pVar vec <-- unsafeLiftIO' @@ "Data.ProtoLens.Encoding.Growing.new"
        | vec <- Map.elems vectorNames
        ]
    initialState = ParseState
        { partialMessage = "Data.ProtoLens.defMessage"
        , requiredFieldsUnset =
            Map.map (const "Prelude.True") (requiredFieldsUnset names)
        , repeatedFieldMVectors = Map.map (var . unQual) vectorNames
        }
-- | The Haskell types of the generated @loop@'s arguments.
--
-- These are the message type, @Prelude.Bool@ for each required-field flag,
-- and 'growingType' for each repeated field's vector.
parseStateTypes :: Env QName -> MessageInfo Name -> ParseState Type
parseStateTypes env m = ParseState
    { partialMessage = messageType
    , requiredFieldsUnset = Map.map (const "Prelude.Bool") requiredNames
    , repeatedFieldMVectors = Map.map (growingType env) (repeatedFields m)
    }
  where
    messageType = tyCon (unQual (messageName m))
    requiredNames = requiredFieldsUnset (parseStateNames m)
-- | Replace the partial-message expression with @f@ applied to it.
--
-- The other parts of the state are left untouched.
updateParseState ::
    Exp
    -> ParseState Exp
    -> ParseState Exp
updateParseState f s@ParseState{partialMessage = msg} =
    s { partialMessage = f @@ msg }
-- | Record that a required field has been parsed by setting its flag
-- expression to @Prelude.False@.
markRequiredField :: FieldId -> ParseState Exp -> ParseState Exp
markRequiredField f s = s { requiredFieldsUnset = flags' }
  where
    flags' = Map.insert f "Prelude.False" (requiredFieldsUnset s)
-- | Append the value @x@ to the growing vector of repeated field @f@.
--
-- Returns two things:
--
-- * the statement that binds the extended vector to @v@;
-- * a state whose entry for @f@ now refers to @v@.
--
-- @f@ must be a key of 'repeatedFieldMVectors', because the lookup uses
-- 'Map.!'.
appendToRepeated :: FieldId -> Exp -> ParseState Exp -> (Stmt, ParseState Exp)
appendToRepeated f x s = (appendStmt, s { repeatedFieldMVectors = vectors' })
  where
    -- Used both as the bound pattern and as the new vector's name.
    newVec = "v"
    vectors = repeatedFieldMVectors s
    appendStmt =
        newVec <-- unsafeLiftIO'
            @@ ("Data.ProtoLens.Encoding.Growing.append"
                    @@ (vectors Map.! f)
                    @@ x)
    vectors' = Map.insert f (var (unQual newVec)) vectors
-- | Generate code that checks for required fields that were never parsed.
--
-- The generated code collects the name of every required field whose flag
-- is true. If that list is empty it returns @()@; otherwise it fails with
-- @Missing required fields: @ followed by the shown list.
checkMissingFields :: ParseState Exp -> Exp
checkMissingFields s =
    let' [patBind missing missingList]
        $ if' ("Prelude.null" @@ missing) ("Prelude.return" @@ unit)
        $ "Prelude.fail" @@ errorMessage
  where
    missing = "missing"
    errorMessage =
        "Prelude.++"
            @@ stringExp "Missing required fields: "
            @@ ("Prelude.show" @@ (missing @::@ "[Prelude.String]"))
    -- Built at runtime: each field's name is consed on only if its flag
    -- evaluates to True.
    missingList =
        Map.foldrWithKey prependIfUnset emptyList (requiredFieldsUnset s)
    prependIfUnset (FieldId fname) flag rest =
        (if' flag (cons @@ stringExp (Text.unpack fname)) "Prelude.id") @@ rest
-- | All case alternatives used when dispatching on a tag read by the parse
-- loop.
--
-- They are, in order:
--
-- * the alternatives for every plain field;
-- * the alternatives for every oneof case, each parsed like an
--   'OptionalMaybeField';
-- * the catch-all 'unknownFieldCase'.
parseTagCases ::
    (ParseState Exp -> Exp)
    -> ParseState Exp
    -> MessageInfo Name
    -> [Alt]
parseTagCases loop x info =
    concat [parseFieldCase loop x f | f <- plainFields ++ oneofFields]
        ++ [unknownFieldCase info loop x]
  where
    plainFields = messageFields info
    oneofFields =
        [ PlainFieldInfo OptionalMaybeField (caseField c)
        | c <- concatMap oneofCases (messageOneofFields info)
        ]
-- | Case alternatives for parsing one field, keyed on its wire tag.
--
-- What each alternative does depends on the field kind:
--
-- * Map field: parse an entry message and insert its key and value.
-- * Repeated field: append to the growing vector. Unless the field is
--   'NotPackable', there is a second alternative for the packed tag.
-- * Required field: set the value and also clear its unset flag.
-- * Any other field: just set the value.
--
-- Every alternative ends by calling @loop@ with the updated state.
parseFieldCase ::
    (ParseState Exp -> Exp) -> ParseState Exp -> PlainFieldInfo -> [Alt]
parseFieldCase loop x f = case plainFieldKind f of
    MapField entryInfo -> [mapCase entryInfo]
    RepeatedField p
        | p == NotPackable -> [unpackedCase]
        | otherwise -> [unpackedCase, packedCase]
    RequiredField -> [requiredCase]
    _ -> [valueCase]
  where
    y = "y"
    entry = "entry"
    info = plainFieldInfo f
    -- Optional field: parse one value and set it on the message.
    valueCase = pLitInt (fieldTag info) --> do'
        [ y <-- parseField info
        , stmt . loop . updateParseState (setField info @@ y)
            $ x
        ]
    -- Required field: like valueCase, but also mark the field as seen.
    requiredCase = pLitInt (fieldTag info) --> do'
        [ y <-- parseField info
        , stmt . loop
            . updateParseState (setField info @@ y)
            . markRequiredField (fieldId f)
            $ x
        ]
    -- Single element under the field's regular tag: parse it strictly and
    -- append it to the field's vector.
    unpackedCase = pLitInt (fieldTag info) -->
        let (appendStmt, x') = appendToRepeated (fieldId f) y x
        in do'
            [ bangPat y <-- parseField info
            , appendStmt
            , stmt . loop $ x'
            ]
    -- Packed tag: run parsePackedField (wrapped in isolatedLengthy) on the
    -- current vector, and continue with the vector it returns.
    packedCase = pLitInt (packedFieldTag info) --> do'
        [ y <-- isolatedLengthy (parsePackedField info
                @@ repeatedFieldMVectors x Map.! fieldId f)
        , stmt $ loop x { repeatedFieldMVectors =
                            Map.insert (fieldId f) (var $ unQual y)
                                $ repeatedFieldMVectors x }
        ]
    -- Map field: parse one entry message, read its key and value fields,
    -- and insert them into the map with Data.Map.insert.
    mapCase entryInfo = pLitInt (fieldTag info) --> do'
        [ bangPat (entry `patTypeSig` tyCon (unQual $ mapEntryTypeName entryInfo))
            <-- parseField info
        , stmt . let' [ patBind "key"
                        $ view' @@ fieldOf (keyField entryInfo)
                            @@ entry
                      , patBind "value"
                        $ view' @@ fieldOf (valueField entryInfo)
                            @@ entry
                      ]
            . loop
            . updateParseState
                (overField info
                    ("Data.Map.insert" @@ "key" @@ "value"))
            $ x
        ]
-- | The catch-all alternative for tags that match no earlier alternative.
--
-- It parses the tagged value with
-- @Data.ProtoLens.Encoding.Wire.parseTaggedValueFromWire@, conses it onto
-- the message's unknown fields, and loops.
--
-- When this message is a group, an extra check is generated first: a value
-- whose payload is @EndGroup@ reaching this alternative makes the parser
-- fail with a mismatched group-end error.
unknownFieldCase ::
    MessageInfo Name -> (ParseState Exp -> Exp) -> ParseState Exp -> Alt
unknownFieldCase info loop x = wire --> (do' $
    [ bangPat y <-- "Data.ProtoLens.Encoding.Wire.parseTaggedValueFromWire" @@ wire
    ]
    ++
    -- This list has one element only when groupFieldNumber is Just.
    [ stmt $ case' y
        [ pApp "Data.ProtoLens.Encoding.Wire.TaggedValue"
            [utag, "Data.ProtoLens.Encoding.Wire.EndGroup"]
            --> "Prelude.fail" @@
                ("Prelude.++"
                    @@ stringExp "Mismatched group-end tag number "
                    @@ ("Prelude.show" @@ utag))
        , pWildCard --> "Prelude.return" @@ unit
        ]
    | Just _ <- [groupFieldNumber info]
    ]
    ++
    [ stmt . loop . updateParseState (over' unknownFields' (cons @@ y))
        $ x
    ])
  where
    wire = "wire"
    y = "y"
    utag = "utag"
-- | The expression @Lens.Family2.set@ partially applied to the field's lens.
--
-- Applying it to a value and then a message sets that field.
setField :: FieldInfo -> Exp
setField f = set' @@ fieldOf f
-- | Modify the given field of a message with the function expression @g@.
--
-- This is @over'@ applied to the field's lens.
overField :: FieldInfo -> Exp -> Exp
overField f g = over' (fieldOf f) g
-- | Generate @Lens.Family2.over@ applied to the lens @f@ and a lambda.
--
-- The lambda strictly matches its argument @t@ and applies @g@ to it.
over' :: Exp -> Exp -> Exp
over' f g = "Lens.Family2.over" @@ f @@ lambda [bangPat "t"] (g @@ "t")
-- | Generate a local function @ploop@ that parses a packed run of elements.
--
-- @ploop@ takes a growing vector @qs@. Until @atEnd@ reports the end of
-- input, it parses one element (strictly) and appends it with
-- @Data.ProtoLens.Encoding.Growing.append@. At the end it returns the
-- extended vector.
parsePackedField :: FieldInfo -> Exp
parsePackedField info = let' [funBind [match ploop [qs] ploopExp]]
    ploop
  where
    ploop = "ploop"
    q = "q"
    qs = "qs"
    qs' = "qs'"
    packedEnd = "packedEnd"
    ploopExp = do'
        [ packedEnd <-- "Data.ProtoLens.Encoding.Bytes.atEnd"
        , stmt $
            if' packedEnd
                ("Prelude.return" @@ qs)
                $ do'
                    [ bangPat q <-- parseField info
                    , qs' <-- unsafeLiftIO' @@
                        ("Data.ProtoLens.Encoding.Growing.append"
                            @@ qs @@ q)
                    , stmt $ ploop @@ qs'
                    ]
        ]
-- | Generate the expression for a message's encoder.
--
-- The result is a lambda over the message that concatenates, with
-- 'foldMapExp', the encodings of:
--
-- * its plain fields;
-- * its oneof fields;
-- * its unknown fields;
-- * for group messages only, the closing group-end tag.
generatedBuilder :: MessageInfo Name -> Exp
generatedBuilder m = lambda [msg] (foldMapExp pieces)
  where
    -- Used both as the lambda's pattern and as the message expression.
    msg = "_x"
    pieces = concat
        [ buildPlainField msg <$> messageFields m
        , buildOneofField msg <$> messageOneofFields m
        , [buildUnknown msg]
        , [ putVarInt' @@ litInt (groupEndTag g)
          | Just g <- [groupFieldNumber m]
          ]
        ]
-- | Encode the message's unknown fields with
-- @Data.ProtoLens.Encoding.Wire.buildFieldSet@.
buildUnknown :: Exp -> Exp
buildUnknown x = "Data.ProtoLens.Encoding.Wire.buildFieldSet" @@ unknowns
  where
    unknowns = view' @@ unknownFields' @@ x
-- | Combine expressions with a right-nested chain of @Data.Monoid.<>@.
--
-- An empty list yields @Data.Monoid.mempty@; a single expression is
-- returned unchanged.
foldMapExp :: [Exp] -> Exp
foldMapExp = \case
    [] -> mempty'
    e : es -> chain e es
  where
    -- Right-nest: e1 <> (e2 <> (... <> en)).
    chain e [] = e
    chain e (next : rest) = "Data.Monoid.<>" @@ e @@ chain next rest
-- | Encode one plain (non-oneof) field of the message @x@.
--
-- How the field is written depends on its kind:
--
-- * Required field: always written.
-- * Optional @Maybe@ field: written only when it is @Just@.
-- * Optional value field: written only when it differs from
--   @Data.ProtoLens.fieldDefault@.
-- * Map field: one tagged entry message per key/value pair.
-- * Packed repeated field: written by 'buildPackedField'.
-- * Other repeated field: one tagged value per element.
buildPlainField :: Exp -> PlainFieldInfo -> Exp
buildPlainField x f = case plainFieldKind f of
    RequiredField -> buildTaggedField info fieldValue
    OptionalMaybeField -> case' maybeFieldValue
        ["Prelude.Nothing" --> mempty'
        , "Prelude.Just" `pApp` [v]
            --> buildTaggedField info v
        ]
    OptionalValueField -> let' [patBind v fieldValue]
        $ if' ("Prelude.==" @@ v @@ "Data.ProtoLens.fieldDefault")
            mempty'
            (buildTaggedField info v)
    MapField entryInfo
        -> "Data.Monoid.mconcat"
            @@ ("Prelude.map"
                @@ lambda [v] (buildEntry entryInfo v)
                @@ ("Data.Map.toList" @@ fieldValue))
    RepeatedField Packed -> buildPackedField info vectorFieldValue
    RepeatedField _ -> "Data.ProtoLens.Encoding.Bytes.foldMapBuilder"
        @@ lambda [v]
            (buildTaggedField info v)
        @@ vectorFieldValue
  where
    info = plainFieldInfo f
    v = "_v"
    -- The field read through its plain, maybe' and vec' lenses.
    fieldValue = view'
        @@ fieldOf info
        @@ x
    maybeFieldValue = view'
        @@ fieldOfMaybe info
        @@ x
    vectorFieldValue = view'
        @@ fieldOfVector info
        @@ x
    -- Encode one (key, value) pair kv as a tagged entry message. The entry
    -- starts from the entry type's defMessage, with its key and value
    -- fields set.
    buildEntry entry kv
        = buildTaggedField info
            $ set'
                @@ fieldOf (keyField entry)
                @@ ("Prelude.fst" @@ kv)
                @@ (set' @@ fieldOf (valueField entry)
                    @@ ("Prelude.snd" @@ kv)
                    @@ ("Data.ProtoLens.defMessage"
                        @::@ tyCon (unQual $ mapEntryTypeName entry)))
-- | The overloaded-field lens for a field, named by its overloaded name.
fieldOf :: FieldInfo -> Exp
fieldOf info = fieldOfExp (overloadedFieldName info)
-- | The overloaded-field lens named @maybe'@ plus the field's overloaded
-- name.
fieldOfMaybe :: FieldInfo -> Exp
fieldOfMaybe info = fieldOfExp ("maybe'" <> overloadedFieldName info)
-- | The overloaded-field lens named @maybe'@ plus the oneof's overloaded
-- name.
fieldOfOneof :: OneofInfo -> Exp
fieldOfOneof oneof =
    fieldOfExp ("maybe'" <> overloadedName (oneofFieldName oneof))
-- | The overloaded-field lens named @vec'@ plus the field's overloaded name.
fieldOfVector :: FieldInfo -> Exp
fieldOfVector info = fieldOfExp ("vec'" <> overloadedFieldName info)
-- | Encode a value for a field: the field's tag as a varint, followed by
-- the value encoded with 'buildField'.
buildTaggedField :: FieldInfo -> Exp -> Exp
buildTaggedField f x = foldMapExp [tagBuilder, valueBuilder]
  where
    tagBuilder = putVarInt' @@ litInt (fieldTag f)
    valueBuilder = buildField f @@ x
-- | Encode a repeated field in packed form.
--
-- An empty vector produces nothing. Otherwise the output is the packed tag
-- followed by all elements, rendered together and then written as a single
-- 'lengthy'-encoded value.
buildPackedField :: FieldInfo -> Exp -> Exp
buildPackedField f x =
    let' [patBind p x] $ if' ("Data.Vector.Generic.null" @@ p) mempty' nonEmpty
  where
    -- Used both as the let-bound pattern and as the vector expression.
    p = "p"
    packedTag = putVarInt' @@ litInt (packedFieldTag f)
    elements =
        "Data.ProtoLens.Encoding.Bytes.runBuilder"
            @@ ("Data.ProtoLens.Encoding.Bytes.foldMapBuilder"
                    @@ buildField f
                    @@ p)
    nonEmpty =
        "Data.Monoid.<>" @@ packedTag @@ (buildFieldType lengthy @@ elements)
-- | Encode a oneof.
--
-- Nothing is written when no case is set. Otherwise the output is the
-- tagged value of whichever case constructor is present.
buildOneofField :: Exp -> OneofInfo -> Exp
buildOneofField x info =
    case' (view' @@ fieldOfOneof info @@ x)
        (noneAlt : map caseAlt (oneofCases info))
  where
    v = "v"
    noneAlt = "Prelude.Nothing" --> mempty'
    caseAlt c =
        pApp "Prelude.Just" [pApp (unQual $ caseConstructorName c) [v]]
            --> buildTaggedField (caseField c) v
-- | The tag value for a field number and encoding.
--
-- It is computed by 'joinTypeAndTag' from the number and the encoding's
-- 'wireType', then widened to 'Integer'.
makeTag :: Int32 -> FieldEncoding -> Integer
makeTag num enc = toInteger tagValue
  where
    tagValue = joinTypeAndTag (fromIntegral num) (wireType enc)
-- | The tag for a field under its own encoding.
fieldTag :: FieldInfo -> Integer
fieldTag f = makeTag fieldNumber (fieldInfoEncoding f)
  where
    fieldNumber = fieldDescriptor f ^. number
-- | The tag for a field's packed form, which always uses the 'lengthy'
-- encoding.
packedFieldTag :: FieldInfo -> Integer
packedFieldTag f = makeTag (view number (fieldDescriptor f)) lengthy
-- | The tag that closes the group with the given field number.
groupEndTag :: Int32 -> Integer
groupEndTag = (`makeTag` groupEnd)
-- | @Data.ProtoLens.Field.field@ applied to the promoted symbol as a type
-- argument.
fieldOfExp :: Symbol -> Exp
fieldOfExp sym = "Data.ProtoLens.Field.field" @@ typeApp symbolType
  where
    symbolType = promoteSymbol sym
-- | Fully-qualified references to runtime-library functions that the
-- generated code uses repeatedly.
getVarInt', putVarInt', mempty', view', set', unknownFields', unsafeLiftIO'
    :: Exp
getVarInt' = "Data.ProtoLens.Encoding.Bytes.getVarInt"
putVarInt' = "Data.ProtoLens.Encoding.Bytes.putVarInt"
mempty' = "Data.Monoid.mempty"
view' = "Lens.Family2.view"
set' = "Lens.Family2.set"
unknownFields' = "Data.ProtoLens.unknownFields"
unsafeLiftIO' = "Data.ProtoLens.Encoding.Parser.Unsafe.unsafeLiftIO"
-- | A parser for one value of the field.
--
-- It is labelled with the field's descriptor name via
-- @Data.ProtoLens.Encoding.Bytes.<?>@.
parseField :: FieldInfo -> Exp
parseField f = "Data.ProtoLens.Encoding.Bytes.<?>" @@ parser @@ label
  where
    parser = parseFieldType (fieldInfoEncoding f)
    label = stringExp (Text.unpack (view name (fieldDescriptor f)))
-- | The builder for one value of the field, without its tag.
buildField :: FieldInfo -> Exp
buildField f = buildFieldType (fieldInfoEncoding f)
-- | The 'FieldEncoding' for the @type'@ recorded in the field's descriptor.
fieldInfoEncoding :: FieldInfo -> FieldEncoding
fieldInfoEncoding f = fieldEncoding (fieldDescriptor f ^. type')
-- | The type of a repeated field's mutable vector.
--
-- It is @Data.ProtoLens.Encoding.Growing.Growing@ applied, in order, to:
--
-- * the field's 'hsFieldVectorType';
-- * @RealWorld@;
-- * the field's 'hsFieldType'.
growingType :: Env QName -> FieldInfo -> Type
growingType env f = growing @@ vectorType @@ realWorld @@ elementType
  where
    growing = "Data.ProtoLens.Encoding.Growing.Growing"
    vectorType = hsFieldVectorType f
    realWorld = "Data.ProtoLens.Encoding.Growing.RealWorld"
    elementType = hsFieldType env f