{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- Low-level c2hs bindings for SPIRV-Reflect, plus "inflate" helpers that
-- copy the C structures reachable from a shader module pointer into the
-- Haskell records of the Data.SpirV.Reflect modules.
module Data.SpirV.Reflect.FFI.Internal where

import Data.Coerce (Coercible, coerce)
import Data.List (sortOn)
import Data.Text (Text)
import Data.Text qualified as Text
import Data.Vector (Vector)
import Data.Vector qualified as Vector
import Data.Word (Word32)
import Foreign.C.String (CString)
import Foreign.C.Types (CULong)
import Foreign.Marshal.Utils (maybePeek)
import Foreign.Ptr (Ptr, castPtr, plusPtr, nullPtr)
import Foreign.Storable (peek, peekElemOff)
import GHC.Ptr qualified as GHC

import Data.SpirV.Reflect.BlockVariable (BlockVariable)
import Data.SpirV.Reflect.BlockVariable qualified as BlockVariable
import Data.SpirV.Reflect.DescriptorBinding (DescriptorBinding)
import Data.SpirV.Reflect.DescriptorBinding qualified as DescriptorBinding
import Data.SpirV.Reflect.DescriptorSet (DescriptorSet)
import Data.SpirV.Reflect.DescriptorSet qualified as DescriptorSet
import Data.SpirV.Reflect.Enums qualified as Enums
import Data.SpirV.Reflect.InterfaceVariable (InterfaceVariable)
import Data.SpirV.Reflect.InterfaceVariable qualified as InterfaceVariable
import Data.SpirV.Reflect.Module (Module)
import Data.SpirV.Reflect.Module qualified as Module
import Data.SpirV.Reflect.Traits qualified as Traits
import Data.SpirV.Reflect.TypeDescription (TypeDescription)
import Data.SpirV.Reflect.TypeDescription qualified as TypeDescription

#include "spirv_reflect.h"

-- * Loader interface

{#pointer *SpvReflectShaderModule as ShaderModulePtr #}

{#enum SpvReflectResult as Result {underscoreToCase} deriving (Eq, Ord, Show) #}

-- Binding for spvReflectCreateShaderModule: code size, code pointer and the
-- destination module structure, in that order.
{#fun unsafe spvReflectCreateShaderModule as createShaderModule
  { id `CULong'
  , `Ptr ()'
  , `ShaderModulePtr'
  } -> `Result'
#}

{#enum SpvReflectModuleFlagBits as ModuleFlags {underscoreToCase} deriving (Eq, Ord, Show) #}

-- Binding for spvReflectCreateShaderModule2: same as above with a leading
-- flags argument.
{#fun unsafe spvReflectCreateShaderModule2 as createShaderModule2
  { `ModuleFlags'
  , id `CULong'
  , `Ptr ()'
  , `ShaderModulePtr'
  } -> `Result'
#}

-- Binding for spvReflectDestroyShaderModule.
{#fun unsafe spvReflectDestroyShaderModule as destroyShaderModule
  { `ShaderModulePtr'
  } -> `()'
#}

-- TODO: inflateEntryPoints :: ShaderModulePtr -> IO [EntryPoint]

-- * Module

-- | Size in bytes of the C @SpvReflectShaderModule@ structure.
shaderModuleSize :: Int
shaderModuleSize = {#sizeof SpvReflectShaderModule #}

-- | Copy the contents of a C shader module into a Haskell 'Module'.
--
-- Input and output variables are not read from dedicated C fields: they are
-- derived from the full interface variable list by filtering on storage
-- class and sorting by location.
--
-- The record is assembled with a wildcard, so the local binding names below
-- must keep matching the field names of 'Module.Module'.
inflateModule :: ShaderModulePtr -> IO Module
inflateModule smp = do
  let sm = castPtr smp

  generator <- inflateEnum $
    {#get SpvReflectShaderModule->generator #} sm

  entry_point_name <- inflateText $
    {#get SpvReflectShaderModule->entry_point_name #} sm

  entry_point_id <- inflateIntegral $
    {#get SpvReflectShaderModule->entry_point_id #} sm

  -- TODO: enums
  source_language <- inflateIntegral $
    {#get SpvReflectShaderModule->source_language #} sm

  source_language_version <- inflateIntegral $
    {#get SpvReflectShaderModule->source_language_version #} sm

  -- XXX: Uses value(s) from first entry point

  -- TODO: flags
  spirv_execution_model <- inflateIntegral $
    {#get SpvReflectShaderModule->spirv_execution_model #} sm

  -- TODO: flags
  shader_stage <- inflateIntegral $
    {#get SpvReflectShaderModule->shader_stage #} sm

  descriptor_bindings <- inflateVector
    ({#get SpvReflectShaderModule->descriptor_binding_count #} sm)
    ({#get SpvReflectShaderModule->descriptor_bindings #} sm)
    {#sizeof SpvReflectDescriptorBinding #}
    inflateDescriptorBinding

  descriptor_sets <- inflateVector
    ({#get SpvReflectShaderModule->descriptor_set_count #} sm)
    ({#get SpvReflectShaderModule->descriptor_sets #} sm)
    {#sizeof SpvReflectDescriptorSet #}
    inflateDescriptorSet

  interface_variables <- inflateVector
    ({#get SpvReflectShaderModule->interface_variable_count #} sm)
    ({#get SpvReflectShaderModule->interface_variables #} sm)
    {#sizeof SpvReflectInterfaceVariable #}
    inflateInterfaceVariable

  let
    ivLocation InterfaceVariable.InterfaceVariable{location} = location

    -- Select the interface variables matching the predicate, ordered by
    -- their location.
    pickIvs query =
      Vector.fromList . sortOn ivLocation . filter query $
        Vector.toList interface_variables

    input_variables =
      pickIvs $ (== Enums.StorageClassInput) . InterfaceVariable.storage_class

    output_variables =
      pickIvs $ (== Enums.StorageClassOutput) . InterfaceVariable.storage_class

  push_constants <- inflateVector
    ({#get SpvReflectShaderModule->push_constant_block_count #} sm)
    ({#get SpvReflectShaderModule->push_constant_blocks #} sm)
    {#sizeof SpvReflectBlockVariable #}
    inflateBlockVariable

  pure Module.Module{..}

-- | Inflate a contiguous C array of structures into a 'Vector'.
--
-- The element at index @i@ is read from the base pointer advanced by
-- @itemSize * i@ bytes. When the count is zero the base pointer is never
-- dereferenced.
inflateVector
  :: Integral i
  => IO i             -- ^ action producing the element count
  -> IO (Ptr p)       -- ^ action producing the pointer to the first element
  -> Int              -- ^ size of one element, in bytes
  -> (Ptr p -> IO a)  -- ^ inflater applied to each element pointer
  -> IO (Vector a)
inflateVector getCount getItems itemSize inflate = do
  count <- getCount
  itemsPtr <- getItems
  Vector.generateM (fromIntegral count) \pos ->
    inflate $ itemsPtr `plusPtr` (itemSize * pos)

{#pointer *SpvReflectDescriptorBinding as DescriptorBindingPtr #}

-- | Copy a C descriptor binding into a Haskell 'DescriptorBinding'.
--
-- Pointer-valued fields (block, UAV counter binding, type description) are
-- followed only when non-null and become 'Nothing' otherwise.
--
-- The record is assembled with a wildcard, so the local binding names below
-- must keep matching the field names of 'DescriptorBinding.DescriptorBinding'.
inflateDescriptorBinding :: DescriptorBindingPtr -> IO DescriptorBinding
inflateDescriptorBinding db = do
  spirv_id <- fmap Just . inflateIntegral $
    {#get SpvReflectDescriptorBinding->spirv_id #} db

  name <- inflateText $
    {#get SpvReflectDescriptorBinding->name #} db

  binding <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->binding #} db

  input_attachment_index <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->input_attachment_index #} db

  set <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->set #} db

  descriptor_type <- inflateEnum $
    {#get SpvReflectDescriptorBinding->descriptor_type #} db

  resource_type <- inflateFlags32 $
    {#get SpvReflectDescriptorBinding->resource_type #} db

  image <- inflateImageTraits db
    {#offsetof SpvReflectDescriptorBinding->image #}

  block <-
    {#get SpvReflectDescriptorBinding->block #} db >>=
      maybePeek inflateBlockVariable

  array <- inflateArrayTraits db
    {#offsetof SpvReflectDescriptorBinding->array #}

  count <- fmap Just . inflateIntegral $
    {#get SpvReflectDescriptorBinding->count #} db

  accessed <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->accessed #} db

  uav_counter_id <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->uav_counter_id #} db

  uav_counter_binding <-
    {#get SpvReflectDescriptorBinding->uav_counter_binding #} db >>=
      maybePeek inflateDescriptorBinding

  type_description <-
    {#get SpvReflectDescriptorBinding->type_description #} db >>=
      maybePeek inflateTypeDescription

  -- The word offsets are read from the nested word_offset structure and
  -- passed explicitly: a wildcard here would silently pick up the
  -- descriptor's own binding and set numbers bound above.
  wo_binding <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->word_offset.binding #} db
  wo_set <- inflateIntegral $
    {#get SpvReflectDescriptorBinding->word_offset.set #} db
  let
    word_offset =
      DescriptorBinding.WordOffset{binding = wo_binding, set = wo_set}

  decoration_flags <- inflateFlags32 $
    {#get SpvReflectDescriptorBinding->decoration_flags #} db

  pure DescriptorBinding.DescriptorBinding{..}

{#pointer *SpvReflectBlockVariable as BlockVariablePtr #}

-- | Copy a C block variable into a Haskell 'BlockVariable', recursively
-- inflating its members.
--
-- The record is assembled with a wildcard, so the local binding names below
-- must keep matching the field names of 'BlockVariable.BlockVariable'.
inflateBlockVariable :: BlockVariablePtr -> IO BlockVariable
inflateBlockVariable bv = do
  spirv_id <- fmap Just . inflateIntegral $
    {#get SpvReflectBlockVariable->spirv_id #} bv

  name <- fmap Just . inflateText $
    {#get SpvReflectBlockVariable->name #} bv

  offset <- inflateIntegral $
    {#get SpvReflectBlockVariable->offset #} bv

  absolute_offset <- inflateIntegral $
    {#get SpvReflectBlockVariable->absolute_offset #} bv

  size <- inflateIntegral $
    {#get SpvReflectBlockVariable->size #} bv

  padded_size <- inflateIntegral $
    {#get SpvReflectBlockVariable->padded_size #} bv

  decorations <- inflateFlags32 $
    {#get SpvReflectBlockVariable->decoration_flags #} bv

  numeric <- inflateNumericTraits bv
    {#offsetof SpvReflectBlockVariable->numeric #}

  array <- inflateArrayTraits bv
    {#offsetof SpvReflectBlockVariable->array #}

  members <- inflateVector
    ({#get SpvReflectBlockVariable->member_count #} bv)
    ({#get SpvReflectBlockVariable->members #} bv)
    {#sizeof SpvReflectBlockVariable #}
    inflateBlockVariable

  -- Read from the type_description field; the members pointer refers to
  -- block variables and must not be inflated as a type description.
  type_description <-
    {#get SpvReflectBlockVariable->type_description #} bv >>=
      maybePeek inflateTypeDescription

  pure BlockVariable.BlockVariable{..}

{#pointer *SpvReflectTypeDescription as TypeDescriptionPtr #}

-- | Copy a C type description into a Haskell 'TypeDescription', recursively
-- inflating its members.
--
-- The numeric, image and array traits are read from the nested @traits@
-- structure and are always wrapped in 'Just'.
inflateTypeDescription :: TypeDescriptionPtr -> IO TypeDescription
inflateTypeDescription td = do
  id_ <- fmap Just . inflateIntegral $
    {#get SpvReflectTypeDescription->id #} td

  op <- fmap Just . inflateEnum $
    {#get SpvReflectTypeDescription->op #} td

  type_name <- fmap Just . inflateText $
    {#get SpvReflectTypeDescription->type_name #} td

  struct_member_name <- fmap Just . inflateText $
    {#get SpvReflectTypeDescription->struct_member_name #} td

  storage_class <- inflateEnum $
    {#get SpvReflectTypeDescription->storage_class #} td

  type_flags <- inflateFlags32 $
    {#get SpvReflectTypeDescription->type_flags #} td

  numeric <- inflateNumericTraits td
    {#offsetof SpvReflectTypeDescription->traits.numeric #}

  image <- inflateImageTraits td
    {#offsetof SpvReflectTypeDescription->traits.image #}

  array <- inflateArrayTraits td
    {#offsetof SpvReflectTypeDescription->traits.array #}

  let traits = Just TypeDescription.Traits{..}

  members <- inflateVector
    ({#get SpvReflectTypeDescription->member_count #} td)
    ({#get SpvReflectTypeDescription->members #} td)
    {#sizeof SpvReflectTypeDescription #}
    inflateTypeDescription

  -- The local is called id_ to avoid shadowing the Prelude function, so the
  -- record field has to be assigned explicitly.
  pure TypeDescription.TypeDescription{id=id_, ..}

{#pointer *SpvReflectDescriptorSet as DescriptorSetPtr #}

-- | Copy a C descriptor set into a Haskell 'DescriptorSet'.
--
-- The C @bindings@ field is an array of pointers to descriptor bindings, so
-- every slot is dereferenced individually instead of assuming that the
-- pointed-to structures are laid out contiguously after the first one.
-- With a zero binding count the array is never touched.
inflateDescriptorSet :: DescriptorSetPtr -> IO DescriptorSet
inflateDescriptorSet ds = do
  set <- inflateIntegral $
    {#get SpvReflectDescriptorSet->set #} ds

  bindingCount <- {#get SpvReflectDescriptorSet->binding_count #} ds
  bindingsPtr <- {#get SpvReflectDescriptorSet->bindings #} ds
  bindings <- Vector.generateM (fromIntegral bindingCount) \pos ->
    peekElemOff bindingsPtr pos >>= inflateDescriptorBinding

  pure DescriptorSet.DescriptorSet{..}

{#pointer *SpvReflectInterfaceVariable as InterfaceVariablePtr #}

-- | Copy a C interface variable into a Haskell 'InterfaceVariable',
-- recursively inflating its members.
--
-- The record is assembled with a wildcard, so the local binding names below
-- must keep matching the field names of 'InterfaceVariable.InterfaceVariable'.
inflateInterfaceVariable :: InterfaceVariablePtr -> IO InterfaceVariable
inflateInterfaceVariable iv = do
  spirv_id <- fmap Just . inflateIntegral $
    {#get SpvReflectInterfaceVariable->spirv_id #} iv

  name <- fmap Just . inflateText $
    {#get SpvReflectInterfaceVariable->name #} iv

  location <- inflateIntegral $
    {#get SpvReflectInterfaceVariable->location #} iv

  storage_class <- inflateEnum $
    {#get SpvReflectInterfaceVariable->storage_class #} iv

  semantic <- fmap Just . inflateText $
    {#get SpvReflectInterfaceVariable->semantic #} iv

  decoration_flags <- inflateFlags32 $
    {#get SpvReflectInterfaceVariable->decoration_flags #} iv

  built_in <- inflateEnum $
    {#get SpvReflectInterfaceVariable->built_in #} iv

  numeric <- inflateNumericTraits iv
    {#offsetof SpvReflectInterfaceVariable->numeric #}

  array <- inflateArrayTraits iv
    {#offsetof SpvReflectInterfaceVariable->array #}

  members <- inflateVector
    ({#get SpvReflectInterfaceVariable->member_count #} iv)
    ({#get SpvReflectInterfaceVariable->members #} iv)
    {#sizeof SpvReflectInterfaceVariable #}
    inflateInterfaceVariable

  format <- inflateEnum $
    {#get SpvReflectInterfaceVariable->format #} iv

  type_description <-
    {#get SpvReflectInterfaceVariable->type_description #} iv >>=
      maybePeek inflateTypeDescription

  -- Kept under a distinct name so it cannot be confused with the variable's
  -- own location bound above.
  wo_location <- inflateIntegral $
    {#get SpvReflectInterfaceVariable->word_offset.location #} iv
  let word_offset = InterfaceVariable.WordOffset{location=wo_location}

  pure InterfaceVariable.InterfaceVariable{..}

-- * Traits

{#pointer *SpvReflectImageTraits as ImageTraitsPtr #}

-- | Read the image traits embedded in an enclosing structure.
inflateImageTraits
  :: Ptr struct  -- ^ pointer to the enclosing structure
  -> Int         -- ^ byte offset of the traits inside that structure
  -> IO Traits.Image
inflateImageTraits src offset = do
  let it = castPtr src `plusPtr` offset

  dim <- inflateEnum $
    {#get SpvReflectImageTraits->dim #} it

  depth <- inflateIntegral $
    {#get SpvReflectImageTraits->depth #} it

  arrayed <- inflateIntegral $
    {#get SpvReflectImageTraits->arrayed #} it

  ms <- inflateIntegral $
    {#get SpvReflectImageTraits->ms #} it

  sampled <- inflateIntegral $
    {#get SpvReflectImageTraits->sampled #} it

  image_format <- inflateEnum $
    {#get SpvReflectImageTraits->image_format #} it

  pure Traits.Image{..}

{#pointer *SpvReflectNumericTraits as NumericTraitsPtr #}

-- | Read the numeric (scalar, vector and matrix) traits embedded in an
-- enclosing structure.
inflateNumericTraits
  :: Ptr struct  -- ^ pointer to the enclosing structure
  -> Int         -- ^ byte offset of the traits inside that structure
  -> IO Traits.Numeric
inflateNumericTraits src offset = do
  let nt = castPtr src `plusPtr` offset

  width <- inflateIntegral $
    {#get SpvReflectNumericTraits->scalar.width #} nt
  signedness <- inflateIntegral $
    {#get SpvReflectNumericTraits->scalar.signedness #} nt
  let scalar = Traits.Scalar{..}

  component_count <- inflateIntegral $
    {#get SpvReflectNumericTraits->vector.component_count #} nt
  let vector = Traits.Vector{..}

  column_count <- inflateIntegral $
    {#get SpvReflectNumericTraits->matrix.column_count #} nt
  row_count <- inflateIntegral $
    {#get SpvReflectNumericTraits->matrix.row_count #} nt
  stride <- inflateIntegral $
    {#get SpvReflectNumericTraits->matrix.stride #} nt
  let matrix = Traits.Matrix{..}

  pure Traits.Numeric{..}

{#pointer *SpvReflectArrayTraits as ArrayTraitsPtr #}

-- | Read the array traits embedded in an enclosing structure.
--
-- Only the first @dims_count@ entries of the dimensions array are read.
inflateArrayTraits
  :: Ptr struct  -- ^ pointer to the enclosing structure
  -> Int         -- ^ byte offset of the traits inside that structure
  -> IO Traits.Array
inflateArrayTraits src offset = do
  let at = castPtr src `plusPtr` offset

  dims_count <- inflateIntegral $
    {#get SpvReflectArrayTraits->dims_count #} at

  dims <- fmap Vector.convert $
    inflateVector
      ({#get SpvReflectArrayTraits->dims_count #} at)
      ({#get SpvReflectArrayTraits->dims #} at)
      {#sizeof uint32_t #}
      (fmap fromIntegral . peek)

  stride <- fmap Just . inflateIntegral $
    {#get SpvReflectArrayTraits->stride #} at

  pure Traits.Array{..}

-- * Atomic types

-- | Run the getter and convert its result with 'fromIntegral'.
inflateIntegral :: (Integral a, Num b) => IO a -> IO b
inflateIntegral = fmap fromIntegral

-- | Run the getter and convert its integral result with 'toEnum'.
inflateEnum :: (Integral a, Enum b) => IO a -> IO b
inflateEnum = fmap (toEnum . fromIntegral)

-- | Run the getter, convert the result to 'Word32' and coerce it into a
-- type that is representationally a 'Word32'.
inflateFlags32 :: forall a b . (Integral a, Coercible Word32 b) => IO a -> IO b
inflateFlags32 getBits = fmap (coerce @Word32 @b . fromIntegral) getBits

-- | Run the getter and turn the resulting C string into 'Text'.
--
-- A null pointer yields empty text. Otherwise the string is unpacked with
-- 'Text.unpackCString#' and forced before returning, so the result does not
-- keep referring to the C memory afterwards.
inflateText :: IO CString -> IO Text
inflateText getPtr = do
  ptr <- getPtr
  if ptr == nullPtr
    then pure mempty
    else case ptr of
      GHC.Ptr addr -> pure $! Text.unpackCString# addr