{-# LANGUAGE QuasiQuotes #-}

-- | Various boilerplate definitions for the CUDA backend.
module Futhark.CodeGen.Backends.CCUDA.Boilerplate
  ( generateBoilerplate,
    profilingEnclosure,
    module Futhark.CodeGen.Backends.COpenCL.Boilerplate,
  )
where

import Control.Monad
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.COpenCL.Boilerplate
  ( copyDevToDev,
    copyDevToHost,
    copyHostToDev,
    copyScalarFromDev,
    copyScalarToDev,
    costCentreReport,
    failureMsgFunction,
    kernelRuns,
    kernelRuntime,
  )
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.RTS.C (backendsCudaH)
import Futhark.Util (chunk)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The number of arguments an error message takes, computed as the
-- length of the list returned by 'errorMsgArgTypes'.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs = length . errorMsgArgTypes

-- | Block items to put before and after a thing to be profiled.
--
-- The first component declares @pevents@ (initially @NULL@) and, when
-- @ctx->profiling@ is set and @ctx->profiling_paused@ is not, obtains
-- events from @cuda_get_events@ for the run counter and runtime fields
-- named by 'kernelRuns' and 'kernelRuntime', then records
-- @pevents[0]@ on @ctx->stream@.  The second component records
-- @pevents[1]@ on the same stream, but only if @pevents@ is non-NULL.
--
-- The two lists share the C variable @pevents@, so they must be
-- emitted into the same C scope.
profilingEnclosure :: Name -> ([C.BlockItem], [C.BlockItem])
profilingEnclosure name =
  ( [C.citems|
      typename CUevent *pevents = NULL;
      if (ctx->profiling && !ctx->profiling_paused) {
        pevents = cuda_get_events(ctx,
                                  &ctx->program->$id:(kernelRuns name),
                                  &ctx->program->$id:(kernelRuntime name));
        CUDA_SUCCEED_FATAL(cuEventRecord(pevents[0], ctx->stream));
      }
      |],
    [C.citems|
      if (pevents != NULL) {
        CUDA_SUCCEED_FATAL(cuEventRecord(pevents[1], ctx->stream));
      }
      |]
  )

-- | Register the context fields needed by the CUDA kernels and the
-- profiling cost centres.
--
-- For every kernel name (the keys of the kernel map) a dynamic
-- context field of type @CUfunction@ is declared, set up by looking
-- the kernel up in @ctx->module@ with @cuModuleGetFunction@.  Every
-- kernel, and every name in the explicit cost centre list, also gets
-- a pair of profiling fields: an @int64_t@ runtime accumulator and an
-- @int@ run counter, both initialised to zero.
generateCUDADecls ::
  [Name] ->
  M.Map KernelName KernelSafety ->
  GC.CompilerM op s ()
generateCUDADecls cost_centres kernels = do
  -- Declare the two zero-initialised profiling fields of one cost centre.
  let forCostCentre name = do
        GC.contextField
          (C.toIdent (kernelRuntime name) mempty)
          [C.cty|typename int64_t|]
          (Just [C.cexp|0|])
        GC.contextField
          (C.toIdent (kernelRuns name) mempty)
          [C.cty|int|]
          (Just [C.cexp|0|])

  forM_ (M.keys kernels) $ \name -> do
    -- The final argument is an empty C statement; presumably this is
    -- the teardown code for the field, of which a CUfunction needs
    -- none (confirm against 'GC.contextFieldDyn').
    GC.contextFieldDyn
      (C.toIdent name mempty)
      [C.cty|typename CUfunction|]
      [C.cstm|
             CUDA_SUCCEED_FATAL(cuModuleGetFunction(
                                     &ctx->program->$id:name,
                                     ctx->module,
                                     $string:(T.unpack (idText (C.toIdent name mempty)))));|]
      [C.cstm|{}|]
    -- Each kernel is also a cost centre in its own right.
    forCostCentre name

  mapM_ forCostCentre cost_centres

-- | Called after most code has been generated to generate the bulk of
-- the boilerplate.
--
-- Arguments, in order:
--
-- * the CUDA program text;
--
-- * the CUDA prelude text, which is placed in front of the program
--   text when the two are embedded in the generated C;
--
-- * additional profiling cost centres (beyond the kernels themselves);
--
-- * the kernels, of which only the names (map keys) are used here;
--
-- * the failure messages, used to compute @max_failure_args@ and to
--   generate the failure message function.
generateBoilerplate ::
  T.Text ->
  T.Text ->
  [Name] ->
  M.Map KernelName KernelSafety ->
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateBoilerplate cuda_program cuda_prelude cost_centres kernels failures = do
  let cuda_program_fragments =
        -- Some C compilers limit the size of literal strings, so
        -- chunk the entire program into small bits here, and
        -- concatenate it again at runtime.
        [ [C.cinit|$string:s|]
          | s <- chunk 2000 $ T.unpack $ cuda_prelude <> cuda_program
        ]
      -- Note that the array initialiser below appends a NULL of its
      -- own after these fragments, so the emitted array ends in two
      -- NULLs.  The second is redundant, but is kept so that the
      -- generated C is unchanged.
      program_fragments = cuda_program_fragments ++ [[C.cinit|NULL|]]
      -- Largest argument count of any failure message; the leading 0
      -- keeps 'maximum' total when there are no failures.
      max_failure_args =
        maximum $ 0 : map (errorMsgNumArgs . failureError) failures

  mapM_
    GC.earlyDecl
    [C.cunit|static const int max_failure_args = $int:max_failure_args;
             static const char *cuda_program[] = {$inits:program_fragments, NULL};
             $esc:(T.unpack backendsCudaH)
            |]
  GC.earlyDecl $ failureMsgFunction failures

  generateCUDADecls cost_centres kernels

  -- Prototypes for the CUDA-specific configuration functions, added
  -- to the initialisation section of the header in this order.
  mapM_
    (GC.headerDecl GC.InitDecl)
    [ [C.cedecl|void futhark_context_config_add_nvrtc_option(struct futhark_context_config *cfg, const char* opt);|],
      [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_dump_program_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_program_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_dump_ptx_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_ptx_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_reg_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_threshold(struct futhark_context_config *cfg, int size);|]
    ]

  GC.generateProgramStruct

  -- Only release device memory on clear if the context has no
  -- pending error.
  GC.onClear
    [C.citem|if (ctx->error == NULL) {
               CUDA_SUCCEED_NONFATAL(cuda_free_all(ctx));
             }|]

  -- Tally the outstanding profiling records before emitting the
  -- per-cost-centre report items.
  GC.profileReport [C.citem|CUDA_SUCCEED_FATAL(cuda_tally_profiling_records(ctx));|]
  mapM_ GC.profileReport $ costCentreReport $ cost_centres ++ M.keys kernels
{-# NOINLINE generateBoilerplate #-}