{-# LANGUAGE OverloadedLists #-}
{-# LANGUAGE TemplateHaskell #-}

module FlatBuffers.Internal.Compiler.TH where

import Control.Monad ( join )
import Control.Monad.Except ( runExceptT )
import Data.Bits ( (.&.) )
import Data.Foldable ( traverse_ )
import Data.Functor ( (<&>) )
import Data.Int
import qualified Data.List as List
import Data.List.NonEmpty ( NonEmpty(..) )
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as Map
import Data.Text ( Text )
import qualified Data.Text as T
import Data.Word
import FlatBuffers.Internal.Build
import qualified FlatBuffers.Internal.Compiler.NamingConventions as NC
import qualified FlatBuffers.Internal.Compiler.ParserIO as ParserIO
import FlatBuffers.Internal.Compiler.SemanticAnalysis ( SymbolTable(..) )
import qualified FlatBuffers.Internal.Compiler.SemanticAnalysis as SemanticAnalysis
import qualified FlatBuffers.Internal.Compiler.SyntaxTree as SyntaxTree
import FlatBuffers.Internal.Compiler.ValidSyntaxTree
import FlatBuffers.Internal.FileIdentifier ( HasFileIdentifier(..), unsafeFileIdentifier )
import FlatBuffers.Internal.Read
import FlatBuffers.Internal.Types
import FlatBuffers.Internal.Write
import Language.Haskell.TH
import Language.Haskell.TH.Syntax ( lift )
import qualified Language.Haskell.TH.Syntax as TH

-- | Builds a function type from an argument type and a result type.
-- It is right-associative, so chains read like ordinary Haskell arrows:
--
-- @ConT ''Int ~> ConT ''String === Int -> String@
(~>) :: Type -> Type -> Type
argTy ~> resTy = AppT (AppT ArrowT argTy) resTy

infixr 1 ~>

-- | Options controlling which flatbuffers constructors\/accessors are generated.
--
-- Set them with record-update syntax on `defaultOptions`, using the fields below:
--
-- > defaultOptions { compileAllSchemas = True }
data Options = Options
  { -- | Directories searched when resolving @include@s (equivalent to flatc's @-I@ option).
    includeDirectories :: [FilePath]
    -- | Also generate code for every schema the root schema includes, not just for
    -- the root schema itself (equivalent to flatc's @--gen-all@ option).
  , compileAllSchemas :: Bool
  }
  deriving (Show, Eq)

-- | The default options: no extra include directories, and only the root schema is compiled.
--
-- > Options
-- >   { includeDirectories = []
-- >   , compileAllSchemas = False
-- >   }
defaultOptions :: Options
defaultOptions =
  Options
    { compileAllSchemas = False
    , includeDirectories = []
    }

-- | Generates constructors and accessors for every data type declared in the given
-- flatbuffers schema whose namespace matches the module the splice appears in.
--
-- > namespace Data.Game;
-- >
-- > table Monster {}
--
-- > {-# LANGUAGE TemplateHaskell #-}
-- >
-- > module Data.Game where
-- > import FlatBuffers
-- >
-- > $(mkFlatBuffers "schemas/game.fbs" defaultOptions)
--
-- Parse and validation errors abort compilation through 'fail', after being
-- reformatted by 'fixMsg'. The root schema and every file it includes are
-- registered with 'TH.addDependentFile'.
mkFlatBuffers :: FilePath -> Options -> Q [Dec]
mkFlatBuffers rootFilePath opts = do
  thisModule <- fmap (T.pack . loc_module) location

  parsed <- runIO . runExceptT $ ParserIO.parseSchemas rootFilePath (includeDirectories opts)
  fileTree <- orFail parsed
  trackDependencies fileTree

  validated <- orFail (SemanticAnalysis.validateSchemas fileTree)
  let rootDecls = SyntaxTree.fileTreeRoot validated
      selected
        | compileAllSchemas opts =
            rootDecls <> mconcat (Map.elems (SyntaxTree.fileTreeForest validated))
        | otherwise = rootDecls

  compileSymbolTable (keepModule thisModule selected)
  where
    -- Turns a 'Left' error message into a compile-time failure.
    orFail :: Either String a -> Q a
    orFail = either (fail . fixMsg) pure

    -- Registers the root schema and all included schemas as dependencies of the splice.
    trackDependencies (SyntaxTree.FileTree schemaPath _ includes) = do
      TH.addDependentFile schemaPath
      traverse_ TH.addDependentFile (Map.keys includes)

    -- Keeps only the declarations whose namespace maps to the given module name.
    keepModule modName (SymbolTable enums structs tables unions) =
      SymbolTable
        { allEnums = Map.filterWithKey (inModule modName) enums
        , allStructs = Map.filterWithKey (inModule modName) structs
        , allTables = Map.filterWithKey (inModule modName) tables
        , allUnions = Map.filterWithKey (inModule modName) unions
        }

    inModule modName (ns, _) _ = NC.namespace ns == modName

-- | Reformats an error message so that ghcid displays it in full. It does two things:
--
-- 1. ghcid stops parsing an error when it reaches a line that starts with an
--    alphabetical character, or an empty line, so every line is prefixed with a space.
-- 2. Any trailing newline is dropped; otherwise ghcid would stop parsing there
--    and would not show the source-code location.
fixMsg :: String -> String
fixMsg msg = List.intercalate "\n" [ ' ' : line | line <- lines msg ]

-- | Generates the declarations for all enums, then structs, then tables, then unions.
-- Within each group, declarations follow the key order of the symbol table's maps.
compileSymbolTable :: SemanticAnalysis.ValidDecls -> Q [Dec]
compileSymbolTable symbolTable = do
  enumDecs <- generateAll mkEnum (allEnums symbolTable)
  structDecs <- generateAll mkStruct (allStructs symbolTable)
  tableDecs <- generateAll mkTable (allTables symbolTable)
  unionDecs <- generateAll mkUnion (allUnions symbolTable)
  pure (enumDecs <> structDecs <> tableDecs <> unionDecs)
  where
    generateAll :: (a -> Q [Dec]) -> Map.Map k a -> Q [Dec]
    generateAll gen = fmap concat . traverse gen . Map.elems

-- | Bit-flags enums ('enumBitFlags') become integer constants; all other enums
-- become a Haskell data type.
mkEnum :: EnumDecl -> Q [Dec]
mkEnum enum
  | enumBitFlags enum = mkEnumBitFlags enum
  | otherwise = mkEnumNormal enum

-- | Generates the declarations for a bit-flags enum:
--
-- * one constant per value;
-- * a list of all the constants;
-- * a function listing the names of the flags set in a value.
mkEnumBitFlags :: EnumDecl -> Q [Dec]
mkEnumBitFlags enum = do
  let constNames =
        map (mkName . T.unpack . NC.enumBitFlagsConstant enum) (NE.toList (enumVals enum))
  namesDecs <- mkEnumBitFlagsNames enum constNames
  pure $
    mkEnumBitFlagsConstants enum constNames
      <> mkEnumBitFlagsAllValls enum constNames
      <> namesDecs

-- | For each enum value, emits a signature with the enum's underlying integer type,
-- plus a definition equal to that value's integer. Values are paired with the
-- given names positionally.
mkEnumBitFlagsConstants :: EnumDecl -> [Name] -> [Dec]
mkEnumBitFlagsConstants enum enumValNames =
  concat (zipWith constantDecs (NE.toList (enumVals enum)) enumValNames)
  where
    valType = enumTypeToType (enumType enum)

    constantDecs :: EnumVal -> Name -> [Dec]
    constantDecs enumVal constName =
      [ SigD constName valType
      , FunD constName [Clause [] (NormalB (intLitE (enumValInt enumVal))) []]
      ]

-- | Generates a list with all the enum values, e.g.
--
-- > allColors = [colorsRed, colorsGreen, colorsBlue]
--
-- The list is typed with the enum's underlying integer type and is marked INLINE.
mkEnumBitFlagsAllValls :: EnumDecl -> [Name] -> [Dec]
mkEnumBitFlagsAllValls enum enumValNames =
  [ SigD allName (AppT ListT (enumTypeToType (enumType enum)))
  , FunD allName [Clause [] (NormalB (ListE (map VarE enumValNames))) []]
  , inlinePragma allName
  ]
  where
    allName = mkName (T.unpack (NC.enumBitFlagsAllFun enum))

-- | Generates @colorsNames@. Given a value of the enum's underlying type, it returns
-- the identifier of every flag for which @flag .&. value /= 0@, in declaration order.
-- The function is marked INLINE.
--
-- The generated function is a chain of @where@ bindings, one per flag, roughly:
--
-- > colorsNames c = resN
-- >   where
-- >     res0 = []
-- >     res1 = if colorsBlue .&. c /= 0 then T.pack "Blue" : res0 else res0
-- >     ...
mkEnumBitFlagsNames :: EnumDecl -> [Name] -> Q [Dec]
mkEnumBitFlagsNames enum enumValNames = do
  flagsArg <- newName "c"
  seed <- newName "res0"
  seedDecs <- [d| $(varP seed) = [] |]
  (bindings, result) <- chain 1 flagsArg seed seedDecs flagsLastFirst
  let fun = FunD funName [Clause [VarP flagsArg] (NormalB (VarE result)) (List.reverse bindings)]
  pure [SigD funName funType, fun, inlinePragma funName]
  where
    funName = mkName (T.unpack (NC.enumBitFlagsNamesFun enum))
    funType = enumTypeToType (enumType enum) ~> AppT ListT (ConT ''Text)

    -- Flags are consumed last-to-first. Each one is consed onto the names of the
    -- flags declared after it, so the final list comes out in declaration order.
    flagsLastFirst :: [(Name, Ident)]
    flagsLastFirst =
      List.reverse (zip enumValNames (map enumValIdent (NE.toList (enumVals enum))))

    -- Emits one @resN@ binding per flag, each extending the previous binding.
    -- Returns the bindings (newest first) and the name of the last binding.
    chain :: Int -> Name -> Name -> [Dec] -> [(Name, Ident)] -> Q ([Dec], Name)
    chain _ _ prev acc [] = pure (acc, prev)
    chain ix arg prev acc ((flagName, Ident label) : remaining) = do
      current <- newName ("res" <> show ix)
      dec <-
        [d| $(varP current) =
              if $(varE flagName) .&. $(varE arg) /= 0
                then $(pure (textLitE label)) : $(varE prev)
                else $(varE prev) |]
      chain (ix + 1) arg current (dec <> acc) remaining

-- | Generates the declarations for a non-bit-flags enum.
--
-- Produces the enum's data type, the conversions to and from its underlying
-- integer type, and the function returning each value's name.
mkEnumNormal :: EnumDecl -> Q [Dec]
mkEnumNormal enum = do
  let typeName = mkName' (NC.dataTypeName enum)
      conName val = mkName (T.unpack (NC.enumUnionMember enum val))
      conNames = fmap conName (enumVals enum)
      valsWithCons = NE.zip (enumVals enum) conNames
  toDecs <- mkToEnum typeName enum valsWithCons
  fromDecs <- mkFromEnum typeName enum valsWithCons
  nameDecs <- mkEnumNameFun typeName enum valsWithCons
  pure (mkEnumDataDec typeName conNames : toDecs <> fromDecs <> nameDecs)

-- | Builds a plain enumeration data type with one nullary constructor per name.
-- The type derives 'Eq', 'Show', 'Read', 'Ord' and 'Bounded'.
mkEnumDataDec :: Name -> NonEmpty Name -> Dec
mkEnumDataDec enumName enumValNames =
  DataD [] enumName [] Nothing constructors [DerivClause Nothing derivedClasses]
  where
    constructors = map nullary (NE.toList enumValNames)
    nullary conName = NormalC conName []
    derivedClasses = map ConT [''Eq, ''Show, ''Read, ''Ord, ''Bounded]

-- | Generates the conversion from the enum's underlying integer type to @Maybe@ the enum.
-- Each declared value maps to @Just@ its constructor, and any other input maps
-- to @Nothing@. The function is marked INLINE.
mkToEnum :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkToEnum enumName enum enumValsAndNames = do
  let fun = mkName' (NC.toEnumFun enum)
  n <- newName "n"
  let known =
        [ Match (intLitP (enumValInt val)) (NormalB (AppE (ConE 'Just) (ConE con))) []
        | (val, con) <- NE.toList enumValsAndNames
        ]
      fallback = Match WildP (NormalB (ConE 'Nothing)) []
      funType = enumTypeToType (enumType enum) ~> AppT (ConT ''Maybe) (ConT enumName)
  pure
    [ SigD fun funType
    , FunD fun [Clause [VarP n] (NormalB (CaseE (VarE n) (known <> [fallback]))) []]
    , inlinePragma fun
    ]

-- | Generates the conversion from the enum data type to its underlying integer type.
-- Each constructor maps to its declared value. The function is marked INLINE.
-- NOTE(review): two-argument 'ConP' is the template-haskell < 2.18 form; confirm
-- against the supported GHC range.
mkFromEnum :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkFromEnum enumName enum enumValsAndNames = do
  let fun = mkName' (NC.fromEnumFun enum)
  x <- newName "n"
  let alternatives =
        [ Match (ConP con []) (NormalB (intLitE (enumValInt val))) []
        | (val, con) <- NE.toList enumValsAndNames
        ]
  pure
    [ SigD fun (ConT enumName ~> enumTypeToType (enumType enum))
    , FunD fun [Clause [VarP x] (NormalB (CaseE (VarE x) alternatives)) []]
    , inlinePragma fun
    ]

-- | Generates @colorsName@.
--
-- The generated function maps each constructor of the enum data type to the
-- value's identifier ('getIdent') as 'Text', via a @case@ over all constructors.
-- It is marked INLINE.
mkEnumNameFun :: Name -> EnumDecl -> NonEmpty (EnumVal, Name) -> Q [Dec]
mkEnumNameFun enumName enum enumValsAndNames = do
  let funName = mkName' $ NC.enumNameFun enum
  argName <- newName "c"
  pure
    [ SigD funName (ConT enumName ~> ConT ''Text)
    , FunD funName
        [ Clause
            [VarP argName]
            (NormalB (CaseE (VarE argName) (mkMatch <$> NE.toList enumValsAndNames)))
            []
        ]
    , inlinePragma funName
    ]
  where
    -- NOTE(review): two-argument 'ConP' is the template-haskell < 2.18 form; confirm
    -- against the supported GHC range.
    mkMatch (enumVal, enumName) =
      Match (ConP enumName []) (NormalB (textLitE (unIdent (getIdent enumVal)))) []

-- | Generates the declarations for a struct:
--
-- * an empty data type;
-- * its 'IsStruct' instance;
-- * a constructor function producing a 'WriteStruct';
-- * one getter per field.
mkStruct :: StructDecl -> Q [Dec]
mkStruct struct = do
  let structName = mkName' $ NC.dataTypeName struct
  isStructInstance <- mkIsStructInstance structName struct
  let dataDec = DataD [] structName [] Nothing [] []
  (consSig, cons) <- mkStructConstructor structName struct
  let getters = foldMap (mkStructFieldGetter structName struct) (structFields struct)
  pure $ dataDec : isStructInstance <> [ consSig, cons ] <> getters

-- | Generates the 'IsStruct' instance. Its alignment and size are taken from the
-- struct declaration ('structAlignment' / 'structSize') and lifted into the splice.
mkIsStructInstance :: Name -> StructDecl -> Q [Dec]
mkIsStructInstance structName struct =
  [d|
    instance IsStruct $(conT structName) where
      structAlignmentOf = $(lift . unAlignment . structAlignment $ struct)
      structSizeOf = $(lift . unInlineSize . structSize $ struct)
  |]

-- | Generates the struct constructor function (signature and definition). It takes
-- one argument per field. Its body chains every field's builder expression
-- (including padding) with '<>' and wraps the result in 'WriteStruct'.
mkStructConstructor :: Name -> StructDecl -> Q (Dec, Dec)
mkStructConstructor structName struct = do
  argsInfo <- traverse mkStructConstructorArg (structFields struct)
  let (argTypes, pats, exps) = nonEmptyUnzip3 argsInfo
  let retType = AppT (ConT ''WriteStruct) (ConT structName)
  let sigType = foldr (~>) retType argTypes
  let consName = mkName' $ NC.dataTypeConstructor struct
  let consSig = SigD consName sigType
  -- Flatten each field's expressions into one right-nested '<>' chain.
  let exp = foldr1 (\e acc -> InfixE (Just e) (VarE '(<>)) (Just acc)) (join exps)
  let body = NormalB $ ConE 'WriteStruct `AppE` exp
  let cons = FunD consName [ Clause (NE.toList pats) body [] ]
  pure (consSig, cons)

-- | For a single struct field, returns three things:
--
-- * the argument's Haskell type;
-- * a fresh variable pattern for the argument;
-- * the builder expression(s): the field write, followed by a 'buildPadding' call
--   when the field's padding is non-zero.
--
-- Enum fields are written with the builder for their underlying integer type.
mkStructConstructorArg :: StructField -> Q (Type, Pat, NonEmpty Exp)
mkStructConstructorArg sf = do
  argName <- newName' $ NC.arg sf
  let argPat = VarP argName
  let argRef = VarE argName
  let argType = structFieldTypeToWriteType (structFieldType sf)
  let mkWriteExp sft =
        case sft of
          SInt8 -> VarE 'buildInt8
          SInt16 -> VarE 'buildInt16
          SInt32 -> VarE 'buildInt32
          SInt64 -> VarE 'buildInt64
          SWord8 -> VarE 'buildWord8
          SWord16 -> VarE 'buildWord16
          SWord32 -> VarE 'buildWord32
          SWord64 -> VarE 'buildWord64
          SFloat -> VarE 'buildFloat
          SDouble -> VarE 'buildDouble
          SBool -> VarE 'buildBool
          SEnum _ enumType -> mkWriteExp (enumTypeToStructFieldType enumType)
          SStruct _ -> VarE 'buildStruct
  let exp = mkWriteExp (structFieldType sf) `AppE` argRef
  -- OverloadedLists: these list literals are 'NonEmpty Exp'.
  let exps =
        if structFieldPadding sf == 0
          then [ exp ]
          else [ exp , VarE 'buildPadding `AppE` intLitE (structFieldPadding sf) ]
  pure (argType, argPat, exps)

-- | Generates a getter that reads a struct field at its fixed offset via
-- 'readStructField'. Nested struct fields return the 'Struct' directly; every
-- other field type returns @Either ReadError@.
mkStructFieldGetter :: Name -> StructDecl -> StructField -> [Dec]
mkStructFieldGetter structName struct sf =
  [sig, fun]
  where
    funName = mkName (T.unpack (NC.getter struct sf))
    fieldOffsetExp = intLitE (structFieldOffset sf)
    retType = structFieldTypeToReadType (structFieldType sf)

    sig =
      SigD funName $
        case structFieldType sf of
          SStruct _ ->
            ConT ''Struct `AppT` ConT structName ~> retType
          _ ->
            ConT ''Struct `AppT` ConT structName ~> ConT ''Either `AppT` ConT ''ReadError `AppT` retType

    fun = FunD funName [ Clause [] (NormalB body) [] ]

    body =
      app
        [ VarE 'readStructField
        , mkReadExp (structFieldType sf)
        , fieldOffsetExp
        ]

    mkReadExp sft =
      case sft of
        SInt8 -> VarE 'readInt8
        SInt16 -> VarE 'readInt16
        SInt32 -> VarE 'readInt32
        SInt64 -> VarE 'readInt64
        SWord8 -> VarE 'readWord8
        SWord16 -> VarE 'readWord16
        SWord32 -> VarE 'readWord32
        SWord64 -> VarE 'readWord64
        SFloat -> VarE 'readFloat
        SDouble -> VarE 'readDouble
        SBool -> VarE 'readBool
        SEnum _ enumType -> mkReadExp $ enumTypeToStructFieldType enumType
        SStruct _ -> VarE 'readStruct

-- | Generates the declarations for a table:
--
-- * an empty data type;
-- * its constructor function;
-- * a 'HasFileIdentifier' instance, when applicable;
-- * getters for its fields (deprecated fields get none).
mkTable :: TableDecl -> Q [Dec]
mkTable table = do
  let tableName = mkName' $ NC.dataTypeName table
  (consSig, cons) <- mkTableConstructor tableName table
  let fileIdentifierDec = mkTableFileIdentifier tableName (tableIsRoot table)
  let getters = foldMap (mkTableFieldGetter tableName table) (tableFields table)
  pure $
    [ DataD [] tableName [] Nothing [] []
    , consSig
    , cons
    ]
      <> fileIdentifierDec
      <> getters

-- | Only a root table that declares a file identifier gets a 'HasFileIdentifier'
-- instance (built with 'unsafeFileIdentifier'). Otherwise no declarations are produced.
mkTableFileIdentifier :: Name -> IsRoot -> [Dec]
mkTableFileIdentifier tableName isRoot =
  case isRoot of
    NotRoot -> []
    IsRoot Nothing -> []
    IsRoot (Just fileIdentifier) ->
      [ InstanceD
          Nothing
          []
          (ConT ''HasFileIdentifier `AppT` ConT tableName)
          [ FunD 'getFileIdentifier
              [ Clause [] (NormalB $ VarE 'unsafeFileIdentifier `AppE` textLitE fileIdentifier) [] ]
          ]
      ]

-- | Generates the table constructor function (signature and definition). Its
-- arguments and the field-write expressions passed to 'writeTable' come from
-- 'mkTableContructorArg', concatenated in field order.
mkTableConstructor :: Name -> TableDecl -> Q (Dec, Dec)
mkTableConstructor tableName table = do
  (argTypes, pats, exps) <- mconcat <$> traverse mkTableContructorArg (tableFields table)
  let retType = AppT (ConT ''WriteTable) (ConT tableName)
  let sigType = foldr (~>) retType argTypes
  let consName = mkName' $ NC.dataTypeConstructor table
  let consSig = SigD consName sigType
  let body = NormalB $ AppE (VarE 'writeTable) (ListE exps)
  let cons = FunD consName [ Clause pats body [] ]
  pure (consSig, cons)

-- | For a single table field, returns three things: the argument type(s), the
-- argument pattern(s), and the write expression(s).
--
-- Deprecated fields take no argument but still emit 'deprecated' placeholders.
-- Union and union-vector fields get two placeholders, matching the two write
-- expressions a non-deprecated union field produces; all other fields get one.
mkTableContructorArg :: TableField -> Q ([Type], [Pat], [Exp])
mkTableContructorArg tf =
  if tableFieldDeprecated tf
    then
      case tableFieldType tf of
        TUnion _ _ -> pure ([], [], [VarE 'deprecated, VarE 'deprecated])
        TVector _ (VUnion _) -> pure ([], [], [VarE 'deprecated, VarE 'deprecated])
        _ -> pure ([], [], [VarE 'deprecated])
    else do
      argName <- newName' $ NC.arg tf
      let argPat = VarP argName
      let argRef = VarE argName
      let argType = tableFieldTypeToWriteType (tableFieldType tf)
      let exps = mkExps argRef (tableFieldType tf)
      pure ([argType], [argPat], exps)
  where
    -- Scalars: @optionalDef <default> <writer> <arg>@.
    expForScalar :: Exp -> Exp -> Exp -> Exp
    expForScalar defaultValExp writeExp varExp =
      VarE 'optionalDef `AppE` defaultValExp `AppE` writeExp `AppE` varExp

    -- Non-scalars: required fields are written directly; optional ones go through 'optional'.
    expForNonScalar :: Required -> Exp -> Exp -> Exp
    expForNonScalar Req exp argRef = exp `AppE` argRef
    expForNonScalar Opt exp argRef = VarE 'optional `AppE` exp `AppE` argRef

    -- The list monad's 'pure' yields a single expression; unions yield two.
    mkExps :: Exp -> TableFieldType -> [Exp]
    mkExps argRef tfType =
      case tfType of
        TInt8 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt8TableField ) argRef
        TInt16 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt16TableField ) argRef
        TInt32 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt32TableField ) argRef
        TInt64 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeInt64TableField ) argRef
        TWord8 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord8TableField ) argRef
        TWord16 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord16TableField ) argRef
        TWord32 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord32TableField ) argRef
        TWord64 (DefaultVal n) -> pure $ expForScalar (intLitE n) (VarE 'writeWord64TableField ) argRef
        TFloat (DefaultVal n) -> pure $ expForScalar (realLitE n) (VarE 'writeFloatTableField ) argRef
        TDouble (DefaultVal n) -> pure $ expForScalar (realLitE n) (VarE 'writeDoubleTableField ) argRef
        TBool (DefaultVal b) -> pure $ expForScalar (if b then ConE 'True else ConE 'False) (VarE 'writeBoolTableField) argRef
        TString req -> pure $ expForNonScalar req (VarE 'writeTextTableField) argRef
        TEnum _ enumType dflt -> mkExps argRef (enumTypeToTableFieldType enumType dflt)
        TStruct _ req -> pure $ expForNonScalar req (VarE 'writeStructTableField) argRef
        TTable _ req -> pure $ expForNonScalar req (VarE 'writeTableTableField) argRef
        TUnion _ _ ->
          [ VarE 'writeUnionTypeTableField `AppE` argRef
          , VarE 'writeUnionValueTableField `AppE` argRef
          ]
        TVector req vecElemType -> mkExpForVector argRef req vecElemType

    -- Vector fields. Enum vectors use their underlying integer type; union
    -- vectors yield a types vector and a values vector.
    mkExpForVector :: Exp -> Required -> VectorElementType -> [Exp]
    mkExpForVector argRef req vecElemType =
      case vecElemType of
        VInt8 -> [ expForNonScalar req (VarE 'writeVectorInt8TableField) argRef ]
        VInt16 -> [ expForNonScalar req (VarE 'writeVectorInt16TableField) argRef ]
        VInt32 -> [ expForNonScalar req (VarE 'writeVectorInt32TableField) argRef ]
        VInt64 -> [ expForNonScalar req (VarE 'writeVectorInt64TableField) argRef ]
        VWord8 -> [ expForNonScalar req (VarE 'writeVectorWord8TableField) argRef ]
        VWord16 -> [ expForNonScalar req (VarE 'writeVectorWord16TableField) argRef ]
        VWord32 -> [ expForNonScalar req (VarE 'writeVectorWord32TableField) argRef ]
        VWord64 -> [ expForNonScalar req (VarE 'writeVectorWord64TableField) argRef ]
        VFloat -> [ expForNonScalar req (VarE 'writeVectorFloatTableField) argRef ]
        VDouble -> [ expForNonScalar req (VarE 'writeVectorDoubleTableField) argRef ]
        VBool -> [ expForNonScalar req (VarE 'writeVectorBoolTableField) argRef ]
        VString -> [ expForNonScalar req (VarE 'writeVectorTextTableField) argRef ]
        VEnum _ enumType -> mkExpForVector argRef req (enumTypeToVectorElementType enumType)
        VStruct _ -> [ expForNonScalar req (VarE 'writeVectorStructTableField) argRef ]
        VTable _ -> [ expForNonScalar req (VarE 'writeVectorTableTableField) argRef ]
        VUnion _ ->
          [ expForNonScalar req (VarE 'writeUnionTypesVectorTableField) argRef
          , expForNonScalar req (VarE 'writeUnionValuesVectorTableField) argRef
          ]

-- | Generates a getter (signature and definition) that reads a table field by its
-- field index ('tableFieldId'). Deprecated fields produce no declarations.
mkTableFieldGetter :: Name -> TableDecl -> TableField -> [Dec]
mkTableFieldGetter tableName table tf =
  if tableFieldDeprecated tf
    then []
    else [sig, mkFun (tableFieldType tf)]
  where
    funName = mkName (T.unpack (NC.getter table tf))
    fieldIndex = intLitE (tableFieldId tf)

    sig =
      SigD funName $
        ConT ''Table `AppT` ConT tableName ~> ConT ''Either `AppT` ConT ''ReadError `AppT` tableFieldTypeToReadType (tableFieldType tf)

    mkFun :: TableFieldType -> Dec
    mkFun tft =
      case tft of
        TWord8 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord8))
        TWord16 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord16))
        TWord32 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord32))
        TWord64 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readWord64))
        TInt8 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt8))
        TInt16 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt16))
        TInt32 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt32))
        TInt64 (DefaultVal n) -> mkFunWithBody (bodyForScalar (intLitE n) (VarE 'readInt64))
        TFloat (DefaultVal n) -> mkFunWithBody (bodyForScalar (realLitE n) (VarE 'readFloat))
        TDouble (DefaultVal n) -> mkFunWithBody (bodyForScalar (realLitE n) (VarE 'readDouble))
        TBool (DefaultVal b) -> mkFunWithBody (bodyForScalar (if b then ConE 'True else ConE 'False) (VarE 'readBool))
        TString req -> mkFunWithBody (bodyForNonScalar req (VarE 'readText))
        TEnum _ enumType dflt -> mkFun $ enumTypeToTableFieldType enumType dflt
        -- 'readStruct' is wrapped in 'Right' so it has the same shape as the other readers.
        TStruct _ req -> mkFunWithBody (bodyForNonScalar req (compose [ConE 'Right, VarE 'readStruct]))
        TTable _ req -> mkFunWithBody (bodyForNonScalar req (VarE 'readTable))
        TUnion (TypeRef ns ident) _req ->
          mkFunWithBody $
            app
              [ VarE 'readTableFieldUnion
              , VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
              , fieldIndex
              ]
        TVector req vecElemType -> mkFunForVector req vecElemType

    -- Vector getters. Enum vectors use their underlying integer type; union vectors
    -- call the union's generated read function (see 'mkReadUnionFun').
    mkFunForVector :: Required -> VectorElementType -> Dec
    mkFunForVector req vecElemType =
      case vecElemType of
        VInt8 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt8
        VInt16 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt16
        VInt32 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt32
        VInt64 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorInt64
        VWord8 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord8
        VWord16 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord16
        VWord32 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord32
        VWord64 -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorWord64
        VFloat -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorFloat
        VDouble -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorDouble
        VBool -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorBool
        VString -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorText
        VEnum _ enumType -> mkFunForVector req (enumTypeToVectorElementType enumType)
        VStruct _ -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readPrimVector `AppE` ConE 'VectorStruct
        VTable _ -> mkFunWithBody $ bodyForNonScalar req $ VarE 'readTableVector
        VUnion (TypeRef ns ident) ->
          mkFunWithBody $
            case req of
              Opt ->
                app
                  [ VarE 'readTableFieldUnionVectorOpt
                  , VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
                  , fieldIndex
                  ]
              Req ->
                app
                  [ VarE 'readTableFieldUnionVectorReq
                  , VarE . mkName . T.unpack . NC.withModulePrefix ns $ NC.readUnionFun ident
                  , fieldIndex
                  , stringLitE . unIdent . getIdent $ tf
                  ]

    -- A point-free getter definition: @funName = <body>@.
    mkFunWithBody :: Exp -> Dec
    mkFunWithBody body = FunD funName [ Clause [] (NormalB body) [] ]

    -- Required fields use 'readTableFieldReq', which is also passed the field's
    -- identifier as a string literal (presumably for its error message; confirm in Read).
    -- Optional fields use 'readTableFieldOpt'.
    bodyForNonScalar req readExp =
      case req of
        Req ->
          app
            [ VarE 'readTableFieldReq
            , readExp
            , fieldIndex
            , stringLitE . unIdent . getIdent $ tf
            ]
        Opt ->
          app
            [ VarE 'readTableFieldOpt
            , readExp
            , fieldIndex
            ]

    -- Scalars use 'readTableFieldWithDef', passing the schema default value.
    bodyForScalar defaultValExp readExp =
      app
        [ VarE 'readTableFieldWithDef
        , readExp
        , fieldIndex
        , defaultValExp
        ]

-- | Generates the declarations for a union: its data type, one 'WriteUnion'
-- constructor function per member, and its read function.
mkUnion :: UnionDecl -> Q [Dec]
mkUnion union = do
  let unionName = mkName' $ NC.dataTypeName union
  let unionValNames = unionVals union <&> \unionVal -> mkName $ T.unpack $ NC.enumUnionMember union unionVal
  unionConstructors <- mkUnionConstructors unionName union
  readFun <- mkReadUnionFun unionName unionValNames union
  pure $
    mkUnionDataDec unionName (unionVals union `NE.zip` unionValNames)
      : unionConstructors
      <> readFun

-- | The union data type: one constructor per member, each holding a strict
-- 'Table' of the member's type. No instances are derived.
mkUnionDataDec :: Name -> NonEmpty (UnionVal, Name) -> Dec
mkUnionDataDec unionName unionValsAndNames =
  DataD [] unionName [] Nothing (NE.toList $ fmap mkCons unionValsAndNames) []
  where
    mkCons (unionVal, unionValName) =
      NormalC unionValName [(bang, ConT ''Table `AppT` typeRefToType (unionValTableRef unionVal))]
    bang = Bang NoSourceUnpackedness SourceStrict

-- | For each union member (numbered from 1, in declaration order), generates a
-- function of type @WriteTable Member -> WriteUnion Union@, defined as
-- @writeUnion <number>@.
mkUnionConstructors :: Name -> UnionDecl -> Q [Dec]
mkUnionConstructors unionName union =
  fmap join . traverse mkUnionConstructor $ NE.toList (unionVals union) `zip` [1..]
  where
    mkUnionConstructor :: (UnionVal, Integer) -> Q [Dec]
    mkUnionConstructor (unionVal, ix) = do
      let constructorName = mkName' $ NC.unionConstructor union unionVal
      pure
        [ SigD constructorName $
            ConT ''WriteTable `AppT` typeRefToType (unionValTableRef unionVal)
              ~> ConT ''WriteUnion `AppT` ConT unionName
        , FunD constructorName [ Clause [] (NormalB $ VarE 'writeUnion `AppE` intLitE ix) [] ]
        ]

-- | Generates the union read function @Positive Word8 -> PositionInfo -> Either ReadError (Union U)@.
-- It dispatches on 'getPositive' of its first argument:
--
-- * values @1..n@ read the corresponding member's table at the given position
--   (matching the numbering in 'mkUnionConstructors');
-- * any other value produces @UnionUnknown@ carrying that value.
mkReadUnionFun :: Name -> NonEmpty Name -> UnionDecl -> Q [Dec]
mkReadUnionFun unionName unionValNames union = do
  nArg <- newName "n"
  posArg <- newName "pos"
  wildcard <- newName "n'"
  let funName = mkName $ T.unpack $ NC.readUnionFun union
  let sig =
        SigD funName $
          ConT ''Positive `AppT` ConT ''Word8
            ~> ConT ''PositionInfo
            ~> ConT ''Either `AppT` ConT ''ReadError `AppT` (ConT ''Union `AppT` ConT unionName)
  -- Known member: @ix -> (Union . Member) <$> readTable' pos@.
  let mkMatch :: Name -> Integer -> Match
      mkMatch unionValName ix =
        Match
          (intLitP ix)
          ( NormalB $
              InfixE
                (Just (compose [ConE 'Union, ConE unionValName]))
                (VarE '(<$>))
                (Just (VarE 'readTable' `AppE` VarE posArg))
          )
          []
  -- Unknown member: @n' -> pure $! UnionUnknown n'@.
  let matchWildcard =
        Match
          (VarP wildcard)
          ( NormalB $
              InfixE
                (Just (VarE 'pure))
                (VarE '($!))
                (Just (ConE 'UnionUnknown `AppE` VarE wildcard))
          )
          []
  let matches = (uncurry mkMatch <$> NE.toList unionValNames `zip` [1..]) <> [matchWildcard]
  let funBody = NormalB $ CaseE (VarE 'getPositive `AppE` VarE nArg) matches
  let fun = FunD funName [ Clause [VarP nArg, VarP posArg] funBody [] ]
  pure [sig, fun]

-- | The Haskell integer type backing an enum's underlying type.
enumTypeToType :: EnumType -> Type
enumTypeToType et =
  case et of
    EInt8 -> ConT ''Int8
    EInt16 -> ConT ''Int16
    EInt32 -> ConT ''Int32
    EInt64 -> ConT ''Int64
    EWord8 -> ConT ''Word8
    EWord16 -> ConT ''Word16
    EWord32 -> ConT ''Word32
    EWord64 -> ConT ''Word64

-- | Treats an enum table field as a field of its underlying integer type,
-- converting the default value with 'fromIntegral'.
enumTypeToTableFieldType :: Integral a => EnumType -> DefaultVal a -> TableFieldType
enumTypeToTableFieldType et dflt =
  case et of
    EInt8 -> TInt8 (fromIntegral dflt)
    EInt16 -> TInt16 (fromIntegral dflt)
    EInt32 -> TInt32 (fromIntegral dflt)
    EInt64 -> TInt64 (fromIntegral dflt)
    EWord8 -> TWord8 (fromIntegral dflt)
    EWord16 -> TWord16 (fromIntegral dflt)
    EWord32 -> TWord32 (fromIntegral dflt)
    EWord64 -> TWord64 (fromIntegral dflt)

-- | Treats an enum struct field as a field of its underlying integer type.
enumTypeToStructFieldType :: EnumType -> StructFieldType
enumTypeToStructFieldType et =
  case et of
    EInt8 -> SInt8
    EInt16 -> SInt16
    EInt32 -> SInt32
    EInt64 -> SInt64
    EWord8 -> SWord8
    EWord16 -> SWord16
    EWord32 -> SWord32
    EWord64 -> SWord64

-- | Treats an enum vector element as an element of its underlying integer type.
enumTypeToVectorElementType :: EnumType -> VectorElementType
enumTypeToVectorElementType et =
  case et of
    EInt8 -> VInt8
    EInt16 -> VInt16
    EInt32 -> VInt32
    EInt64 -> VInt64
    EWord8 -> VWord8
    EWord16 -> VWord16
    EWord32 -> VWord32
    EWord64 -> VWord64

-- | Argument type of a struct field in the generated struct constructor.
structFieldTypeToWriteType :: StructFieldType -> Type
structFieldTypeToWriteType sft =
  case sft of
    SInt8 -> ConT ''Int8
    SInt16 -> ConT ''Int16
    SInt32 -> ConT ''Int32
    SInt64 -> ConT ''Int64
    SWord8 -> ConT ''Word8
    SWord16 -> ConT ''Word16
    SWord32 -> ConT ''Word32
    SWord64 -> ConT ''Word64
    SFloat -> ConT ''Float
    SDouble -> ConT ''Double
    SBool -> ConT ''Bool
    SEnum _ enumType -> enumTypeToType enumType
    SStruct (namespace, structDecl) -> ConT ''WriteStruct `AppT` typeRefToType (TypeRef namespace (getIdent structDecl))

-- | Result type of a generated struct field getter (before any @Either ReadError@ wrapping).
structFieldTypeToReadType :: StructFieldType -> Type
structFieldTypeToReadType sft =
  case sft of
    SInt8 -> ConT ''Int8
    SInt16 -> ConT ''Int16
    SInt32 -> ConT ''Int32
    SInt64 -> ConT ''Int64
    SWord8 -> ConT ''Word8
    SWord16 -> ConT ''Word16
    SWord32 -> ConT ''Word32
    SWord64 -> ConT ''Word64
    SFloat -> ConT ''Float
    SDouble -> ConT ''Double
    SBool -> ConT ''Bool
    SEnum _ enumType -> enumTypeToType enumType
    SStruct (namespace, structDecl) -> ConT ''Struct `AppT` typeRefToType (TypeRef namespace (getIdent structDecl))

-- | Argument type of a table field in the generated table constructor.
-- Scalars and enums are always wrapped in 'Maybe'. Non-scalars are wrapped only
-- when optional. Union fields are never wrapped.
tableFieldTypeToWriteType :: TableFieldType -> Type
tableFieldTypeToWriteType tft =
  case tft of
    TInt8 _ -> ConT ''Maybe `AppT` ConT ''Int8
    TInt16 _ -> ConT ''Maybe `AppT` ConT ''Int16
    TInt32 _ -> ConT ''Maybe `AppT` ConT ''Int32
    TInt64 _ -> ConT ''Maybe `AppT` ConT ''Int64
    TWord8 _ -> ConT ''Maybe `AppT` ConT ''Word8
    TWord16 _ -> ConT ''Maybe `AppT` ConT ''Word16
    TWord32 _ -> ConT ''Maybe `AppT` ConT ''Word32
    TWord64 _ -> ConT ''Maybe `AppT` ConT ''Word64
    TFloat _ -> ConT ''Maybe `AppT` ConT ''Float
    TDouble _ -> ConT ''Maybe `AppT` ConT ''Double
    TBool _ -> ConT ''Maybe `AppT` ConT ''Bool
    TString req -> requiredType req (ConT ''Text)
    TEnum _ enumType _ -> ConT ''Maybe `AppT` enumTypeToType enumType
    TStruct typeRef req -> requiredType req (ConT ''WriteStruct `AppT` typeRefToType typeRef)
    TTable typeRef req -> requiredType req (ConT ''WriteTable `AppT` typeRefToType typeRef)
    TUnion typeRef _ -> ConT ''WriteUnion `AppT` typeRefToType typeRef
    TVector req vecElemType -> requiredType req (vectorElementTypeToWriteType vecElemType)

-- | Result type of a generated table field getter (inside @Either ReadError@).
-- Scalars and enums are returned directly; optional non-scalars are wrapped in 'Maybe'.
tableFieldTypeToReadType :: TableFieldType -> Type
tableFieldTypeToReadType tft =
  case tft of
    TInt8 _ -> ConT ''Int8
    TInt16 _ -> ConT ''Int16
    TInt32 _ -> ConT ''Int32
    TInt64 _ -> ConT ''Int64
    TWord8 _ -> ConT ''Word8
    TWord16 _ -> ConT ''Word16
    TWord32 _ -> ConT ''Word32
    TWord64 _ -> ConT ''Word64
    TFloat _ -> ConT ''Float
    TDouble _ -> ConT ''Double
    TBool _ -> ConT ''Bool
    TString req -> requiredType req (ConT ''Text)
    TEnum _ enumType _ -> enumTypeToType enumType
    TStruct typeRef req -> requiredType req (ConT ''Struct `AppT` typeRefToType typeRef)
    TTable typeRef req -> requiredType req (ConT ''Table `AppT` typeRefToType typeRef)
    TUnion typeRef _ -> ConT ''Union `AppT` typeRefToType typeRef
    TVector req vecElemType -> requiredType req (vectorElementTypeToReadType vecElemType)

-- | @WriteVector@ type used for a vector field's constructor argument.
vectorElementTypeToWriteType :: VectorElementType -> Type
vectorElementTypeToWriteType vet =
  case vet of
    VInt8 -> ConT ''WriteVector `AppT` ConT ''Int8
    VInt16 -> ConT ''WriteVector `AppT` ConT ''Int16
    VInt32 -> ConT ''WriteVector `AppT` ConT ''Int32
    VInt64 -> ConT ''WriteVector `AppT` ConT ''Int64
    VWord8 -> ConT ''WriteVector `AppT` ConT ''Word8
    VWord16 -> ConT ''WriteVector `AppT` ConT ''Word16
    VWord32 -> ConT ''WriteVector `AppT` ConT ''Word32
    VWord64 -> ConT ''WriteVector `AppT` ConT ''Word64
    VFloat -> ConT ''WriteVector `AppT` ConT ''Float
    VDouble -> ConT ''WriteVector `AppT` ConT ''Double
    VBool -> ConT ''WriteVector `AppT` ConT ''Bool
    VString -> ConT ''WriteVector `AppT` ConT ''Text
    VEnum _ enumType -> ConT ''WriteVector `AppT` enumTypeToType enumType
    VStruct typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteStruct `AppT` typeRefToType typeRef)
    VTable typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteTable `AppT` typeRefToType typeRef)
    VUnion typeRef -> ConT ''WriteVector `AppT` (ConT ''WriteUnion `AppT` typeRefToType typeRef)

-- | @Vector@ type returned by a vector field's getter.
vectorElementTypeToReadType :: VectorElementType -> Type
vectorElementTypeToReadType vet =
  case vet of
    VInt8 -> ConT ''Vector `AppT` ConT ''Int8
    VInt16 -> ConT ''Vector `AppT` ConT ''Int16
    VInt32 -> ConT ''Vector `AppT` ConT ''Int32
    VInt64 -> ConT ''Vector `AppT` ConT ''Int64
    VWord8 -> ConT ''Vector `AppT` ConT ''Word8
    VWord16 -> ConT ''Vector `AppT` ConT ''Word16
    VWord32 -> ConT ''Vector `AppT` ConT ''Word32
    VWord64 -> ConT ''Vector `AppT` ConT ''Word64
    VFloat -> ConT ''Vector `AppT` ConT ''Float
    VDouble -> ConT ''Vector `AppT` ConT ''Double
    VBool -> ConT ''Vector `AppT` ConT ''Bool
    VString -> ConT ''Vector `AppT` ConT ''Text
    VEnum _ enumType -> ConT ''Vector `AppT` enumTypeToType enumType
    VStruct typeRef -> ConT ''Vector `AppT` (ConT ''Struct `AppT` typeRefToType typeRef)
    VTable typeRef -> ConT ''Vector `AppT` (ConT ''Table `AppT` typeRefToType typeRef)
    VUnion typeRef -> ConT ''Vector `AppT` (ConT ''Union `AppT` typeRefToType typeRef)

-- | Type constructor for a referenced declaration. Its name is 'NC.dataTypeName'
-- prefixed by 'NC.withModulePrefix' for the reference's namespace.
typeRefToType :: TypeRef -> Type
typeRefToType (TypeRef ns ident) =
  ConT . mkName' . NC.withModulePrefix ns . NC.dataTypeName $ ident

-- | Wraps a type in 'Maybe' unless the field is required.
requiredType :: Required -> Type -> Type
requiredType Req t = t
requiredType Opt t = AppT (ConT ''Maybe) t

-- | 'mkName' for 'Text'.
mkName' :: Text -> Name
mkName' = mkName . T.unpack

-- | 'newName' for 'Text'.
newName' :: Text -> Q Name
newName' = newName . T.unpack

-- | Integer literal pattern.
intLitP :: Integral i => i -> Pat
intLitP = LitP . IntegerL . toInteger

-- | Integer literal expression.
intLitE :: Integral i => i -> Exp
intLitE = LitE . IntegerL . toInteger

-- | Rational literal expression (via 'toRational').
realLitE :: Real i => i -> Exp
realLitE = LitE . RationalL . toRational

-- | The expression @T.pack "<text>"@.
textLitE :: Text -> Exp
textLitE t = VarE 'T.pack `AppE` LitE (StringL (T.unpack t))

-- | A plain string literal expression.
stringLitE :: Text -> Exp
stringLitE t = LitE (StringL (T.unpack t))

-- | @{-# INLINE funName #-}@.
inlinePragma :: Name -> Dec
inlinePragma funName = PragmaD $ InlineP funName Inline FunLike AllPhases

-- | Applies the first expression to the remaining ones, left to right.
-- Partial: assumes the list is not empty ('foldl1').
app :: [Exp] -> Exp
app = foldl1 AppE

-- | Composes the expressions with '(.)', right-nested.
-- Partial: assumes the list is not empty ('foldr1').
compose :: [Exp] -> Exp
compose = foldr1 (\e1 e2 -> InfixE (Just e1) (VarE '(.)) (Just e2))

-- | 'NE.unzip' for triples.
nonEmptyUnzip3 :: NonEmpty (a,b,c) -> (NonEmpty a, NonEmpty b, NonEmpty c)
nonEmptyUnzip3 xs =
  ( (\(x, _, _) -> x) <$> xs
  , (\(_, x, _) -> x) <$> xs
  , (\(_, _, x) -> x) <$> xs
  )