{-# OPTIONS_GHC -Wwarn=orphans #-}

{-# LANGUAGE OverloadedLists #-}

module Engine.Vulkan.Shader
  ( Shader(..)
  , create
  , destroy

  , withSpecialization
  , Specialization(..)

  , SpecializationConst(..)
  ) where

import RIO

import Data.Vector qualified as Vector
import Data.Vector.Storable qualified as Storable
import Foreign qualified
import GHC.Float (castFloatToWord32)
import Vulkan.Core10 qualified as Vk
import Vulkan.CStruct.Extends (SomeStruct(..))
import Vulkan.Zero (Zero(..))
import Unsafe.Coerce (unsafeCoerce)

import Engine.Vulkan.Types (HasVulkan(..))

-- | A set of shader modules together with the pipeline stage descriptions
-- that reference them (both produced by 'create').
data Shader = Shader
  { sModules :: Vector Vk.ShaderModule
    -- ^ One module per stage, in the same order as 'sPipelineStages'.
  , sPipelineStages :: Vector (SomeStruct Vk.PipelineShaderStageCreateInfo)
    -- ^ Stage create infos; each one points at the matching module.
  }

-- | Create one shader module per stage and the matching pipeline stage
-- descriptions.
--
-- Every stage uses the @main@ entry point and shares the same optional
-- specialization info. Release the result with 'destroy'.
--
-- If creating any module throws, the modules created so far are destroyed
-- before the exception is rethrown, so a partial failure does not leak them.
create
  :: (MonadIO io, HasVulkan ctx)
  => ctx
  -> Vector (Vk.ShaderStageFlagBits, ByteString)
  -> Maybe Vk.SpecializationInfo
  -> io Shader
create context stages spec = liftIO do
  createdRef <- newIORef []

  let
    device = getDevice context

    -- Failure path only: destroy every module recorded so far.
    releaseCreated =
      readIORef createdRef >>= traverse_ \module_ ->
        Vk.destroyShaderModule device module_ Nothing

    createStage (stage, code) = do
      -- Masked so an async exception cannot arrive between creating the
      -- module and recording it for cleanup.
      module_ <- mask_ do
        created <- Vk.createShaderModule device zero { Vk.code = code } Nothing
        modifyIORef' createdRef (created :)
        pure created
      pure
        ( module_
        , SomeStruct zero
            { Vk.stage              = stage
            , Vk.module'            = module_
            , Vk.name               = "main"
            , Vk.specializationInfo = spec
            }
        )

  staged <- Vector.mapM createStage stages `onException` releaseCreated

  let (modules, pStages) = Vector.unzip staged
  pure Shader
    { sModules        = modules
    , sPipelineStages = pStages
    }

-- | Destroy every shader module owned by the 'Shader', in order.
--
-- The stage descriptions in 'sPipelineStages' reference these modules, so
-- they should not be used after this call.
destroy :: (MonadIO io, HasVulkan ctx) => ctx -> Shader -> io ()
destroy context shader =
  traverse_ destroyModule (sModules shader)
  where
    destroyModule module_ =
      Vk.destroyShaderModule (getDevice context) module_ Nothing

-- * Specialization constants

-- | Run an action with a 'Vk.SpecializationInfo' built from a
-- 'Specialization' value, or with 'Nothing' when it yields no constants.
--
-- Each packed 'Word32' becomes one map entry: constant id @i@ reads 4 bytes
-- at byte offset @4 * i@. The info points into a temporary buffer that is
-- only valid while the action runs, so it must not escape the callback.
withSpecialization
  :: ( Specialization spec
     , MonadUnliftIO m
     )
  => spec
  -> (Maybe Vk.SpecializationInfo -> m a)
  -> m a
withSpecialization spec action
  | Storable.null packed =
      action Nothing
  | otherwise =
      withRunInIO \runInIO ->
        Storable.unsafeWith packed \packedPtr ->
          runInIO $ action (Just (info packedPtr))
  where
    packed = Storable.fromList (specializationData spec)

    -- Every constant is packed as a single 32-bit word.
    wordSize :: Int
    wordSize = 4

    info ptr = Vk.SpecializationInfo
      { mapEntries = entries
      , dataSize   = fromIntegral (Storable.length packed * wordSize)
      , data'      = Foreign.castPtr @_ @() ptr
      }

    entries = Vector.generate (Storable.length packed) \ix ->
      Vk.SpecializationMapEntry
        { constantID = fromIntegral ix
        , offset     = fromIntegral (ix * wordSize)
        , size       = fromIntegral wordSize
        }

-- | Values that flatten into a list of 32-bit specialization constants.
--
-- The word at position @i@ is bound to @constant_id = i@ by
-- 'withSpecialization'.
class Specialization a where
  -- XXX: abusing the fact that most scalars (sans double) are 4-wide
  specializationData :: a -> [Word32]

instance Specialization () where
  -- No constants at all; 'withSpecialization' then passes 'Nothing'.
  specializationData _ = []

instance Specialization [Word32] where
  -- Pre-packed words pass through unchanged, one per constant id.
  specializationData ws = ws

{- |
  In GLSL, @constant_id@ can only be applied to a scalar /int/, a scalar /float/ or a scalar /bool/
  (see <https://github.com/KhronosGroup/GLSL/blob/master/extensions/khr/GL_KHR_vulkan_glsl.txt>).

  XXX: Apparently uint and double constants can be passed too. Note that doubles are 8 bytes wide
  and do not fit the one-'Word32'-per-constant packing used here.
-}

instance Specialization Word32 where
  -- A single scalar fills exactly one constant slot.
  specializationData = pure . packConstData

instance Specialization Int32 where
  -- A single scalar fills exactly one constant slot.
  specializationData = (: []) . packConstData

instance Specialization Float where
  -- A single scalar fills exactly one constant slot.
  specializationData value = packConstData value : []

instance Specialization Bool where
  -- A single scalar fills exactly one constant slot.
  specializationData flag = pure (packConstData flag)

-- | Scalars that pack into exactly one 32-bit specialization word.
class SpecializationConst a where
  -- | The bit pattern stored in the specialization data buffer.
  packConstData :: a -> Word32

instance SpecializationConst Word32 where
  -- Already a 32-bit word; stored as-is.
  packConstData w = w

instance SpecializationConst Int32 where
  -- Two's-complement reinterpretation: 'fromIntegral' between same-width
  -- integral types keeps the bit pattern (e.g. -1 becomes 0xFFFFFFFF).
  -- Unlike 'unsafeCoerce' between two distinct boxed types, this is
  -- well-defined.
  packConstData = fromIntegral

instance SpecializationConst Float where
  -- Reinterpret the IEEE-754 single-precision bits as a word.
  -- 'castFloatToWord32' is the supported primitive for this; 'unsafeCoerce'
  -- between the boxed Float and Word32 types is undefined behaviour.
  packConstData = castFloatToWord32

instance SpecializationConst Bool where
  -- Packed as a 32-bit integer: False is 0, True is 1.
  packConstData flag = if flag then 1 else 0

instance
  ( SpecializationConst a
  , SpecializationConst b
  ) => Specialization (a, b) where
  -- Components take consecutive constant ids, left to right.
  specializationData (a, b) = packConstData a : pure (packConstData b)

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  ) => Specialization (a, b, c) where
  -- The head component, followed by the remaining pair packed left to right.
  specializationData (a, b, c) = packConstData a : specializationData (b, c)

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  ) => Specialization (a, b, c, d) where
  -- The head component, followed by the remaining triple packed left to right.
  specializationData (a, b, c, d) =
    packConstData a : specializationData (b, c, d)

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  ) => Specialization (a, b, c, d, e) where
  -- The head component, followed by the remaining 4-tuple packed left to right.
  specializationData (a, b, c, d, e) =
    packConstData a : specializationData (b, c, d, e)

instance
  ( SpecializationConst a
  , SpecializationConst b
  , SpecializationConst c
  , SpecializationConst d
  , SpecializationConst e
  , SpecializationConst f
  ) => Specialization (a, b, c, d, e, f) where
  -- The head component, followed by the remaining 5-tuple packed left to right.
  specializationData (a, b, c, d, e, f) =
    packConstData a : specializationData (b, c, d, e, f)