{-# OPTIONS_GHC -Wwarn=orphans #-}

{-# LANGUAGE OverloadedLists #-}

-- | Shader module creation and specialization-constant packing.
module Engine.Vulkan.Shader
  ( Shader(..)
  , create
  , destroy

  , withSpecialization
  , Specialization(..)
  , SpecializationConst(..)
  ) where

import RIO

import Data.Vector qualified as Vector
import Data.Vector.Storable qualified as Storable
import Foreign qualified
import GHC.Float (castFloatToWord32)
import GHC.Stack (withFrozenCallStack)
import Vulkan.Core10 qualified as Vk
import Vulkan.CStruct.Extends (SomeStruct(..))
import Vulkan.Zero (Zero(..))

import Engine.Vulkan.Pipeline.Stages (StageInfo(..))
import Engine.Vulkan.Types (MonadVulkan, getDevice)
import Resource.Vulkan.Named qualified as Named

-- | Shader modules together with the pipeline stage descriptions that
-- reference them.
--
-- The two vectors are built in lockstep by 'create': element @i@ of
-- 'sPipelineStages' refers to element @i@ of 'sModules'.
data Shader = Shader
  { sModules        :: Vector Vk.ShaderModule
    -- ^ Modules owned by this value; release them with 'destroy'.
  , sPipelineStages :: Vector (SomeStruct Vk.PipelineShaderStageCreateInfo)
    -- ^ Stage descriptions pointing at the modules above.
  }

-- | Create a shader module for every stage that has code, together with the
-- matching pipeline stage description.
--
-- Stages whose slot is 'Nothing' are skipped. Every created stage uses the
-- @"main"@ entry point and shares the same optional specialization info.
--
-- NOTE(review): there is no exception handling here. If module creation
-- fails for a later stage, this function does not destroy the modules
-- already created for earlier stages.
create
  :: ( MonadVulkan env io
     , StageInfo t
     , HasCallStack
     )
  => t (Maybe ByteString)
  -> Maybe Vk.SpecializationInfo
  -> io Shader
create stages spec = withFrozenCallStack do
  device <- asks getDevice

  staged <- Vector.forM collected \(stage, code) -> do
    module_ <- Vk.createShaderModule device zero { Vk.code = code } Nothing
    Named.objectOrigin module_
    pure
      ( module_
      , SomeStruct zero
          { Vk.stage              = stage
          , Vk.module'            = module_
          , Vk.name               = "main"
          , Vk.specializationInfo = spec
          }
      )

  let (modules, pStages) = Vector.unzip staged
  pure Shader
    { sModules        = modules
    , sPipelineStages = pStages
    }
  where
    -- Pair each stage's flag bit with its code, keeping only stages that
    -- actually have code (the list-monad pattern match drops 'Nothing').
    collected = Vector.fromList do
      (stage, Just code) <- toList $ (,) <$> stageFlagBits <*> stages
      pure (stage, code)

-- | Destroy every shader module owned by the 'Shader'.
--
-- The stage descriptions in 'sPipelineStages' reference these modules, so
-- they must not be used afterwards.
destroy :: MonadVulkan env io => Shader -> io ()
destroy Shader{sModules} = do
  device <- asks getDevice
  Vector.forM_ sModules \module_ ->
    Vk.destroyShaderModule device module_ Nothing

-- * Specialization constants

-- | Pack specialization constants and pass them to an action.
--
-- Each packed constant gets ID @ix@ (its position in 'specializationData'),
-- byte offset @4 * ix@ and size 4. When there are no constants, the action
-- receives 'Nothing'.
--
-- The 'Vk.SpecializationInfo' points into vector memory that is only
-- guaranteed valid while the action runs ('Storable.unsafeWith'). Do not let
-- it, or anything that captured its pointer, escape the action.
withSpecialization
  :: ( Specialization spec
     , MonadUnliftIO m
     )
  => spec
  -> (Maybe Vk.SpecializationInfo -> m a)
  -> m a
withSpecialization spec action =
  case mapEntries of
    [] ->
      action Nothing
    _some ->
      withRunInIO \run ->
        Storable.unsafeWith specData \specPtr ->
          run . action $ Just Vk.SpecializationInfo
            { mapEntries = mapEntries
            , dataSize   = fromIntegral $ Storable.length specData * 4
            , data'      = Foreign.castPtr @_ @() specPtr
            }
  where
    specData = Storable.fromList $ specializationData spec

    -- One 4-byte entry per packed word, laid out back to back.
    mapEntries = Vector.imap
      (\ix _data -> Vk.SpecializationMapEntry
        { constantID = fromIntegral ix
        , offset     = fromIntegral $ ix * 4
        , size       = 4
        }
      )
      (Vector.convert specData)

-- | Values that can be flattened into a sequence of specialization constants.
class Specialization a where
  -- | Packed constant slots in @constant_id@ order, one 'Word32' per slot.
  --
  -- XXX: this relies on most shader scalars (but not doubles) being 4 bytes
  -- wide.
  specializationData :: a -> [Word32]

-- | No specialization constants.
instance Specialization () where
  specializationData = const []

-- | Pre-packed constant slots, used as-is.
instance Specialization [Word32] where
  specializationData = id

{- | A single constant.

The @constant_id@ layout qualifier can only be applied to a scalar /int/, a
scalar /float/ or a scalar /bool/.
(https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_vulkan_glsl.txt)

XXX: uints and doubles are apparently accepted too. Doubles are 8 bytes wide
and do not fit the 4-byte packing used here.
-}
instance Specialization Word32 where
  specializationData x = [packConstData x]

instance Specialization Int32 where
  specializationData x = [packConstData x]

instance Specialization Float where
  specializationData x = [packConstData x]

instance Specialization Bool where
  specializationData x = [packConstData x]

-- | Scalars that fit into one 32-bit specialization constant slot.
class SpecializationConst a where
  -- | The raw 32-bit pattern to place in the slot.
  packConstData :: a -> Word32

instance SpecializationConst Word32 where
  packConstData = id

-- | Two's-complement bit pattern, e.g. @-1@ becomes @0xFFFFFFFF@.
--
-- 'fromIntegral' between same-width integral types is a well-defined bit
-- reinterpretation. The previous 'unsafeCoerce' depended on GHC's heap
-- layout for boxed values.
instance SpecializationConst Int32 where
  packConstData = fromIntegral

-- | IEEE-754 single-precision bit pattern, e.g. @1.0@ becomes @0x3F800000@.
instance SpecializationConst Float where
  packConstData = castFloatToWord32

-- | 'False' becomes 0 and 'True' becomes 1.
instance SpecializationConst Bool where
  packConstData = bool 0 1

instance
  ( SpecializationConst a
  , SpecializationConst b
  ) => Specialization (a, b) where
  specializationData (a, b) =
    [ packConstData a
    , packConstData b
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  ) => Specialization (a, b, c) where
  specializationData (a, b, c) =
    [ packConstData a
    , packConstData b
    , packConstData c
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  ) => Specialization (a, b, c, d) where
  specializationData (a, b, c, d) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  ) => Specialization (a, b, c, d, e) where
  specializationData (a, b, c, d, e) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    , packConstData e
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  , SpecializationConst f
  ) => Specialization (a, b, c, d, e, f) where
  specializationData (a, b, c, d, e, f) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    , packConstData e
    , packConstData f
    ]