{-# LANGUAGE OverloadedLists #-}

-- XXX: TypeError in Compatible generates unused constraint argument
{-# OPTIONS_GHC -Wno-redundant-constraints #-}

module Engine.Vulkan.Pipeline.Compute
  ( Config(..)
  , Configure

  , Stages(..)
  , stageNames
  , stageFlagBits
  , StageCode
  , StageSpirv
  , StageReflect

  , Pipeline(..)
  , allocate
  , create

  , bind
  , Compute
  ) where

import RIO

import Data.Kind (Type)
import Data.Tagged (Tagged(..))
import Data.Vector qualified as Vector
import GHC.Generics (Generic1)
import GHC.Stack (withFrozenCallStack)
import UnliftIO.Resource (MonadResource, ReleaseKey)
import Vulkan.Core10 qualified as Vk
import Vulkan.CStruct.Extends (SomeStruct(..))
import Vulkan.Zero (Zero(..))

import Engine.SpirV.Reflect (Reflect)
import Engine.Vulkan.DescSets (Bound(..), Compatible)
import Engine.Vulkan.Pipeline (Pipeline(..))
import Engine.Vulkan.Pipeline qualified as Pipeline
import Engine.Vulkan.Pipeline.Stages (StageInfo(..))
import Engine.Vulkan.Shader qualified as Shader
import Engine.Vulkan.Types (HasVulkan(..), MonadVulkan, DsLayoutBindings, getPipelineCache)
import Render.Code (Code)
import Resource.Collection (Generically1(..))
import Resource.Vulkan.DescriptorLayout qualified as Layout
import Resource.Vulkan.Named qualified as Named

-- | Everything 'create' needs to build a compute pipeline.
--
-- The @dsl@ index is a type-level list describing the descriptor set layouts;
-- it is carried over to the resulting 'Pipeline' so that bound descriptor sets
-- can be checked for compatibility at compile time.
data Config (dsl :: [Type]) spec = Config
  { cComputeCode        :: ByteString
    -- ^ Code for the single compute stage, handed to the shader loader as-is
    -- (presumably compiled SPIR-V -- confirm against "Engine.Vulkan.Shader").
  , cDescLayouts        :: Tagged dsl [DsLayoutBindings]
    -- ^ Bindings for each descriptor set layout, one list element per set.
  , cPushConstantRanges :: Vector Vk.PushConstantRange
    -- ^ Push constant ranges to declare in the pipeline layout.
  , cSpecialization     :: spec
    -- ^ Specialization constants applied when the shader is created.
  }

-- | Uninhabited marker type.
--
-- In this module it fills both the @vertices@ and @instances@ parameters of
-- 'Pipeline' and 'Bound' to mark them as belonging to a compute pipeline.
data Compute

-- | Map a pipeline type to the configuration type used to build it.
--
-- Only compute pipelines (@Pipeline dsl Compute Compute@) are covered here;
-- the family is closed and has no other equations.
type family Configure pipeline spec where
  Configure (Pipeline dsl Compute Compute) spec = Config dsl spec

-- | Per-stage container for compute pipelines.
--
-- A compute pipeline has exactly one shader stage, so this is a single-field
-- wrapper. It exists so that code written against the 'StageInfo' interface
-- can treat compute pipelines uniformly with multi-stage ones.
--
-- 'Applicative' is obtained generically through 'Generically1'.
newtype Stages a = Stages
  { comp :: a -- ^ compute
  }
  deriving (Eq, Ord, Show, Functor, Foldable, Traversable, Generic1)
  deriving Applicative via (Generically1 Stages)

-- | Stage metadata for the lone compute stage.
instance StageInfo Stages where
  -- Label of the stage, available at any 'IsString' type.
  stageNames = Stages
    { comp = "comp"
    }

  -- Vulkan stage flag matching the stage.
  stageFlagBits = Stages
    { comp = Vk.SHADER_STAGE_COMPUTE_BIT
    }

-- | Optional shader 'Code' for each compute stage.
type StageCode = Stages (Maybe Code)

-- | Optional 'ByteString' blob for each compute stage; this is the shape
-- 'create' passes to @Shader.create@.
type StageSpirv = Stages (Maybe ByteString)

-- | 'Reflect' specialised to the compute 'Stages' container.
type StageReflect = Reflect Stages

-- | Create a compute pipeline and register it with the resource scope.
--
-- This is 'create' wrapped in @Pipeline.allocateWith@: the pipeline is returned
-- together with the 'ReleaseKey' produced by that wrapper.
--
-- The call stack is frozen so that locations recorded further down point at
-- the caller of 'allocate' rather than at this module.
allocate
  :: ( MonadVulkan env m
     , MonadResource m
     , HasCallStack
     , Shader.Specialization spec
     )
  => Config dsl spec
  -> m (ReleaseKey, Pipeline dsl Compute Compute)
allocate config =
  withFrozenCallStack $
    Pipeline.allocateWith $ create config

-- | Build a compute pipeline from a 'Config'.
--
-- Steps, in order:
--
-- 1. Create the descriptor set layouts from 'cDescLayouts'.
-- 2. Create the pipeline layout from those and 'cPushConstantRanges'.
-- 3. Create the shader from 'cComputeCode' with 'cSpecialization' applied.
-- 4. Create the Vulkan compute pipeline and destroy the shader.
--
-- The returned 'Pipeline' carries the pipeline handle, the pipeline layout and
-- the descriptor set layouts. The caller owns them (see 'allocate' for the
-- resource-managed variant).
--
-- Failure modes:
--
-- * Any result code other than @SUCCESS@ is rethrown with 'throwString'.
-- * Receiving a number of pipelines or shader stages other than one is treated
--   as a broken invariant and raises 'error'.
--
-- The shader is destroyed on every path, including failures.
create
  :: ( MonadVulkan env io
     , Shader.Specialization spec
     , HasCallStack
     )
  => Config dsl spec
  -> io (Pipeline dsl Compute Compute)
create Config{..} = withFrozenCallStack do
  -- TODO: get from outside ?
  dsLayouts <- Layout.create $ Vector.fromList (unTagged cDescLayouts)

  -- TODO: get from outside ??
  pipelineLayout <- Layout.forPipeline
    dsLayouts
    cPushConstantRanges
  Named.objectOrigin pipelineLayout

  shader <- Shader.withSpecialization cSpecialization $
    Shader.create Stages
      { comp = Just cComputeCode
      }

  let
    cis = Vector.singleton . SomeStruct $
      pipelineCI (Shader.sPipelineStages shader) pipelineLayout

  device <- asks getDevice

  -- The shader is only needed while the pipeline is being created. Release it
  -- whatever the outcome, so a failed creation does not leak it.
  result <-
    Vk.createComputePipelines device cache cis Nothing
      `finally` Shader.destroy shader

  case result of
    (Vk.SUCCESS, pipelines) ->
      case pipelines of
        [pipeline] -> do
          Named.objectOrigin pipeline
          pure Pipeline
            { pipeline     = pipeline
            , pLayout      = Tagged pipelineLayout
            , pDescLayouts = Tagged dsLayouts
            }
        _ ->
          error "assert: exactly one pipeline requested"
    (err, _) ->
      throwString $ "createComputePipelines: " <> show err

  where
    -- 'getPipelineCache' is polymorphic in its (unused) context argument, so a
    -- unit value is enough and avoids passing a bottom around.
    cache = getPipelineCache ()

    -- Fill in the create-info; everything not listed keeps its 'zero' value.
    pipelineCI stages layout = zero
      { Vk.layout             = layout
      , Vk.stage              = stage
      , Vk.basePipelineHandle = zero
      }
      where
        -- A compute pipeline takes exactly one stage.
        stage = case stages of
          [one]   -> one
          _assert -> error "compute code has one stage"

-- | Run an action with the given compute pipeline bound.
--
-- Records @cmdBindPipeline@ for the compute bind point into the command
-- buffer, then runs the wrapped action.
--
-- The 'Compatible' constraint requires the pipeline's descriptor layout to be
-- compatible with the currently bound one. The inner action is typed as
-- compute-bound, while the result can be used at any vertex/instance marker.
bind
  :: ( Compatible pipeLayout boundLayout
     , MonadIO m
     )
  => Vk.CommandBuffer
  -> Pipeline pipeLayout Compute Compute
  -> Bound boundLayout Compute Compute m ()
  -> Bound boundLayout noVertices noInstances m ()
bind cb Pipeline{pipeline} (Bound attrAction) = do
  Bound $ Vk.cmdBindPipeline cb Vk.PIPELINE_BIND_POINT_COMPUTE pipeline
  Bound attrAction