module Vulkan.Utils.ShaderQQ.Backend.Shaderc.Internal
  ( compileShaderQ
  , compileShader
  ) where

import           Control.Monad.IO.Class
import           Data.ByteString                ( ByteString )
import qualified Data.ByteString               as BS
import           Data.FileEmbed
import           Language.Haskell.TH
import           System.Exit
import           System.IO                      ( IOMode(WriteMode)
                                                , hPutStr
                                                , hSetEncoding
                                                , utf8
                                                , withFile
                                                )
import           System.IO.Temp
import           System.Process.Typed
import           Vulkan.Utils.ShaderQQ.ShaderType
import qualified Vulkan.Utils.ShaderQQ.GLSL    as GLSL
import qualified Vulkan.Utils.ShaderQQ.HLSL    as HLSL
import           Vulkan.Utils.ShaderQQ.Backend.Shaderc
import           Vulkan.Utils.ShaderQQ.Backend.Internal

-- * Utilities

-- | Compile a GLSL/HLSL shader to SPIR-V at Haskell compile time using
-- @glslc@ (from the shaderc project), and splice the resulting bytecode in
-- as an expression (via 'bsToExp').
--
-- The splice's own source location is handed to 'compileShader', so that the
-- shader text can be annotated with where it came from. glslc's messages are
-- turned into GHC warnings ('reportWarning') or a compile error ('fail'),
-- depending on whether compilation succeeded.
compileShaderQ
  :: Maybe String
  -- ^ Argument to pass to @--target-spv@
  -> ShaderType
  -- ^ Selects between GLSL and HLSL input
  -> String
  -- ^ Shader stage (passed to glslc as @-fshader-stage@)
  -> Maybe String
  -- ^ Entry-point function name (only used for HLSL)
  -> String
  -- ^ GLSL or HLSL shader source code
  -> Q Exp
  -- ^ SPIR-V bytecode
compileShaderQ targetSpv shaderType stage entryPoint code = do
  here            <- location
  (msgs, outcome) <-
    compileShader (Just here) targetSpv shaderType stage entryPoint code
  -- Warnings become GHC warnings; errors abort the splice via 'fail'.
  messageProcess "glslc" reportWarning fail (msgs, outcome) >>= bsToExp

-- | Compile a GLSL/HLSL shader to SPIR-V using @glslc@.
--
-- The source is written into a fresh temporary directory as
-- @shader.\<stage\>@ and compiled with roughly
--
-- > glslc [--target-spv=T] [-fentry-point=E] -fshader-stage=<stage> -x <type> shader.<stage> -o shader.spv
--
-- When a source location is supplied, the code is first passed through the
-- GLSL/HLSL @insertLineDirective@ helper for that location.
--
-- Everything glslc prints (stdout followed by stderr) is split into warnings
-- and errors by 'processShadercMessages'. Warnings are always returned. On a
-- zero exit code the contents of the produced @.spv@ file are returned as
-- 'Right'. Otherwise the parsed errors are returned as 'Left'.
compileShader
  :: MonadIO m
  => Maybe Loc
  -- ^ Source location
  -> Maybe String
  -- ^ Argument to pass to @--target-spv@
  -> ShaderType
  -- ^ Selects between GLSL and HLSL input
  -> String
  -- ^ Shader stage (also used as the temporary file's extension)
  -> Maybe String
  -- ^ Entry-point function name (only used for HLSL)
  -> String
  -- ^ GLSL or HLSL shader source code
  -> m ([ShadercWarning], Either [ShadercError] ByteString)
  -- ^ Spir-V bytecode with warnings or errors
compileShader loc targetSpv shaderType stage entryPoint code =
  liftIO $ withSystemTempDirectory "th-shader" $ \dir -> do
    let prependDirective = case shaderType of
          GLSL -> GLSL.insertLineDirective code
          HLSL -> HLSL.insertLineDirective code
        codeWithLineDirective = maybe code prependDirective loc
        shader = dir <> "/shader." <> stage
        spirv  = dir <> "/shader.spv"

    -- Write the source explicitly as UTF-8 (the encoding GLSL source is
    -- specified in). Prelude's 'writeFile' would use the locale encoding,
    -- which fails on any non-ASCII character (e.g. in a comment) when the
    -- build runs under a C/POSIX locale.
    withFile shader WriteMode $ \h -> do
      hSetEncoding h utf8
      hPutStr h codeWithLineDirective

    let targetArgs = case targetSpv of
          Nothing -> []
          Just t  -> ["--target-spv=" <> t]
        -- https://github.com/google/shaderc/blob/01dd72d6079ebdc0f96859365ba7abb1b62758bf/glslc/src/main.cc#L64
        -- The entry point is only forwarded for HLSL; it is dropped for GLSL.
        entryPointArgs = case (shaderType, entryPoint) of
          (HLSL, Just name) -> ["-fentry-point=" <> name]
          _                 -> []
        args =
          targetArgs
            ++ entryPointArgs
            ++ [ "-fshader-stage=" <> stage
               , "-x", show shaderType
               , shader
               , "-o", spirv
               ]

    (rc, out, err) <- readProcess $ proc "glslc" args
    let (warnings, errors) = processShadercMessages (out <> err)
    case rc of
      ExitSuccess -> do
        bs <- BS.readFile spirv
        pure (warnings, Right bs)
      ExitFailure _ -> pure (warnings, Left errors)