{-# LANGUAGE UndecidableInstances #-}

-- | Reflection helpers for SPIR-V shader modules.
--
-- Collects descriptor bindings (@layout(set=X, binding=Y)@) and stage
-- interface variables (@layout(location=N)@) and checks that they agree
-- across the stages of a pipeline.
module Engine.SpirV.Reflect where

import RIO

import Data.IntMap qualified as IntMap
import Data.List qualified as List
import Data.SpirV.Reflect.BlockVariable qualified as BlockVariable
import Data.SpirV.Reflect.DescriptorBinding qualified as DescriptorBinding
import Data.SpirV.Reflect.DescriptorSet qualified as DescriptorSet
import Data.SpirV.Reflect.Enums qualified as Enums
import Data.SpirV.Reflect.FFI qualified as Reflect
import Data.SpirV.Reflect.InterfaceVariable (InterfaceVariable)
import Data.SpirV.Reflect.InterfaceVariable qualified as InterfaceVariable
import Data.SpirV.Reflect.Module (Module)
import Data.SpirV.Reflect.Module qualified as Module
import Data.SpirV.Reflect.Traits qualified as Traits
import Data.SpirV.Reflect.TypeDescription qualified as TypeDescription
import Data.Tree (Tree(..))
import Engine.Vulkan.Pipeline.Stages (StageInfo(..), withLabels)
import RIO.Text qualified as Text
import RIO.ByteString (readFile)
import Vulkan.Core10.Enums.Format qualified as VkFormat

-- | Read a SPIR-V binary from disk and load it with the reflection FFI.
invoke :: MonadIO m => FilePath -> m Module
invoke file = readFile file >>= Reflect.loadBytes

-- | Reflection data gathered for a whole pipeline.
data Reflect stages = Reflect
  { bindMap :: BindMap BlockBinding
    -- ^ Descriptor bindings merged over all stages.
  , interfaces :: StageInterface stages
    -- ^ Per-stage (inputs, outputs) interface variables.
  , inputStage :: Text
    -- ^ NOTE(review): presumably the label of the first present stage
    -- (cf. 'inputStageInterface') — confirm at the construction site.
  , inputs :: InterfaceBinds
    -- ^ NOTE(review): presumably the inputs of 'inputStage' — confirm.
  }

-- | @layout(set=X, binding=Y) ...@
--
-- The outer key is the descriptor set number and the inner key is the
-- binding number within that set.
type BindMap a = IntMap (IntMap a)

-- | For every stage, its (input, output) interface variables, or 'Nothing'
-- when the stage has no module.
type StageInterface stages = stages (Maybe (InterfaceBinds, InterfaceBinds))

-- | @layout(location=N)@ interface variables keyed by location.
type InterfaceBinds = IntMap InterfaceBinding

deriving instance (Eq (StageInterface stages)) => Eq (Reflect stages)
deriving instance (Show (StageInterface stages)) => Show (Reflect stages)

-- * Block variables

-- | @uniform Foo { ... } foo;@
--
-- Binding name, descriptor type and, for bindings that have a block
-- variable, the tree of member signatures.
type BlockBinding =
  ( Text
  , Enums.DescriptorType
  , Maybe (Tree ([Maybe Text], BlockSignature))
  )

-- | Layout-relevant description of a block variable or one of its members.
data BlockSignature = BlockSignature
  { offset :: Word32
  , size :: Word32
  , flags :: Enums.TypeFlags
  , scalar :: Maybe Traits.Scalar
    -- ^ Scalar traits; only present when the reported scalar width is non-zero.
  }
  deriving (Eq, Ord, Show)

-- | Merge the descriptor bindings of every present stage module into one
-- 'BindMap'.
--
-- Stages without a module are skipped. A binding that appears in several
-- stages must have the same descriptor type and the same member signature
-- tree (member names are ignored); the first-seen value is kept. On a
-- conflict both versions are logged together with the stages they came from,
-- and a 'StringException' describing the conflicting @layout(...)@ is thrown.
stagesBindMap
  :: ( MonadIO m
     , MonadReader env m
     , HasLogFunc env
     , StageInfo stages
     )
  => stages (Maybe Module)
  -> m (BindMap BlockBinding)
stagesBindMap =
  fmap snd . foldM collect ([] :: [Text], mempty) . annotate
  where
    annotate modules = (,) <$> stageNames <*> modules

    collect acc@(visited, old) (source, stageModule) =
      case stageModule of
        Nothing ->
          pure acc
        Just new ->
          case unionDS old (moduleBindMap new) of
            Left (six, bix, inAcc, inNew) -> do
              let
                site = mconcat
                  [ "layout("
                  , "set=", display six
                  , ", "
                  , "binding=", display bix
                  , ")"
                  ]
              logError $ "incompatible data at " <> site
              logError $ "old: " <> displayDS inAcc
              logError $ " from " <> displayShow visited
              logError $ "new: " <> displayDS inNew
              logError $ " from " <> displayShow source
              -- Carry the conflict site in the exception itself so callers
              -- that catch it get actionable information, not a placeholder.
              throwString . Text.unpack . utf8BuilderToText $
                "incompatible data at " <> site <> " in stage " <> display source
            Right matching ->
              pure (visited <> [source], matching)

    -- Compatible when descriptor types agree and the signature trees agree,
    -- ignoring the member name paths.
    unionDS =
      bindMapUnionWith \(_, adt, asig) (_, bdt, bsig) ->
        adt == bdt && fmap (fmap snd) asig == fmap (fmap snd) bsig

    displayDS (name, dt, sigs) =
      mconcat
        [ display name
        , " :: "
        , maybe (displayShow dt) display $ Enums.descriptorTypeName @Text dt
        , maybe
            mempty
            ( \sigs' ->
                mappend " -- " $
                  displayShow . toList $
                    sigs' <&> \(path, BlockSignature{..}) ->
                      ( Text.intercalate "|" (catMaybes path)
                      , (size, offset, Enums.typeFlagsNames @Text flags)
                      , scalar
                      )
            )
            sigs
        ]

-- | Descriptor bindings of a single module, keyed by set and binding number.
moduleBindMap :: Module -> BindMap BlockBinding
moduleBindMap refl =
  IntMap.fromList do
    ds <- toList $ Module.descriptor_sets refl
    pure
      ( fromIntegral $ DescriptorSet.set ds
      , IntMap.fromList do
          db <- toList $ DescriptorSet.bindings ds
          let DescriptorBinding.DescriptorBinding {binding, name, descriptor_type, block} = db
          pure
            ( fromIntegral binding
            , ( name
              , descriptor_type
              , fmap (blockTree []) block
              )
            )
      )

-- | Convert a block variable and its members into a tree.
--
-- Each node carries the list of names from the root down to and including
-- that variable, together with its 'BlockSignature'.
blockTree
  :: [Maybe Text]
  -> BlockVariable.BlockVariable
  -> Tree ([Maybe Text], BlockSignature)
blockTree ancestors bv =
  Node (path, here) $ map (blockTree path) there
  where
    here =
      BlockSignature
        { offset = BlockVariable.offset bv
        , size = BlockVariable.size bv
        , ..
        }
      where
        (flags, scalar) =
          case BlockVariable.type_description bv of
            Nothing ->
              (Enums.TYPE_FLAG_UNDEFINED, Nothing)
            Just td ->
              ( TypeDescription.type_flags td
              , do
                  TypeDescription.Traits{numeric} <- TypeDescription.traits td
                  let st@Traits.Scalar{width} = Traits.scalar numeric
                  -- A zero width is treated as "no scalar traits".
                  guard $ width > 0
                  pure st
              )

    path = ancestors ++ [BlockVariable.name bv]

    there = toList $ BlockVariable.members bv

-- | Union of two 'BindMap's where entries sharing a @(set, binding)@ key must
-- satisfy @compatible@.
--
-- For shared keys the value from the first map is kept. On failure returns
-- the set index, the binding index and the two conflicting values (first
-- map's value first).
bindMapUnionWith
  :: (a -> a -> Bool)
  -> BindMap a
  -> BindMap a
  -> Either (Int, Int, a, a) (BindMap a)
bindMapUnionWith compatible as bs =
  traverse sequence validated
  where
    -- Every leaf starts as 'Right'; clashing leaves are resolved by 'check',
    -- and the final 'traverse' surfaces a 'Left' if any clash failed.
    validated =
      IntMap.unionWithKey (IntMap.unionWithKey . check) (wrap as) (wrap bs)

    wrap = fmap (fmap pure)

    check six bix a' b' = do
      a <- a'
      b <- b'
      if compatible a b
        then Right a
        else Left (six, bix, a, b)
{-# INLINE bindMapUnionWith #-}

-- * Interface variables

-- | Variable name, names of its type flags, and its 'InterfaceSignature'.
type InterfaceBinding =
  ( Maybe Text
  , [Text]
  , InterfaceSignature
  )

-- | The part of an interface variable that 'interfaceCompatible' requires to
-- be equal between a stage input and the preceding stage output.
data InterfaceSignature = InterfaceSignature
  { format :: VkFormat.Format
  , flags :: Enums.TypeFlags
  , matrix :: Maybe Traits.Matrix
    -- ^ Matrix traits; only present when both column and row counts are non-zero.
  }
  deriving (Eq, Ord, Show)

-- | Input and output interface variables for every stage that has a module.
stagesInterfaceMap
  :: Functor stages
  => stages (Maybe Module)
  -> StageInterface stages
stagesInterfaceMap = fmap (fmap moduleInterfaceBinds)

-- | Input and output interface variables of a single module.
moduleInterfaceBinds :: Module -> (InterfaceBinds, InterfaceBinds)
moduleInterfaceBinds refl =
  ( interfaceBinds Enums.StorageClassInput (Module.input_variables refl)
  , interfaceBinds Enums.StorageClassOutput (Module.output_variables refl)
  )

-- | Index the interface variables of the given storage class by location.
--
-- Variables of any other storage class are dropped, and so is every variable
-- whose @built_in@ is not 'maxBound'.
interfaceBinds
  :: Enums.StorageClass
  -> Vector InterfaceVariable
  -> InterfaceBinds
interfaceBinds cls vars =
  IntMap.fromList do
    var@InterfaceVariable.InterfaceVariable{location} <- toList vars
    guard $ InterfaceVariable.storage_class var == cls
    -- XXX: Remove vars like @gl_FragCoord@/@SV_Position@ from potential signatures.
    -- NOTE(review): assumes the reflection library uses 'maxBound' as the
    -- "not a built-in" marker — confirm against spirv-reflect.
    guard $ InterfaceVariable.built_in var == maxBound
    let
      td = InterfaceVariable.type_description var
      Enums.Format format = InterfaceVariable.format var
      flags = maybe Enums.TYPE_FLAG_UNDEFINED TypeDescription.type_flags td
      matrixTraits = do
        TypeDescription.TypeDescription{traits} <- td
        TypeDescription.Traits{numeric} <- traits
        let mt@Traits.Matrix{column_count, row_count} = Traits.matrix numeric
        -- Zero dimensions mean the variable is not a matrix.
        guard $ column_count > 0 && row_count > 0
        pure mt
      signature = InterfaceSignature
        { format = VkFormat.Format $ fromIntegral format
        , flags = flags
        , matrix = matrixTraits
        }
    pure
      ( fromIntegral location
      , ( InterfaceVariable.name var
        , Enums.typeFlagsNames @Text flags
        , signature
        )
      )

-- | A failed interface check: consuming stage label, providing stage label,
-- location, and the (requested, provided) signatures when both exist, or
-- 'Nothing' when the location is not provided at all.
type IncompatibleInterfaces label =
  (label, label, Int, Maybe (InterfaceSignature, InterfaceSignature))

-- | A successful interface check: providing stage label, consuming stage
-- label (note: the reverse of 'IncompatibleInterfaces'), and per location the
-- input's type flag names and whether the variable names match.
type CompatibleInterfaces label =
  (label, label, IntMap ([Text], Matching (Maybe Text)))

-- | 'Right' when both sides agree, 'Left' holding both values otherwise.
type Matching a = Either (a, a) a

-- | Check that every input of each present stage is provided, with an equal
-- 'InterfaceSignature', by the outputs of the preceding present stage.
--
-- Outputs not consumed by the next stage are not checked. Returns the first
-- failure found, or one 'CompatibleInterfaces' per adjacent stage pair.
interfaceCompatible
  :: ( StageInfo stages
     , IsString label
     )
  => StageInterface stages
  -> Either (IncompatibleInterfaces label) [CompatibleInterfaces label]
interfaceCompatible staged =
  for chained \((inputLabel, input), (outputLabel, output)) -> do
    checked <- for (IntMap.assocs input) \(location, requested) ->
      case IntMap.lookup location output of
        Just provided -> do
          let
            (rName, rFlags, rSignature) = requested
            (pName, _pFlags, pSignature) = provided
          if rSignature == pSignature
            then
              let
                names =
                  if rName == pName
                    then Right rName
                    else Left (rName, pName)
              in
                Right (location, (rFlags, names))
            else
              Left
                ( inputLabel
                , outputLabel
                , location
                , Just (rSignature, pSignature)
                )
        Nothing ->
          Left
            ( inputLabel
            , outputLabel
            , location
            , Nothing
            )
    Right
      ( outputLabel
      , inputLabel
      , IntMap.fromList checked
      )
  where
    -- Pair each present stage's inputs with the outputs of the present
    -- stage before it.
    chained = zip (drop 1 ins) outs

    (ins, outs) =
      List.unzip do
        (label, Just binds) <- toList $ withLabels staged
        pure
          ( (label, fst binds)
          , (label, snd binds)
          )

-- | Label and inputs of the first present stage (in the 'Foldable' order of
-- @stages@), if any stage is present.
inputStageInterface
  :: (StageInfo stages, IsString label)
  => StageInterface stages
  -> Maybe (label, InterfaceBinds)
inputStageInterface staged =
  listToMaybe active
  where
    active = do
      (label, Just binds) <- toList $ withLabels staged
      pure (label, fst binds)