{-# LANGUAGE QuasiQuotes #-}

-- | Various boilerplate definitions for the CUDA backend.
module Futhark.CodeGen.Backends.CCUDA.Boilerplate
  ( generateBoilerplate,
    profilingEnclosure,
    module Futhark.CodeGen.Backends.COpenCL.Boilerplate,
  )
where

import Control.Monad
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.COpenCL.Boilerplate
  ( copyDevToDev,
    copyDevToHost,
    copyHostToDev,
    copyScalarFromDev,
    copyScalarToDev,
    costCentreReport,
    failureMsgFunction,
    kernelRuns,
    kernelRuntime,
  )
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.RTS.C (backendsCudaH)
import Futhark.Util (chunk)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The number of arguments of an error message, computed as the
-- length of the list produced by 'errorMsgArgTypes'.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs = length . errorMsgArgTypes

-- | Block items to put before and after a thing to be profiled.
--
-- The first component declares @pevents@ (initially @NULL@) and, when
-- @ctx->profiling@ is set and @ctx->profiling_paused@ is not, obtains
-- events via @cuda_get_events@ using the program fields named by
-- 'kernelRuns' and 'kernelRuntime' for the given name, and then
-- records @pevents[0]@.  The second component records @pevents[1]@,
-- but only if @pevents@ is not @NULL@.  The second component refers
-- to the @pevents@ variable declared by the first, so both must be
-- placed in the same C scope.
profilingEnclosure :: Name -> ([C.BlockItem], [C.BlockItem])
profilingEnclosure name =
  ( [C.citems|
      typename cudaEvent_t *pevents = NULL;
      if (ctx->profiling && !ctx->profiling_paused) {
        pevents = cuda_get_events(ctx,
                                  &ctx->program->$id:(kernelRuns name),
                                  &ctx->program->$id:(kernelRuntime name));
        CUDA_SUCCEED_FATAL(cudaEventRecord(pevents[0], 0));
      }
      |],
    [C.citems|
      if (pevents != NULL) {
        CUDA_SUCCEED_FATAL(cudaEventRecord(pevents[1], 0));
      }
      |]
  )

-- | Declare the context fields used by the CUDA backend.
--
-- For every key of the kernel map, a @CUfunction@ field is declared
-- whose setup statement looks the function up in @ctx->module@ with
-- @cuModuleGetFunction@.  For every kernel and for every name in the
-- list of cost centres, two further fields are declared: one named by
-- 'kernelRuntime' (an @int64_t@) and one named by 'kernelRuns' (an
-- @int@), both with initial value @0@.
generateCUDADecls ::
  [Name] ->
  M.Map KernelName KernelSafety ->
  GC.CompilerM op s ()
generateCUDADecls cost_centres kernels = do
  forM_ (M.keys kernels) $ \name -> do
    GC.contextFieldDyn
      (C.toIdent name mempty)
      [C.cty|typename CUfunction|]
      [C.cstm|
             CUDA_SUCCEED_FATAL(cuModuleGetFunction(
                                     &ctx->program->$id:name,
                                     ctx->module,
                                     $string:(T.unpack (idText (C.toIdent name mempty)))));|]
      -- The second statement argument is deliberately an empty block.
      [C.cstm|{}|]
    forCostCentre name

  mapM_ forCostCentre cost_centres
  where
    -- Declare the pair of profiling fields for a single name.
    forCostCentre :: Name -> GC.CompilerM op' s' ()
    forCostCentre name = do
      GC.contextField
        (C.toIdent (kernelRuntime name) mempty)
        [C.cty|typename int64_t|]
        (Just [C.cexp|0|])
      GC.contextField
        (C.toIdent (kernelRuns name) mempty)
        [C.cty|int|]
        (Just [C.cexp|0|])

-- | Called after most code has been generated to generate the bulk of
-- the boilerplate.
--
-- The arguments are, in order: the CUDA program text, the CUDA
-- prelude text (which is placed /before/ the program when the two are
-- concatenated), the names of the cost centres, the kernels, and the
-- failure messages.
generateBoilerplate ::
  T.Text ->
  T.Text ->
  [Name] ->
  M.Map KernelName KernelSafety ->
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateBoilerplate cuda_program cuda_prelude cost_centres kernels failures = do
  let cuda_program_fragments =
        -- Some C compilers limit the size of literal strings, so
        -- chunk the entire program into small bits here, and
        -- concatenate it again at runtime.
        [ [C.cinit|$string:s|]
          | s <- chunk 2000 $ T.unpack $ cuda_prelude <> cuda_program
        ]
      -- Note that this list already ends in a NULL initialiser, and
      -- the array literal below appends one more.
      program_fragments = cuda_program_fragments ++ [[C.cinit|NULL|]]
      -- The largest argument count of any failure message; the
      -- leading 0 makes this well-defined (and zero) when there are
      -- no failure messages at all.
      max_failure_args =
        maximum (0 : map (errorMsgNumArgs . failureError) failures)

  mapM_
    GC.earlyDecl
    [C.cunit|static const int max_failure_args = $int:max_failure_args;
             static const char *cuda_program[] = {$inits:program_fragments, NULL};
             $esc:(T.unpack backendsCudaH)
            |]
  GC.earlyDecl $ failureMsgFunction failures

  generateCUDADecls cost_centres kernels

  -- Prototypes of the CUDA-specific configuration functions, added to
  -- the initialisation section of the header in this order.
  mapM_
    (GC.headerDecl GC.InitDecl)
    [ [C.cedecl|void futhark_context_config_add_nvrtc_option(struct futhark_context_config *cfg, const char* opt);|],
      [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_dump_program_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_program_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_dump_ptx_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_ptx_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_reg_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_threshold(struct futhark_context_config *cfg, int size);|]
    ]

  GC.generateProgramStruct

  GC.onClear
    [C.citem|if (ctx->error == NULL) {
               CUDA_SUCCEED_NONFATAL(cuda_free_all(ctx));
             }|]

  GC.profileReport [C.citem|CUDA_SUCCEED_FATAL(cuda_tally_profiling_records(ctx));|]
  -- Cost centres are reported first, followed by the kernels.
  mapM_ GC.profileReport $ costCentreReport $ cost_centres ++ M.keys kernels
{-# NOINLINE generateBoilerplate #-}