module Vulkan.Utils.ShaderQQ.Backend.Glslang.Internal
  ( compileShaderQ
  , compileShader
  ) where

import           Control.Monad.IO.Class
import           Data.ByteString                ( ByteString )
import qualified Data.ByteString               as BS
import qualified Data.ByteString.Lazy.Char8    as BSL8
import           Data.FileEmbed
import           Language.Haskell.TH
import           System.Exit
import           System.IO                      ( IOMode(WriteMode)
                                                , hPutStr
                                                , hSetEncoding
                                                , utf8
                                                , withFile
                                                )
import           System.IO.Temp
import           System.Process.Typed
import           Vulkan.Utils.ShaderQQ.ShaderType
import qualified Vulkan.Utils.ShaderQQ.GLSL    as GLSL
import qualified Vulkan.Utils.ShaderQQ.HLSL    as HLSL
import           Vulkan.Utils.ShaderQQ.Backend.Glslang
import           Vulkan.Utils.ShaderQQ.Backend.Internal

-- * Utilities

-- | Compile a GLSL or HLSL shader to SPIR-V at Template Haskell splice time,
-- using @glslangValidator@ via 'compileShader'.
--
-- The location of the splice is passed to 'compileShader', which uses it to
-- insert a line directive into the shader source. The compiler's messages and
-- result are handed to 'messageProcess' together with 'reportWarning' and
-- 'fail', so messages surface as GHC warnings or errors depending on whether
-- compilation succeeded. The resulting SPIR-V bytes are embedded with
-- 'bsToExp'.
compileShaderQ
  :: Maybe String
  -- ^ Optional argument to pass to @--target-env@
  -> ShaderType
  -- ^ Whether the source is GLSL or HLSL
  -> String
  -- ^ Shader stage, passed to @glslangValidator -S@
  -> Maybe String
  -- ^ Optional entry-point function name
  -> String
  -- ^ GLSL or HLSL shader source code
  -> Q Exp
  -- ^ Expression evaluating to the SPIR-V bytecode
compileShaderQ targetEnv shaderType stage entryPoint code = do
  spliceLoc <- location
  compiled  <-
    compileShader (Just spliceLoc) targetEnv shaderType stage entryPoint code
  bsToExp =<< messageProcess "glslangValidator" reportWarning fail compiled

-- | Compile a GLSL or HLSL shader to SPIR-V by running @glslangValidator@.
--
-- The shader source is written, UTF-8 encoded, to a file in a fresh temporary
-- directory and compiled there. On success the SPIR-V output is read back
-- strictly before the directory is removed. The compiler's stdout and stderr
-- are concatenated and split into warnings and errors with
-- 'processGlslangMessages'.
--
-- Warnings are returned whether or not compilation succeeds. If the compiler
-- exits with a failure status but no error messages could be parsed from its
-- output, a single error containing the exit code and the raw compiler output
-- is returned, so the caller never receives an empty error list.
--
-- Exceptions raised while running the process are not caught here.
compileShader
  :: MonadIO m
  => Maybe Loc
  -- ^ Source location. When present, a line directive derived from it is
  -- inserted into the shader source.
  -> Maybe String
  -- ^ Optional argument to pass to @--target-env@
  -> ShaderType
  -- ^ Whether the source is GLSL or HLSL
  -> String
  -- ^ Shader stage. Passed to @-S@ and used as the temporary file's extension.
  -> Maybe String
  -- ^ Optional entry-point name, passed to @-e@
  -> String
  -- ^ GLSL or HLSL shader source code
  -> m ([GlslangWarning], Either [GlslangError] ByteString)
  -- ^ Warnings, plus either errors or the SPIR-V bytecode
compileShader loc targetEnv shaderType stage entryPoint code =
  liftIO $ withSystemTempDirectory "th-shader" $ \dir -> do
    let codeWithLineDirective = maybe
          code
          (case shaderType of
            GLSL -> GLSL.insertLineDirective code
            HLSL -> HLSL.insertLineDirective code
          )
          loc
        shader = dir <> "/shader." <> stage
        spirv  = dir <> "/shader.spv"

    -- Write explicitly as UTF-8. Prelude 'writeFile' uses the locale
    -- encoding, which can fail on non-ASCII characters (e.g. in shader
    -- comments) under a non-UTF-8 locale such as C/POSIX.
    withFile shader WriteMode $ \h -> do
      hSetEncoding h utf8
      hPutStr h codeWithLineDirective

    let targetArgs = case targetEnv of
          Nothing -> []
          Just t  -> ["--target-env", t]
        -- A bare "-D" tells glslangValidator the input is HLSL.
        shaderTypeArgs = case shaderType of
          GLSL -> []
          HLSL -> ["-D"]
        -- https://github.com/KhronosGroup/glslang/issues/1045#issuecomment-328707953
        -- For GLSL the source function is always "main"; "-e" renames it.
        entryPointArgs = case entryPoint of
          Nothing -> []
          Just name -> case shaderType of
            GLSL -> ["-e", name, "--source-entry-point", "main"]
            HLSL -> ["-e", name]
        args =
          targetArgs
            ++ shaderTypeArgs
            ++ entryPointArgs
            ++ ["-S", stage, "-V", shader, "-o", spirv]

    (rc, out, err) <- readProcess $ proc "glslangValidator" args
    let output             = out <> err
        (warnings, errors) = processGlslangMessages output
    case rc of
      ExitSuccess -> do
        -- Strict read: the temporary directory is deleted once this
        -- callback returns.
        bs <- BS.readFile spirv
        pure (warnings, Right bs)
      ExitFailure status
        | null errors ->
          -- Nothing parseable was reported; surface the raw output rather
          -- than an empty error list.
          pure
            ( warnings
            , Left
              [ "glslangValidator exited with code "
                <> show status
                <> " without reporting a parseable error:\n"
                <> BSL8.unpack output
              ]
            )
        | otherwise -> pure (warnings, Left errors)