-- XXX: TypeError in Compatible generates unused constraint argument
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

-- | Graphics pipeline creation, destruction and binding.
module Engine.Vulkan.Pipeline
  ( Pipeline(..)
  , Config(..)
  , allocate
  , bind
  , pushPlaceholder
  , vertexInput
  , attrBindings
  , formatSize
  ) where

import RIO

import Data.Bits ((.|.))
import Data.Kind (Type)
import Data.List qualified as List
import Data.Tagged (Tagged(..))
import Data.Vector qualified as Vector
import GHC.Stack (callStack, getCallStack, srcLocModule, withFrozenCallStack)
import UnliftIO.Resource (MonadResource, ReleaseKey)
import UnliftIO.Resource qualified as Resource
import Vulkan.Core10 qualified as Vk
import Vulkan.Core12.Promoted_From_VK_EXT_descriptor_indexing qualified as Vk12
import Vulkan.CStruct.Extends (SomeStruct(..), pattern (:&), pattern (::&))
import Vulkan.NamedType ((:::))
import Vulkan.Utils.Debug qualified as Debug
import Vulkan.Zero (Zero(..))

import Engine.Vulkan.DescSets (Bound(..), Compatible)
import Engine.Vulkan.Types
  ( HasVulkan(..)
  , HasRenderPass(..)
  , MonadVulkan
  , DsBindings
  , DsLayouts
  , getPipelineCache
  )

-- | A created graphics pipeline together with the layout objects it owns.
--
-- All three handles are destroyed together by the release action that
-- 'allocate' registers.
--
-- The type parameters are phantom tags.
--
-- * @dsl@ tags the descriptor set layout list. 'bind' requires
--   'Compatible' between it and the layout of the 'Bound' action it wraps.
-- * @vertices@ and @instances@ must match those of that 'Bound' action.
data Pipeline (dsl :: [Type]) vertices instances = Pipeline
  { pipeline     :: Vk.Pipeline
  , pLayout      :: Tagged dsl Vk.PipelineLayout
  , pDescLayouts :: Tagged dsl DsLayouts
  }

-- * Pipeline

-- | Static description of a graphics pipeline.
--
-- Start from 'zero' and override the fields you need.
data Config (dsl :: [Type]) vertices instances = Config
  { cVertexCode         :: Maybe ByteString
    -- ^ Shader module code for the vertex stage (entry point @main@).
  , cFragmentCode       :: Maybe ByteString
    -- ^ Shader module code for the fragment stage (entry point @main@).
  , cVertexInput        :: SomeStruct Vk.PipelineVertexInputStateCreateInfo
    -- ^ Vertex input state; see 'vertexInput'.
  , cDescLayouts        :: Tagged dsl [DsBindings]
    -- ^ One entry per descriptor set, in set order.
  , cPushConstantRanges :: Vector Vk.PushConstantRange
  , cBlend              :: Bool
    -- ^ Enable blending with src factor ONE and dst factor ONE_MINUS_SRC_ALPHA.
  , cDepthWrite         :: Bool
  , cDepthTest          :: Bool
  , cTopology           :: Vk.PrimitiveTopology
  , cCull               :: Vk.CullModeFlagBits
  , cDepthBias          :: Maybe ("constant" ::: Float, "slope" ::: Float)
    -- ^ 'Nothing' disables depth bias.
  }

-- | Defaults: no shaders and no descriptor sets. Blending is off, depth test
-- and depth write are on, and the pipeline draws back-face-culled triangle
-- lists.
instance Zero (Config dsl vertices instances) where
  zero = Config
    { cVertexCode         = Nothing
    , cFragmentCode       = Nothing
    , cVertexInput        = zero
    , cDescLayouts        = Tagged [] -- FIXME: unsafe wrt. "dsl"
    , cPushConstantRanges = mempty
    , cBlend              = False
    , cDepthWrite         = True
    , cDepthTest          = True
    , cTopology           = Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
    , cCull               = Vk.CULL_MODE_BACK_BIT
    , cDepthBias          = Nothing
    }

-- | A 16-byte push constant range at offset 0, visible to the vertex and
-- fragment stages.
--
-- XXX: consider using instance attrs or uniforms
pushPlaceholder :: Vk.PushConstantRange
pushPlaceholder = Vk.PushConstantRange
  { Vk.stageFlags = Vk.SHADER_STAGE_VERTEX_BIT .|. Vk.SHADER_STAGE_FRAGMENT_BIT
  , Vk.offset     = 0
  , Vk.size       = 4 * dwords
  }
  where
    -- XXX: each 4 word32s eat up one register (on AMD)
    dwords = 4

-- | Create a pipeline and register its destruction with the surrounding
-- 'MonadResource'.
--
-- If the extent is 'Nothing', viewport and scissor are dynamic state and
-- must be set while recording. Otherwise a fixed viewport and scissor
-- covering the given extent are baked into the pipeline.
allocate
  :: ( MonadVulkan env m
     , MonadResource m
     , HasRenderPass renderpass
     , HasCallStack
     )
  => Maybe Vk.Extent2D
  -> Vk.SampleCountFlagBits
  -> Config dsl vertices instances
  -> renderpass
  -> m (ReleaseKey, Pipeline dsl vertices instances)
allocate extent msaa config renderpass = withFrozenCallStack do
  ctx <- ask
  Resource.allocate
    (create ctx extent msaa renderpass config)
    (destroy ctx)

-- | Build the descriptor set layouts, the pipeline layout and the pipeline.
--
-- Each acquired object is released again if any later step fails. This
-- covers exceptions from Vulkan calls and the 'error' raised for a
-- non-SUCCESS result. The temporary shader modules are released whether or
-- not pipeline creation succeeds.
create
  :: (MonadUnliftIO io, HasVulkan ctx, HasRenderPass renderpass, HasCallStack)
  => ctx
  -> Maybe Vk.Extent2D
  -> Vk.SampleCountFlagBits
  -> renderpass
  -> Config dsl vertices instances
  -> io (Pipeline dsl vertices instances)
create context mextent msaa renderpass Config{..} = do
  let
    -- Debug name: module names of the call stack frames, joined with '|'.
    originModule =
      fromString . List.intercalate "|" $
        map (srcLocModule . snd) (getCallStack callStack)

  withSetLayouts [] (unTagged cDescLayouts) \dsLayouts ->
    -- TODO: get from outside
    bracketOnError (createLayout dsLayouts) releaseLayout \layout -> do
      Debug.nameObject device layout originModule
      bracket (createShader context codeStages) (destroyShader context) \shader ->
        bracketOnError (createPipeline (sPipelineStages shader) layout) releasePipeline \one -> do
          Debug.nameObject device one originModule
          pure Pipeline
            { pipeline     = one
            , pLayout      = Tagged layout
            , pDescLayouts = Tagged dsLayouts
            }
  where
    device = getDevice context
    cache = getPipelineCache context

    -- Create the set layouts one by one, in list order, and pass them on as
    -- a vector. Any already-created layouts are destroyed if a later
    -- creation, or anything inside the continuation, throws.
    withSetLayouts created pending action =
      case pending of
        [] ->
          action . Vector.fromList $ List.reverse created
        bindsFlags : rest ->
          bracketOnError
            (createSetLayout bindsFlags)
            (\dsLayout -> Vk.destroyDescriptorSetLayout device dsLayout Nothing)
            \dsLayout -> withSetLayouts (dsLayout : created) rest action

    -- Binding flags are chained through the descriptor-indexing extension
    -- struct, one flag entry per binding.
    createSetLayout bindsFlags = do
      let
        (binds, flags) = List.unzip bindsFlags
        setCI =
          zero { Vk.bindings = Vector.fromList binds }
            ::& zero { Vk12.bindingFlags = Vector.fromList flags }
            :& ()
      Vk.createDescriptorSetLayout device setCI Nothing

    createLayout dsLayouts =
      Vk.createPipelineLayout device (layoutCI dsLayouts) Nothing

    releaseLayout layout =
      Vk.destroyPipelineLayout device layout Nothing

    releasePipeline one =
      Vk.destroyPipeline device one Nothing

    -- Request exactly one pipeline. A non-SUCCESS result or an unexpected
    -- count is fatal.
    createPipeline stages layout =
      Vk.createGraphicsPipelines device cache cis Nothing >>= \case
        (Vk.SUCCESS, pipelines) ->
          case Vector.toList pipelines of
            [one] ->
              pure one
            _ ->
              error "assert: exactly one pipeline requested"
        (err, _) ->
          error $ "createGraphicsPipelines: " <> show err
      where
        cis = Vector.singleton . SomeStruct $ pipelineCI stages layout

    codeStages = Vector.fromList $ case (cVertexCode, cFragmentCode) of
      (Just vertCode, Just fragCode) ->
        [ (Vk.SHADER_STAGE_VERTEX_BIT,   vertCode)
        , (Vk.SHADER_STAGE_FRAGMENT_BIT, fragCode)
        ]
      (Just vertCode, Nothing) ->
        [ (Vk.SHADER_STAGE_VERTEX_BIT, vertCode)
        ]
      (Nothing, Just fragCode) ->
        -- XXX: good luck
        [ (Vk.SHADER_STAGE_FRAGMENT_BIT, fragCode)
        ]
      (Nothing, Nothing) ->
        []

    layoutCI dsLayouts = Vk.PipelineLayoutCreateInfo
      { flags              = zero
      , setLayouts         = dsLayouts
      , pushConstantRanges = cPushConstantRanges
      }

    pipelineCI stages layout = zero
      { Vk.stages             = stages
      , Vk.vertexInputState   = Just cVertexInput
      , Vk.inputAssemblyState = Just inputAssembly
      , Vk.viewportState      = Just $ SomeStruct viewportState
      , Vk.rasterizationState = SomeStruct rasterizationState
      , Vk.multisampleState   = Just $ SomeStruct multisampleState
      , Vk.depthStencilState  = Just depthStencilState
      , Vk.colorBlendState    = Just $ SomeStruct colorBlendState
      , Vk.dynamicState       = dynamicState
      , Vk.layout             = layout
      , Vk.renderPass         = getRenderPass renderpass
      , Vk.subpass            = 0
      , Vk.basePipelineHandle = zero
      }
      where
        inputAssembly = zero
          { Vk.topology               = cTopology
          , Vk.primitiveRestartEnable = restartable
          }

        -- Primitive restart is enabled only for the strip/fan topologies
        -- listed here.
        restartable = elem cTopology
          [ Vk.PRIMITIVE_TOPOLOGY_LINE_STRIP
          , Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP
          , Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_FAN
          ]

        -- Without an extent, viewport and scissor become dynamic state.
        (viewportState, dynamicState) = case mextent of
          Nothing ->
            ( zero
                { Vk.viewportCount = 1
                , Vk.scissorCount  = 1
                }
            , Just zero
                { Vk.dynamicStates = Vector.fromList
                    [ Vk.DYNAMIC_STATE_VIEWPORT
                    , Vk.DYNAMIC_STATE_SCISSOR
                    ]
                }
            )
          Just extent@Vk.Extent2D{width, height} ->
            ( zero
                { Vk.viewports = Vector.fromList
                    [ Vk.Viewport
                        { Vk.x        = 0
                        , Vk.y        = 0
                        , Vk.width    = realToFrac width
                        , Vk.height   = realToFrac height
                        , Vk.minDepth = 0
                        , Vk.maxDepth = 1
                        }
                    ]
                , Vk.scissors = Vector.singleton Vk.Rect2D
                    { Vk.offset = Vk.Offset2D 0 0
                    , extent    = extent
                    }
                }
            , Nothing
            )

        rasterizationState = case cDepthBias of
          Nothing ->
            rasterizationBase
          Just (constantFactor, slopeFactor) ->
            rasterizationBase
              { Vk.depthBiasEnable         = True
              , Vk.depthBiasConstantFactor = constantFactor
              , Vk.depthBiasSlopeFactor    = slopeFactor
              }

        rasterizationBase = zero
          { Vk.depthClampEnable        = False
          , Vk.rasterizerDiscardEnable = False
          , Vk.lineWidth               = 1
          , Vk.polygonMode             = Vk.POLYGON_MODE_FILL
          , Vk.cullMode                = cCull
          , Vk.frontFace               = Vk.FRONT_FACE_CLOCKWISE
          , Vk.depthBiasEnable         = False
          }

        multisampleState = zero
          { Vk.rasterizationSamples = msaa
          , Vk.sampleShadingEnable  = enable
          , Vk.minSampleShading     = if enable then 0.2 else 1.0
          , Vk.sampleMask           = Vector.singleton maxBound
          }
          where
            enable = True -- TODO: check and enable sample rate shading feature

        depthStencilState = zero
          { Vk.depthTestEnable       = cDepthTest
          , Vk.depthWriteEnable      = cDepthWrite
          , Vk.depthCompareOp        = Vk.COMPARE_OP_LESS
          , Vk.depthBoundsTestEnable = False
          , Vk.minDepthBounds        = 0.0 -- Optional
          , Vk.maxDepthBounds        = 1.0 -- Optional
          , Vk.stencilTestEnable     = False
          , Vk.front                 = zero -- Optional
          , Vk.back                  = zero -- Optional
          }

        colorBlendState = zero
          { Vk.logicOpEnable = False
          , Vk.attachments = Vector.singleton zero
              { Vk.blendEnable         = cBlend
              , Vk.srcColorBlendFactor = Vk.BLEND_FACTOR_ONE
              , Vk.dstColorBlendFactor = Vk.BLEND_FACTOR_ONE_MINUS_SRC_ALPHA
              , Vk.colorBlendOp        = Vk.BLEND_OP_ADD
              , Vk.srcAlphaBlendFactor = Vk.BLEND_FACTOR_ONE
              , Vk.dstAlphaBlendFactor = Vk.BLEND_FACTOR_ONE_MINUS_SRC_ALPHA
              , Vk.alphaBlendOp        = Vk.BLEND_OP_ADD
              , Vk.colorWriteMask      = colorRgba
              }
          }

        colorRgba =
          Vk.COLOR_COMPONENT_R_BIT .|.
          Vk.COLOR_COMPONENT_G_BIT .|.
          Vk.COLOR_COMPONENT_B_BIT .|.
          Vk.COLOR_COMPONENT_A_BIT

-- | Destroy everything owned by a 'Pipeline', in reverse creation order.
destroy :: (MonadIO io, HasVulkan ctx) => ctx -> Pipeline dsl vertices instances -> io ()
destroy context Pipeline{..} = do
  Vk.destroyPipeline device pipeline Nothing
  Vk.destroyPipelineLayout device (unTagged pLayout) Nothing
  Vector.forM_ (unTagged pDescLayouts) \dsLayout ->
    Vk.destroyDescriptorSetLayout device dsLayout Nothing
  where
    device = getDevice context

-- | Record a graphics-pipeline bind into the command buffer, then run the
-- wrapped action.
--
-- The pipeline's descriptor layout tag must be 'Compatible' with the one
-- already bound. Its vertex/instance tags must match the wrapped action's.
bind
  :: ( Compatible pipeLayout boundLayout
     , MonadIO m
     )
  => Vk.CommandBuffer
  -> Pipeline pipeLayout vertices instances
  -> Bound boundLayout vertices instances m ()
  -> Bound boundLayout oldVertices oldInstances m ()
bind cb Pipeline{pipeline} (Bound attrAction) = do
  Bound $ Vk.cmdBindPipeline cb Vk.PIPELINE_BIND_POINT_GRAPHICS pipeline
  Bound attrAction

-- | Build a vertex input state from one @(rate, formats)@ entry per binding.
--
-- Bindings are numbered from 0 in list order. Each binding's stride is the
-- sum of its format sizes, so attributes are tightly packed. Attribute
-- descriptions come from 'attrBindings'.
vertexInput :: [(Vk.VertexInputRate, [Vk.Format])] -> SomeStruct Vk.PipelineVertexInputStateCreateInfo
vertexInput bindings = SomeStruct zero
  { Vk.vertexBindingDescriptions   = binds
  , Vk.vertexAttributeDescriptions = attrs
  }
  where
    binds = Vector.fromList do
      (ix, (rate, formats)) <- zip [0..] bindings
      pure Vk.VertexInputBindingDescription
        { binding   = ix
        , stride    = sum $ map formatSize formats
        , inputRate = rate
        }

    attrs = attrBindings $ map snd bindings

-- * Shader code

-- | Shader modules plus the stage create-infos that reference them.
data Shader = Shader
  { sModules        :: Vector Vk.ShaderModule
  , sPipelineStages :: Vector (SomeStruct Vk.PipelineShaderStageCreateInfo)
  }

-- | Create one shader module per @(stage, code)@ pair. Each module gets a
-- stage create-info with entry point @main@.
createShader :: (MonadIO io, HasVulkan ctx) => ctx -> Vector (Vk.ShaderStageFlagBits, ByteString) -> io Shader
createShader context stages = do
  staged <- Vector.forM stages \(stage, code) -> do
    module_ <- Vk.createShaderModule (getDevice context) zero { Vk.code = code } Nothing
    pure
      ( module_
      , SomeStruct zero
          { Vk.stage   = stage
          , Vk.module' = module_
          , Vk.name    = "main"
          -- , Vk.specializationInfo = Nothing
          }
      )
  let (modules, pStages) = Vector.unzip staged
  pure Shader
    { sModules        = modules
    , sPipelineStages = pStages
    }

-- | Destroy all shader modules of a 'Shader'.
destroyShader :: (MonadIO io, HasVulkan ctx) => ctx -> Shader -> io ()
destroyShader context Shader{sModules} =
  Vector.forM_ sModules \module_ ->
    Vk.destroyShaderModule (getDevice context) module_ Nothing

-- * Utils

-- | Attribute descriptions for a list of bindings, one format list per
-- binding.
--
-- * Binding numbers follow list order, starting at 0.
-- * Locations are consecutive across all bindings, starting at 0.
-- * Offsets within a binding are the running sum of the preceding
--   attributes' 'formatSize'.
attrBindings :: [[Vk.Format]] -> Vector Vk.VertexInputAttributeDescription
attrBindings bindings = mconcat $ List.unfoldr shiftLocations (0, 0, bindings)
  where
    shiftLocations = \case
      (_binding, _lastLoc, []) ->
        Nothing
      (binding, lastLoc, formats : rest) ->
        Just (bound, next)
        where
          -- Offset of each attribute: sizes of all attributes before it.
          offsets = List.scanl (+) 0 (map formatSize formats)

          bound = Vector.fromList do
            (ix, (format, offset)) <- zip [0..] (zip formats offsets)
            pure zero
              { Vk.binding  = binding
              , Vk.location = fromIntegral $ lastLoc + ix
              , Vk.format   = format
              , Vk.offset   = offset
              }

          next =
            ( binding + 1
            , lastLoc + Vector.length bound
            , rest
            )

-- | Size in bytes of a vertex attribute of the given format.
--
-- Calls 'error' for formats not listed here.
formatSize :: Integral a => Vk.Format -> a
formatSize = \case
  Vk.FORMAT_R32G32B32A32_SFLOAT -> 16
  Vk.FORMAT_R32G32B32_SFLOAT    -> 12
  Vk.FORMAT_R32G32_SFLOAT       -> 8
  Vk.FORMAT_R32_SFLOAT          -> 4

  Vk.FORMAT_R32G32B32A32_UINT -> 16
  Vk.FORMAT_R32G32B32_UINT    -> 12
  Vk.FORMAT_R32G32_UINT       -> 8
  Vk.FORMAT_R32_UINT          -> 4

  Vk.FORMAT_R32G32B32A32_SINT -> 16
  Vk.FORMAT_R32G32B32_SINT    -> 12
  Vk.FORMAT_R32G32_SINT       -> 8
  Vk.FORMAT_R32_SINT          -> 4

  -- Packed 16-bit float and 8-bit formats.
  Vk.FORMAT_R16G16B16A16_SFLOAT -> 8
  Vk.FORMAT_R16G16_SFLOAT       -> 4
  Vk.FORMAT_R8G8B8A8_UNORM      -> 4
  Vk.FORMAT_R8G8B8A8_SNORM      -> 4
  Vk.FORMAT_R8G8B8A8_UINT       -> 4
  Vk.FORMAT_R8G8B8A8_SINT       -> 4

  format -> error $ "Format size unknown: " <> show format