-- XXX: The TypeError-based 'Compatible' constraint on 'bind' is never used at
-- the value level, so GHC would report it as a redundant constraint.
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

{-# LANGUAGE OverloadedLists #-}

module Engine.Vulkan.Pipeline
  ( Config(..)
  , baseConfig
  , Configure
  , Pipeline(..)
  , allocate
  , create
  , destroy
  , bind
  , pushPlaceholder
  , vertexInput
  , attrBindings
  , formatSize
  ) where

import RIO

import Data.Bits ((.|.))
import Data.Kind (Type)
import Data.List qualified as List
import Data.Tagged (Tagged(..))
import Data.Vector qualified as Vector
import GHC.Stack (callStack, getCallStack, srcLocModule, withFrozenCallStack)
import UnliftIO.Resource (MonadResource, ReleaseKey)
import UnliftIO.Resource qualified as Resource
import Vulkan.Core10 qualified as Vk
import Vulkan.Core12.Promoted_From_VK_EXT_descriptor_indexing qualified as Vk12
import Vulkan.CStruct.Extends (SomeStruct(..), pattern (:&), pattern (::&))
import Vulkan.NamedType ((:::))
import Vulkan.Utils.Debug qualified as Debug
import Vulkan.Zero (Zero(..))

import Engine.Vulkan.DescSets (Bound(..), Compatible)
import Engine.Vulkan.Types
  ( HasVulkan(..)
  , HasRenderPass(..)
  , MonadVulkan
  , DsBindings
  , DsLayouts
  , getPipelineCache
  )
import Engine.Vulkan.Shader qualified as Shader

-- | A graphics pipeline together with the layout objects it owns
-- (all three are released by 'destroy').
--
-- @dsl@ tags the layouts with their descriptor set types; @vertices@ and
-- @instances@ only appear at the type level (see 'bind').
data Pipeline (dsl :: [Type]) vertices instances = Pipeline
  { pipeline     :: Vk.Pipeline
  , pLayout      :: Tagged dsl Vk.PipelineLayout
  , pDescLayouts :: Tagged dsl DsLayouts
  }

-- * Pipeline

-- | The 'Config' type that builds a given 'Pipeline' type.
type family Configure pipeline spec where
  Configure (Pipeline dsl vertices instances) spec = Config dsl vertices instances spec

-- | Everything 'create' needs to build a graphics pipeline.
data Config (dsl :: [Type]) vertices instances spec = Config
  { cVertexCode         :: Maybe ByteString
    -- ^ Vertex stage code handed to 'Shader.create', if any.
  , cFragmentCode       :: Maybe ByteString
    -- ^ Fragment stage code handed to 'Shader.create', if any.
  , cVertexInput        :: SomeStruct Vk.PipelineVertexInputStateCreateInfo
    -- ^ Vertex input state, e.g. built with 'vertexInput'.
  , cDescLayouts        :: Tagged dsl [DsBindings]
    -- ^ One entry per descriptor set; each entry pairs a binding with its
    -- binding flags.
  , cPushConstantRanges :: Vector Vk.PushConstantRange
    -- ^ Push constant ranges of the pipeline layout.
  , cBlend              :: Bool
    -- ^ Enable blending on the single color attachment
    -- (ONE / ONE_MINUS_SRC_ALPHA factors, i.e. premultiplied alpha).
  , cDepthWrite         :: Bool
  , cDepthTest          :: Bool
  , cDepthCompare       :: Vk.CompareOp
  , cTopology           :: Vk.PrimitiveTopology
    -- ^ Primitive restart is enabled automatically for strip/fan topologies.
  , cCull               :: Vk.CullModeFlagBits
    -- ^ Cull mode; front faces are wound clockwise.
  , cDepthBias          :: Maybe ("constant" ::: Float, "slope" ::: Float)
    -- ^ When set, enables depth bias with these factors.
  , cSpecialization     :: spec
    -- ^ Passed to 'Shader.withSpecialization' when creating the shaders.
  }

-- | Settings for generic triangle-rendering pipeline.
baseConfig :: Config '[] vertices instances ()
baseConfig = Config
  { cVertexCode         = Nothing
  , cFragmentCode       = Nothing
  , cVertexInput        = zero
  , cDescLayouts        = Tagged []
  , cPushConstantRanges = mempty
  , cBlend              = False
  , cDepthWrite         = True
  , cDepthTest          = True
  , cDepthCompare       = Vk.COMPARE_OP_LESS
  , cTopology           = Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_LIST
  , cCull               = Vk.CULL_MODE_BACK_BIT
  , cDepthBias          = Nothing
  , cSpecialization     = ()
  }

-- | A 16-byte (4 dword) push constant range visible to the vertex and
-- fragment stages.
--
-- XXX: consider using instance attrs or uniforms
pushPlaceholder :: Vk.PushConstantRange
pushPlaceholder = Vk.PushConstantRange
  { Vk.stageFlags = Vk.SHADER_STAGE_VERTEX_BIT .|. Vk.SHADER_STAGE_FRAGMENT_BIT
  , Vk.offset     = 0
  , Vk.size       = 4 * dwords
  }
  where
    -- XXX: each 4 word32s eat up one register (on AMD)
    dwords = 4

-- | 'create' a pipeline in 'ResourceT', registering 'destroy' as its release
-- action.
allocate
  :: ( MonadVulkan env m
     , MonadResource m
     , HasRenderPass renderpass
     , Shader.Specialization spec
     , HasCallStack
     )
  => Maybe Vk.Extent2D
  -> Vk.SampleCountFlagBits
  -> Config dsl vertices instances spec
  -> renderpass
  -> m (ReleaseKey, Pipeline dsl vertices instances)
allocate extent msaa config renderpass = withFrozenCallStack do
  ctx <- ask
  Resource.allocate
    (create ctx extent msaa renderpass config)
    (destroy ctx)

-- | Create a graphics pipeline, its pipeline layout and its descriptor set
-- layouts from a 'Config'.
--
-- * With @Nothing@ for the extent, viewport and scissor are dynamic state;
--   with @Just extent@ they are fixed to cover the whole extent.
-- * The pipeline targets subpass 0 of the given render pass.
-- * The pipeline and its layout are debug-named after the modules on the
--   call stack.
-- * Shader modules are destroyed once pipeline creation has finished.
-- * If any step throws, every object created so far is destroyed before the
--   exception propagates, so nothing leaks on failure.
--
-- Calls 'error' when @createGraphicsPipelines@ reports anything other than
-- @SUCCESS@, or does not return exactly one pipeline.
create
  :: ( MonadUnliftIO io
     , HasVulkan ctx
     , HasRenderPass renderpass
     , Shader.Specialization spec
     , HasCallStack
     )
  => ctx
  -> Maybe Vk.Extent2D
  -> Vk.SampleCountFlagBits
  -> renderpass
  -> Config dsl vertices instances spec
  -> io (Pipeline dsl vertices instances)
create context mextent msaa renderpass Config{..} = do
  let
    originModule =
      fromString . List.intercalate "|" $
        map (srcLocModule . snd) (getCallStack callStack)

  dsLayouts <- createSetLayouts

  -- Everything below runs with cleanup of the already-created objects
  -- attached, so a failure does not leak them.
  flip onException (destroySetLayouts dsLayouts) do
    -- TODO: get from outside
    layout <- Vk.createPipelineLayout device (layoutCI dsLayouts) Nothing
    flip onException (Vk.destroyPipelineLayout device layout Nothing) do
      Debug.nameObject device layout originModule

      -- Shader modules are only needed while the pipeline is being built;
      -- 'bracket' destroys them whether or not that succeeds.
      one <- bracket
        ( Shader.withSpecialization cSpecialization $
            Shader.create context $
              case (cVertexCode, cFragmentCode) of
                (Just vertCode, Just fragCode) ->
                  [ (Vk.SHADER_STAGE_VERTEX_BIT, vertCode)
                  , (Vk.SHADER_STAGE_FRAGMENT_BIT, fragCode)
                  ]
                (Just vertCode, Nothing) ->
                  [ (Vk.SHADER_STAGE_VERTEX_BIT, vertCode)
                  ]
                (Nothing, Just fragCode) ->
                  -- XXX: good luck
                  [ (Vk.SHADER_STAGE_FRAGMENT_BIT, fragCode)
                  ]
                (Nothing, Nothing) ->
                  []
        )
        (Shader.destroy context)
        \shader ->
          createPipeline (Shader.sPipelineStages shader) layout

      Debug.nameObject device one originModule
      pure Pipeline
        { pipeline     = one
        , pLayout      = Tagged layout
        , pDescLayouts = Tagged dsLayouts
        }
  where
    device = getDevice context
    cache = getPipelineCache context

    -- One descriptor set layout per 'cDescLayouts' entry, in order. Layouts
    -- created before a failing call are destroyed before the exception
    -- propagates.
    createSetLayouts = go [] (unTagged cDescLayouts)
      where
        go created = \case
          [] ->
            pure . Vector.fromList $ List.reverse created
          bindsFlags : rest -> do
            dsLayout <-
              createSetLayout bindsFlags
                `onException` destroySetLayouts (Vector.fromList created)
            go (dsLayout : created) rest

    -- The bindings go into the base create info; their flags go into a
    -- chained DescriptorSetLayoutBindingFlagsCreateInfo.
    createSetLayout bindsFlags =
      Vk.createDescriptorSetLayout device setCI Nothing
      where
        (binds, flags) = List.unzip bindsFlags
        setCI =
          zero { Vk.bindings = Vector.fromList binds }
            ::& zero { Vk12.bindingFlags = Vector.fromList flags }
            :& ()

    destroySetLayouts setLayouts =
      Vector.forM_ setLayouts \dsLayout ->
        Vk.destroyDescriptorSetLayout device dsLayout Nothing

    -- Request exactly one graphics pipeline. Any unexpected extra handles
    -- are destroyed before failing.
    createPipeline stages layout =
      Vk.createGraphicsPipelines device cache cis Nothing >>= \case
        (Vk.SUCCESS, pipelines) ->
          case Vector.toList pipelines of
            [one] ->
              pure one
            _ -> do
              Vector.forM_ pipelines \extra ->
                Vk.destroyPipeline device extra Nothing
              error "assert: exactly one pipeline requested"
        (err, _) ->
          error $ "createGraphicsPipelines: " <> show err
      where
        cis = Vector.singleton . SomeStruct $ pipelineCI stages layout

    layoutCI dsLayouts = Vk.PipelineLayoutCreateInfo
      { flags              = zero
      , setLayouts         = dsLayouts
      , pushConstantRanges = cPushConstantRanges
      }

    pipelineCI stages layout = zero
      { Vk.stages             = stages
      , Vk.vertexInputState   = Just cVertexInput
      , Vk.inputAssemblyState = Just inputAssembly
      , Vk.viewportState      = Just $ SomeStruct viewportState
      , Vk.rasterizationState = SomeStruct rasterizationState
      , Vk.multisampleState   = Just $ SomeStruct multisampleState
      , Vk.depthStencilState  = Just depthStencilState
      , Vk.colorBlendState    = Just $ SomeStruct colorBlendState
      , Vk.dynamicState       = dynamicState
      , Vk.layout             = layout
      , Vk.renderPass         = getRenderPass renderpass
      , Vk.subpass            = 0
      , Vk.basePipelineHandle = zero
      }
      where
        inputAssembly = zero
          { Vk.topology               = cTopology
          , Vk.primitiveRestartEnable = restartable
          }

        -- Primitive restart only applies to strip and fan topologies.
        restartable = elem @Set cTopology
          [ Vk.PRIMITIVE_TOPOLOGY_LINE_STRIP
          , Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_STRIP
          , Vk.PRIMITIVE_TOPOLOGY_TRIANGLE_FAN
          ]

        -- Without a fixed extent, viewport and scissor become dynamic state.
        (viewportState, dynamicState) =
          case mextent of
            Nothing ->
              ( zero
                  { Vk.viewportCount = 1
                  , Vk.scissorCount  = 1
                  }
              , Just zero
                  { Vk.dynamicStates = Vector.fromList
                      [ Vk.DYNAMIC_STATE_VIEWPORT
                      , Vk.DYNAMIC_STATE_SCISSOR
                      ]
                  }
              )
            Just extent@Vk.Extent2D{width, height} ->
              ( zero
                  { Vk.viewports = Vector.fromList
                      [ Vk.Viewport
                          { Vk.x        = 0
                          , Vk.y        = 0
                          , Vk.width    = realToFrac width
                          , Vk.height   = realToFrac height
                          , Vk.minDepth = 0
                          , Vk.maxDepth = 1
                          }
                      ]
                  , Vk.scissors = Vector.singleton Vk.Rect2D
                      { Vk.offset = Vk.Offset2D 0 0
                      , extent    = extent
                      }
                  }
              , Nothing
              )

        rasterizationState =
          case cDepthBias of
            Nothing ->
              rasterizationBase
            Just (constantFactor, slopeFactor) ->
              rasterizationBase
                { Vk.depthBiasEnable         = True
                , Vk.depthBiasConstantFactor = constantFactor
                , Vk.depthBiasSlopeFactor    = slopeFactor
                }

        rasterizationBase = zero
          { Vk.depthClampEnable        = False
          , Vk.rasterizerDiscardEnable = False
          , Vk.lineWidth               = 1
          , Vk.polygonMode             = Vk.POLYGON_MODE_FILL
          , Vk.cullMode                = cCull
          , Vk.frontFace               = Vk.FRONT_FACE_CLOCKWISE
          , Vk.depthBiasEnable         = False
          }

        multisampleState = zero
          { Vk.rasterizationSamples = msaa
          , Vk.sampleShadingEnable  = enable
          , Vk.minSampleShading     = if enable then 0.2 else 1.0
          , Vk.sampleMask           = Vector.singleton maxBound
          }
          where
            enable = True -- TODO: check and enable sample rate shading feature

        depthStencilState = zero
          { Vk.depthTestEnable       = cDepthTest
          , Vk.depthWriteEnable      = cDepthWrite
          , Vk.depthCompareOp        = cDepthCompare
          , Vk.depthBoundsTestEnable = False
          , Vk.minDepthBounds        = 0.0 -- Optional
          , Vk.maxDepthBounds        = 1.0 -- Optional
          , Vk.stencilTestEnable     = False
          , Vk.front                 = zero -- Optional
          , Vk.back                  = zero -- Optional
          }

        colorBlendState = zero
          { Vk.logicOpEnable = False
          , Vk.attachments   = Vector.singleton zero
              { Vk.blendEnable         = cBlend
              , Vk.srcColorBlendFactor = Vk.BLEND_FACTOR_ONE
              , Vk.dstColorBlendFactor = Vk.BLEND_FACTOR_ONE_MINUS_SRC_ALPHA
              , Vk.colorBlendOp        = Vk.BLEND_OP_ADD
              , Vk.srcAlphaBlendFactor = Vk.BLEND_FACTOR_ONE
              , Vk.dstAlphaBlendFactor = Vk.BLEND_FACTOR_ONE_MINUS_SRC_ALPHA
              , Vk.alphaBlendOp        = Vk.BLEND_OP_ADD
              , Vk.colorWriteMask      = colorRgba
              }
          }

        colorRgba =
          Vk.COLOR_COMPONENT_R_BIT .|.
          Vk.COLOR_COMPONENT_G_BIT .|.
          Vk.COLOR_COMPONENT_B_BIT .|.
          Vk.COLOR_COMPONENT_A_BIT

-- | Destroy the descriptor set layouts, the pipeline and the pipeline layout.
destroy :: (MonadIO io, HasVulkan ctx) => ctx -> Pipeline dsl vertices instances -> io ()
destroy context Pipeline{..} = do
  Vector.forM_ (unTagged pDescLayouts) \dsLayout ->
    Vk.destroyDescriptorSetLayout device dsLayout Nothing
  Vk.destroyPipeline device pipeline Nothing
  Vk.destroyPipelineLayout device (unTagged pLayout) Nothing
  where
    device = getDevice context

-- | Record a graphics pipeline bind, then run the given action.
--
-- The pipeline's descriptor layout must be 'Compatible' with the bound one.
-- The action's @vertices@/@instances@ are taken from the pipeline, replacing
-- whatever the surrounding 'Bound' had.
bind
  :: ( Compatible pipeLayout boundLayout
     , MonadIO m
     )
  => Vk.CommandBuffer
  -> Pipeline pipeLayout vertices instances
  -> Bound boundLayout vertices instances m ()
  -> Bound boundLayout oldVertices oldInstances m ()
bind cb Pipeline{pipeline} (Bound attrAction) = do
  Bound $ Vk.cmdBindPipeline cb Vk.PIPELINE_BIND_POINT_GRAPHICS pipeline
  Bound attrAction

-- | Vertex input state with one binding per list entry (numbered from 0).
--
-- Each binding's stride is the summed size of its formats; attributes are
-- laid out by 'attrBindings'.
vertexInput
  :: [(Vk.VertexInputRate, [Vk.Format])]
  -> SomeStruct Vk.PipelineVertexInputStateCreateInfo
vertexInput bindings = SomeStruct zero
  { Vk.vertexBindingDescriptions   = binds
  , Vk.vertexAttributeDescriptions = attrs
  }
  where
    binds = Vector.fromList do
      (ix, (rate, formats)) <- zip [0..] bindings
      pure Vk.VertexInputBindingDescription
        { binding   = ix
        , stride    = sum $ map formatSize formats
        , inputRate = rate
        }

    attrs = attrBindings $ map snd bindings

-- * Utils

-- | Attribute descriptions for a list of bindings (numbered from 0).
--
-- Locations are assigned consecutively across all bindings, one per format.
-- Within a binding, attributes are packed tightly in list order: each offset
-- is the summed 'formatSize' of the formats before it.
attrBindings :: [[Vk.Format]] -> Vector Vk.VertexInputAttributeDescription
attrBindings bindings = mconcat $ List.unfoldr shiftLocations (0, 0, bindings)
  where
    shiftLocations = \case
      (_binding, _lastLoc, []) ->
        Nothing
      (binding, lastLoc, formats : rest) ->
        Just (bound, next)
        where
          bound = Vector.fromList do
            (ix, format, offset) <- List.zip3 [0..] formats offsets
            pure zero
              { Vk.binding  = binding
              , Vk.location = fromIntegral $ lastLoc + ix
              , Vk.format   = format
              , Vk.offset   = offset
              }

          -- Running byte offsets (a prefix sum) instead of re-summing the
          -- preceding formats for every attribute.
          offsets = List.scanl (+) 0 $ map formatSize formats

          next =
            ( binding + 1
            , lastLoc + Vector.length bound
            , rest
            )

-- | Size in bytes of a 32-bit float/uint/sint format with 1-4 components.
--
-- Calls 'error' for any other format.
formatSize :: Integral a => Vk.Format -> a
formatSize = \case
  Vk.FORMAT_R32G32B32A32_SFLOAT -> 16
  Vk.FORMAT_R32G32B32_SFLOAT    -> 12
  Vk.FORMAT_R32G32_SFLOAT       -> 8
  Vk.FORMAT_R32_SFLOAT          -> 4

  Vk.FORMAT_R32G32B32A32_UINT   -> 16
  Vk.FORMAT_R32G32B32_UINT      -> 12
  Vk.FORMAT_R32G32_UINT         -> 8
  Vk.FORMAT_R32_UINT            -> 4

  Vk.FORMAT_R32G32B32A32_SINT   -> 16
  Vk.FORMAT_R32G32B32_SINT      -> 12
  Vk.FORMAT_R32G32_SINT         -> 8
  Vk.FORMAT_R32_SINT            -> 4

  format -> error $ "Format size unknown: " <> show format