{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
module FlatBuffers.Internal.Compiler.SemanticAnalysis where
import Control.Monad ( forM_, join, when )
import Control.Monad.Except ( MonadError, throwError )
import Control.Monad.Reader ( MonadReader(..), asks, runReaderT )
import Control.Monad.State ( MonadState, State, StateT, evalState, evalStateT, get, modify, put )
import Data.Coerce ( coerce )
import Data.Foldable ( asum, find, foldlM, traverse_ )
import qualified Data.Foldable as Foldable
import Data.Functor ( ($>), (<&>) )
import Data.Int
import Data.Ix ( inRange )
import qualified Data.List as List
import Data.List.NonEmpty ( NonEmpty )
import qualified Data.List.NonEmpty as NE
import Data.Map.Strict ( Map )
import qualified Data.Map.Strict as Map
import Data.Maybe ( catMaybes, fromMaybe, isJust )
import Data.Monoid ( Sum(..) )
import Data.Scientific ( Scientific )
import qualified Data.Scientific as Scientific
import Data.Set ( Set )
import qualified Data.Set as Set
import Data.Text ( Text )
import qualified Data.Text as T
import Data.Traversable ( for )
import Data.Word
import FlatBuffers.Internal.Compiler.Display ( Display(..) )
import FlatBuffers.Internal.Compiler.SyntaxTree ( FileTree(..), HasIdent(..), HasMetadata(..), Ident, Namespace, Schema, TypeRef(..), qualify )
import qualified FlatBuffers.Internal.Compiler.SyntaxTree as ST
import FlatBuffers.Internal.Compiler.ValidSyntaxTree
import FlatBuffers.Internal.Constants
import FlatBuffers.Internal.Types
import FlatBuffers.Internal.Util ( isPowerOfTwo, roundUpToNearestMultipleOf )
import Text.Read ( readMaybe )
-- | Constraint alias shared by the validation functions: the monad must be able to fail
-- with a 'Text' error message and must give read access to the current 'ValidationState'.
type ValidationCtx m = (MonadError Text m, MonadReader ValidationState m)
-- | Read-only environment available while validating.
data ValidationState = ValidationState
{ validationStateCurrentContext :: !Ident
-- ^ Identifier of the item currently being validated. 'throwErrorMsg' prefixes error
-- messages with it whenever it is not the empty identifier.
, validationStateAllAttributes :: !(Set ST.AttributeDecl)
-- ^ Every attribute that may be used. 'checkUndeclaredAttributes' rejects attributes
-- that are not members of this set.
}
-- | Run an action with the current validation context replaced by the result of applying
-- the given function to it. The change is scoped to that action only (it goes through 'local').
modifyContext :: ValidationCtx m => (Ident -> Ident) -> m a -> m a
modifyContext f = local withNewContext
  where
    withNewContext st =
      st { validationStateCurrentContext = f (validationStateCurrentContext st) }
-- | The top-level declarations of a schema, grouped by kind. Every declaration is paired
-- with the namespace it was declared in.
--
-- Each kind of declaration has its own type parameter, so the four kinds can be at
-- different validation stages independently of one another.
data SymbolTable enum struct table union = SymbolTable
{ allEnums :: ![(Namespace, enum)]
, allStructs :: ![(Namespace, struct)]
, allTables :: ![(Namespace, table)]
, allUnions :: ![(Namespace, union)]
}
deriving (Eq, Show)
-- | Combining two symbol tables concatenates them kind by kind; declarations from the
-- left operand come first.
instance Semigroup (SymbolTable e s t u) where
  left <> right =
    SymbolTable
      { allEnums = allEnums left <> allEnums right
      , allStructs = allStructs left <> allStructs right
      , allTables = allTables left <> allTables right
      , allUnions = allUnions left <> allUnions right
      }
-- | The empty symbol table contains no declarations of any kind.
instance Monoid (SymbolTable e s t u) where
  mempty =
    SymbolTable
      { allEnums = []
      , allStructs = []
      , allTables = []
      , allUnions = []
      }
-- The symbol table is validated one kind of declaration at a time; each alias below names
-- the table after one more kind has been replaced by its validated counterpart.

-- | Nothing validated yet (output of 'createSymbolTables').
type Stage1 = SymbolTable ST.EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums validated (output of 'validateEnums').
type Stage2 = SymbolTable EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums and structs validated (output of 'validateStructs').
type Stage3 = SymbolTable EnumDecl StructDecl ST.TableDecl ST.UnionDecl
-- | Enums, structs and tables validated (output of 'validateTables').
type Stage4 = SymbolTable EnumDecl StructDecl TableDecl ST.UnionDecl
-- | Everything validated (output of 'validateUnions').
type ValidDecls = SymbolTable EnumDecl StructDecl TableDecl UnionDecl
-- | Build the initial, unvalidated symbol table for every schema in the file tree.
--
-- Declarations are visited in source order. Each enum, struct, table and union is paired
-- with the namespace set by the closest preceding namespace declaration (the empty
-- namespace when there is none). Every other kind of declaration is ignored.
createSymbolTables :: FileTree Schema -> FileTree Stage1
createSymbolTables = fmap (pairDeclsWithNamespaces . ST.decls)
  where
    pairDeclsWithNamespaces :: [ST.Decl] -> Stage1
    pairDeclsWithNamespaces decls =
      -- 'mconcat' nests the appends to the right, so the tables are built in time linear
      -- in the number of declarations. A lazy left fold that appends one element at a
      -- time re-walks the accumulated lists on every step and accumulates thunks.
      mconcat (zipWith entry (namespacesInEffect decls) decls)

    -- The namespace in effect at each declaration, i.e. before that declaration itself
    -- is taken into account. 'scanl' yields one extra trailing element, which the
    -- 'zipWith' above discards.
    namespacesInEffect :: [ST.Decl] -> [Namespace]
    namespacesInEffect = scanl step ""
      where
        step _ (ST.DeclN (ST.NamespaceDecl newNamespace)) = newNamespace
        step currentNamespace _ = currentNamespace

    -- The single-declaration symbol table contributed by one declaration.
    entry :: Namespace -> ST.Decl -> Stage1
    entry currentNamespace decl =
      case decl of
        ST.DeclE enum -> SymbolTable [(currentNamespace, enum)] [] [] []
        ST.DeclS struct -> SymbolTable [] [(currentNamespace, struct)] [] []
        ST.DeclT table -> SymbolTable [] [] [(currentNamespace, table)] []
        ST.DeclU union -> SymbolTable [] [] [] [(currentNamespace, union)]
        _ -> mempty
-- | Validate a tree of parsed schemas.
--
-- Top-level identifiers are first checked for duplicates across all files. Enums, structs,
-- tables and unions are then validated in that order, and finally the root table (if the
-- root schema declares one) is flagged. The first validation error aborts the whole run.
validateSchemas :: MonadError Text m => FileTree Schema -> m (FileTree ValidDecls)
validateSchemas schemas =
  runReaderT pipeline (ValidationState "" allAttributes)
  where
    pipeline = do
      checkDuplicateIdentifiers allQualifiedTopLevelIdentifiers
      withEnums <- validateEnums symbolTables
      withStructs <- validateStructs withEnums
      withTables <- validateTables withStructs
      withUnions <- validateUnions withTables
      updateRootTable (fileTreeRoot schemas) withUnions

    symbolTables = createSymbolTables schemas

    -- Fully qualified names of every enum/struct/table/union in every file.
    allQualifiedTopLevelIdentifiers = concatMap qualifiedIdentifiers symbolTables

    qualifiedIdentifiers :: Stage1 -> [Ident]
    qualifiedIdentifiers st =
      concat
        [ map (uncurry qualify) (allEnums st)
        , map (uncurry qualify) (allStructs st)
        , map (uncurry qualify) (allTables st)
        , map (uncurry qualify) (allUnions st)
        ]

    -- Attributes declared in any of the schemas.
    declaredAttributes =
      [ attr
      | schema <- Foldable.toList schemas
      , ST.DeclA attr <- ST.decls schema
      ]

    allAttributes = Set.fromList (declaredAttributes <> knownAttributes)
-- | What 'getRootInfo' found out about a schema's root type.
data RootInfo = RootInfo
{ rootTableNamespace :: !Namespace
-- ^ Namespace the root table was found in.
, rootTable :: !TableDecl
-- ^ The validated table that the root declaration resolved to.
, rootFileIdent :: !(Maybe Text)
-- ^ File identifier declared in the schema, if there is one.
}
-- | Flag the root table declared by the given schema.
--
-- When the schema has a root declaration, every table (in any file of the tree) that is
-- equal to the root table and lives in the same namespace gets its 'tableIsRoot' set to
-- 'IsRoot' together with the file identifier. Without a root declaration the symbol tables
-- are returned untouched.
updateRootTable :: forall m. ValidationCtx m => Schema -> FileTree ValidDecls -> m (FileTree ValidDecls)
updateRootTable schema symbolTables =
  maybe symbolTables markRootEverywhere <$> getRootInfo schema symbolTables
  where
    markRootEverywhere :: RootInfo -> FileTree ValidDecls
    markRootEverywhere info = fmap (markRootIn info) symbolTables

    markRootIn :: RootInfo -> ValidDecls -> ValidDecls
    markRootIn info st = st { allTables = map (markIfRoot info) (allTables st) }

    markIfRoot :: RootInfo -> (Namespace, TableDecl) -> (Namespace, TableDecl)
    markIfRoot (RootInfo rootNamespace rootDecl fileIdent) entry@(namespace, table)
      | namespace == rootNamespace && table == rootDecl =
          (namespace, table { tableIsRoot = IsRoot fileIdent })
      | otherwise = entry
-- | Scan a schema's declarations for its root type and file identifier.
--
-- Declarations are processed in order while tracking the current namespace. A root
-- declaration is resolved relative to the namespace in effect at that point and must
-- refer to a table, otherwise validation fails. If several root or file identifier
-- declarations are present, the last one of each wins. Returns 'Nothing' when the schema
-- has no root declaration.
getRootInfo :: forall m. ValidationCtx m => Schema -> FileTree ValidDecls -> m (Maybe RootInfo)
getRootInfo schema symbolTables = do
  (_, foundRoot, fileIdent) <- foldlM step ("", Nothing, Nothing) (ST.decls schema)
  pure $ fmap (\(namespace, table) -> RootInfo namespace table fileIdent) foundRoot
  where
    step :: (Namespace, Maybe (Namespace, TableDecl), Maybe Text) -> ST.Decl -> m (Namespace, Maybe (Namespace, TableDecl), Maybe Text)
    step acc@(currentNamespace, foundRoot, fileIdent) = \case
      ST.DeclN (ST.NamespaceDecl newNamespace) ->
        pure (newNamespace, foundRoot, fileIdent)
      ST.DeclFI (ST.FileIdentifierDecl newFileIdent) ->
        pure (currentNamespace, foundRoot, Just (coerce newFileIdent))
      ST.DeclR (ST.RootDecl typeRef) -> do
        match <- findDecl currentNamespace symbolTables typeRef
        case match of
          MatchT (rootNamespace, rootDecl) ->
            pure (currentNamespace, Just (rootNamespace, rootDecl), fileIdent)
          _ ->
            throwErrorMsg "root type must be a table"
      _ ->
        pure acc
-- | Attributes that can be used without being declared by the schema: the ones this
-- module interprets itself, followed by 'otherKnownAttributes'.
knownAttributes :: [ST.AttributeDecl]
knownAttributes = coerce interpretedAttributes ++ otherKnownAttributes
  where
    interpretedAttributes :: [Text]
    interpretedAttributes =
      [ idAttr
      , deprecatedAttr
      , requiredAttr
      , forceAlignAttr
      , bitFlagsAttr
      ]
-- | Attribute carrying an explicit field id on table fields; rejected on struct fields.
idAttr :: Text
idAttr = "id"

-- | Attribute marking a table field as deprecated; rejected on struct fields.
deprecatedAttr :: Text
deprecatedAttr = "deprecated"

-- | Attribute marking a non-scalar table field as required.
requiredAttr :: Text
requiredAttr = "required"

-- | Attribute overriding a struct's alignment.
forceAlignAttr :: Text
forceAlignAttr = "force_align"

-- | Attribute for bit-flag enums; enums carrying it are currently rejected as unsupported.
bitFlagsAttr :: Text
bitFlagsAttr = "bit_flags"
-- | Attributes that are accepted without a declaration (they are part of
-- 'knownAttributes', so 'checkUndeclaredAttributes' lets them through) but that nothing
-- in the visible part of this module inspects any further.
otherKnownAttributes :: [ST.AttributeDecl]
otherKnownAttributes =
[ "nested_flatbuffer"
, "flexbuffer"
, "key"
, "hash"
, "original_order"
, "native_inline"
, "native_default"
, "native_custom_alloc"
, "native_type"
, "cpp_type"
, "cpp_ptr_type"
, "cpp_str_type"
, "cpp_str_flex_ctor"
, "shared"
]
-- | Result of resolving a type reference with 'findDecl': the declaration that was found,
-- tagged with its kind and paired with the namespace it was found in.
data Match enum struct table union
= MatchE !(Namespace, enum)
| MatchS !(Namespace, struct)
| MatchT !(Namespace, table)
| MatchU !(Namespace, union)
-- | Resolve a type reference to the declaration it names.
--
-- The reference is looked up relative to the current namespace first and then relative to
-- each of its parent namespaces, innermost first (see 'parentNamespaces'). For every
-- candidate namespace all symbol tables are searched; within one symbol table enums are
-- tried before structs, then tables, then unions. The first hit wins. Fails with an error
-- listing the searched namespaces when nothing matches.
findDecl ::
     ValidationCtx m
  => (HasIdent e, HasIdent s, HasIdent t, HasIdent u)
  => Namespace
  -> FileTree (SymbolTable e s t u)
  -> TypeRef
  -> m (Match e s t u)
findDecl currentNamespace symbolTables typeRef@(TypeRef refNamespace refIdent) =
  maybe notFound pure (asum (fmap searchFrom searchPath))
  where
    searchPath = parentNamespaces currentNamespace

    -- Search every symbol table for the reference, interpreted relative to one parent namespace.
    searchFrom parentNamespace =
      asum (fmap (searchSymbolTable (parentNamespace <> refNamespace)) symbolTables)

    searchSymbolTable candidateNamespace symbolTable =
      asum
        [ MatchE <$> find (isTarget candidateNamespace) (allEnums symbolTable)
        , MatchS <$> find (isTarget candidateNamespace) (allStructs symbolTable)
        , MatchT <$> find (isTarget candidateNamespace) (allTables symbolTable)
        , MatchU <$> find (isTarget candidateNamespace) (allUnions symbolTable)
        ]

    isTarget :: HasIdent a => Namespace -> (Namespace, a) -> Bool
    isTarget candidateNamespace (namespace, decl) =
      namespace == candidateNamespace && getIdent decl == refIdent

    notFound =
      throwErrorMsg $
        "type '"
          <> display typeRef
          <> "' does not exist (checked in these namespaces: "
          <> display searchPath
          <> ")"
-- | A namespace followed by all of its parents, from the namespace itself down to the
-- empty (root) namespace. Built from the prefixes of the namespace's components, longest first.
parentNamespaces :: ST.Namespace -> NonEmpty ST.Namespace
parentNamespaces (ST.Namespace ns) =
  let prefixesLongestFirst = NE.reverse (NE.inits ns)
   in coerce prefixesLongestFirst
-- | Validate every enum of every symbol table, leaving the other declarations untouched.
validateEnums :: forall m. ValidationCtx m => FileTree Stage1 -> m (FileTree Stage2)
validateEnums = traverse validateFile
  where
    validateFile :: Stage1 -> m Stage2
    validateFile symbolTable = do
      validEnums <- traverse withNamespace (allEnums symbolTable)
      pure symbolTable { allEnums = validEnums }

    -- Validate one enum and pair the result with the namespace it came from.
    withNamespace :: (Namespace, ST.EnumDecl) -> m (Namespace, EnumDecl)
    withNamespace entry@(namespace, _) = (,) namespace <$> validateEnum entry
-- | Validate a single enum declaration.
--
-- In order: rejects the @bit_flags@ attribute, rejects duplicate value names, rejects
-- undeclared attributes, checks that the underlying type is integral, assigns a number to
-- every value, checks that the numbers are strictly ascending and checks that each number
-- fits the underlying type. Errors are reported with the enum's qualified name as context.
validateEnum :: forall m. ValidationCtx m => (Namespace, ST.EnumDecl) -> m EnumDecl
validateEnum (currentNamespace, enum) =
  modifyContext (const (qualify currentNamespace enum)) $
    checkBitFlags
      >> checkDuplicateFields
      >> checkUndeclaredAttributes enum
      >> validEnum
  where
    validEnum :: m EnumDecl
    validEnum = do
      underlyingType <- validateEnumType (ST.enumType enum)
      let numberedVals = evalState (traverse mapEnumVal (ST.enumVals enum)) Nothing
      validateOrder numberedVals
      forM_ numberedVals (validateBounds underlyingType)
      pure EnumDecl
        { enumIdent = getIdent enum
        , enumType = underlyingType
        , enumVals = numberedVals
        }

    -- The state is the number given to the previous value, if any. A value with an
    -- explicit literal uses it; otherwise it gets the previous number plus one, or 0 when
    -- it is the first value.
    mapEnumVal :: ST.EnumVal -> State (Maybe Integer) EnumVal
    mapEnumVal enumVal = do
      previous <- get
      let current =
            case ST.enumValLiteral enumVal of
              Just (ST.IntLiteral explicit) -> explicit
              Nothing -> maybe 0 (+ 1) previous
      put (Just current)
      pure (EnumVal (getIdent enumVal) current)

    -- Every value must be strictly greater than the one before it.
    validateOrder :: NonEmpty EnumVal -> m ()
    validateOrder vals
      | and (zipWith isAscending (NE.toList vals) (NE.tail vals)) = pure ()
      | otherwise = throwErrorMsg "enum values must be specified in ascending order"
      where
        isAscending earlier later = enumValInt earlier < enumValInt later

    -- Dispatch the range check on the enum's underlying type; the value's name is
    -- appended to the error context.
    validateBounds :: EnumType -> EnumVal -> m ()
    validateBounds underlyingType enumVal =
      modifyContext (<> ("." <> getIdent enumVal)) $
        case underlyingType of
          EInt8 -> validateBounds' @Int8 enumVal
          EInt16 -> validateBounds' @Int16 enumVal
          EInt32 -> validateBounds' @Int32 enumVal
          EInt64 -> validateBounds' @Int64 enumVal
          EWord8 -> validateBounds' @Word8 enumVal
          EWord16 -> validateBounds' @Word16 enumVal
          EWord32 -> validateBounds' @Word32 enumVal
          EWord64 -> validateBounds' @Word64 enumVal

    -- The type argument selects the bounded integral type whose range is enforced.
    validateBounds' :: forall a. (Integral a, Bounded a, Show a) => EnumVal -> m ()
    validateBounds' e =
      when (not (inRange (lowest, highest) (enumValInt e))) $
        throwErrorMsg $
          "enum value does not fit ["
            <> T.pack (show (minBound @a))
            <> "; "
            <> T.pack (show (maxBound @a))
            <> "]"
      where
        lowest = toInteger (minBound @a)
        highest = toInteger (maxBound @a)

    validateEnumType :: ST.Type -> m EnumType
    validateEnumType = \case
      ST.TInt8 -> pure EInt8
      ST.TInt16 -> pure EInt16
      ST.TInt32 -> pure EInt32
      ST.TInt64 -> pure EInt64
      ST.TWord8 -> pure EWord8
      ST.TWord16 -> pure EWord16
      ST.TWord32 -> pure EWord32
      ST.TWord64 -> pure EWord64
      _ -> throwErrorMsg "underlying enum type must be integral"

    checkDuplicateFields :: m ()
    checkDuplicateFields = checkDuplicateIdentifiers (ST.enumVals enum)

    checkBitFlags :: m ()
    checkBitFlags
      | hasAttribute bitFlagsAttr (ST.enumMetadata enum) =
          throwErrorMsg "`bit_flags` are not supported yet"
      | otherwise = pure ()
-- | A validated table field that has not been given an id yet: its identifier, its
-- validated type, and whether it carries the @deprecated@ attribute.
data TableFieldWithoutId = TableFieldWithoutId !Ident !TableFieldType !Bool
-- | Validate every table of every symbol table, leaving the other declarations untouched.
-- Type references inside tables are resolved against the whole file tree.
validateTables :: ValidationCtx m => FileTree Stage3 -> m (FileTree Stage4)
validateTables symbolTables = traverse validateFile symbolTables
  where
    validateFile symbolTable = do
      validTables <- traverse withNamespace (allTables symbolTable)
      pure symbolTable { allTables = validTables }

    -- Validate one table and pair the result with the namespace it came from.
    withNamespace entry@(namespace, _) =
      (,) namespace <$> validateTable symbolTables entry
-- | Validate a single table declaration.
--
-- Checks for duplicate field names and undeclared attributes, validates every field's
-- type and default value, and then assigns (or verifies) the field ids. Errors are
-- reported with the table's qualified name as context. The resulting table is always
-- marked 'NotRoot' here.
validateTable :: forall m. ValidationCtx m => FileTree Stage3 -> (Namespace, ST.TableDecl) -> m TableDecl
validateTable symbolTables (currentNamespace, table) =
modifyContext (\_ -> qualify currentNamespace table) $ do
let fields = ST.tableFields table
let fieldsMetadata = ST.tableFieldMetadata <$> fields
checkDuplicateFields fields
checkUndeclaredAttributes table
validFieldsWithoutIds <- traverse validateTableField fields
validFields <- assignFieldIds fieldsMetadata validFieldsWithoutIds
pure TableDecl
{ tableIdent = getIdent table
, tableIsRoot = NotRoot
, tableFields = validFields
}
where
checkDuplicateFields :: [ST.TableField] -> m ()
checkDuplicateFields = checkDuplicateIdentifiers
-- Either no field has an `id` attribute, in which case ids are generated in declaration
-- order, or every field has one, in which case the fields are sorted by id and the
-- resulting sequence is checked. A mix of the two is an error.
assignFieldIds :: [ST.Metadata] -> [TableFieldWithoutId] -> m [TableField]
assignFieldIds metadata fieldsWithoutIds = do
ids <- catMaybes <$> traverse (findIntAttr idAttr) metadata
if null ids
then pure $ evalState (traverse assignFieldId fieldsWithoutIds) (-1)
else if length ids == length fieldsWithoutIds
then do
let fields = zipWith (\(TableFieldWithoutId ident typ depr) id -> TableField id ident typ depr) fieldsWithoutIds ids
let sorted = List.sortOn tableFieldId fields
evalStateT (traverse_ checkFieldId sorted) (-1)
pure sorted
else
throwErrorMsg "either all fields or no fields must have an 'id' attribute"
-- The state is the id given to the previous field; it starts at -1 so that the first
-- ordinary field gets id 0. Unions and vectors of unions advance the counter by 2.
-- NOTE(review): presumably because such a field takes up an additional slot - confirm
-- against the FlatBuffers schema documentation.
assignFieldId :: TableFieldWithoutId -> State Integer TableField
assignFieldId (TableFieldWithoutId ident typ depr) = do
lastId <- get
let fieldId =
case typ of
TUnion _ _ -> lastId + 2
TVector _ (VUnion _) -> lastId + 2
_ -> lastId + 1
put fieldId
pure (TableField fieldId ident typ depr)
-- Mirror of 'assignFieldId' for explicitly numbered fields: the state is the id of the
-- previously checked field (starting at -1), and each field's id must be exactly the
-- value that 'assignFieldId' would have generated for it.
checkFieldId :: TableField -> StateT Integer m ()
checkFieldId field = do
lastId <- get
modifyContext (\context -> context <> "." <> getIdent field) $ do
case tableFieldType field of
TUnion _ _ ->
when (tableFieldId field /= lastId + 2) $
throwErrorMsg "the id of a union field must be the last field's id + 2"
TVector _ (VUnion _) ->
when (tableFieldId field /= lastId + 2) $
throwErrorMsg "the id of a vector of unions field must be the last field's id + 2"
_ ->
when (tableFieldId field /= lastId + 1) $
throwErrorMsg $ "field ids must be consecutive from 0; id " <> display (lastId + 1) <> " is missing"
put (tableFieldId field)
-- Validate one field; the field's name is appended to the error context.
validateTableField :: ST.TableField -> m TableFieldWithoutId
validateTableField tf =
modifyContext (\context -> context <> "." <> getIdent tf) $ do
checkUndeclaredAttributes tf
validFieldType <- validateTableFieldType (ST.tableFieldMetadata tf) (ST.tableFieldDefault tf) (ST.tableFieldType tf)
pure $ TableFieldWithoutId
(getIdent tf)
validFieldType
(hasAttribute deprecatedAttr (ST.tableFieldMetadata tf))
-- Numbers, booleans and enums may have a default value but may not be `required`.
-- Strings, structs, tables, unions and vectors may be `required` but may not have a
-- default value.
validateTableFieldType :: ST.Metadata -> Maybe ST.DefaultVal -> ST.Type -> m TableFieldType
validateTableFieldType md dflt tableFieldType =
case tableFieldType of
ST.TInt8 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt8
ST.TInt16 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt16
ST.TInt32 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt32
ST.TInt64 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt64
ST.TWord8 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord8
ST.TWord16 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord16
ST.TWord32 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord32
ST.TWord64 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord64
ST.TFloat -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TFloat
ST.TDouble -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TDouble
ST.TBool -> checkNoRequired md >> validateDefaultValAsBool dflt <&> TBool
ST.TString -> checkNoDefault dflt $> TString (isRequired md)
-- A named type: resolve it and build the field type from the kind of declaration found.
ST.TRef typeRef ->
findDecl currentNamespace symbolTables typeRef >>= \case
MatchE (ns, enum) -> do
checkNoRequired md
validDefault <- validateDefaultAsEnum dflt enum
pure $ TEnum (TypeRef ns (getIdent enum)) (enumType enum) validDefault
MatchS (ns, struct) -> checkNoDefault dflt $> TStruct (TypeRef ns (getIdent struct)) (isRequired md)
MatchT (ns, table) -> checkNoDefault dflt $> TTable (TypeRef ns (getIdent table)) (isRequired md)
MatchU (ns, union) -> checkNoDefault dflt $> TUnion (TypeRef ns (getIdent union)) (isRequired md)
-- A vector: the element type is validated separately; vectors of vectors are rejected.
ST.TVector vecType ->
checkNoDefault dflt >> TVector (isRequired md) <$>
case vecType of
ST.TInt8 -> pure VInt8
ST.TInt16 -> pure VInt16
ST.TInt32 -> pure VInt32
ST.TInt64 -> pure VInt64
ST.TWord8 -> pure VWord8
ST.TWord16 -> pure VWord16
ST.TWord32 -> pure VWord32
ST.TWord64 -> pure VWord64
ST.TFloat -> pure VFloat
ST.TDouble -> pure VDouble
ST.TBool -> pure VBool
ST.TString -> pure VString
ST.TVector _ -> throwErrorMsg "nested vector types not supported"
ST.TRef typeRef ->
findDecl currentNamespace symbolTables typeRef <&> \case
MatchE (ns, enum) ->
VEnum (TypeRef ns (getIdent enum))
(enumType enum)
MatchS (ns, struct) ->
VStruct (TypeRef ns (getIdent struct))
MatchT (ns, table) -> VTable (TypeRef ns (getIdent table))
MatchU (ns, union) -> VUnion (TypeRef ns (getIdent union))
-- | Fail if the metadata carries the @required@ attribute. Used for field types that
-- cannot be marked as required.
checkNoRequired :: ValidationCtx m => ST.Metadata -> m ()
checkNoRequired md
  | hasAttribute requiredAttr md =
      throwErrorMsg "only non-scalar fields (strings, vectors, unions, structs, tables) may be 'required'"
  | otherwise = pure ()
-- | Fail if a default value is present. Used for field types that cannot have a default.
checkNoDefault :: ValidationCtx m => Maybe ST.DefaultVal -> m ()
checkNoDefault = \case
  Nothing -> pure ()
  Just _ ->
    throwErrorMsg
      "default values currently only supported for scalar fields (integers, floating point, bool, enums)"
-- | 'Req' when the metadata carries the @required@ attribute, 'Opt' otherwise.
isRequired :: ST.Metadata -> Required
isRequired md
  | hasAttribute requiredAttr md = Req
  | otherwise = Opt
-- | Interpret an optional default value as an integer of type @a@.
--
-- A missing default becomes 0. A numeric default must be a whole number that fits the
-- range of @a@. Any other kind of default is rejected.
validateDefaultValAsInt :: forall m a. (ValidationCtx m, Integral a, Bounded a, Show a) => Maybe ST.DefaultVal -> m (DefaultVal a)
validateDefaultValAsInt = \case
  Nothing -> pure (DefaultVal 0)
  Just (ST.DefaultNum n)
    | Scientific.isInteger n ->
        maybe outOfRange (pure . DefaultVal) (Scientific.toBoundedInteger @a n)
  -- Reached for non-numeric defaults and for numbers with a fractional part.
  Just _ -> throwErrorMsg "default value must be integral"
  where
    outOfRange =
      throwErrorMsg $
        "default value does not fit ["
          <> T.pack (show (minBound @a))
          <> "; "
          <> T.pack (show (maxBound @a))
          <> "]"
-- | Interpret an optional default value as a number. A missing default becomes 0; a
-- non-numeric default is rejected.
validateDefaultValAsScientific :: ValidationCtx m => Maybe ST.DefaultVal -> m (DefaultVal Scientific)
validateDefaultValAsScientific = \case
  Nothing -> pure (DefaultVal 0)
  Just (ST.DefaultNum n) -> pure (DefaultVal n)
  Just _ -> throwErrorMsg "default value must be a number"
-- | Interpret an optional default value as a boolean. A missing default becomes 'False';
-- a non-boolean default is rejected.
validateDefaultValAsBool :: ValidationCtx m => Maybe ST.DefaultVal -> m (DefaultVal Bool)
validateDefaultValAsBool = \case
  Nothing -> pure (DefaultVal False)
  Just (ST.DefaultBool b) -> pure (DefaultVal b)
  Just _ -> throwErrorMsg "default value must be a boolean"
-- | Resolve the default of an enum-typed field to the integer of one of the enum's values.
--
-- Without an explicit default the enum must have a value equal to 0. A numeric default
-- must be a whole number matching one of the enum's values. A named default must be the
-- identifier of one of the enum's values. Boolean defaults are rejected.
validateDefaultAsEnum :: ValidationCtx m => Maybe ST.DefaultVal -> EnumDecl -> m (DefaultVal Integer)
validateDefaultAsEnum dflt enum =
  fmap DefaultVal $
    case dflt of
      Nothing ->
        lookupVal
          ((== 0) . enumValInt)
          "enum does not have a 0 value; please manually specify a default for this field"
      Just (ST.DefaultNum n) ->
        case Scientific.floatingOrInteger @Float n of
          Left _float -> wrongKind
          Right i -> lookupVal ((== i) . enumValInt) (notInEnum (display i))
      Just (ST.DefaultRef ref) ->
        lookupVal ((== ref) . getIdent) (notInEnum (display ref))
      Just (ST.DefaultBool _) ->
        wrongKind
  where
    -- Integer of the first enum value satisfying the predicate, or fail with the given message.
    lookupVal predicate failureMsg =
      maybe (throwErrorMsg failureMsg) (pure . enumValInt) (find predicate (enumVals enum))

    notInEnum shown =
      "default value of " <> shown <> " is not part of enum " <> display (getIdent enum)

    wrongKind =
      throwErrorMsg $ "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)
-- | Validate every union of every symbol table, leaving the other declarations untouched.
-- Union members are resolved against the whole file tree.
validateUnions :: ValidationCtx m => FileTree Stage4 -> m (FileTree ValidDecls)
validateUnions symbolTables = traverse validateFile symbolTables
  where
    validateFile symbolTable = do
      validUnions <- traverse withNamespace (allUnions symbolTable)
      pure symbolTable { allUnions = validUnions }

    -- Validate one union and pair the result with the namespace it came from.
    withNamespace entry@(namespace, _) =
      (,) namespace <$> validateUnion symbolTables entry
-- | Validate a single union declaration.
--
-- Every member must refer to a table. A member without an explicit name is named after
-- its (partially qualified) type reference; in either case dots in the name are replaced
-- by underscores. Member names must be unique and must not be @NONE@. Finally the union's
-- attributes are checked. Errors are reported with the union's qualified name as context.
validateUnion :: forall m. ValidationCtx m => FileTree Stage4 -> (Namespace, ST.UnionDecl) -> m UnionDecl
validateUnion symbolTables (currentNamespace, union) =
  modifyContext (const (qualify currentNamespace union)) $ do
    validUnionVals <- traverse validateUnionVal (ST.unionVals union)
    checkDuplicateVals validUnionVals
    checkUndeclaredAttributes union
    pure UnionDecl
      { unionIdent = getIdent union
      , unionVals = validUnionVals
      }
  where
    validateUnionVal :: ST.UnionVal -> m UnionVal
    validateUnionVal uv =
      modifyContext (<> ("." <> memberIdent)) $ do
        tableRef <- validateUnionValType memberRef
        pure UnionVal
          { unionValIdent = memberIdent
          , unionValTableRef = tableRef
          }
      where
        memberRef :: TypeRef
        memberRef = ST.unionValTypeRef uv

        -- Name used when the member has no explicit one.
        fallbackIdent :: Ident
        fallbackIdent = qualify (typeRefNamespace memberRef) (typeRefIdent memberRef)

        chosenIdent :: Ident
        chosenIdent = fromMaybe fallbackIdent (ST.unionValIdent uv)

        memberIdent :: Ident
        memberIdent = coerce (T.replace "." "_" (coerce chosenIdent))

    validateUnionValType :: TypeRef -> m TypeRef
    validateUnionValType typeRef = do
      match <- findDecl currentNamespace symbolTables typeRef
      case match of
        MatchT (ns, table) -> pure (TypeRef ns (getIdent table))
        _ -> throwErrorMsg "union members may only be tables"

    -- "NONE" is added to the names being compared, so a member called NONE is a duplicate.
    checkDuplicateVals :: NonEmpty UnionVal -> m ()
    checkDuplicateVals = checkDuplicateIdentifiers . NE.cons "NONE" . fmap getIdent
-- | Validate every struct of every symbol table, leaving the other declarations untouched.
--
-- For each file, all of its structs are first checked for cyclic dependencies and only
-- then validated. The list of structs validated so far is kept as state across the whole
-- file tree, starting out empty.
validateStructs :: ValidationCtx m => FileTree Stage2 -> m (FileTree Stage3)
validateStructs symbolTables =
  evalStateT (traverse validateFile symbolTables) []
  where
    validateFile :: (MonadState [(Namespace, StructDecl)] m, ValidationCtx m) => Stage2 -> m Stage3
    validateFile symbolTable = do
      let declared = allStructs symbolTable
      forM_ declared (checkStructCycles symbolTables)
      validated <- for declared (validateStruct symbolTables)
      pure symbolTable { allStructs = validated }
-- | Fail if a struct contains itself, directly or through other structs.
--
-- Starting from the given struct, every field that refers to another struct is followed
-- recursively while remembering the qualified names of the structs on the current path.
-- Reaching a struct that is already on the path is reported as a cycle, and the error
-- message spells out the structs that form it.
checkStructCycles :: forall m. ValidationCtx m => FileTree Stage2 -> (Namespace, ST.StructDecl) -> m ()
checkStructCycles symbolTables = visit []
  where
    -- The first argument holds the structs on the current path, most recent first.
    visit :: [Ident] -> (Namespace, ST.StructDecl) -> m ()
    visit ancestors (currentNamespace, struct)
      | here `elem` ancestors = inContext (throwErrorMsg cycleMessage)
      | otherwise = inContext (traverse_ followField (ST.structFields struct))
      where
        here :: Ident
        here = qualify currentNamespace struct

        inContext :: m a -> m a
        inContext = modifyContext (const here)

        -- The path from the first occurrence of this struct back to itself, oldest first.
        cyclePath :: [Ident]
        cyclePath = List.dropWhile (/= here) (List.reverse (here : ancestors))

        cycleMessage :: Text
        cycleMessage =
          "cyclic dependency detected ["
            <> display (T.intercalate " -> " (coerce cyclePath))
            <> "] - structs cannot contain themselves, directly or indirectly"

        -- Only fields whose type resolves to a struct are followed.
        followField :: ST.StructField -> m ()
        followField field =
          modifyContext (<> ("." <> getIdent field)) $
            case ST.structFieldType field of
              ST.TRef typeRef -> do
                match <- findDecl currentNamespace symbolTables typeRef
                case match of
                  MatchS nested -> visit (here : ancestors) nested
                  _ -> pure ()
              _ -> pure ()
-- | A validated struct field for which padding and offset have not been computed yet
-- (see @addFieldPadding@ inside 'validateStruct').
data UnpaddedStructField = UnpaddedStructField
{ unpaddedStructFieldIdent :: !Ident
, unpaddedStructFieldType :: !StructFieldType
} deriving (Show, Eq)
-- | Validate a single struct declaration and compute its alignment, its size and the
-- padding and offset of each of its fields.
--
-- The state holds the structs validated so far. A struct that is already in there (same
-- namespace and identifier) is returned as is; otherwise it is validated and added.
-- Fields that refer to other structs are validated recursively through this function.
-- NOTE(review): the recursion relies on the struct graph being acyclic; 'validateStructs'
-- runs 'checkStructCycles' on a file's structs before calling this - confirm that every
-- other caller does the same.
validateStruct ::
forall m. (MonadState [(Namespace, StructDecl)] m, ValidationCtx m)
=> FileTree Stage2
-> (Namespace, ST.StructDecl)
-> m (Namespace, StructDecl)
validateStruct symbolTables (currentNamespace, struct) =
modifyContext (\_ -> qualify currentNamespace struct) $ do
validStructs <- get
case find (\(ns, s) -> ns == currentNamespace && getIdent s == getIdent struct) validStructs of
Just match -> pure match
Nothing -> do
checkDuplicateFields
checkUndeclaredAttributes struct
fields <- traverse validateStructField (ST.structFields struct)
-- The natural alignment is the largest alignment among the fields; `force_align`,
-- when present and valid, overrides it.
let naturalAlignment = maximum (structFieldAlignment <$> fields)
forceAlignAttrVal <- getForceAlignAttr
forceAlign <- traverse (validateForceAlign naturalAlignment) forceAlignAttrVal
let alignment = fromMaybe naturalAlignment forceAlign
let (size, paddedFields) = addFieldPadding alignment fields
let validStruct = StructDecl
{ structIdent = getIdent struct
, structAlignment = alignment
, structSize = size
, structFields = paddedFields
}
-- Remember the result so that later references to this struct reuse it.
modify ((currentNamespace, validStruct) :)
pure (currentNamespace, validStruct)
where
invalidStructFieldType = "struct fields may only be integers, floating point, bool, enums, or other structs"
-- Compute each field's offset and trailing padding, and the struct's total size.
-- Every field but the last is padded so that the following field starts at a multiple
-- of that following field's alignment. The last field is padded so that the total size
-- is a multiple of the struct's alignment.
addFieldPadding :: Alignment -> NonEmpty UnpaddedStructField -> (InlineSize, NonEmpty StructField)
addFieldPadding structAlignment unpaddedFields =
-- `go` accumulates the fields in reverse order. The input is non-empty and `go` emits
-- one output field per input field, so the reversed list is non-empty as well.
(size, NE.fromList (reverse paddedFields))
where
(size, paddedFields) = go 0 [] (NE.toList unpaddedFields)
-- Arguments: the offset at which the next field starts, the fields emitted so far
-- (most recent first), and the fields still to be processed.
go :: InlineSize -> [StructField] -> [UnpaddedStructField] -> (InlineSize, [StructField])
go size paddedFields [] = (size, paddedFields)
go size paddedFields (x : y : tail) =
let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
nextFieldsAlignment = fromIntegral @Alignment @InlineSize (structFieldAlignment y)
paddingNeeded = (size' `roundUpToNearestMultipleOf` nextFieldsAlignment) - size'
size'' = size' + paddingNeeded
paddedField = StructField
{ structFieldIdent = unpaddedStructFieldIdent x
, structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
, structFieldOffset = coerce size
, structFieldType = unpaddedStructFieldType x
}
in go size'' (paddedField : paddedFields) (y : tail)
go size paddedFields [x] =
let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
structAlignment' = fromIntegral @Alignment @InlineSize structAlignment
paddingNeeded = (size' `roundUpToNearestMultipleOf` structAlignment') - size'
size'' = size' + paddingNeeded
paddedField = StructField
{ structFieldIdent = unpaddedStructFieldIdent x
, structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
, structFieldOffset = coerce size
, structFieldType = unpaddedStructFieldType x
}
in (size'', paddedField : paddedFields)
-- Validate one field; the field's name is appended to the error context.
validateStructField :: ST.StructField -> m UnpaddedStructField
validateStructField sf =
modifyContext (\context -> context <> "." <> getIdent sf) $ do
checkUnsupportedAttributes sf
checkUndeclaredAttributes sf
structFieldType <- validateStructFieldType (ST.structFieldType sf)
pure $ UnpaddedStructField
{ unpaddedStructFieldIdent = getIdent sf
, unpaddedStructFieldType = structFieldType
}
-- Struct fields may only be numbers, booleans, enums or other structs.
validateStructFieldType :: ST.Type -> m StructFieldType
validateStructFieldType structFieldType =
case structFieldType of
ST.TInt8 -> pure SInt8
ST.TInt16 -> pure SInt16
ST.TInt32 -> pure SInt32
ST.TInt64 -> pure SInt64
ST.TWord8 -> pure SWord8
ST.TWord16 -> pure SWord16
ST.TWord32 -> pure SWord32
ST.TWord64 -> pure SWord64
ST.TFloat -> pure SFloat
ST.TDouble -> pure SDouble
ST.TBool -> pure SBool
ST.TString -> throwErrorMsg invalidStructFieldType
ST.TVector _ -> throwErrorMsg invalidStructFieldType
ST.TRef typeRef ->
findDecl currentNamespace symbolTables typeRef >>= \case
MatchE (enumNamespace, enum) ->
pure (SEnum (TypeRef enumNamespace (getIdent enum)) (enumType enum))
-- A nested struct is validated on the spot (or fetched from the state).
MatchS (nestedNamespace, nestedStruct) ->
SStruct <$> validateStruct symbolTables (nestedNamespace, nestedStruct)
_ -> throwErrorMsg invalidStructFieldType
-- Attributes that are rejected on struct fields.
checkUnsupportedAttributes :: ST.StructField -> m ()
checkUnsupportedAttributes structField = do
when (hasAttribute deprecatedAttr (ST.structFieldMetadata structField)) $
throwErrorMsg "can't deprecate fields in a struct"
when (hasAttribute requiredAttr (ST.structFieldMetadata structField)) $
throwErrorMsg "struct fields are already required, the 'required' attribute is redundant"
when (hasAttribute idAttr (ST.structFieldMetadata structField)) $
throwErrorMsg "struct fields cannot be reordered using the 'id' attribute"
getForceAlignAttr :: m (Maybe Integer)
getForceAlignAttr = findIntAttr forceAlignAttr (ST.structMetadata struct)
-- `force_align` must be a power of two between the natural alignment and 16, inclusive.
validateForceAlign :: Alignment -> Integer -> m Alignment
validateForceAlign naturalAlignment forceAlign =
if isPowerOfTwo forceAlign
&& inRange (fromIntegral @Alignment @Integer naturalAlignment, 16) forceAlign
then pure (fromIntegral @Integer @Alignment forceAlign)
else throwErrorMsg $
"force_align must be a power of two integer ranging from the struct's natural alignment (in this case, "
<> T.pack (show naturalAlignment)
<> ") to 16"
checkDuplicateFields :: m ()
checkDuplicateFields =
checkDuplicateIdentifiers
(ST.structFields struct)
-- | Alignment of a struct field: the size constant of its type for numbers and booleans,
-- the alignment of the underlying type for enums, and the nested struct's own alignment
-- for structs.
structFieldAlignment :: UnpaddedStructField -> Alignment
structFieldAlignment = alignmentOf . unpaddedStructFieldType
  where
    alignmentOf :: StructFieldType -> Alignment
    alignmentOf = \case
      SInt8 -> int8Size
      SInt16 -> int16Size
      SInt32 -> int32Size
      SInt64 -> int64Size
      SWord8 -> word8Size
      SWord16 -> word16Size
      SWord32 -> word32Size
      SWord64 -> word64Size
      SFloat -> floatSize
      SDouble -> doubleSize
      SBool -> boolSize
      SEnum _ enumType -> enumAlignment enumType
      SStruct (_, nestedStruct) -> structAlignment nestedStruct
-- | Alignment of an enum: the size of its underlying type (see 'enumSize').
enumAlignment :: EnumType -> Alignment
enumAlignment underlyingType = Alignment (enumSize underlyingType)
-- | Size of an enum's underlying integer type, taken from the matching size constant.
enumSize :: EnumType -> Word8
enumSize = \case
  EInt8 -> int8Size
  EInt16 -> int16Size
  EInt32 -> int32Size
  EInt64 -> int64Size
  EWord8 -> word8Size
  EWord16 -> word16Size
  EWord32 -> word32Size
  EWord64 -> word64Size
-- | Inline size of a struct field type: the size constant of the type for numbers and
-- booleans, the size of the underlying type for enums, and the nested struct's own size
-- for structs.
structFieldTypeSize :: StructFieldType -> InlineSize
structFieldTypeSize = \case
  SInt8 -> int8Size
  SInt16 -> int16Size
  SInt32 -> int32Size
  SInt64 -> int64Size
  SWord8 -> word8Size
  SWord16 -> word16Size
  SWord32 -> word32Size
  SWord64 -> word64Size
  SFloat -> floatSize
  SDouble -> doubleSize
  SBool -> boolSize
  SEnum _ enumType -> fromIntegral @Word8 @InlineSize (enumSize enumType)
  SStruct (_, nestedStruct) -> structSize nestedStruct
-- | Fail if two or more of the given items share an identifier. The error message lists
-- every identifier that occurs more than once, in ascending order.
checkDuplicateIdentifiers :: (ValidationCtx m, Foldable f, Functor f, HasIdent a) => f a -> m ()
checkDuplicateIdentifiers xs
  | null duplicates = pure ()
  | otherwise = throwErrorMsg (display duplicates <> " declared more than once")
  where
    duplicates :: [Ident]
    duplicates = Map.keys (Map.filter (> 1) occurrences)

    -- How many times each identifier occurs.
    occurrences :: Map Ident Int
    occurrences = Map.fromListWith (+) [(getIdent x, 1) | x <- Foldable.toList xs]
-- | Fail if the item's metadata uses an attribute that is not in the set of allowed
-- attributes held by the validation state. Only the first offending attribute (in key
-- order) is reported.
checkUndeclaredAttributes :: (ValidationCtx m, HasMetadata a) => a -> m ()
checkUndeclaredAttributes a = do
  allowed <- asks validationStateAllAttributes
  let usedAttrs = Map.keys (ST.unMetadata (getMetadata a))
      isUndeclared attr = Set.notMember (coerce attr) allowed
  case filter isUndeclared usedAttrs of
    [] -> pure ()
    attr : _ -> throwErrorMsg $ "user defined attributes must be declared before use: " <> attr
-- | Whether the metadata contains an attribute with the given name, regardless of its value.
hasAttribute :: Text -> ST.Metadata -> Bool
hasAttribute name (ST.Metadata attrs) =
  name `Map.member` attrs
-- | Look up an attribute that is expected to hold an integer.
--
-- Returns 'Nothing' when the attribute is absent. An integer value is returned as is and a
-- string value is accepted if it parses as an integer. An attribute without a value, or
-- with a string that is not an integer, is an error.
findIntAttr :: ValidationCtx m => Text -> ST.Metadata -> m (Maybe Integer)
findIntAttr name (ST.Metadata attrs) =
  maybe (pure Nothing) fromValue (Map.lookup name attrs)
  where
    fromValue = \case
      Nothing -> notAnInteger
      Just (ST.AttrI i) -> pure (Just i)
      Just (ST.AttrS t) ->
        maybe notAnInteger (pure . Just) (readMaybe @Integer (T.unpack t))

    notAnInteger =
      throwErrorMsg $
        "expected attribute '"
          <> name
          <> "' to have an integer value, e.g. '"
          <> name
          <> ": 123'"
-- | Look up an attribute that is expected to hold a string.
--
-- Returns 'Nothing' when the attribute is absent and the string when the attribute has a
-- string value. An attribute that is present with anything else is an error.
findStringAttr :: ValidationCtx m => Text -> ST.Metadata -> m (Maybe Text)
findStringAttr name (ST.Metadata attrs) =
  maybe (pure Nothing) fromValue (Map.lookup name attrs)
  where
    fromValue = \case
      Just (ST.AttrS s) -> pure (Just s)
      _ ->
        throwErrorMsg $
          "expected attribute '"
            <> name
            <> "' to have a string value, e.g. '"
            <> name
            <> ": \"abc\"'"
-- | Abort validation with the given message. Unless the current validation context is
-- empty, the message is prefixed with that context in square brackets.
throwErrorMsg :: ValidationCtx m => Text -> m a
throwErrorMsg msg =
  asks validationStateCurrentContext >>= \context ->
    throwError $
      if context == ""
        then msg
        else "[" <> display context <> "]: " <> msg