{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module FlatBuffers.Internal.Compiler.SemanticAnalysis where

import           Control.Monad                                 ( forM_, join, when )
import           Control.Monad.Except                          ( throwError )
import           Control.Monad.Reader                          ( ReaderT, asks, local, runReaderT )
import           Control.Monad.State                           ( MonadState, State, StateT, evalState, evalStateT, get, mapStateT, modify, put )
import           Control.Monad.Trans                           ( lift )

import           Data.Bits                                     ( (.&.), (.|.), Bits, FiniteBits, bit, finiteBitSize )
import           Data.Coerce                                   ( coerce )
import           Data.Foldable                                 ( asum, find, foldlM, traverse_ )
import qualified Data.Foldable                                 as Foldable
import           Data.Functor                                  ( ($>), (<&>) )
import           Data.Int
import           Data.Ix                                       ( inRange )
import qualified Data.List                                     as List
import           Data.List.NonEmpty                            ( NonEmpty((:|)) )
import qualified Data.List.NonEmpty                            as NE
import           Data.Map.Strict                               ( Map )
import qualified Data.Map.Strict                               as Map
import           Data.Maybe                                    ( catMaybes, fromMaybe, isJust )
import           Data.Monoid                                   ( Sum(..) )
import           Data.Scientific                               ( Scientific )
import qualified Data.Scientific                               as Scientific
import           Data.Set                                      ( Set )
import qualified Data.Set                                      as Set
import           Data.Text                                     ( Text )
import qualified Data.Text                                     as T
import           Data.Traversable                              ( for )
import           Data.Word

import           FlatBuffers.Internal.Compiler.Display         ( Display(..) )
import           FlatBuffers.Internal.Compiler.SyntaxTree      ( FileTree(..), HasIdent(..), HasMetadata(..), Ident, Namespace, Schema, TypeRef(..), qualify )
import qualified FlatBuffers.Internal.Compiler.SyntaxTree      as ST
import           FlatBuffers.Internal.Compiler.ValidSyntaxTree
import           FlatBuffers.Internal.Constants
import           FlatBuffers.Internal.Types

import           Text.Read                                     ( readMaybe )


----------------------------------
------- MonadValidation ----------
----------------------------------
-- | The concrete validation monad: a reader over 'ValidationState' layered on
-- top of @Either String@, so a failed validation yields a 'String' error message.
newtype Validation a = Validation
  { runValidation :: ReaderT ValidationState (Either String) a
  }
  deriving newtype (Functor, Applicative, Monad)

-- | Read-only environment carried by the 'Validation' monad.
data ValidationState = ValidationState
  { validationStateCurrentContext :: ![Ident]
    -- ^ The thing being validated (e.g. a fully-qualified struct name, or a table field name).
    -- Stored innermost-first: entering a nested item prepends its identifier.
  , validationStateAllAttributes  :: !(Set ST.AttributeDecl)
    -- ^ All the attributes declared in all the schemas (including imported ones).
  }

-- | Operations available while validating a schema: tracking which item is
-- currently being validated, reading the set of declared attributes, and
-- failing with an error message.
class Monad m => MonadValidation m where
  -- | Start validating an item @a@
  validating :: HasIdent a => a -> m b -> m b
  -- | Clear validation context, i.e. forget which item is currently being validated, if any.
  resetContext :: m a -> m a
  -- | Get the path to the item currently being validated
  getContext :: m [Ident]
  -- | Get a list of all the attributes declared in every loaded schema
  getDeclaredAttributes :: m (Set ST.AttributeDecl)
  -- | Fail validation with a message
  throwErrorMsg :: String -> m a

-- | The context is kept innermost-first in the reader environment and is
-- reversed when read, so 'getContext' returns the path outermost-first.
instance MonadValidation Validation where
  validating item (Validation inner) =
    Validation $ local pushIdent inner
    where
      pushIdent st =
        st { validationStateCurrentContext = getIdent item : validationStateCurrentContext st }

  resetContext (Validation inner) =
    Validation $ local (\st -> st { validationStateCurrentContext = [] }) inner

  getContext =
    Validation $ asks (List.reverse . validationStateCurrentContext)

  getDeclaredAttributes =
    Validation $ asks validationStateAllAttributes

  -- Error messages are prefixed with the dotted path of the item being
  -- validated, unless the context is empty.
  throwErrorMsg msg =
    getContext >>= \path -> Validation . throwError $ prefix path <> msg
    where
      prefix []     = ""
      prefix idents = "[" <> List.intercalate "." (T.unpack . unIdent <$> idents) <> "]: "

-- | Lifts every operation through 'StateT'; the context-changing operations
-- are applied to the underlying computation with 'mapStateT'.
instance MonadValidation m => MonadValidation (StateT s m) where
  validating item       = mapStateT (validating item)
  resetContext action   = mapStateT resetContext action
  getContext            = lift getContext
  getDeclaredAttributes = lift getDeclaredAttributes
  throwErrorMsg msg     = lift (throwErrorMsg msg)

----------------------------------
------- Validation stages --------
----------------------------------
{-
During validation, we translate `SyntaxTree.XDecl` declarations into
`ValidSyntaxTree.XDecl` declarations.

This is done in stages: we first translate enums, then structs, then tables,
and lastly unions.
-}

-- | The type declarations of a schema, grouped by kind and keyed by the
-- namespace they were declared in together with their identifier.
--
-- The four type parameters let each kind of declaration be either in its
-- unvalidated (@SyntaxTree@) or validated (@ValidSyntaxTree@) form.
data SymbolTable enum struct table union = SymbolTable
  { allEnums   :: !(Map (Namespace, Ident) enum)
  , allStructs :: !(Map (Namespace, Ident) struct)
  , allTables  :: !(Map (Namespace, Ident) table)
  , allUnions  :: !(Map (Namespace, Ident) union)
  }
  deriving (Eq, Show)

-- | Combines two symbol tables by combining their maps kind by kind.
instance Semigroup (SymbolTable e s t u) where
  a <> b =
    SymbolTable
      { allEnums   = allEnums a   <> allEnums b
      , allStructs = allStructs a <> allStructs b
      , allTables  = allTables a  <> allTables b
      , allUnions  = allUnions a  <> allUnions b
      }

-- | The empty symbol table: no declarations of any kind.
instance Monoid (SymbolTable e s t u) where
  mempty =
    SymbolTable
      { allEnums   = Map.empty
      , allStructs = Map.empty
      , allTables  = Map.empty
      , allUnions  = Map.empty
      }

-- | Nothing validated yet: every declaration is still in its @SyntaxTree@ form.
type Stage1     = SymbolTable ST.EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums have been validated.
type Stage2     = SymbolTable    EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums and structs have been validated.
type Stage3     = SymbolTable    EnumDecl    StructDecl ST.TableDecl ST.UnionDecl
-- | Enums, structs and tables have been validated.
type Stage4     = SymbolTable    EnumDecl    StructDecl    TableDecl ST.UnionDecl
-- | Every kind of declaration has been validated.
type ValidDecls = SymbolTable    EnumDecl    StructDecl    TableDecl    UnionDecl

-- | Entry point of semantic analysis.
--
-- Builds the symbol tables for every schema in the tree, checks that no
-- fully-qualified top-level identifier is used twice, then validates enums,
-- structs, tables and unions (in that order) and finally flags the root table.
-- Returns the error message of the first failed check, if any.
validateSchemas :: FileTree Schema -> Either String (FileTree ValidDecls)
validateSchemas schemas =
  runReaderT (runValidation pipeline) initialState
  where
    -- Validation starts with an empty context and every known/declared attribute.
    initialState = ValidationState [] allAttributes

    pipeline = do
      stage1 <- createSymbolTables schemas
      checkDuplicateIdentifiers (concatMap qualifiedIdents stage1)
      stage2 <- validateEnums stage1
      stage3 <- validateStructs stage2
      stage4 <- validateTables stage3
      validDecls <- validateUnions stage4
      updateRootTable (fileTreeRoot schemas) validDecls

    -- Fully-qualified identifiers of all the declarations in one symbol table.
    qualifiedIdents symbolTable =
      concat
        [ qualifiedKeys (allEnums symbolTable)
        , qualifiedKeys (allStructs symbolTable)
        , qualifiedKeys (allTables symbolTable)
        , qualifiedKeys (allUnions symbolTable)
        ]

    qualifiedKeys m = uncurry qualify <$> Map.keys m

    -- Attributes declared in any of the schemas, plus the built-in ones.
    allAttributes =
      Set.fromList $
        [ attr
        | schema <- Foldable.toList schemas
        , ST.DeclA attr <- ST.decls schema
        ]
        <> knownAttributes

-- | Builds one symbol table per schema, pairing each type declaration with the
-- namespace that was in effect at the point where it was declared.
--
-- Declarations are processed in order; a namespace declaration changes the
-- namespace used for the declarations that follow it.
createSymbolTables :: FileTree Schema -> Validation (FileTree Stage1)
createSymbolTables = traverse (fmap snd . foldlM step ("", mempty) . ST.decls)
  where
    step :: (Namespace, Stage1) -> ST.Decl -> Validation (Namespace, Stage1)
    step (namespace, table) decl =
      case decl of
        ST.DeclN (ST.NamespaceDecl namespace') -> pure (namespace', table)
        ST.DeclE enum   -> register enum   (allEnums table)   $ \m -> table { allEnums = m }
        ST.DeclS struct -> register struct (allStructs table) $ \m -> table { allStructs = m }
        ST.DeclT tbl    -> register tbl    (allTables table)  $ \m -> table { allTables = m }
        ST.DeclU union  -> register union  (allUnions table)  $ \m -> table { allUnions = m }
        _               -> pure (namespace, table)
      where
        -- Inserts a symbol into one of the maps and rebuilds the symbol table,
        -- leaving the current namespace untouched.
        register
          :: HasIdent a
          => a
          -> Map (Namespace, Ident) a
          -> (Map (Namespace, Ident) a -> Stage1)
          -> Validation (Namespace, Stage1)
        register symbol symbols rebuild =
          insertSymbol namespace symbol symbols <&> \symbols' -> (namespace, rebuild symbols')

-- | Adds a symbol to a map under its namespace and identifier.
-- Fails if that key is already present in the map.
insertSymbol :: HasIdent a => Namespace -> a -> Map (Namespace, Ident) a -> Validation (Map (Namespace, Ident) a)
insertSymbol namespace symbol symbols
  | key `Map.member` symbols =
      throwErrorMsg $ display (qualify namespace symbol) <> " declared more than once"
  | otherwise =
      pure $ Map.insert key symbol symbols
  where
    key = (namespace, getIdent symbol)


----------------------------------
------------ Root Type -----------
----------------------------------
-- | What is known about the root table of a schema: the table itself, the
-- namespace it was found in, and the file identifier (if one was declared).
data RootInfo = RootInfo
  { rootTableNamespace :: !Namespace
  , rootTable          :: !TableDecl
  , rootFileIdent      :: !(Maybe Text)
  }

-- | Finds the root table (if any) and sets the `tableIsRoot` flag accordingly.
-- We only care about @root_type@ declarations in the root schema. Imported schemas are not scanned for @root_type@s.
-- The root type declaration can point to a table in any schema (root or imported).
--
-- When no root type is declared, the symbol tables are returned unchanged.
updateRootTable :: Schema -> FileTree ValidDecls -> Validation (FileTree ValidDecls)
updateRootTable schema symbolTables =
  maybe symbolTables markRoot <$> getRootInfo schema symbolTables
  where
    markRoot :: RootInfo -> FileTree ValidDecls
    markRoot info =
      fmap (\st -> st { allTables = Map.mapWithKey (flagIfRoot info) (allTables st) }) symbolTables

    -- A table is the root when it sits in the root table's namespace and is
    -- equal to the root table.
    flagIfRoot :: RootInfo -> (Namespace, Ident) -> TableDecl -> TableDecl
    flagIfRoot (RootInfo rootNamespace root fileIdent) (namespace, _) table
      | namespace == rootNamespace && table == root = table { tableIsRoot = IsRoot fileIdent }
      | otherwise                                   = table

-- | Scans the declarations of a schema, in order, for its root type and file identifier.
--
-- Type references in a @root_type@ declaration are resolved relative to the
-- namespace in effect at that point. When several root type or file identifier
-- declarations are present, the last one wins. Fails if the root type does not
-- resolve to a table.
getRootInfo :: Schema -> FileTree ValidDecls -> Validation (Maybe RootInfo)
getRootInfo schema symbolTables =
  summarise <$> foldlM step ("", Nothing, Nothing) (ST.decls schema)
  where
    -- A file identifier on its own (without a root type) yields no root info.
    summarise :: (Namespace, Maybe (Namespace, TableDecl), Maybe Text) -> Maybe RootInfo
    summarise (_, root, fileIdent) =
      root <&> \(namespace, table) -> RootInfo namespace table fileIdent

    step :: (Namespace, Maybe (Namespace, TableDecl), Maybe Text) -> ST.Decl -> Validation (Namespace, Maybe (Namespace, TableDecl), Maybe Text)
    step acc@(currentNamespace, root, fileIdent) decl =
      case decl of
        ST.DeclN (ST.NamespaceDecl newNamespace) ->
          pure (newNamespace, root, fileIdent)
        ST.DeclFI (ST.FileIdentifierDecl newFileIdent) ->
          pure (currentNamespace, root, Just (coerce newFileIdent))
        ST.DeclR (ST.RootDecl typeRef) -> do
          match <- findDecl currentNamespace symbolTables typeRef
          case match of
            MatchT namespace table -> pure (currentNamespace, Just (namespace, table), fileIdent)
            _                      -> throwErrorMsg "root type must be a table"
        _ ->
          pure acc

----------------------------------
----------- Attributes -----------
----------------------------------
-- | The built-in attributes: the ones this module refers to by name, followed
-- by 'otherKnownAttributes'.
knownAttributes :: [ST.AttributeDecl]
knownAttributes = namedAttributes ++ otherKnownAttributes
  where
    namedAttributes :: [ST.AttributeDecl]
    namedAttributes =
      coerce [idAttr, deprecatedAttr, requiredAttr, forceAlignAttr, bitFlagsAttr]

-- | Name of the attribute that gives a table field an explicit id.
idAttr :: Text
idAttr = "id"

-- | Name of the attribute that marks a table field as deprecated.
deprecatedAttr :: Text
deprecatedAttr = "deprecated"

-- | Name of the attribute that marks a non-scalar table field as required.
requiredAttr :: Text
requiredAttr = "required"

-- | Name of the @force_align@ attribute.
forceAlignAttr :: Text
forceAlignAttr = "force_align"

-- | Name of the attribute that turns an enum into a set of bit flags.
bitFlagsAttr :: Text
bitFlagsAttr = "bit_flags"

-- | Remaining built-in attributes, appended to the named ones in 'knownAttributes'.
-- The links below point to where these attributes are documented.
otherKnownAttributes :: [ST.AttributeDecl]
otherKnownAttributes =
  -- https://google.github.io/flatbuffers/flatbuffers_guide_writing_schema.html
  [ "nested_flatbuffer"
  , "flexbuffer"
  , "key"
  , "hash"
  , "original_order"
  -- https://google.github.io/flatbuffers/flatbuffers_guide_use_cpp.html#flatbuffers_cpp_object_based_api
  , "native_inline"
  , "native_default"
  , "native_custom_alloc"
  , "native_type"
  , "cpp_type"
  , "cpp_ptr_type"
  , "cpp_str_type"
  , "cpp_str_flex_ctor"
  , "shared"
  ]

----------------------------------
--------- Symbol search ----------
----------------------------------
-- | The result of resolving a type reference: the kind of declaration that was
-- found, the namespace it was found in, and the declaration itself.
data Match enum struct table union
  = MatchE !Namespace !enum
  | MatchS !Namespace !struct
  | MatchT !Namespace !table
  | MatchU !Namespace !union

-- | Looks for a type reference in a set of type declarations.
--
-- The reference is tried relative to the current namespace first, then relative
-- to each of its parent namespaces in turn (see 'parentNamespaces'). Within a
-- namespace, every symbol table is searched, looking at enums, structs, tables
-- and unions in that order. The first hit wins; fails if there is none.
findDecl ::
     MonadValidation m
  => Namespace
  -> FileTree (SymbolTable e s t u)
  -> TypeRef
  -> m (Match e s t u)
findDecl currentNamespace symbolTables typeRef@(TypeRef refNamespace refIdent) =
  case asum (searchNamespace <$> candidateNamespaces) of
    Just match -> pure match
    Nothing    ->
      throwErrorMsg $
        "type "
        <> display typeRef
        <> " does not exist (checked in these namespaces: "
        <> display searchedNamespaces
        <> ")"
  where
    searchedNamespaces = parentNamespaces currentNamespace

    -- The namespaces the referenced type could live in, most specific first.
    candidateNamespaces = (<> refNamespace) <$> searchedNamespaces

    searchNamespace namespace =
      asum (searchSymbolTable namespace <$> symbolTables)

    searchSymbolTable namespace symbolTable =
      let key = (namespace, refIdent)
      in  asum
            [ MatchE namespace <$> Map.lookup key (allEnums symbolTable)
            , MatchS namespace <$> Map.lookup key (allStructs symbolTable)
            , MatchT namespace <$> Map.lookup key (allTables symbolTable)
            , MatchU namespace <$> Map.lookup key (allUnions symbolTable)
            ]

-- | Returns a list of all the namespaces "between" the current namespace
-- and the root namespace, in that order.
-- See: https://github.com/google/flatbuffers/issues/5234#issuecomment-471680403
--
-- > parentNamespaces "A.B.C" == ["A.B.C", "A.B", "A", ""]
parentNamespaces :: ST.Namespace -> NonEmpty ST.Namespace
parentNamespaces (ST.Namespace fragments) =
  -- 'NE.inits' yields the prefixes shortest-first, hence the reversal.
  ST.Namespace <$> NE.reverse (NE.inits fragments)

----------------------------------
------------- Enums --------------
----------------------------------
-- | Validates every enum of every symbol table, leaving the other declarations untouched.
validateEnums :: FileTree Stage1 -> Validation (FileTree Stage2)
validateEnums =
  traverse $ \symbolTable ->
    Map.traverseWithKey validateEnum (allEnums symbolTable) <&> \validEnums ->
      symbolTable { allEnums = validEnums }

-- | Validates a single enum declaration.
--
-- Checks, in this order: duplicate value names, undeclared attributes, the
-- underlying type, that values are in strictly ascending order, and that every
-- value fits the underlying type. For @bit_flags@ enums, the resulting values
-- are converted to their corresponding bitmasks.
validateEnum :: (Namespace, Ident) -> ST.EnumDecl -> Validation EnumDecl
validateEnum (currentNamespace, _) enum =
  validating (qualify currentNamespace enum) $ do
    checkDuplicateFields
    checkUndeclaredAttributes enum
    buildEnum
  where
    isBitFlags :: Bool
    isBitFlags = hasAttribute bitFlagsAttr (ST.enumMetadata enum)

    checkDuplicateFields :: Validation ()
    checkDuplicateFields = checkDuplicateIdentifiers (ST.enumVals enum)

    buildEnum :: Validation EnumDecl
    buildEnum = do
      underlyingType <- validateEnumType (ST.enumType enum)
      let numberedVals = evalState (traverse numberEnumVal (ST.enumVals enum)) Nothing
      validateOrder numberedVals
      forM_ numberedVals (validateBounds underlyingType)
      pure EnumDecl
        { enumIdent = getIdent enum
        , enumType = underlyingType
        , enumBitFlags = isBitFlags
        , enumVals = fmap shiftBitFlags numberedVals
        }

    -- Gives an enum value its integer: the explicit literal if there is one,
    -- otherwise the previous value plus one (or 0 for the very first value).
    -- The state holds the integer of the previous value, if any.
    numberEnumVal :: ST.EnumVal -> State (Maybe Integer) EnumVal
    numberEnumVal enumVal = do
      previous <- get
      let current =
            case ST.enumValLiteral enumVal of
              Just (ST.IntLiteral explicit) -> explicit
              Nothing                       -> maybe 0 (+ 1) previous
      put (Just current)
      pure (EnumVal (getIdent enumVal) current)

    -- Reports the first pair of neighbouring values that is not strictly ascending.
    validateOrder :: NonEmpty EnumVal -> Validation ()
    validateOrder vals =
      case find outOfOrder (NE.toList vals `zip` NE.tail vals) of
        Nothing     -> pure ()
        Just (x, y) ->
          throwErrorMsg $
            "enum values must be specified in ascending order. "
            <> display (enumValIdent y)
            <> " ("
            <> display (enumValInt y)
            <> ") should be greater than "
            <> display (enumValIdent x)
            <> " ("
            <> display (enumValInt x)
            <> ")"
      where
        outOfOrder (x, y) = enumValInt x >= enumValInt y

    validateBounds :: EnumType -> EnumVal -> Validation ()
    validateBounds underlyingType enumVal =
      validating enumVal (checkFits enumVal)
      where
        checkFits =
          case underlyingType of
            EInt8   -> validateBounds' @Int8
            EInt16  -> validateBounds' @Int16
            EInt32  -> validateBounds' @Int32
            EInt64  -> validateBounds' @Int64
            EWord8  -> validateBounds' @Word8
            EWord16 -> validateBounds' @Word16
            EWord32 -> validateBounds' @Word32
            EWord64 -> validateBounds' @Word64

    -- For regular enums a value must fit the range of the underlying type.
    -- For `bit_flags` enums a value is a bit position, so it must lie between 0
    -- and the index of the underlying type's highest bit.
    validateBounds' :: forall a. (FiniteBits a, Integral a, Bounded a) => EnumVal -> Validation ()
    validateBounds' e
      | inRange (lower, upper) value = pure ()
      | otherwise =
          throwErrorMsg $
            "enum value of "
            <> display value
            <> " does not fit ["
            <> display lower
            <> "; "
            <> display upper
            <> "]"
      where
        value = enumValInt e
        (lower, upper)
          | isBitFlags = (0, toInteger (finiteBitSize @a (undefined :: a) - 1))
          | otherwise  = (toInteger (minBound @a), toInteger (maxBound @a))

    -- Only integral types are accepted, and `bit_flags` enums must be unsigned.
    validateEnumType :: ST.Type -> Validation EnumType
    validateEnumType t =
      case t of
        ST.TInt8   -> signed EInt8
        ST.TInt16  -> signed EInt16
        ST.TInt32  -> signed EInt32
        ST.TInt64  -> signed EInt64
        ST.TWord8  -> pure EWord8
        ST.TWord16 -> pure EWord16
        ST.TWord32 -> pure EWord32
        ST.TWord64 -> pure EWord64
        _          -> throwErrorMsg "underlying enum type must be integral"
      where
        signed enumType'
          | isBitFlags = throwErrorMsg "underlying type of bit_flags enum must be unsigned"
          | otherwise  = pure enumType'

    -- If this enum has the `bit_flags` attribute, convert its int value to the corresponding bitmask.
    -- E.g., 2 -> 00000100
    shiftBitFlags :: EnumVal -> EnumVal
    shiftBitFlags e
      | isBitFlags = e { enumValInt = bit (fromIntegral @Integer @Int (enumValInt e)) }
      | otherwise  = e


----------------------------------
------------ Tables --------------
----------------------------------
-- | A validated table field that has not been given an id yet:
-- its identifier, its type, and whether it carries the @deprecated@ attribute.
data TableFieldWithoutId = TableFieldWithoutId !Ident !TableFieldType !Bool

-- | Validates every table of every symbol table. Type references inside a table
-- are resolved against all the symbol tables, not just the table's own.
validateTables :: FileTree Stage3 -> Validation (FileTree Stage4)
validateTables symbolTables = traverse validateFile symbolTables
  where
    validateFile :: Stage3 -> Validation Stage4
    validateFile symbolTable =
      Map.traverseWithKey (validateTable symbolTables) (allTables symbolTable) <&> \validTables ->
        symbolTable { allTables = validTables }

-- | Validates a single table declaration.
--
-- Checks for duplicate field names and undeclared attributes, validates every
-- field (type, default value, @required@ / @deprecated@ attributes) and then
-- gives each field an id. The resulting table is always marked as 'NotRoot'.
validateTable :: FileTree Stage3 -> (Namespace, Ident) -> ST.TableDecl -> Validation TableDecl
validateTable symbolTables (currentNamespace, _) table =
  validating (qualify currentNamespace table) $ do

    let fields = ST.tableFields table
    let fieldsMetadata = ST.tableFieldMetadata <$> fields

    checkDuplicateFields fields
    checkUndeclaredAttributes table

    validFieldsWithoutIds <- traverse validateTableField fields
    validFields <- assignFieldIds fieldsMetadata validFieldsWithoutIds

    pure TableDecl
      { tableIdent = getIdent table
      , tableIsRoot = NotRoot
      , tableFields = validFields
      }

  where
    checkDuplicateFields :: [ST.TableField] -> Validation ()
    checkDuplicateFields = checkDuplicateIdentifiers

    -- Gives every field an id.
    --
    -- * If no field has an @id@ attribute, ids are generated following declaration order.
    -- * If every field has one, the fields are sorted by id and the ids are checked.
    -- * Any other combination is an error.
    --
    -- NOTE(review): presumably `findIntAttr` yields `Nothing` when the attribute
    -- is absent; it is defined elsewhere, confirm there.
    assignFieldIds :: [ST.Metadata] -> [TableFieldWithoutId] -> Validation [TableField]
    assignFieldIds metadata fieldsWithoutIds = do
      ids <- catMaybes <$> traverse (findIntAttr idAttr) metadata
      if null ids
        -- The state is the last id handed out; starting at -1 makes the first id 0
        -- (or 1, if the first field is a union / vector of unions).
        then pure $ evalState (traverse assignFieldId fieldsWithoutIds) (-1)
        else if length ids == length fieldsWithoutIds
          then do
            -- The lengths match, so the n-th id belongs to the n-th field.
            let fields = zipWith (\(TableFieldWithoutId ident typ depr) id -> TableField id ident typ depr) fieldsWithoutIds ids
            let sorted = List.sortOn tableFieldId fields
            evalStateT (traverse_ checkFieldId sorted) (-1)
            pure sorted
          else
            throwErrorMsg "either all fields or no fields must have an 'id' attribute"

    -- Generates the id of a field from the id of the field before it (kept in the state).
    -- Unions and vectors of unions skip one id.
    -- NOTE(review): presumably because such a field takes up two slots (one for
    -- the union type, one for the value) - confirm against the FlatBuffers docs.
    assignFieldId :: TableFieldWithoutId -> State Integer TableField
    assignFieldId (TableFieldWithoutId ident typ depr) = do
      lastId <- get
      let fieldId =
            case typ of
              TUnion _ _           -> lastId + 2
              TVector _ (VUnion _) -> lastId + 2
              _                    -> lastId + 1
      put fieldId
      pure (TableField fieldId ident typ depr)

    -- Checks an explicitly assigned id against the id of the previous field
    -- (kept in the state), using the same rules `assignFieldId` generates ids with.
    -- Expects to be run over the fields sorted by id.
    checkFieldId :: TableField -> StateT Integer Validation ()
    checkFieldId field = do
      lastId <- get
      validating field $ do
        case tableFieldType field of
          TUnion _ _ ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a union field must be the last field's id + 2"
          TVector _ (VUnion _) ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a vector of unions field must be the last field's id + 2"
          _ ->
            when (tableFieldId field /= lastId + 1) $
              throwErrorMsg $ "field ids must be consecutive from 0; id " <> display (lastId + 1) <> " is missing"
        put (tableFieldId field)

    -- Validates one field: its attributes, its type and its default value.
    -- The third component of the result is whether the field is deprecated.
    validateTableField :: ST.TableField -> Validation TableFieldWithoutId
    validateTableField tf =
      validating tf $ do
        checkUndeclaredAttributes tf
        validFieldType <- validateTableFieldType (ST.tableFieldMetadata tf) (ST.tableFieldDefault tf) (ST.tableFieldType tf)

        pure $ TableFieldWithoutId
          (getIdent tf)
          validFieldType
          (hasAttribute deprecatedAttr (ST.tableFieldMetadata tf))

    -- Translates the type of a field, together with its default value and
    -- `required` attribute:
    --
    -- * scalars (numbers, booleans, enums) may have a default but cannot be `required`;
    -- * everything else may be `required` but cannot have a default.
    validateTableFieldType :: ST.Metadata -> Maybe ST.DefaultVal -> ST.Type -> Validation TableFieldType
    validateTableFieldType md dflt tableFieldType =
      case tableFieldType of
        ST.TInt8 -> checkNoRequired md >> validateDefaultValAsInt @Int8 dflt <&> TInt8
        ST.TInt16 -> checkNoRequired md >> validateDefaultValAsInt @Int16 dflt <&> TInt16
        ST.TInt32 -> checkNoRequired md >> validateDefaultValAsInt @Int32 dflt <&> TInt32
        ST.TInt64 -> checkNoRequired md >> validateDefaultValAsInt @Int64 dflt <&> TInt64
        ST.TWord8 -> checkNoRequired md >> validateDefaultValAsInt @Word8 dflt <&> TWord8
        ST.TWord16 -> checkNoRequired md >> validateDefaultValAsInt @Word16 dflt <&> TWord16
        ST.TWord32 -> checkNoRequired md >> validateDefaultValAsInt @Word32 dflt <&> TWord32
        ST.TWord64 -> checkNoRequired md >> validateDefaultValAsInt @Word64 dflt <&> TWord64
        ST.TFloat -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TFloat
        ST.TDouble -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TDouble
        ST.TBool -> checkNoRequired md >> validateDefaultValAsBool dflt <&> TBool
        ST.TString -> checkNoDefault dflt $> TString (isRequired md)
        -- A reference to a user-defined type: what is allowed depends on what it resolves to.
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE ns enum -> do
              checkNoRequired md
              validDefault <- validateDefaultAsEnum dflt enum
              pure $ TEnum (TypeRef ns (getIdent enum)) (enumType enum) validDefault
            MatchS ns struct -> checkNoDefault dflt $> TStruct (TypeRef ns (getIdent struct))  (isRequired md)
            MatchT ns table  -> checkNoDefault dflt $> TTable  (TypeRef ns (getIdent table)) (isRequired md)
            MatchU ns union  -> checkNoDefault dflt $> TUnion  (TypeRef ns (getIdent union)) (isRequired md)
        -- Vectors never have a default; only the element type is left to translate.
        ST.TVector vecType ->
          checkNoDefault dflt >> TVector (isRequired md) <$>
            case vecType of
              ST.TInt8 -> pure VInt8
              ST.TInt16 -> pure VInt16
              ST.TInt32 -> pure VInt32
              ST.TInt64 -> pure VInt64
              ST.TWord8 -> pure VWord8
              ST.TWord16 -> pure VWord16
              ST.TWord32 -> pure VWord32
              ST.TWord64 -> pure VWord64
              ST.TFloat -> pure VFloat
              ST.TDouble -> pure VDouble
              ST.TBool -> pure VBool
              ST.TString -> pure VString
              ST.TVector _ -> throwErrorMsg "nested vector types not supported"
              ST.TRef typeRef ->
                findDecl currentNamespace symbolTables typeRef <&> \case
                  MatchE ns enum ->
                    VEnum (TypeRef ns (getIdent enum))
                          (enumType enum)
                  MatchS ns struct ->
                    VStruct (TypeRef ns (getIdent struct))
                  MatchT ns table -> VTable (TypeRef ns (getIdent table))
                  MatchU ns union -> VUnion (TypeRef ns (getIdent union))

-- | Fails if the metadata contains the @required@ attribute.
-- Used for fields whose type does not support it.
checkNoRequired :: ST.Metadata -> Validation ()
checkNoRequired md
  | hasAttribute requiredAttr md =
      throwErrorMsg "only non-scalar fields (strings, vectors, unions, structs, tables) may be 'required'"
  | otherwise =
      pure ()

-- | Fails if a default value was given.
-- Used for fields whose type does not support default values.
checkNoDefault :: Maybe ST.DefaultVal -> Validation ()
checkNoDefault = \case
  Nothing -> pure ()
  Just _  ->
    throwErrorMsg
      "default values currently only supported for scalar fields (integers, floating point, bool, enums)"

-- | 'Req' if the metadata contains the @required@ attribute, 'Opt' otherwise.
isRequired :: ST.Metadata -> Required
isRequired md
  | hasAttribute requiredAttr md = Req
  | otherwise                    = Opt

-- | Validates the default value of an integer field of type @a@.
--
-- A missing default means 0. A given default must be a whole number that fits
-- the bounds of @a@.
validateDefaultValAsInt :: forall a. (Integral a, Bounded a, Display a) => Maybe ST.DefaultVal -> Validation (DefaultVal Integer)
validateDefaultValAsInt = \case
  Nothing                -> pure (DefaultVal 0)
  Just (ST.DefaultNum n) -> scientificToInteger @a n "default value must be integral"
  Just _                 -> throwErrorMsg "default value must be integral"

-- | Validates the default value of a floating point field.
-- A missing default means 0; a given default must be a number.
validateDefaultValAsScientific :: Maybe ST.DefaultVal -> Validation (DefaultVal Scientific)
validateDefaultValAsScientific = \case
  Nothing                -> pure (DefaultVal 0)
  Just (ST.DefaultNum n) -> pure (DefaultVal n)
  Just _                 -> throwErrorMsg "default value must be a number"

-- | Validates the default value of a boolean field.
-- A missing default means 'False'; a given default must be a boolean.
validateDefaultValAsBool :: Maybe ST.DefaultVal -> Validation (DefaultVal Bool)
validateDefaultValAsBool = \case
  Nothing                 -> pure (DefaultVal False)
  Just (ST.DefaultBool b) -> pure (DefaultVal b)
  Just _                  -> throwErrorMsg "default value must be a boolean"

-- | Validates the default value of a table field whose type is the given enum.
--
-- * No default: 0 for @bit_flags@ enums; for regular enums 0 is only accepted
--   if the enum has a value equal to 0.
-- * Numeric default: for @bit_flags@ enums, a whole number that fits the
--   underlying type; for regular enums, a whole number equal to one of the
--   enum's values.
-- * Identifier default: for @bit_flags@ enums, one or more value names, which
--   are OR-ed together; for regular enums, exactly one value name.
-- * Boolean default: always rejected.
validateDefaultAsEnum :: Maybe ST.DefaultVal -> EnumDecl -> Validation (DefaultVal Integer)
validateDefaultAsEnum dflt enum =
  case dflt of
    Nothing ->
      if enumBitFlags enum
        then pure 0
        else
          case find (\val -> enumValInt val == 0) (enumVals enum) of
            Just _  -> pure 0
            Nothing -> throwErrorMsg "enum does not have a 0 value; please manually specify a default for this field"
    Just (ST.DefaultNum n) ->
      if enumBitFlags enum
        then
          case enumType enum of
            EWord8  -> scientificToInteger @Word8  n defaultErrorMsg
            EWord16 -> scientificToInteger @Word16 n defaultErrorMsg
            EWord32 -> scientificToInteger @Word32 n defaultErrorMsg
            EWord64 -> scientificToInteger @Word64 n defaultErrorMsg
            _       -> throwErrorMsg "The 'impossible' has happened: bit_flags enum with signed integer"
        else
          -- `Scientific.floatingOrInteger` is deliberately avoided here: for a
          -- number with a huge exponent (e.g. 1e1000000000) it would try to build
          -- the whole `Integer` in memory. `isInteger` and `toBoundedInteger` are
          -- safe for any input.
          if not (Scientific.isInteger n)
            then throwErrorMsg defaultErrorMsg
            else
              case toInteger64 n of
                -- Too large for any 64-bit enum type, so it cannot be one of the enum's values.
                Nothing -> throwErrorMsg $ "default value of " <> show n <> " is not part of enum " <> display (getIdent enum)
                Just i ->
                  case find (\val -> enumValInt val == i) (enumVals enum) of
                    Just matchingVal -> pure (DefaultVal (enumValInt matchingVal))
                    Nothing -> throwErrorMsg $ "default value of " <> display i <> " is not part of enum " <> display (getIdent enum)
    Just (ST.DefaultRef refs) ->
      if enumBitFlags enum
        then
          foldr1 (.|.) <$> traverse findEnumByRef refs
        else
          case refs of
            ref :| [] -> findEnumByRef ref
            _         -> throwErrorMsg $ "default value must be a single identifier, found "
                          <> display (NE.length refs)
                          <> ": "
                          <> display (fmap (\ref -> "'" <> ref <> "'") refs)
    Just (ST.DefaultBool _) ->
      throwErrorMsg defaultErrorMsg
  where
    -- The message for `bit_flags` enums with at least two values also shows how
    -- to combine several values.
    defaultErrorMsg =
      if enumBitFlags enum
        then case enumVals enum of
          x :| y : _ ->
            "default value must be integral, one of ["
            <> display (getIdent <$> enumVals enum)
            <> "], or a combination of the latter in double quotes (e.g. \""
            <> T.unpack (unIdent (getIdent x))
            <> " "
            <> T.unpack (unIdent (getIdent y))
            <> "\")"
          _ ->
            "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)
        else
          "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)

    -- Looks up an enum value by name.
    findEnumByRef :: Text -> Validation (DefaultVal Integer)
    findEnumByRef ref =
      case find (\val -> unIdent (getIdent val) == ref) (enumVals enum) of
        Just matchingVal -> pure (DefaultVal (enumValInt matchingVal))
        Nothing          -> throwErrorMsg $ "default value of " <> display ref <> " is not part of enum " <> display (getIdent enum)

    -- Converts a whole number to an `Integer`, provided it fits in an `Int64`
    -- or a `Word64` (together these cover every possible enum type).
    toInteger64 :: Scientific -> Maybe Integer
    toInteger64 x =
      asum
        [ toInteger <$> Scientific.toBoundedInteger @Int64 x
        , toInteger <$> Scientific.toBoundedInteger @Word64 x
        ]

-- | Converts a number to an 'Integer', checking that it is a whole number that
-- fits the bounds of the integral type @a@.
--
-- The second argument is the error message to fail with when the number is not
-- a whole number.
scientificToInteger ::
  forall a. (Integral a, Bounded a, Display a)
  => Scientific -> String -> Validation (DefaultVal Integer)
scientificToInteger n notIntegerErrorMsg
  | not (Scientific.isInteger n) = throwErrorMsg notIntegerErrorMsg
  | otherwise =
      maybe outOfBounds (pure . DefaultVal . toInteger) (Scientific.toBoundedInteger @a n)
  where
    outOfBounds =
      throwErrorMsg $
        "default value does not fit ["
        <> display (minBound @a)
        <> "; "
        <> display (maxBound @a)
        <> "]"

----------------------------------
------------ Unions --------------
----------------------------------
-- | Validates every union of every symbol table. Union members are resolved
-- against all the symbol tables, not just the union's own.
validateUnions :: FileTree Stage4 -> Validation (FileTree ValidDecls)
validateUnions symbolTables = traverse validateFile symbolTables
  where
    validateFile :: Stage4 -> Validation ValidDecls
    validateFile symbolTable =
      Map.traverseWithKey (validateUnion symbolTables) (allUnions symbolTable) <&> \validUnions ->
        symbolTable { allUnions = validUnions }

-- | Validates a single union declaration.
--
-- Every member must refer to a table. Once all members are validated, their
-- names are checked for duplicates and the union's attributes are checked.
validateUnion :: FileTree Stage4 -> (Namespace, Ident) -> ST.UnionDecl -> Validation UnionDecl
validateUnion symbolTables (currentNamespace, _) union =
  validating (qualify currentNamespace union) $ do
    members <- traverse validateUnionVal (ST.unionVals union)
    checkDuplicateVals members
    checkUndeclaredAttributes union
    pure UnionDecl
      { unionIdent = getIdent union
      , unionVals = members
      }
  where
    -- A member is named after its alias if it has one, and after its
    -- (partially qualified) type reference otherwise. Dots are replaced by underscores.
    validateUnionVal :: ST.UnionVal -> Validation UnionVal
    validateUnionVal uv =
      validating memberIdent $
        validateUnionValType tref <&> \tableRef ->
          UnionVal
            { unionValIdent = memberIdent
            , unionValTableRef = tableRef
            }
      where
        tref = ST.unionValTypeRef uv

        rawIdent :: Ident
        rawIdent =
          fromMaybe
            (qualify (typeRefNamespace tref) (typeRefIdent tref))
            (ST.unionValIdent uv)

        memberIdent :: Ident
        memberIdent = coerce (T.replace "." "_" (coerce rawIdent))

    -- Resolves the type of a member, which must be a table.
    validateUnionValType :: TypeRef -> Validation TypeRef
    validateUnionValType typeRef = do
      match <- findDecl currentNamespace symbolTables typeRef
      case match of
        MatchT ns table -> pure (TypeRef ns (getIdent table))
        _               -> throwErrorMsg "union members may only be tables"

    -- "NONE" takes part in the check as well, so no member can be named "NONE".
    checkDuplicateVals :: NonEmpty UnionVal -> Validation ()
    checkDuplicateVals vals =
      checkDuplicateIdentifiers (NE.cons "NONE" (getIdent <$> vals))


----------------------------------
------------ Structs -------------
----------------------------------

-- | Cache of already validated structs, keyed by the namespace and identifier of each struct.
--
-- When we're validating a struct @A@, it may contain an inner struct @B@ which also needs validating.
-- @B@ needs to be fully validated before we can consider @A@ valid.
--
-- If we've validated @B@ in a previous iteration, we will find it in this Map
-- and therefore avoid re-validating it.
type ValidatedStructs = Map (Namespace, Ident) StructDecl


-- | Validates every struct in every file of the tree.
--
-- A single 'ValidatedStructs' cache (initially empty) is threaded through all files,
-- so a struct referenced from several places is only validated once.
validateStructs :: FileTree Stage2 -> Validation (FileTree Stage3)
validateStructs symbolTables =
  evalStateT (traverse validateFile symbolTables) Map.empty
  where
  validateFile :: Stage2 -> StateT ValidatedStructs Validation Stage3
  validateFile symbolTable = do
    -- Cycles are checked for all of the file's structs before any of them is validated.
    forM_ (Map.toList declaredStructs) $ \((ns, _), struct) ->
      checkStructCycles symbolTables (ns, struct)

    validStructs <-
      Map.traverseWithKey
        (\(ns, _) struct -> validateStruct symbolTables ns struct)
        declaredStructs

    pure symbolTable { allStructs = validStructs }
    where
      declaredStructs = allStructs symbolTable

-- | Fails if the given struct contains itself, either directly or through a chain of nested structs.
--
-- The struct's fields are followed depth-first while keeping track of the structs on the current path;
-- reaching a struct that is already on the path means there is a cycle.
checkStructCycles :: forall m. MonadValidation m => FileTree Stage2 -> (Namespace, ST.StructDecl) -> m ()
checkStructCycles symbolTables = visit []
  where
    -- | @ancestors@ holds the qualified names of the structs on the current path, innermost first.
    visit :: [Ident] -> (Namespace, ST.StructDecl) -> m ()
    visit ancestors (currentNamespace, struct) =
      resetContext $
        validating name $
          if name `elem` ancestors
            then
              throwErrorMsg $
                "cyclic dependency detected ["
                <> display (T.intercalate " -> " (coerce cyclePath))
                <> "] - structs cannot contain themselves, directly or indirectly"
            else traverse_ visitField (ST.structFields struct)
      where
        name = qualify currentNamespace struct

        -- The path in declaration order, trimmed so that it starts at the first occurrence of the repeated struct.
        cyclePath = List.dropWhile (/= name) (List.reverse (name : ancestors))

        visitField :: ST.StructField -> m ()
        visitField field =
          validating field $
            case ST.structFieldType field of
              ST.TRef typeRef -> do
                decl <- findDecl currentNamespace symbolTables typeRef
                case decl of
                  MatchS ns nested -> visit (name : ancestors) (ns, nested)
                  -- The TypeRef points to an enum (or is invalid), so no further validation is needed at this point
                  _                -> pure ()
              -- Field is not a TypeRef, no validation needed
              _ -> pure ()

-- | A struct field whose type has been validated, but for which no offset or padding
-- has been computed yet.
data UnpaddedStructField = UnpaddedStructField
  { unpaddedStructFieldIdent :: !Ident
    -- ^ The field's identifier.
  , unpaddedStructFieldType  :: !StructFieldType
    -- ^ The field's validated type.
  }
  deriving stock (Show, Eq)

-- | Validates a struct declaration and computes its memory layout (alignment, size, and the
-- offset/padding of every field).
--
-- Structs that were already validated are served from the 'ValidatedStructs' cache held in the state;
-- newly validated structs are added to it. Nested structs are validated recursively before the
-- struct that contains them.
--
-- Fails when the struct has duplicate fields, uses undeclared or unsupported attributes,
-- has a field of a type that is not allowed in structs, or has an invalid @force_align@ value.
validateStruct ::
     forall m. (MonadState ValidatedStructs m, MonadValidation m)
  => FileTree Stage2
  -> Namespace
  -> ST.StructDecl
  -> m StructDecl
validateStruct symbolTables currentNamespace struct =
  resetContext $
  validating (qualify currentNamespace struct) $ do
    validStructs <- get
    -- Check if this struct has already been validated in a previous iteration
    case Map.lookup (currentNamespace, getIdent struct) validStructs of
      Just match -> pure match
      Nothing -> do
        checkDuplicateFields
        checkUndeclaredAttributes struct

        fields <- traverse validateStructField (ST.structFields struct)
        -- `fields` is a `NonEmpty`, so `maximum` is total here.
        let naturalAlignment = maximum (structFieldAlignment <$> fields)
        forceAlignAttrVal <- getForceAlignAttr
        forceAlign <- traverse (validateForceAlign naturalAlignment) forceAlignAttrVal
        let alignment = fromMaybe naturalAlignment forceAlign

        -- In order to calculate the padding between fields, we must first know the fields' and the struct's
        -- alignment. Which means we must first validate all the struct's fields, and then do a second
        -- pass to calculate the padding.
        let (size, paddedFields) = addFieldPadding alignment fields

        let validStruct = StructDecl
              { structIdent      = getIdent struct
              , structAlignment  = alignment
              , structSize       = size
              , structFields     = paddedFields
              }
        modify (Map.insert (currentNamespace, getIdent validStruct) validStruct)
        pure validStruct

  where
    invalidStructFieldType = "struct fields may only be integers, floating point, bool, enums, or other structs"

    -- | Calculates how much padding each field needs, and returns the struct's total size
    -- and a list of fields with padding information.
    --
    -- Every field is padded up to the alignment of the field that follows it; the last field
    -- is padded up to the alignment of the struct itself.
    addFieldPadding :: Alignment -> NonEmpty UnpaddedStructField -> (InlineSize, NonEmpty StructField)
    addFieldPadding finalAlignment unpaddedFields =
      -- `layout` accumulates the fields in reverse; the input is non-empty, so the result is too.
      (totalSize, NE.fromList (reverse reversedFields))
      where
        (totalSize, reversedFields) = layout 0 [] (NE.toList unpaddedFields)

        -- | Walks the fields in order, carrying the offset of the next field and the fields laid out so far.
        layout :: InlineSize -> [StructField] -> [UnpaddedStructField] -> (InlineSize, [StructField])
        layout offset acc [] = (offset, acc)
        layout offset acc [field] =
          let (end, padded) = padField offset finalAlignment field
          in  (end, padded : acc)
        layout offset acc (field : rest@(next : _)) =
          let (end, padded) = padField offset (structFieldAlignment next) field
          in  layout end (padded : acc) rest

        -- | Places a field at @offset@ and pads it so that whatever comes next starts at a multiple of
        -- @followingAlignment@. Returns the offset right after the padding, and the padded field.
        padField :: InlineSize -> Alignment -> UnpaddedStructField -> (InlineSize, StructField)
        padField offset followingAlignment field =
          let unpaddedEnd = offset + structFieldTypeSize (unpaddedStructFieldType field)
              boundary = fromIntegral @Alignment @InlineSize followingAlignment
              paddingNeeded = (unpaddedEnd `roundUpToNearestMultipleOf` boundary) - unpaddedEnd
              paddedField = StructField
                { structFieldIdent = unpaddedStructFieldIdent field
                -- NOTE: it is safe to narrow `paddingNeeded` to a word8 here because it's always smaller than `followingAlignment`
                , structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
                , structFieldOffset = coerce offset
                , structFieldType = unpaddedStructFieldType field
                }
          in  (unpaddedEnd + paddingNeeded, paddedField)

    -- | Validates a field's attributes and type; padding is computed later by 'addFieldPadding'.
    validateStructField :: ST.StructField -> m UnpaddedStructField
    validateStructField sf =
      validating sf $ do
        checkUnsupportedAttributes sf
        checkUndeclaredAttributes sf
        fieldType <- validateStructFieldType (ST.structFieldType sf)
        pure $ UnpaddedStructField
          { unpaddedStructFieldIdent = getIdent sf
          , unpaddedStructFieldType = fieldType
          }

    -- | Maps a parsed type to a struct field type, rejecting strings, vectors,
    -- and references to anything other than an enum or a struct.
    validateStructFieldType :: ST.Type -> m StructFieldType
    validateStructFieldType parsedType =
      case parsedType of
        ST.TInt8 -> pure SInt8
        ST.TInt16 -> pure SInt16
        ST.TInt32 -> pure SInt32
        ST.TInt64 -> pure SInt64
        ST.TWord8 -> pure SWord8
        ST.TWord16 -> pure SWord16
        ST.TWord32 -> pure SWord32
        ST.TWord64 -> pure SWord64
        ST.TFloat -> pure SFloat
        ST.TDouble -> pure SDouble
        ST.TBool -> pure SBool
        ST.TString -> throwErrorMsg invalidStructFieldType
        ST.TVector _ -> throwErrorMsg invalidStructFieldType
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE enumNamespace enum ->
              pure (SEnum (TypeRef enumNamespace (getIdent enum)) (enumType enum))
            MatchS nestedNamespace nestedStruct -> do
              -- if this is a reference to a struct, we need to validate it first
              validNestedStruct <- validateStruct symbolTables nestedNamespace nestedStruct
              pure $ SStruct (nestedNamespace, validNestedStruct)
            _ -> throwErrorMsg invalidStructFieldType

    -- | Rejects attributes that are not allowed on struct fields.
    checkUnsupportedAttributes :: ST.StructField -> m ()
    checkUnsupportedAttributes structField = do
      let metadata = ST.structFieldMetadata structField
      when (hasAttribute deprecatedAttr metadata) $
        throwErrorMsg "can't deprecate fields in a struct"
      when (hasAttribute requiredAttr metadata) $
        throwErrorMsg "struct fields are already required, the 'required' attribute is redundant"
      when (hasAttribute idAttr metadata) $
        throwErrorMsg "struct fields cannot be reordered using the 'id' attribute"

    -- | Reads the struct's @force_align@ attribute, if present.
    getForceAlignAttr :: m (Maybe Integer)
    getForceAlignAttr = findIntAttr forceAlignAttr (ST.structMetadata struct)

    -- | A forced alignment must be a power of two between the struct's natural alignment and 16 (inclusive).
    validateForceAlign :: Alignment -> Integer -> m Alignment
    validateForceAlign naturalAlignment forceAlign =
      if isPowerOfTwo forceAlign
        && inRange (fromIntegral @Alignment @Integer naturalAlignment, 16) forceAlign
        -- The range check above guarantees the value fits in an `Alignment`.
        then pure (fromIntegral @Integer @Alignment forceAlign)
        else throwErrorMsg $
              "force_align must be a power of two integer ranging from the struct's natural alignment (in this case, "
              <> display naturalAlignment
              <> ") to 16"

    -- | Fails if two of the struct's fields share an identifier.
    checkDuplicateFields :: m ()
    checkDuplicateFields =
      checkDuplicateIdentifiers
        (ST.structFields struct)

----------------------------------
------------ Helpers -------------
----------------------------------
-- | The alignment of a struct field, determined by its type.
--
-- For a nested struct, this is the alignment of that struct.
structFieldAlignment :: UnpaddedStructField -> Alignment
structFieldAlignment = alignmentOf . unpaddedStructFieldType
  where
    alignmentOf :: StructFieldType -> Alignment
    alignmentOf = \case
      SInt8                     -> int8Size
      SInt16                    -> int16Size
      SInt32                    -> int32Size
      SInt64                    -> int64Size
      SWord8                    -> word8Size
      SWord16                   -> word16Size
      SWord32                   -> word32Size
      SWord64                   -> word64Size
      SFloat                    -> floatSize
      SDouble                   -> doubleSize
      SBool                     -> boolSize
      SEnum _ underlyingType    -> enumAlignment underlyingType
      SStruct (_, nestedStruct) -> structAlignment nestedStruct

-- | The alignment of an enum, which is the size of its underlying integer type.
enumAlignment :: EnumType -> Alignment
enumAlignment underlyingType = Alignment (enumSize underlyingType)

-- | The size (in bytes) of an enum's underlying integer type.
--
-- This is either 1, 2, 4 or 8 bytes, so it always fits in a 'Word8'.
enumSize :: EnumType -> Word8
enumSize = \case
  EInt8   -> int8Size
  EInt16  -> int16Size
  EInt32  -> int32Size
  EInt64  -> int64Size
  EWord8  -> word8Size
  EWord16 -> word16Size
  EWord32 -> word32Size
  EWord64 -> word64Size

-- | The number of bytes a struct field of the given type occupies, excluding padding.
--
-- For a nested struct, this is the total size of that struct.
structFieldTypeSize :: StructFieldType -> InlineSize
structFieldTypeSize = \case
  SInt8                     -> int8Size
  SInt16                    -> int16Size
  SInt32                    -> int32Size
  SInt64                    -> int64Size
  SWord8                    -> word8Size
  SWord16                   -> word16Size
  SWord32                   -> word32Size
  SWord64                   -> word64Size
  SFloat                    -> floatSize
  SDouble                   -> doubleSize
  SBool                     -> boolSize
  SEnum _ underlyingType    -> fromIntegral @Word8 @InlineSize (enumSize underlyingType)
  SStruct (_, nestedStruct) -> structSize nestedStruct

-- | Fails if any identifier occurs more than once among the given items.
--
-- The error message lists every duplicated identifier (in ascending order).
checkDuplicateIdentifiers :: (MonadValidation m, Foldable f, Functor f, HasIdent a) => f a -> m ()
checkDuplicateIdentifiers xs =
  when (not (null duplicates)) $
    throwErrorMsg $
      display duplicates <> " declared more than once"
  where
    -- How many times each identifier occurs.
    occurrences :: Map Ident Int
    occurrences =
      Map.fromListWith (+) [ (ident, 1) | ident <- Foldable.toList (getIdent <$> xs) ]

    duplicates :: [Ident]
    duplicates = Map.keys (Map.filter (> 1) occurrences)

-- | Fails if the given item uses an attribute that is not among the declared attributes.
checkUndeclaredAttributes :: (MonadValidation m, HasMetadata a) => a -> m ()
checkUndeclaredAttributes a = do
  declaredAttrs <- getDeclaredAttributes
  let usedAttrs = Map.keys (ST.unMetadata (getMetadata a))
  traverse_
    (\attr ->
      when (Set.notMember (coerce attr) declaredAttrs) $
        throwErrorMsg $ "user defined attributes must be declared before use: " <> display attr
    )
    usedAttrs

-- | Whether an attribute with the given name is present in the metadata, regardless of its value.
hasAttribute :: Text -> ST.Metadata -> Bool
hasAttribute name metadata =
  case metadata of
    ST.Metadata attrs -> name `Map.member` attrs

-- | Looks up an attribute that is expected to hold an integer.
--
-- Returns 'Nothing' when the attribute is absent. An integer value is returned as is, and a string
-- value is accepted as long as it can be read as an integer. Fails when the attribute is present
-- but has no value, or has a string value that is not an integer.
findIntAttr :: MonadValidation m => Text -> ST.Metadata -> m (Maybe Integer)
findIntAttr name (ST.Metadata attrs) =
  maybe (pure Nothing) fromAttrVal (Map.lookup name attrs)
  where
    fromAttrVal = \case
      Nothing             -> invalid
      Just (ST.AttrI i)   -> pure (Just i)
      Just (ST.AttrS str) -> maybe invalid (pure . Just) (readMaybe @Integer (T.unpack str))

    invalid =
      throwErrorMsg $
        "expected attribute '"
        <> display name
        <> "' to have an integer value, e.g. '"
        <> display name
        <> ": 123'"

-- | Looks up an attribute that is expected to hold a string.
--
-- Returns 'Nothing' when the attribute is absent, and fails when the attribute is present
-- but has no value or a non-string value.
--
-- Like 'findIntAttr', this works in any monad that supports validation, so it can also be used
-- from within monad transformer stacks built on top of 'Validation'.
findStringAttr :: MonadValidation m => Text -> ST.Metadata -> m (Maybe Text)
findStringAttr name (ST.Metadata attrs) =
  case Map.lookup name attrs of
    Nothing                  -> pure Nothing
    Just (Just (ST.AttrS s)) -> pure (Just s)
    Just _ ->
      throwErrorMsg $
        "expected attribute '"
        <> display name
        <> "' to have a string value, e.g. '"
        <> display name
        <> ": \"abc\"'"

-- | Whether a number is a (strictly positive) power of two.
--
-- Zero and negative numbers are never powers of two. The sign is checked explicitly because,
-- for fixed-width signed types, @minBound@ has a single bit set (e.g. @-128 :: Int8@)
-- and would otherwise pass the bit test below.
isPowerOfTwo :: (Num a, Bits a) => a -> Bool
isPowerOfTwo n =
  -- A positive power of two has exactly one bit set, so clearing its lowest set bit yields zero.
  signum n == 1 && (n .&. (n - 1)) == 0

-- | @x \`roundUpToNearestMultipleOf\` y@ rounds @x@ up to the closest multiple of @y@.
--
-- If @x@ is already a multiple of @y@, it is returned unchanged.
-- @y@ must not be zero, since the remainder of a division by zero is undefined.
roundUpToNearestMultipleOf :: Integral n => n -> n -> n
roundUpToNearestMultipleOf x y
  | leftover == 0 = x
  | otherwise     = (y - leftover) + x
  where
    leftover = x `rem` y