{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}

module FlatBuffers.Internal.Compiler.SemanticAnalysis where

import           Control.Monad                                 ( forM_, join, when )
import           Control.Monad.Except                          ( MonadError, throwError )
import           Control.Monad.Reader                          ( MonadReader(..), asks, runReaderT )
import           Control.Monad.State                           ( MonadState, State, StateT, evalState, evalStateT, get, modify, put )

import           Data.Coerce                                   ( coerce )
import           Data.Foldable                                 ( asum, find, foldlM, traverse_ )
import qualified Data.Foldable                                 as Foldable
import           Data.Functor                                  ( ($>), (<&>) )
import           Data.Int
import           Data.Ix                                       ( inRange )
import qualified Data.List                                     as List
import           Data.List.NonEmpty                            ( NonEmpty )
import qualified Data.List.NonEmpty                            as NE
import           Data.Map.Strict                               ( Map )
import qualified Data.Map.Strict                               as Map
import           Data.Maybe                                    ( catMaybes, fromMaybe, isJust )
import           Data.Monoid                                   ( Sum(..) )
import           Data.Scientific                               ( Scientific )
import qualified Data.Scientific                               as Scientific
import           Data.Set                                      ( Set )
import qualified Data.Set                                      as Set
import           Data.Text                                     ( Text )
import qualified Data.Text                                     as T
import           Data.Traversable                              ( for )
import           Data.Word

import           FlatBuffers.Internal.Compiler.Display         ( Display(..) )
import           FlatBuffers.Internal.Compiler.SyntaxTree      ( FileTree(..), HasIdent(..), HasMetadata(..), Ident, Namespace, Schema, TypeRef(..), qualify )
import qualified FlatBuffers.Internal.Compiler.SyntaxTree      as ST
import           FlatBuffers.Internal.Compiler.ValidSyntaxTree
import           FlatBuffers.Internal.Constants
import           FlatBuffers.Internal.Types
import           FlatBuffers.Internal.Util                     ( isPowerOfTwo, roundUpToNearestMultipleOf )

import           Text.Read                                     ( readMaybe )


-- | Constraints shared by the validation functions in this module:
-- failures are reported as 'Text' through 'MonadError', and the current
-- 'ValidationState' is read through 'MonadReader'.
type ValidationCtx m = (MonadError Text m, MonadReader ValidationState m)

-- | Read-only environment that is available while declarations are being validated.
data ValidationState = ValidationState
  { validationStateCurrentContext :: !Ident
    -- ^ The thing being validated (e.g. a fully-qualified struct name, or a table field name).
  , validationStateAllAttributes  :: !(Set ST.AttributeDecl)
    -- ^ All the attributes declared in all the schemas (including imported ones).
  }


-- | Runs an action with the current validation context replaced by the result of
-- applying the given function to it. Only the given action sees the new context.
modifyContext :: ValidationCtx m => (Ident -> Ident) -> m a -> m a
modifyContext f = local updateContext
  where
    updateContext state =
      state { validationStateCurrentContext = f (validationStateCurrentContext state) }

-- | The enums, structs, tables and unions found in a schema, each paired with the
-- namespace it was declared in.
--
-- The type parameters let each kind of declaration be stored either in its
-- unvalidated or in its validated form (see 'Stage1' .. 'ValidDecls').
data SymbolTable enum struct table union = SymbolTable
  { allEnums   :: ![(Namespace, enum)]
  , allStructs :: ![(Namespace, struct)]
  , allTables  :: ![(Namespace, table)]
  , allUnions  :: ![(Namespace, union)]
  }
  deriving (Eq, Show)

-- | Combines two symbol tables by appending each kind of declaration separately,
-- keeping the declarations of the left operand first.
instance Semigroup (SymbolTable e s t u)  where
  SymbolTable enumsL structsL tablesL unionsL <> SymbolTable enumsR structsR tablesR unionsR =
    SymbolTable
      (enumsL   <> enumsR)
      (structsL <> structsR)
      (tablesL  <> tablesR)
      (unionsL  <> unionsR)

-- | The empty symbol table contains no declarations of any kind.
instance Monoid (SymbolTable e s t u) where
  mempty = SymbolTable mempty mempty mempty mempty

-- Each stage of the validation pipeline replaces one more kind of unvalidated
-- declaration (from the @ST@ syntax tree) with its validated counterpart.

-- | Nothing validated yet.
type Stage1     = SymbolTable ST.EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums validated (output of 'validateEnums').
type Stage2     = SymbolTable    EnumDecl ST.StructDecl ST.TableDecl ST.UnionDecl
-- | Enums and structs validated (output of 'validateStructs').
type Stage3     = SymbolTable    EnumDecl    StructDecl ST.TableDecl ST.UnionDecl
-- | Enums, structs and tables validated (output of 'validateTables').
type Stage4     = SymbolTable    EnumDecl    StructDecl    TableDecl ST.UnionDecl
-- | Everything validated (output of 'validateUnions').
type ValidDecls = SymbolTable    EnumDecl    StructDecl    TableDecl    UnionDecl

-- | Takes a collection of schemas, and pairs each type declaration with the namespace
-- that was in effect at the point where it was declared.
--
-- Declarations that appear before any @namespace@ declaration are paired with the empty
-- namespace. Declarations other than enums, structs, tables and unions contribute nothing
-- to the symbol table. The relative order of declarations of the same kind is preserved.
createSymbolTables :: FileTree Schema -> FileTree Stage1
createSymbolTables = fmap (pairDeclsWithNamespaces . ST.decls)
  where
    -- Walks the declarations in source order, threading the current namespace through, and
    -- produces one (possibly empty) symbol table per declaration. Merging those with 'mconcat'
    -- keeps the list appends right-nested (linear instead of quadratic), and avoids the chain
    -- of thunks that a lazy left fold over the declarations would accumulate.
    pairDeclsWithNamespaces :: [ST.Decl] -> Stage1
    pairDeclsWithNamespaces = mconcat . snd . List.mapAccumL step ""

    -- Returns the namespace in effect after this declaration, together with the
    -- symbol table entry (if any) that the declaration contributes.
    step :: Namespace -> ST.Decl -> (Namespace, Stage1)
    step currentNamespace decl =
      case decl of
        ST.DeclN (ST.NamespaceDecl newNamespace) -> (newNamespace, mempty)
        ST.DeclE enum   -> (currentNamespace, SymbolTable [(currentNamespace, enum)] [] [] [])
        ST.DeclS struct -> (currentNamespace, SymbolTable [] [(currentNamespace, struct)] [] [])
        ST.DeclT table  -> (currentNamespace, SymbolTable [] [] [(currentNamespace, table)] [])
        ST.DeclU union  -> (currentNamespace, SymbolTable [] [] [] [(currentNamespace, union)])
        _               -> (currentNamespace, mempty)

-- | Validates a tree of schemas (a root schema plus its imports).
--
-- Top-level identifiers are first checked for duplicates across all schemas; then enums,
-- structs, tables and unions are validated, in that order. Finally, the table referenced by
-- the root schema's @root_type@ declaration (if any) is flagged as the root table.
--
-- The set of usable attributes is made up of the attributes declared in any of the schemas
-- plus 'knownAttributes'.
validateSchemas :: MonadError Text m => FileTree Schema -> m (FileTree ValidDecls)
validateSchemas schemas =
  flip runReaderT (ValidationState "" allAttributes) $ do
    checkDuplicateIdentifiers allQualifiedTopLevelIdentifiers
    withEnums   <- validateEnums symbolTables
    withStructs <- validateStructs withEnums
    withTables  <- validateTables withStructs
    withUnions  <- validateUnions withTables
    updateRootTable (fileTreeRoot schemas) withUnions
  where
    symbolTables = createSymbolTables schemas

    allQualifiedTopLevelIdentifiers = concatMap qualifiedIdents symbolTables

    -- Fully-qualified names of every enum, struct, table and union in one symbol table.
    qualifiedIdents symbolTable =
      concat
        [ [ qualify namespace decl | (namespace, decl) <- allEnums symbolTable ]
        , [ qualify namespace decl | (namespace, decl) <- allStructs symbolTable ]
        , [ qualify namespace decl | (namespace, decl) <- allTables symbolTable ]
        , [ qualify namespace decl | (namespace, decl) <- allUnions symbolTable ]
        ]

    allAttributes = Set.fromList (declaredAttributes <> knownAttributes)

    declaredAttributes = concatMap attributesOf schemas

    attributesOf schema = [ attr | ST.DeclA attr <- ST.decls schema ]

----------------------------------
------------ Root Type -----------
----------------------------------
-- | Information about the table that a @root_type@ declaration points to.
data RootInfo = RootInfo
  { rootTableNamespace :: !Namespace
    -- ^ The namespace in which the root table was found.
  , rootTable          :: !TableDecl
    -- ^ The (validated) root table.
  , rootFileIdent      :: !(Maybe Text)
    -- ^ The file identifier declared in the schema, if any.
  }

-- | Finds the root table (if any) and sets the `tableIsRoot` flag accordingly.
--
-- Only the @root_type@ declarations of the given (root) schema are taken into account;
-- imported schemas are not scanned for @root_type@s. The declaration may, however, point
-- to a table in any schema (root or imported).
--
-- When there is no root table, the symbol tables are returned unchanged.
updateRootTable :: forall m. ValidationCtx m => Schema -> FileTree ValidDecls -> m (FileTree ValidDecls)
updateRootTable schema symbolTables =
  maybe symbolTables markRoot <$> getRootInfo schema symbolTables
  where
    markRoot :: RootInfo -> FileTree ValidDecls
    markRoot info = fmap (markRootIn info) symbolTables

    markRootIn :: RootInfo -> ValidDecls -> ValidDecls
    markRootIn info decls = decls { allTables = fmap (markTable info) (allTables decls) }

    -- A table is the root table when both its namespace and the table itself match.
    markTable :: RootInfo -> (Namespace, TableDecl) -> (Namespace, TableDecl)
    markTable (RootInfo rootNamespace root fileIdent) (namespace, table)
      | namespace == rootNamespace && table == root =
          (namespace, table { tableIsRoot = IsRoot fileIdent })
      | otherwise =
          (namespace, table)

-- | Scans the declarations of a schema, in order, looking for @root_type@ and
-- @file_identifier@ declarations.
--
-- The @root_type@ reference is resolved relative to the namespace in effect at the point
-- where it appears, and must resolve to a table; otherwise an error is thrown. If several
-- @root_type@ or @file_identifier@ declarations are present, the last one of each wins.
-- Returns 'Nothing' when the schema has no @root_type@ declaration.
getRootInfo :: forall m. ValidationCtx m => Schema -> FileTree ValidDecls -> m (Maybe RootInfo)
getRootInfo schema symbolTables = do
  (_, root, fileIdent) <- foldlM step ("", Nothing, Nothing) (ST.decls schema)
  pure $ fmap (\(namespace, table) -> RootInfo namespace table fileIdent) root
  where
    -- The accumulator holds: the current namespace, the root table found so far (with its
    -- namespace), and the file identifier found so far.
    step :: (Namespace, Maybe (Namespace, TableDecl), Maybe Text) -> ST.Decl -> m (Namespace, Maybe (Namespace, TableDecl), Maybe Text)
    step acc@(namespace, root, fileIdent) = \case
      ST.DeclN (ST.NamespaceDecl newNamespace) ->
        pure (newNamespace, root, fileIdent)
      ST.DeclFI (ST.FileIdentifierDecl newFileIdent) ->
        pure (namespace, root, Just (coerce newFileIdent))
      ST.DeclR (ST.RootDecl typeRef) ->
        findDecl namespace symbolTables typeRef >>= \case
          MatchT (tableNamespace, table) -> pure (namespace, Just (tableNamespace, table), fileIdent)
          _                              -> throwErrorMsg "root type must be a table"
      _ ->
        pure acc


----------------------------------
----------- Attributes -----------
----------------------------------
-- | Attributes that can be used without being declared in a schema: the ones this
-- module refers to by name, followed by 'otherKnownAttributes'.
knownAttributes :: [ST.AttributeDecl]
knownAttributes = namedAttributes <> otherKnownAttributes
  where
    namedAttributes :: [ST.AttributeDecl]
    namedAttributes =
      coerce
        [ idAttr
        , deprecatedAttr
        , requiredAttr
        , forceAlignAttr
        , bitFlagsAttr
        ]

-- | Name of the attribute that explicitly assigns an id to a table field.
idAttr :: Text
idAttr = "id"

-- | Name of the attribute that marks a table field as deprecated.
deprecatedAttr :: Text
deprecatedAttr = "deprecated"

-- | Name of the attribute that marks a non-scalar table field as required.
requiredAttr :: Text
requiredAttr = "required"

-- | Name of the @force_align@ attribute.
forceAlignAttr :: Text
forceAlignAttr = "force_align"

-- | Name of the @bit_flags@ attribute (currently rejected when used on an enum).
bitFlagsAttr :: Text
bitFlagsAttr = "bit_flags"

-- | Attributes that are accepted without a declaration, but that this module
-- does not otherwise refer to by name.
otherKnownAttributes :: [ST.AttributeDecl]
otherKnownAttributes = schemaAttributes <> cppAttributes
  where
    -- https://google.github.io/flatbuffers/flatbuffers_guide_writing_schema.html
    schemaAttributes :: [ST.AttributeDecl]
    schemaAttributes =
      [ "nested_flatbuffer"
      , "flexbuffer"
      , "key"
      , "hash"
      , "original_order"
      ]

    -- https://google.github.io/flatbuffers/flatbuffers_guide_use_cpp.html#flatbuffers_cpp_object_based_api
    cppAttributes :: [ST.AttributeDecl]
    cppAttributes =
      [ "native_inline"
      , "native_default"
      , "native_custom_alloc"
      , "native_type"
      , "cpp_type"
      , "cpp_ptr_type"
      , "cpp_str_type"
      , "cpp_str_flex_ctor"
      , "shared"
      ]

----------------------------------
--------- Symbol search ----------
----------------------------------
-- | The result of successfully resolving a type reference: the kind of declaration
-- it points to, together with the namespace that declaration lives in.
data Match enum struct table union
  = MatchE !(Namespace, enum)
  | MatchS !(Namespace, struct)
  | MatchT !(Namespace, table)
  | MatchU !(Namespace, union)

-- | Looks for a type reference in a set of type declarations.
--
-- The reference is resolved relative to the current namespace first, and then relative
-- to each of its parent namespaces, ending with the root namespace (see 'parentNamespaces').
-- For each candidate namespace, the symbol tables are searched one by one; within a symbol
-- table, enums are tried first, then structs, tables and unions. The first match wins.
--
-- If nothing matches, an error is thrown; its message lists the namespaces that were tried.
findDecl ::
     ValidationCtx m
  => (HasIdent e, HasIdent s, HasIdent t, HasIdent u)
  => Namespace
  -> FileTree (SymbolTable e s t u)
  -> TypeRef
  -> m (Match e s t u)
findDecl currentNamespace symbolTables typeRef@(TypeRef refNamespace refIdent) =
  maybe notFound pure (asum (fmap searchFrom startingPoints))
  where
    startingPoints = parentNamespaces currentNamespace

    -- Searches every symbol table for the reference, as seen from one of the starting points.
    searchFrom parentNamespace =
      asum (fmap (searchSymbolTable (parentNamespace <> refNamespace)) symbolTables)

    searchSymbolTable candidateNamespace symbolTable =
      asum
        [ MatchE <$> find (isMatch candidateNamespace) (allEnums symbolTable)
        , MatchS <$> find (isMatch candidateNamespace) (allStructs symbolTable)
        , MatchT <$> find (isMatch candidateNamespace) (allTables symbolTable)
        , MatchU <$> find (isMatch candidateNamespace) (allUnions symbolTable)
        ]

    isMatch :: HasIdent a => Namespace -> (Namespace, a) -> Bool
    isMatch candidateNamespace (namespace, decl) =
      namespace == candidateNamespace && getIdent decl == refIdent

    notFound =
      throwErrorMsg $
        "type '"
        <> display typeRef
        <> "' does not exist (checked in these namespaces: "
        <> display startingPoints
        <> ")"

-- | Lists the given namespace followed by each of its ancestors,
-- from the innermost one all the way up to the root (empty) namespace.
-- See: https://github.com/google/flatbuffers/issues/5234#issuecomment-471680403
--
-- > parentNamespaces "A.B.C" == ["A.B.C", "A.B", "A", ""]
parentNamespaces :: ST.Namespace -> NonEmpty ST.Namespace
parentNamespaces (ST.Namespace fragments) =
  coerce (NE.reverse (NE.inits fragments))

----------------------------------
------------- Enums --------------
----------------------------------
-- | Validates every enum of every symbol table, leaving all other declarations untouched.
validateEnums :: forall m. ValidationCtx m => FileTree Stage1 -> m (FileTree Stage2)
validateEnums symbolTables =
  for symbolTables $ \symbolTable -> do
    validEnums <-
      for (allEnums symbolTable) $ \(namespace, enum) ->
        (,) namespace <$> validateEnum (namespace, enum)
    pure symbolTable { allEnums = validEnums }

-- TODO: add support for `bit_flags` attribute

-- | Validates a single enum declaration.
--
-- The following checks are performed, in this order:
--
-- * the enum does not use the @bit_flags@ attribute;
-- * no two enum values share the same name;
-- * all attributes used by the enum have been declared;
-- * the underlying type is an integral type;
-- * the enum values are in strictly ascending order;
-- * every enum value fits in the underlying type.
validateEnum :: forall m. ValidationCtx m => (Namespace, ST.EnumDecl) -> m EnumDecl
validateEnum (currentNamespace, enum) =
  modifyContext (const (qualify currentNamespace enum)) $ do
    checkBitFlags
    checkDuplicateFields
    checkUndeclaredAttributes enum
    validEnum
  where
    validEnum :: m EnumDecl
    validEnum = do
      underlyingType <- validateEnumType (ST.enumType enum)
      let vals = evalState (traverse mapEnumVal (ST.enumVals enum)) Nothing
      validateOrder vals
      traverse_ (validateBounds underlyingType) vals
      pure EnumDecl
        { enumIdent = getIdent enum
        , enumType = underlyingType
        , enumVals = vals
        }

    -- Assigns an integer to an enum value. The state holds the integer given to the previous
    -- value (if any): a value without an explicit literal gets the previous integer plus one,
    -- or 0 if it is the first value.
    mapEnumVal :: ST.EnumVal -> State (Maybe Integer) EnumVal
    mapEnumVal enumVal = do
      previous <- get
      let current =
            case ST.enumValLiteral enumVal of
              Just (ST.IntLiteral explicit) -> explicit
              Nothing                       -> maybe 0 (+ 1) previous
      put (Just current)
      pure (EnumVal (getIdent enumVal) current)

    validateOrder :: NonEmpty EnumVal -> m ()
    validateOrder xs =
      when (not strictlyAscending) $
        throwErrorMsg "enum values must be specified in ascending order"
      where
        ints = enumValInt <$> NE.toList xs
        -- Compare each value with its successor.
        strictlyAscending = and (zipWith (<) ints (drop 1 ints))

    validateBounds :: EnumType -> EnumVal -> m ()
    validateBounds underlyingType enumVal =
      modifyContext (\enumContext -> enumContext <> "." <> getIdent enumVal) $
        case underlyingType of
          EInt8   -> validateBounds' @Int8 enumVal
          EInt16  -> validateBounds' @Int16 enumVal
          EInt32  -> validateBounds' @Int32 enumVal
          EInt64  -> validateBounds' @Int64 enumVal
          EWord8  -> validateBounds' @Word8 enumVal
          EWord16 -> validateBounds' @Word16 enumVal
          EWord32 -> validateBounds' @Word32 enumVal
          EWord64 -> validateBounds' @Word64 enumVal

    -- Checks that the enum value lies within the bounds of the integral type @a@.
    validateBounds' :: forall a. (Integral a, Bounded a, Show a) => EnumVal -> m ()
    validateBounds' e
      | inRange (lowest, highest) (enumValInt e) = pure ()
      | otherwise =
          throwErrorMsg $
            "enum value does not fit ["
            <> T.pack (show (minBound @a))
            <> "; "
            <> T.pack (show (maxBound @a))
            <> "]"
      where
        lowest  = toInteger (minBound @a)
        highest = toInteger (maxBound @a)

    validateEnumType :: ST.Type -> m EnumType
    validateEnumType = \case
      ST.TInt8   -> pure EInt8
      ST.TInt16  -> pure EInt16
      ST.TInt32  -> pure EInt32
      ST.TInt64  -> pure EInt64
      ST.TWord8  -> pure EWord8
      ST.TWord16 -> pure EWord16
      ST.TWord32 -> pure EWord32
      ST.TWord64 -> pure EWord64
      _          -> throwErrorMsg "underlying enum type must be integral"

    checkDuplicateFields :: m ()
    checkDuplicateFields = checkDuplicateIdentifiers (ST.enumVals enum)

    checkBitFlags :: m ()
    checkBitFlags =
      when (hasAttribute bitFlagsAttr (ST.enumMetadata enum)) $
        throwErrorMsg "`bit_flags` are not supported yet"


----------------------------------
------------ Tables --------------
----------------------------------
-- | A validated table field that has not been given an id yet.
-- Holds the field's identifier, its type, and whether the field is deprecated.
data TableFieldWithoutId = TableFieldWithoutId !Ident !TableFieldType !Bool

-- | Validates every table of every symbol table, leaving all other declarations untouched.
-- Type references in table fields are resolved against all the symbol tables.
validateTables :: ValidationCtx m => FileTree Stage3 -> m (FileTree Stage4)
validateTables symbolTables =
  for symbolTables $ \symbolTable -> do
    validTables <-
      for (allTables symbolTable) $ \(namespace, table) ->
        (,) namespace <$> validateTable symbolTables (namespace, table)
    pure symbolTable { allTables = validTables }

-- | Validates a single table declaration.
--
-- Checks for duplicate field names and undeclared attributes, validates the type
-- (and default value) of every field, and then assigns an id to each field.
-- The resulting table is always marked as 'NotRoot' here.
validateTable :: forall m. ValidationCtx m => FileTree Stage3 -> (Namespace, ST.TableDecl) -> m TableDecl
validateTable symbolTables (currentNamespace, table) =
  modifyContext (\_ -> qualify currentNamespace table) $ do

    let fields = ST.tableFields table
    let fieldsMetadata = ST.tableFieldMetadata <$> fields

    checkDuplicateFields fields
    checkUndeclaredAttributes table

    validFieldsWithoutIds <- traverse validateTableField fields
    validFields <- assignFieldIds fieldsMetadata validFieldsWithoutIds

    pure TableDecl
      { tableIdent = getIdent table
      , tableIsRoot = NotRoot
      , tableFields = validFields
      }

  where
    checkDuplicateFields :: [ST.TableField] -> m ()
    checkDuplicateFields = checkDuplicateIdentifiers

    -- | Gives every field an id.
    --
    -- * If no field has an @id@ attribute, ids are generated following the declaration order.
    -- * If every field has an @id@ attribute, those ids are used; the fields are returned
    --   sorted by id, after checking that the ids form a valid sequence.
    -- * Any other combination is an error.
    assignFieldIds :: [ST.Metadata] -> [TableFieldWithoutId] -> m [TableField]
    assignFieldIds metadata fieldsWithoutIds = do
      ids <- catMaybes <$> traverse (findIntAttr idAttr) metadata
      if null ids
        then pure $ evalState (traverse assignFieldId fieldsWithoutIds) (-1)
        else if length ids == length fieldsWithoutIds
          then do
            -- Both lists have the same length, so every field had an id and the
            -- ids line up with the fields they were declared on.
            let fields = zipWith (\(TableFieldWithoutId ident typ depr) id -> TableField id ident typ depr) fieldsWithoutIds ids
            let sorted = List.sortOn tableFieldId fields
            evalStateT (traverse_ checkFieldId sorted) (-1)
            pure sorted
          else
            throwErrorMsg "either all fields or no fields must have an 'id' attribute"

    -- | Generates the next id. The state holds the last id that was handed out (starting at -1).
    -- Unions and vectors of unions skip one id.
    -- NOTE(review): presumably because such fields occupy two slots (type + value) - confirm.
    assignFieldId :: TableFieldWithoutId -> State Integer TableField
    assignFieldId (TableFieldWithoutId ident typ depr) = do
      lastId <- get
      let fieldId =
            case typ of
              TUnion _ _           -> lastId + 2
              TVector _ (VUnion _) -> lastId + 2
              _                    -> lastId + 1
      put fieldId
      pure (TableField fieldId ident typ depr)

    -- | Checks that an explicitly assigned id is the one that 'assignFieldId' would have
    -- generated at this position. The state holds the id of the previous field (starting at -1).
    -- Expects to be run over the fields sorted by id.
    checkFieldId :: TableField -> StateT Integer m ()
    checkFieldId field = do
      lastId <- get
      modifyContext (\context -> context <> "." <> getIdent field) $ do
        case tableFieldType field of
          TUnion _ _ ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a union field must be the last field's id + 2"
          TVector _ (VUnion _) ->
            when (tableFieldId field /= lastId + 2) $
              throwErrorMsg "the id of a vector of unions field must be the last field's id + 2"
          _ ->
            when (tableFieldId field /= lastId + 1) $
              throwErrorMsg $ "field ids must be consecutive from 0; id " <> display (lastId + 1) <> " is missing"
        put (tableFieldId field)

    -- | Validates a field's attributes and type; the field is deprecated
    -- if it carries the @deprecated@ attribute.
    validateTableField :: ST.TableField -> m TableFieldWithoutId
    validateTableField tf =
      modifyContext (\context -> context <> "." <> getIdent tf) $ do
        checkUndeclaredAttributes tf
        validFieldType <- validateTableFieldType (ST.tableFieldMetadata tf) (ST.tableFieldDefault tf) (ST.tableFieldType tf)

        pure $ TableFieldWithoutId
          (getIdent tf)
          validFieldType
          (hasAttribute deprecatedAttr (ST.tableFieldMetadata tf))

    -- | Converts a field's type into its validated form.
    --
    -- * Numeric, boolean and enum fields may have a default value, but may not be @required@.
    -- * Strings, structs, tables, unions and vectors may be @required@, but may not have a default value.
    -- * Type references are resolved with 'findDecl', relative to the table's namespace.
    -- * Vectors of vectors are rejected.
    validateTableFieldType :: ST.Metadata -> Maybe ST.DefaultVal -> ST.Type -> m TableFieldType
    validateTableFieldType md dflt tableFieldType =
      case tableFieldType of
        ST.TInt8 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt8
        ST.TInt16 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt16
        ST.TInt32 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt32
        ST.TInt64 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TInt64
        ST.TWord8 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord8
        ST.TWord16 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord16
        ST.TWord32 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord32
        ST.TWord64 -> checkNoRequired md >> validateDefaultValAsInt dflt <&> TWord64
        ST.TFloat -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TFloat
        ST.TDouble -> checkNoRequired md >> validateDefaultValAsScientific dflt <&> TDouble
        ST.TBool -> checkNoRequired md >> validateDefaultValAsBool dflt <&> TBool
        ST.TString -> checkNoDefault dflt $> TString (isRequired md)
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE (ns, enum) -> do
              checkNoRequired md
              validDefault <- validateDefaultAsEnum dflt enum
              pure $ TEnum (TypeRef ns (getIdent enum)) (enumType enum) validDefault
            MatchS (ns, struct) -> checkNoDefault dflt $> TStruct (TypeRef ns (getIdent struct))  (isRequired md)
            MatchT (ns, table)  -> checkNoDefault dflt $> TTable  (TypeRef ns (getIdent table)) (isRequired md)
            MatchU (ns, union)  -> checkNoDefault dflt $> TUnion  (TypeRef ns (getIdent union)) (isRequired md)
        ST.TVector vecType ->
          checkNoDefault dflt >> TVector (isRequired md) <$>
            case vecType of
              ST.TInt8 -> pure VInt8
              ST.TInt16 -> pure VInt16
              ST.TInt32 -> pure VInt32
              ST.TInt64 -> pure VInt64
              ST.TWord8 -> pure VWord8
              ST.TWord16 -> pure VWord16
              ST.TWord32 -> pure VWord32
              ST.TWord64 -> pure VWord64
              ST.TFloat -> pure VFloat
              ST.TDouble -> pure VDouble
              ST.TBool -> pure VBool
              ST.TString -> pure VString
              ST.TVector _ -> throwErrorMsg "nested vector types not supported"
              ST.TRef typeRef ->
                findDecl currentNamespace symbolTables typeRef <&> \case
                  MatchE (ns, enum) ->
                    VEnum (TypeRef ns (getIdent enum))
                          (enumType enum)
                  MatchS (ns, struct) ->
                    VStruct (TypeRef ns (getIdent struct))
                  MatchT (ns, table) -> VTable (TypeRef ns (getIdent table))
                  MatchU (ns, union) -> VUnion (TypeRef ns (getIdent union))

-- | Fails if the given metadata contains the @required@ attribute.
checkNoRequired :: ValidationCtx m => ST.Metadata -> m ()
checkNoRequired md
  | hasAttribute requiredAttr md =
      throwErrorMsg "only non-scalar fields (strings, vectors, unions, structs, tables) may be 'required'"
  | otherwise =
      pure ()

-- | Fails if a default value was provided.
checkNoDefault :: ValidationCtx m => Maybe ST.DefaultVal -> m ()
checkNoDefault = \case
  Nothing ->
    pure ()
  Just _ ->
    throwErrorMsg
      "default values currently only supported for scalar fields (integers, floating point, bool, enums)"

-- | 'Req' when the metadata contains the @required@ attribute, 'Opt' otherwise.
isRequired :: ST.Metadata -> Required
isRequired md
  | hasAttribute requiredAttr md = Req
  | otherwise                    = Opt

-- | Interprets a default value as an integer of type @a@.
--
-- A missing default is treated as 0. The default must be a whole number that
-- lies within the bounds of @a@; anything else is an error.
validateDefaultValAsInt :: forall m a. (ValidationCtx m, Integral a, Bounded a, Show a) => Maybe ST.DefaultVal -> m (DefaultVal a)
validateDefaultValAsInt = \case
  Nothing -> pure (DefaultVal 0)
  Just (ST.DefaultNum n)
    | Scientific.isInteger n ->
        maybe outOfBounds (pure . DefaultVal) (Scientific.toBoundedInteger @a n)
    | otherwise ->
        notIntegral
  Just _ -> notIntegral
  where
    notIntegral :: m (DefaultVal a)
    notIntegral = throwErrorMsg "default value must be integral"

    outOfBounds :: m (DefaultVal a)
    outOfBounds =
      throwErrorMsg $
        "default value does not fit ["
        <> T.pack (show (minBound @a))
        <> "; "
        <> T.pack (show (maxBound @a))
        <> "]"

-- | Interprets a default value as a number. A missing default is treated as 0;
-- a default that is not a number is an error.
validateDefaultValAsScientific :: ValidationCtx m => Maybe ST.DefaultVal -> m (DefaultVal Scientific)
validateDefaultValAsScientific = \case
  Nothing ->
    pure (DefaultVal 0)
  Just (ST.DefaultNum n) ->
    pure (DefaultVal n)
  Just _ ->
    throwErrorMsg "default value must be a number"

-- | Interprets a default value as a boolean. A missing default is treated as 'False';
-- a default that is not a boolean is an error.
validateDefaultValAsBool :: ValidationCtx m => Maybe ST.DefaultVal -> m (DefaultVal Bool)
validateDefaultValAsBool = \case
  Nothing ->
    pure (DefaultVal False)
  Just (ST.DefaultBool b) ->
    pure (DefaultVal b)
  Just _ ->
    throwErrorMsg "default value must be a boolean"

-- | Interprets a default value as a member of the given enum, and returns that member's integer.
--
-- * A missing default resolves to the enum member whose integer is 0; it is an error if there is none.
-- * A numeric default must be a whole number that matches the integer of one of the enum's members.
-- * A default given as an identifier must match the name of one of the enum's members.
-- * A boolean default is always an error.
validateDefaultAsEnum :: ValidationCtx m => Maybe ST.DefaultVal -> EnumDecl -> m (DefaultVal Integer)
validateDefaultAsEnum dflt enum =
  fmap DefaultVal $
    case dflt of
      Nothing ->
        maybe missingZero (pure . enumValInt) (findByInt 0)
      Just (ST.DefaultNum n) ->
        case Scientific.floatingOrInteger @Float n of
          Left _  -> notIntegralOrMember
          Right i -> maybe (notPartOfEnum (display i)) (pure . enumValInt) (findByInt i)
      Just (ST.DefaultRef ref) ->
        maybe (notPartOfEnum (display ref)) (pure . enumValInt) (findByIdent ref)
      Just (ST.DefaultBool _) ->
        notIntegralOrMember
  where
    findByInt i = find (\val -> enumValInt val == i) (enumVals enum)

    findByIdent ident = find (\val -> getIdent val == ident) (enumVals enum)

    missingZero =
      throwErrorMsg "enum does not have a 0 value; please manually specify a default for this field"

    notIntegralOrMember =
      throwErrorMsg $ "default value must be integral or one of: " <> display (getIdent <$> enumVals enum)

    notPartOfEnum shown =
      throwErrorMsg $ "default value of " <> shown <> " is not part of enum " <> display (getIdent enum)


----------------------------------
------------ Unions --------------
----------------------------------
-- | Validates every union of every symbol table, leaving all other declarations untouched.
-- Type references in union members are resolved against all the symbol tables.
validateUnions :: ValidationCtx m => FileTree Stage4 -> m (FileTree ValidDecls)
validateUnions symbolTables =
  for symbolTables $ \symbolTable -> do
    validUnions <-
      for (allUnions symbolTable) $ \(namespace, union) ->
        (,) namespace <$> validateUnion symbolTables (namespace, union)
    pure symbolTable { allUnions = validUnions }

-- | Validates a single union declaration.
--
-- Every member must refer to a table. A member without an explicit name is named after its
-- type reference; in either case, dots in the name are replaced with underscores.
-- Member names must be unique, and may not clash with @NONE@.
validateUnion :: forall m. ValidationCtx m => FileTree Stage4 -> (Namespace, ST.UnionDecl) -> m UnionDecl
validateUnion symbolTables (currentNamespace, union) =
  modifyContext (const (qualify currentNamespace union)) $ do
    members <- traverse validateUnionVal (ST.unionVals union)
    checkDuplicateVals members
    checkUndeclaredAttributes union
    pure UnionDecl
      { unionIdent = getIdent union
      , unionVals = members
      }
  where
    validateUnionVal :: ST.UnionVal -> m UnionVal
    validateUnionVal uv =
      modifyContext (\unionContext -> unionContext <> "." <> memberIdent) $
        validateUnionValType tref <&> \tableRef ->
          UnionVal
            { unionValIdent = memberIdent
            , unionValTableRef = tableRef
            }
      where
        tref = ST.unionValTypeRef uv

        -- The name used when the member does not have an explicit one.
        defaultIdent = qualify (typeRefNamespace tref) (typeRefIdent tref)

        memberIdent :: Ident
        memberIdent =
          coerce (T.replace "." "_" (coerce (fromMaybe defaultIdent (ST.unionValIdent uv))))

    validateUnionValType :: TypeRef -> m TypeRef
    validateUnionValType typeRef = do
      match <- findDecl currentNamespace symbolTables typeRef
      case match of
        MatchT (ns, table) -> pure (TypeRef ns (getIdent table))
        _                  -> throwErrorMsg "union members may only be tables"

    checkDuplicateVals :: NonEmpty UnionVal -> m ()
    checkDuplicateVals vals =
      checkDuplicateIdentifiers (NE.cons "NONE" (getIdent <$> vals))


----------------------------------
------------ Structs -------------
----------------------------------
-- | Validates every struct of every symbol table, leaving all other declarations untouched.
--
-- The symbol tables are processed one at a time; for each one, all of its structs are
-- first checked for cycles, and then validated. The structs validated so far are kept
-- in a piece of state (initially empty) that is shared across all the symbol tables.
validateStructs :: ValidationCtx m => FileTree Stage2 -> m (FileTree Stage3)
validateStructs symbolTables =
  evalStateT (traverse validateFile symbolTables) []
  where
    validateFile :: (MonadState [(Namespace, StructDecl)] m, ValidationCtx m) => Stage2 -> m Stage3
    validateFile symbolTable = do
      traverse_ (checkStructCycles symbolTables) (allStructs symbolTable)
      validStructs <- for (allStructs symbolTable) (validateStruct symbolTables)
      pure symbolTable { allStructs = validStructs }

-- | Fails if a struct contains itself, either directly or through other structs.
--
-- Starting at the given struct, every field that refers to another struct is followed
-- recursively, keeping track of the structs visited along the way. Coming across a struct
-- that has already been visited means there is a cycle, which is then reported in the error.
checkStructCycles :: forall m. ValidationCtx m => FileTree Stage2 -> (Namespace, ST.StructDecl) -> m ()
checkStructCycles symbolTables = go []
  where
    -- The first argument lists the structs on the path to this one, most recent first.
    go :: [Ident] -> (Namespace, ST.StructDecl) -> m ()
    go visited (currentNamespace, struct) =
      modifyContext (const qualifiedName) checkStruct
      where
        qualifiedName :: Ident
        qualifiedName = qualify currentNamespace struct

        -- The part of the path that forms the cycle, in the order it was walked.
        cyclePath :: [Ident]
        cyclePath =
          List.dropWhile (/= qualifiedName) (List.reverse (qualifiedName : visited))

        checkStruct :: m ()
        checkStruct
          | qualifiedName `elem` visited =
              throwErrorMsg $
                "cyclic dependency detected ["
                <> display (T.intercalate " -> " (coerce cyclePath))
                <> "] - structs cannot contain themselves, directly or indirectly"
          | otherwise =
              traverse_ checkField (ST.structFields struct)

        checkField field =
          modifyContext (\structContext -> structContext <> "." <> getIdent field) $
            case ST.structFieldType field of
              ST.TRef typeRef ->
                findDecl currentNamespace symbolTables typeRef >>= \case
                  MatchS nested -> go (qualifiedName : visited) nested
                  -- The reference does not point to a struct, so there is nothing to follow here.
                  _             -> pure ()
              -- Not a type reference, so it cannot lead to another struct.
              _ -> pure ()

-- | A validated struct field for which padding has not been calculated yet.
data UnpaddedStructField = UnpaddedStructField
  { unpaddedStructFieldIdent :: !Ident
  , unpaddedStructFieldType  :: !StructFieldType
  } deriving (Show, Eq)

validateStruct ::
     forall m. (MonadState [(Namespace, StructDecl)] m, ValidationCtx m)
  => FileTree Stage2
  -> (Namespace, ST.StructDecl)
  -> m (Namespace, StructDecl)
validateStruct symbolTables (currentNamespace, struct) =
  modifyContext (\_ -> qualify currentNamespace struct) $ do
    validStructs <- get
    -- Check if this struct has already been validated in a previous iteration
    case find (\(ns, s) -> ns == currentNamespace && getIdent s == getIdent struct) validStructs of
      Just match -> pure match
      Nothing -> do
        checkDuplicateFields
        checkUndeclaredAttributes struct

        fields <- traverse validateStructField (ST.structFields struct)
        let naturalAlignment = maximum (structFieldAlignment <$> fields)
        forceAlignAttrVal <- getForceAlignAttr
        forceAlign <- traverse (validateForceAlign naturalAlignment) forceAlignAttrVal
        let alignment = fromMaybe naturalAlignment forceAlign

        -- In order to calculate the padding between fields, we must first know the fields' and the struct's
        -- alignment. Which means we must first validate all the struct's fields, and then do a second
        -- pass to calculate the padding.
        let (size, paddedFields) = addFieldPadding alignment fields

        let validStruct = StructDecl
              { structIdent      = getIdent struct
              , structAlignment  = alignment
              , structSize       = size
              , structFields     = paddedFields
              }
        modify ((currentNamespace, validStruct) :)
        pure (currentNamespace, validStruct)

  where
    invalidStructFieldType = "struct fields may only be integers, floating point, bool, enums, or other structs"

    -- | Calculates how much padding each field needs, and returns the struct's total size
    -- and a list of fields with padding information.
    addFieldPadding :: Alignment -> NonEmpty UnpaddedStructField -> (InlineSize, NonEmpty StructField)
    addFieldPadding structAlignment unpaddedFields =
      (size, NE.fromList (reverse paddedFields))
      where

        (size, paddedFields) = go 0 [] (NE.toList unpaddedFields)

        go :: InlineSize -> [StructField] -> [UnpaddedStructField] -> (InlineSize, [StructField])
        go size paddedFields [] = (size, paddedFields)
        go size paddedFields (x : y : tail) =
          let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
              nextFieldsAlignment = fromIntegral @Alignment @InlineSize (structFieldAlignment y)
              paddingNeeded = (size' `roundUpToNearestMultipleOf` nextFieldsAlignment) - size'
              size'' = size' + paddingNeeded
              paddedField = StructField
                { structFieldIdent = unpaddedStructFieldIdent x
                -- NOTE: it is safe to narrow `paddingNeeded` to a word8 here because it's always smaller than `nextFieldsAlignment`
                , structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
                , structFieldOffset = coerce size
                , structFieldType = unpaddedStructFieldType x
                }
          in  go size'' (paddedField : paddedFields) (y : tail)
        go size paddedFields [x] =
          let size' = size + structFieldTypeSize (unpaddedStructFieldType x)
              structAlignment' = fromIntegral @Alignment @InlineSize structAlignment
              paddingNeeded = (size' `roundUpToNearestMultipleOf` structAlignment') - size'
              size'' = size' + paddingNeeded
              paddedField = StructField
                { structFieldIdent = unpaddedStructFieldIdent x
                -- NOTE: it is safe to narrow `paddingNeeded` to a word8 here because it's always smaller than `structAlignment'`
                , structFieldPadding = fromIntegral @InlineSize @Word8 paddingNeeded
                , structFieldOffset = coerce size
                , structFieldType = unpaddedStructFieldType x
                }
          in  (size'', paddedField : paddedFields)

    validateStructField :: ST.StructField -> m UnpaddedStructField
    validateStructField sf =
      modifyContext (\context -> context <> "." <> getIdent sf) $ do
        checkUnsupportedAttributes sf
        checkUndeclaredAttributes sf
        structFieldType <- validateStructFieldType (ST.structFieldType sf)
        pure $ UnpaddedStructField
          { unpaddedStructFieldIdent = getIdent sf
          , unpaddedStructFieldType = structFieldType
          }

    validateStructFieldType :: ST.Type -> m StructFieldType
    validateStructFieldType structFieldType =
      case structFieldType of
        ST.TInt8 -> pure SInt8
        ST.TInt16 -> pure SInt16
        ST.TInt32 -> pure SInt32
        ST.TInt64 -> pure SInt64
        ST.TWord8 -> pure SWord8
        ST.TWord16 -> pure SWord16
        ST.TWord32 -> pure SWord32
        ST.TWord64 -> pure SWord64
        ST.TFloat -> pure SFloat
        ST.TDouble -> pure SDouble
        ST.TBool -> pure SBool
        ST.TString -> throwErrorMsg invalidStructFieldType
        ST.TVector _ -> throwErrorMsg invalidStructFieldType
        ST.TRef typeRef ->
          findDecl currentNamespace symbolTables typeRef >>= \case
            MatchE (enumNamespace, enum) ->
              pure (SEnum (TypeRef enumNamespace (getIdent enum)) (enumType enum))
            MatchS (nestedNamespace, nestedStruct) ->
              -- if this is a reference to a struct, we need to validate it first
              SStruct <$> validateStruct symbolTables (nestedNamespace, nestedStruct)
            _ -> throwErrorMsg invalidStructFieldType

    checkUnsupportedAttributes :: ST.StructField -> m ()
    checkUnsupportedAttributes structField = do
      when (hasAttribute deprecatedAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "can't deprecate fields in a struct"
      when (hasAttribute requiredAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "struct fields are already required, the 'required' attribute is redundant"
      when (hasAttribute idAttr (ST.structFieldMetadata structField)) $
        throwErrorMsg "struct fields cannot be reordered using the 'id' attribute"

    getForceAlignAttr :: m (Maybe Integer)
    getForceAlignAttr = findIntAttr forceAlignAttr (ST.structMetadata struct)

    validateForceAlign :: Alignment -> Integer -> m Alignment
    validateForceAlign naturalAlignment forceAlign =
      if isPowerOfTwo forceAlign
        && inRange (fromIntegral @Alignment @Integer naturalAlignment, 16) forceAlign
        then pure (fromIntegral @Integer @Alignment forceAlign)
        else throwErrorMsg $
              "force_align must be a power of two integer ranging from the struct's natural alignment (in this case, "
              <> T.pack (show naturalAlignment)
              <> ") to 16"

    checkDuplicateFields :: m ()
    checkDuplicateFields =
      checkDuplicateIdentifiers
        (ST.structFields struct)

----------------------------------
------------ Helpers -------------
----------------------------------
-- | Computes the alignment of a struct field from the field's type alone.
-- Scalar types map to their respective @*Size@ constant, enums delegate to 'enumAlignment',
-- and nested structs reuse the alignment recorded on the nested struct declaration.
structFieldAlignment :: UnpaddedStructField -> Alignment
structFieldAlignment = alignmentOf . unpaddedStructFieldType
  where
    alignmentOf :: StructFieldType -> Alignment
    alignmentOf = \case
      SInt8               -> int8Size
      SInt16              -> int16Size
      SInt32              -> int32Size
      SInt64              -> int64Size
      SWord8              -> word8Size
      SWord16             -> word16Size
      SWord32             -> word32Size
      SWord64             -> word64Size
      SFloat              -> floatSize
      SDouble             -> doubleSize
      SBool               -> boolSize
      SEnum _ underlying  -> enumAlignment underlying
      SStruct (_, nested) -> structAlignment nested

-- | The alignment of an enum: its size (see 'enumSize') wrapped in 'Alignment'.
enumAlignment :: EnumType -> Alignment
enumAlignment enumType = Alignment (enumSize enumType)

-- | Returns the size of an enum of the given type.
-- An enum occupies either 1, 2, 4 or 8 bytes, which is why the result fits in a 'Word8'.
enumSize :: EnumType -> Word8
enumSize EInt8   = int8Size
enumSize EInt16  = int16Size
enumSize EInt32  = int32Size
enumSize EInt64  = int64Size
enumSize EWord8  = word8Size
enumSize EWord16 = word16Size
enumSize EWord32 = word32Size
enumSize EWord64 = word64Size

-- | Returns the inline size of a value of the given struct field type.
-- Scalar types map to their respective @*Size@ constant, enums use 'enumSize' (widened to
-- 'InlineSize'), and nested structs reuse the size recorded on the nested struct declaration.
structFieldTypeSize :: StructFieldType -> InlineSize
structFieldTypeSize SInt8                 = int8Size
structFieldTypeSize SInt16                = int16Size
structFieldTypeSize SInt32                = int32Size
structFieldTypeSize SInt64                = int64Size
structFieldTypeSize SWord8                = word8Size
structFieldTypeSize SWord16               = word16Size
structFieldTypeSize SWord32               = word32Size
structFieldTypeSize SWord64               = word64Size
structFieldTypeSize SFloat                = floatSize
structFieldTypeSize SDouble               = doubleSize
structFieldTypeSize SBool                 = boolSize
structFieldTypeSize (SEnum _ underlying)  = fromIntegral @Word8 @InlineSize (enumSize underlying)
structFieldTypeSize (SStruct (_, nested)) = structSize nested

-- | Fails validation when two or more elements of the given collection share an identifier.
-- The error message lists every identifier that occurs more than once.
checkDuplicateIdentifiers :: (ValidationCtx m, Foldable f, Functor f, HasIdent a) => f a -> m ()
checkDuplicateIdentifiers xs
  | null duplicates = pure ()
  | otherwise       = throwErrorMsg $ display duplicates <> " declared more than once"
  where
    idents = Foldable.toList (getIdent <$> xs)

    -- Number of times each identifier appears in the collection.
    counts = Map.fromListWith (+) [ (ident, 1 :: Int) | ident <- idents ]

    -- Identifiers that appear more than once, in ascending key order.
    duplicates = Map.keys (Map.filter (> 1) counts)

-- | Fails validation if the metadata of the given element uses an attribute that is not
-- part of the set of attributes held in the validation state.
-- Attribute names are inspected in ascending key order; the first undeclared one is reported.
checkUndeclaredAttributes :: (ValidationCtx m, HasMetadata a) => a -> m ()
checkUndeclaredAttributes a = do
  declaredAttrs <- asks validationStateAllAttributes
  let usedAttrs = Map.keys (ST.unMetadata (getMetadata a))
  case find (\attrName -> coerce attrName `Set.notMember` declaredAttrs) usedAttrs of
    Nothing -> pure ()
    Just undeclared ->
      throwErrorMsg ("user defined attributes must be declared before use: " <> undeclared)

-- | Whether the metadata contains an attribute with the given name (regardless of its value).
hasAttribute :: Text -> ST.Metadata -> Bool
hasAttribute name metadata = name `Map.member` ST.unMetadata metadata

-- | Looks up an attribute by name and interprets its value as an integer.
--
-- Returns 'Nothing' when the attribute is absent. When it is present, its value must be either
-- an integer or a string that parses as an integer; an attribute without a value, or with a
-- string that is not an integer, fails validation.
findIntAttr :: ValidationCtx m => Text -> ST.Metadata -> m (Maybe Integer)
findIntAttr name (ST.Metadata attrs) =
  -- `traverse` leaves an absent attribute as `Nothing` and only inspects the value when present.
  traverse toInt (Map.lookup name attrs)
  where
    toInt attrVal =
      case attrVal of
        Nothing              -> notAnInteger
        Just (ST.AttrI n)    -> pure n
        Just (ST.AttrS txt)  -> maybe notAnInteger pure (readMaybe @Integer (T.unpack txt))

    notAnInteger =
      throwErrorMsg $ mconcat
        [ "expected attribute '"
        , name
        , "' to have an integer value, e.g. '"
        , name
        , ": 123'"
        ]

-- | Looks up an attribute by name and interprets its value as a string.
--
-- Returns 'Nothing' when the attribute is absent. When it is present, its value must be a
-- string; any other value (or no value at all) fails validation.
findStringAttr :: ValidationCtx m => Text -> ST.Metadata -> m (Maybe Text)
findStringAttr name (ST.Metadata attrs) =
  -- `traverse` leaves an absent attribute as `Nothing` and only inspects the value when present.
  traverse toString (Map.lookup name attrs)
  where
    toString attrVal =
      case attrVal of
        Just (ST.AttrS str) -> pure str
        _ ->
          throwErrorMsg $ mconcat
            [ "expected attribute '"
            , name
            , "' to have a string value, e.g. '"
            , name
            , ": \"abc\"'"
            ]

-- | Aborts validation with the given message.
-- When the current validation context is non-empty, the message is prefixed with that
-- context enclosed in square brackets; otherwise the message is thrown as is.
throwErrorMsg :: ValidationCtx m => Text -> m a
throwErrorMsg msg =
  asks validationStateCurrentContext >>= \ctx ->
    throwError $
      if ctx == ""
        then msg
        else "[" <> display ctx <> "]: " <> msg