{-# OPTIONS_GHC -Wwarn=orphans #-}

{-# LANGUAGE OverloadedLists #-}

-- | Shader module creation and specialization-constant packing.
module Engine.Vulkan.Shader
  ( Shader(..)
  , create
  , destroy

  , withSpecialization
  , Specialization(..)
  , SpecializationConst(..)
  ) where

import RIO

import Data.Vector qualified as Vector
import Data.Vector.Storable qualified as Storable
import Foreign qualified
import GHC.Float (castFloatToWord32)
import Vulkan.Core10 qualified as Vk
import Vulkan.CStruct.Extends (SomeStruct(..))
import Vulkan.Zero (Zero(..))
-- NOTE(review): no longer used here (Int32/Float packing now uses safe
-- conversions); remove if nothing else in this module needs it.
import Unsafe.Coerce (unsafeCoerce)

import Engine.Vulkan.Pipeline.Stages (StageInfo(..))
import Engine.Vulkan.Types (HasVulkan(..))

-- | Shader modules together with the pipeline stage descriptions that use
-- them.
--
-- 'create' builds both vectors in the same order: the @i@-th stage entry
-- references the @i@-th module.
data Shader = Shader
  { sModules        :: Vector Vk.ShaderModule
  , sPipelineStages :: Vector (SomeStruct Vk.PipelineShaderStageCreateInfo)
  }

-- | Create one shader module for every stage that has code, plus a matching
-- 'Vk.PipelineShaderStageCreateInfo' that uses the @"main"@ entry point and
-- the supplied specialization info.
--
-- Stages whose code is 'Nothing' are skipped. If creating a module throws,
-- the modules created so far are destroyed before the exception propagates.
-- Release the result with 'destroy'.
create
  :: ( MonadIO io
     , HasVulkan ctx
     , StageInfo t
     )
  => ctx
  -> t (Maybe ByteString)
  -> Maybe Vk.SpecializationInfo
  -> io Shader
create context stages spec = liftIO do
  modules <- createModules [] (snd <$> collected)
  pure Shader
    { sModules        = modules
    , sPipelineStages =
        Vector.zipWith mkStage (Vector.fromList $ fst <$> collected) modules
    }
  where
    device = getDevice context

    -- Pair each stage's flag bit with its code, keeping only the stages that
    -- actually provide code.
    collected = do
      (stage, Just code) <- toList $ (,) <$> stageFlagBits <*> stages
      pure (stage, code)

    mkStage stage module_ =
      SomeStruct zero
        { Vk.stage              = stage
        , Vk.module'            = module_
        , Vk.name               = "main"
        , Vk.specializationInfo = spec
        }

    -- Create modules in order, accumulating them newest-first. If a creation
    -- fails, destroy everything made so far so a partial shader does not leak.
    createModules created [] =
      pure . Vector.reverse $ Vector.fromList created
    createModules created (code : rest) = do
      module_ <-
        Vk.createShaderModule device zero { Vk.code = code } Nothing
          `onException` traverse_ destroyModule created
      createModules (module_ : created) rest

    destroyModule module_ =
      Vk.destroyShaderModule device module_ Nothing

-- | Destroy every shader module owned by a 'Shader'.
--
-- The entries of 'sPipelineStages' reference these modules, so they must not
-- be used to build pipelines afterwards.
destroy :: (MonadIO io, HasVulkan ctx) => ctx -> Shader -> io ()
destroy context Shader{sModules} =
  Vector.forM_ sModules \module_ ->
    Vk.destroyShaderModule (getDevice context) module_ Nothing

-- * Specialization constants

-- | Run an action with a 'Vk.SpecializationInfo' describing @spec@.
--
-- The values from 'specializationData' are laid out as consecutive 32-bit
-- words. Value @i@ becomes @constant_id = i@ at byte offset @4 * i@, with a
-- size of 4 bytes. When there are no values, the action receives 'Nothing'.
--
-- The info's data pointer comes from 'Storable.unsafeWith' and is only valid
-- inside the callback. It must not escape, for example inside a 'Shader'
-- built by 'create' that is used after the callback returns.
withSpecialization
  :: ( Specialization spec
     , MonadUnliftIO m
     )
  => spec
  -> (Maybe Vk.SpecializationInfo -> m a)
  -> m a
withSpecialization spec action =
  case mapEntries of
    [] ->
      action Nothing
    _some ->
      withRunInIO \run ->
        Storable.unsafeWith specData \specPtr ->
          run . action $ Just Vk.SpecializationInfo
            { mapEntries = mapEntries
            , dataSize   = fromIntegral $ Storable.length specData * 4
            , data'      = Foreign.castPtr @_ @() specPtr
            }
  where
    specData = Storable.fromList $ specializationData spec

    mapEntries =
      Vector.imap
        ( \ix _data -> Vk.SpecializationMapEntry
            { constantID = fromIntegral ix
            , offset     = fromIntegral $ ix * 4
            , size       = 4
            }
        )
        (Vector.convert specData)

-- | Values that can be flattened into 32-bit specialization constant words.
--
-- 'withSpecialization' binds element @i@ of the list to @constant_id = i@.
class Specialization a where
  -- XXX: abusing the fact that most scalars (sans double) are 4-wide
  specializationData :: a -> [Word32]

instance Specialization () where
  specializationData = const []

instance Specialization [Word32] where
  specializationData = id

{- | The @constant_id@ layout qualifier can only be applied to a scalar
@int@, a scalar @float@ or a scalar @bool@
(<https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_vulkan_glsl.txt>).

XXX: Apparently @uint@s and @double@s can be passed too. Doubles would need
8-byte entries, and 'withSpecialization' only emits 4-byte ones.
-}
instance Specialization Word32 where
  specializationData x = [packConstData x]

instance Specialization Int32 where
  specializationData x = [packConstData x]

instance Specialization Float where
  specializationData x = [packConstData x]

instance Specialization Bool where
  specializationData x = [packConstData x]

-- | A scalar stored in a single 32-bit specialization slot as its raw bit
-- pattern.
class SpecializationConst a where
  packConstData :: a -> Word32

instance SpecializationConst Word32 where
  packConstData = id

-- | Two's-complement bit pattern: 'fromIntegral' wraps rather than clamps,
-- so @-1@ becomes @0xFFFFFFFF@.
instance SpecializationConst Int32 where
  packConstData = fromIntegral

-- | The bit pattern of the 'Float', reinterpreted as a 'Word32'.
instance SpecializationConst Float where
  packConstData = castFloatToWord32

-- | 'False' is 0, 'True' is 1.
instance SpecializationConst Bool where
  packConstData = bool 0 1

instance
  ( SpecializationConst a
  , SpecializationConst b
  ) => Specialization (a, b) where
  specializationData (a, b) =
    [ packConstData a
    , packConstData b
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  ) => Specialization (a, b, c) where
  specializationData (a, b, c) =
    [ packConstData a
    , packConstData b
    , packConstData c
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  ) => Specialization (a, b, c, d) where
  specializationData (a, b, c, d) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  ) => Specialization (a, b, c, d, e) where
  specializationData (a, b, c, d, e) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    , packConstData e
    ]

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  , SpecializationConst f
  ) => Specialization (a, b, c, d, e, f) where
  specializationData (a, b, c, d, e, f) =
    [ packConstData a
    , packConstData b
    , packConstData c
    , packConstData d
    , packConstData e
    , packConstData f
    ]