{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Various boilerplate definitions for the CUDA backend.
module Futhark.CodeGen.Backends.CCUDA.Boilerplate
  ( generateBoilerplate,
    profilingEnclosure,
    module Futhark.CodeGen.Backends.COpenCL.Boilerplate,
  )
where

import Data.FileEmbed (embedStringFile)
import qualified Data.Map as M
import Data.Maybe
import Futhark.CodeGen.Backends.COpenCL.Boilerplate
  ( copyDevToDev,
    copyDevToHost,
    copyHostToDev,
    copyScalarFromDev,
    copyScalarToDev,
    costCentreReport,
    failureSwitch,
    kernelRuns,
    kernelRuntime,
  )
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.Util (chunk, zEncodeString)
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C

-- | The number of run-time arguments interpolated into an error
-- message, i.e. how many argument types 'errorMsgArgTypes' reports.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs msg = length (errorMsgArgTypes msg)

-- | Block items to put before and after a thing to be profiled.
--
-- The first list declares @pevents@. When profiling is enabled and not
-- paused, it obtains events from @cuda_get_events@, passing pointers to
-- the run-count and runtime context fields belonging to @name@, and
-- records @pevents[0]@. The second list records @pevents[1]@ if events
-- were obtained. Both lists must be spliced into the same C scope,
-- because the second one refers to @pevents@.
profilingEnclosure :: Name -> ([C.BlockItem], [C.BlockItem])
profilingEnclosure name = (prologue, epilogue)
  where
    runs_field = kernelRuns name
    runtime_field = kernelRuntime name

    prologue =
      [C.citems|
        typename cudaEvent_t *pevents = NULL;
        if (ctx->profiling && !ctx->profiling_paused) {
          pevents = cuda_get_events(&ctx->cuda,
                                    &ctx->$id:runs_field,
                                    &ctx->$id:runtime_field);
          CUDA_SUCCEED(cudaEventRecord(pevents[0], 0));
        }
        |]

    epilogue =
      [C.citems|
        if (pevents != NULL) {
          CUDA_SUCCEED(cudaEventRecord(pevents[1], 0));
        }
        |]

-- | Called after most code has been generated to generate the bulk of
-- the boilerplate. In order, it emits:
--
-- * the embedded @free_list.h@ and @cuda.h@ runtime sources;
-- * the program text as the C array @cuda_program@;
-- * the size tables;
-- * the configuration and context API;
-- * the profiling report entries.
generateBoilerplate ::
  -- | The CUDA program source.
  String ->
  -- | Prelude placed in front of the program source.
  String ->
  -- | Cost centres that get profiling report entries.
  [Name] ->
  -- | Kernels; their names also get profiling report entries.
  M.Map KernelName KernelSafety ->
  -- | Tunable sizes and their classes.
  M.Map Name SizeClass ->
  -- | Failure messages that kernels may report.
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateBoilerplate cuda_program cuda_prelude cost_centres kernels sizes failures = do
  mapM_ GC.earlyDecl rts_decls
  generateSizeFuns sizes
  cfg <- generateConfigFuns sizes
  generateContextFuns cfg cost_centres kernels sizes failures
  GC.profileReport tally_records
  mapM_ GC.profileReport (costCentreReport (cost_centres ++ M.keys kernels))
  where
    rts_decls =
      [C.cunit|
        $esc:("#include <cuda.h>")
        $esc:("#include <nvrtc.h>")
        $esc:("typedef CUdeviceptr fl_mem_t;")
        $esc:free_list_h
        $esc:cuda_h
        const char *cuda_program[] = {$inits:program_fragments, NULL};
        |]

    tally_records =
      [C.citem|CUDA_SUCCEED(cuda_tally_profiling_records(&ctx->cuda));|]

    -- The program text is emitted as a NULL-terminated array of string
    -- literals of at most 2000 characters each. This is presumably to
    -- stay within C compilers' string-literal length limits.
    program_fragments =
      [ [C.cinit|$string:fragment|]
        | fragment <- chunk 2000 (cuda_prelude ++ cuda_program)
      ]

    cuda_h = $(embedStringFile "rts/c/cuda.h")
    free_list_h = $(embedStringFile "rts/c/free_list.h")

-- | Emit three parallel, statically initialised C string arrays that
-- describe the tunable sizes. All three are in 'M.Map' key order:
--
-- * @size_names@: the pretty-printed size names.
-- * @size_vars@: the z-encoded size names.
-- * @size_classes@: the pretty-printed size classes.
generateSizeFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () ()
generateSizeFuns sizes =
  mapM_
    GC.earlyDecl
    [ [C.cedecl|static const char *size_names[] = { $inits:name_inits };|],
      [C.cedecl|static const char *size_vars[] = { $inits:var_inits };|],
      [C.cedecl|static const char *size_classes[] = { $inits:class_inits };|]
    ]
  where
    strInit str = [C.cinit|$string:str|]
    name_inits = map (strInit . pretty) (M.keys sizes)
    var_inits = map (strInit . zEncodeString . pretty) (M.keys sizes)
    class_inits = map (strInit . pretty) (M.elems sizes)

-- | Generate the configuration part of the public API:
--
-- * the @sizes@ struct, with one @int64_t@ field per tunable size;
-- * the public configuration struct;
-- * the @context_config_*@ functions that operate on it.
--
-- Returns the C name of the configuration struct, as chosen by
-- 'GC.publicDef'.
generateConfigFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () String
generateConfigFuns sizes = do
  let size_decls = map (\k -> [C.csdecl|typename int64_t $id:k;|]) $ M.keys sizes
      num_sizes = M.size sizes
  GC.earlyDecl [C.cedecl|struct sizes { $sdecls:size_decls };|]
  cfg <- GC.publicDef "context_config" GC.InitDecl $ \s ->
    ( [C.cedecl|struct $id:s;|],
      [C.cedecl|struct $id:s { struct cuda_config cu_cfg;
                              int profiling;
                              typename int64_t sizes[$int:num_sizes];
                              int num_nvrtc_opts;
                              const char **nvrtc_opts;
                            };|]
    )

  -- Initialise each entry of cfg->sizes to its size class default, or 0
  -- when the class has none. Entries are indexed in 'M.Map' key order.
  let size_value_inits = zipWith sizeInit [0 :: Int ..] (M.elems sizes)
      sizeInit i size = [C.cstm|cfg->sizes[$int:i] = $int:val;|]
        where
          val = fromMaybe 0 $ sizeDefault size

  -- If allocating the option list fails, release the half-built config
  -- and report failure by returning NULL, just as when the config
  -- itself cannot be allocated.
  GC.publicDef_ "context_config_new" GC.InitDecl $ \s ->
    ( [C.cedecl|struct $id:cfg* $id:s(void);|],
      [C.cedecl|struct $id:cfg* $id:s(void) {
                         struct $id:cfg *cfg = (struct $id:cfg*) malloc(sizeof(struct $id:cfg));
                         if (cfg == NULL) {
                           return NULL;
                         }

                         cfg->profiling = 0;
                         cfg->num_nvrtc_opts = 0;
                         cfg->nvrtc_opts = (const char**) malloc(sizeof(const char*));
                         if (cfg->nvrtc_opts == NULL) {
                           free(cfg);
                           return NULL;
                         }
                         cfg->nvrtc_opts[0] = NULL;
                         $stms:size_value_inits
                         cuda_config_init(&cfg->cu_cfg, $int:num_sizes,
                                          size_names, size_vars,
                                          cfg->sizes, size_classes);
                         return cfg;
                       }|]
    )

  GC.publicDef_ "context_config_free" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg) {
                         free(cfg->nvrtc_opts);
                         free(cfg);
                       }|]
    )

  -- Append an option and keep the array NULL-terminated by growing it
  -- one slot past the new count.
  GC.publicDef_ "context_config_add_nvrtc_option" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *opt);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *opt) {
                         cfg->nvrtc_opts[cfg->num_nvrtc_opts] = opt;
                         cfg->num_nvrtc_opts++;
                         cfg->nvrtc_opts = (const char**) realloc(cfg->nvrtc_opts, (cfg->num_nvrtc_opts+1) * sizeof(const char*));
                         cfg->nvrtc_opts[cfg->num_nvrtc_opts] = NULL;
                       }|]
    )

  GC.publicDef_ "context_config_set_debugging" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
                         cfg->cu_cfg.logging = cfg->cu_cfg.debugging = flag;
                       }|]
    )

  GC.publicDef_ "context_config_set_profiling" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
                         cfg->profiling = flag;
                       }|]
    )

  GC.publicDef_ "context_config_set_logging" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
                         cfg->cu_cfg.logging = flag;
                       }|]
    )

  GC.publicDef_ "context_config_set_device" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s) {
                         set_preferred_device(&cfg->cu_cfg, s);
                       }|]
    )

  GC.publicDef_ "context_config_dump_program_to" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                         cfg->cu_cfg.dump_program_to = path;
                       }|]
    )

  GC.publicDef_ "context_config_load_program_from" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                         cfg->cu_cfg.load_program_from = path;
                       }|]
    )

  GC.publicDef_ "context_config_dump_ptx_to" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                          cfg->cu_cfg.dump_ptx_to = path;
                      }|]
    )

  GC.publicDef_ "context_config_load_ptx_from" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                          cfg->cu_cfg.load_ptx_from = path;
                      }|]
    )

  GC.publicDef_ "context_config_set_default_group_size" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_block_size = size;
                         cfg->cu_cfg.default_block_size_changed = 1;
                       }|]
    )

  GC.publicDef_ "context_config_set_default_num_groups" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int num);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int num) {
                         cfg->cu_cfg.default_grid_size = num;
                         cfg->cu_cfg.default_grid_size_changed = 1;
                       }|]
    )

  -- Prototype parameter names below match their definitions, so the
  -- generated header does not misleadingly call a tile size "num".
  GC.publicDef_ "context_config_set_default_tile_size" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_tile_size = size;
                         cfg->cu_cfg.default_tile_size_changed = 1;
                       }|]
    )

  GC.publicDef_ "context_config_set_default_reg_tile_size" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_reg_tile_size = size;
                       }|]
    )

  GC.publicDef_ "context_config_set_default_threshold" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
      [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_threshold = size;
                       }|]
    )

  -- Generic setter. Program-specific sizes are looked up in size_names.
  -- Otherwise the built-in defaults are recognised. Setting the group
  -- size, number of groups or tile size here also sets the same
  -- *_changed flag as the dedicated setters above, so both paths leave
  -- the configuration in the same state.
  -- Returns 0 on success and 1 for an unknown size name.
  GC.publicDef_ "context_config_set_size" GC.InitDecl $ \s ->
    ( [C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value);|],
      [C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value) {

                         for (int i = 0; i < $int:num_sizes; i++) {
                           if (strcmp(size_name, size_names[i]) == 0) {
                             cfg->sizes[i] = size_value;
                             return 0;
                           }
                         }

                         if (strcmp(size_name, "default_group_size") == 0) {
                           cfg->cu_cfg.default_block_size = size_value;
                           cfg->cu_cfg.default_block_size_changed = 1;
                           return 0;
                         }

                         if (strcmp(size_name, "default_num_groups") == 0) {
                           cfg->cu_cfg.default_grid_size = size_value;
                           cfg->cu_cfg.default_grid_size_changed = 1;
                           return 0;
                         }

                         if (strcmp(size_name, "default_threshold") == 0) {
                           cfg->cu_cfg.default_threshold = size_value;
                           return 0;
                         }

                         if (strcmp(size_name, "default_tile_size") == 0) {
                           cfg->cu_cfg.default_tile_size = size_value;
                           cfg->cu_cfg.default_tile_size_changed = 1;
                           return 0;
                         }

                         if (strcmp(size_name, "default_reg_tile_size") == 0) {
                           cfg->cu_cfg.default_reg_tile_size = size_value;
                           return 0;
                         }

                         return 1;
                       }|]
    )
  return cfg

-- | Generate the public context type and its functions:
--
-- * @context_new@;
-- * @context_free@;
-- * @context_sync@;
-- * a clear hook that calls @cuda_free_all@ when no error is pending.
--
-- The first argument is the C name of the configuration struct, as
-- returned by 'generateConfigFuns'.
generateContextFuns ::
  String ->
  [Name] ->
  M.Map KernelName KernelSafety ->
  M.Map Name SizeClass ->
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateContextFuns cfg cost_centres kernels sizes failures = do
  final_inits <- GC.contextFinalInits
  (fields, init_fields) <- GC.contextContents
  let -- Profiling counters (accumulated runtime and run count) for a
      -- cost centre, each paired with its zero-initialisation.
      forCostCentre name =
        [ ( [C.csdecl|typename int64_t $id:(kernelRuntime name);|],
            [C.cstm|ctx->$id:(kernelRuntime name) = 0;|]
          ),
          ( [C.csdecl|int $id:(kernelRuns name);|],
            [C.cstm|ctx->$id:(kernelRuns name) = 0;|]
          )
        ]

      -- A CUfunction field for a kernel, looked up in the loaded module
      -- at context creation, plus the kernel's own profiling counters.
      forKernel name =
        ( [C.csdecl|typename CUfunction $id:name;|],
          [C.cstm|CUDA_SUCCEED(cuModuleGetFunction(
                                &ctx->$id:name,
                                ctx->cuda.module,
                                $string:(pretty (C.toIdent name mempty))));|]
        ) :
        forCostCentre name

      (kernel_fields, init_kernel_fields) =
        unzip $
          concatMap forKernel (M.keys kernels)
            ++ concatMap forCostCentre cost_centres

  ctx <- GC.publicDef "context" GC.InitDecl $ \s ->
    ( [C.cedecl|struct $id:s;|],
      [C.cedecl|struct $id:s {
                         int detail_memory;
                         int debugging;
                         int profiling;
                         int profiling_paused;
                         int logging;
                         typename lock_t lock;
                         char *error;
                         typename FILE *log;
                         $sdecls:fields
                         $sdecls:kernel_fields
                         typename CUdeviceptr global_failure;
                         typename CUdeviceptr global_failure_args;
                         struct cuda_context cuda;
                         struct sizes sizes;
                         // True if a potentially failing kernel has been enqueued.
                         typename int32_t failure_is_an_option;

                         int total_runs;
                         long int total_runtime;
                       };|]
    )

  let -- Copy each configured size into the context's sizes struct.
      -- The order matches the order used by generateConfigFuns.
      set_sizes =
        zipWith
          (\i k -> [C.cstm|ctx->sizes.$id:k = cfg->sizes[$int:i];|])
          [(0 :: Int) ..]
          $ M.keys sizes
      -- Largest number of arguments any failure message carries, or 0 if
      -- there are no failure messages. The leading 0 keeps 'maximum'
      -- total.
      max_failure_args =
        maximum $ 0 : map (errorMsgNumArgs . failureError) failures

  -- If the profiling-record buffer cannot be allocated, free the context
  -- and return NULL, as is already done when the context itself cannot
  -- be allocated.
  GC.publicDef_ "context_new" GC.InitDecl $ \s ->
    ( [C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg);|],
      [C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg) {
                 struct $id:ctx* ctx = (struct $id:ctx*) malloc(sizeof(struct $id:ctx));
                 if (ctx == NULL) {
                   return NULL;
                 }
                 ctx->debugging = ctx->detail_memory = cfg->cu_cfg.debugging;
                 ctx->profiling = cfg->profiling;
                 ctx->profiling_paused = 0;
                 ctx->logging = cfg->cu_cfg.logging;
                 ctx->error = NULL;
                 ctx->log = stderr;
                 ctx->cuda.profiling_records_capacity = 200;
                 ctx->cuda.profiling_records_used = 0;
                 ctx->cuda.profiling_records =
                   malloc(ctx->cuda.profiling_records_capacity *
                          sizeof(struct profiling_record));
                 if (ctx->cuda.profiling_records == NULL) {
                   free(ctx);
                   return NULL;
                 }

                 ctx->cuda.cfg = cfg->cu_cfg;
                 create_lock(&ctx->lock);

                 ctx->failure_is_an_option = 0;
                 ctx->total_runs = 0;
                 ctx->total_runtime = 0;
                 $stms:init_fields

                 cuda_setup(&ctx->cuda, cuda_program, cfg->nvrtc_opts);

                 typename int32_t no_error = -1;
                 CUDA_SUCCEED(cuMemAlloc(&ctx->global_failure, sizeof(no_error)));
                 CUDA_SUCCEED(cuMemcpyHtoD(ctx->global_failure, &no_error, sizeof(no_error)));
                 // The +1 is to avoid zero-byte allocations.
                 CUDA_SUCCEED(cuMemAlloc(&ctx->global_failure_args, sizeof(int64_t)*($int:max_failure_args+1)));

                 $stms:init_kernel_fields

                 $stms:final_inits
                 $stms:set_sizes

                 init_constants(ctx);
                 // Clear the free list of any deallocations that occurred while initialising constants.
                 CUDA_SUCCEED(cuda_free_all(&ctx->cuda));

                 futhark_context_sync(ctx);

                 return ctx;
               }|]
    )

  GC.publicDef_ "context_free" GC.InitDecl $ \s ->
    ( [C.cedecl|void $id:s(struct $id:ctx* ctx);|],
      [C.cedecl|void $id:s(struct $id:ctx* ctx) {
                                 free_constants(ctx);
                                 cuda_cleanup(&ctx->cuda);
                                 free_lock(&ctx->lock);
                                 free(ctx);
                               }|]
    )

  -- Synchronise with the device and, if a potentially failing kernel was
  -- enqueued, read back and reset the global failure index. A failure is
  -- reported via ctx->error (set by failureSwitch) and a return of 1.
  -- The cuCtxPushCurrent at the start is matched by a cuCtxPopCurrent on
  -- every return path, including the failure path, so that the thread's
  -- context stack does not grow with each reported failure.
  GC.publicDef_ "context_sync" GC.MiscDecl $ \s ->
    ( [C.cedecl|int $id:s(struct $id:ctx* ctx);|],
      [C.cedecl|int $id:s(struct $id:ctx* ctx) {
                 CUDA_SUCCEED(cuCtxPushCurrent(ctx->cuda.cu_ctx));
                 CUDA_SUCCEED(cuCtxSynchronize());
                 if (ctx->failure_is_an_option) {
                   // Check for any delayed error.
                   typename int32_t failure_idx;
                   CUDA_SUCCEED(
                     cuMemcpyDtoH(&failure_idx,
                                  ctx->global_failure,
                                  sizeof(int32_t)));
                   ctx->failure_is_an_option = 0;

                   if (failure_idx >= 0) {
                     // We have to clear global_failure so that the next entry point
                     // is not considered a failure from the start.
                     typename int32_t no_failure = -1;
                     CUDA_SUCCEED(
                       cuMemcpyHtoD(ctx->global_failure,
                                    &no_failure,
                                    sizeof(int32_t)));

                     typename int64_t args[$int:max_failure_args+1];
                     CUDA_SUCCEED(
                       cuMemcpyDtoH(&args,
                                    ctx->global_failure_args,
                                    sizeof(args)));

                     $stm:(failureSwitch failures)

                     CUDA_SUCCEED(cuCtxPopCurrent(&ctx->cuda.cu_ctx));
                     return 1;
                   }
                 }
                 CUDA_SUCCEED(cuCtxPopCurrent(&ctx->cuda.cu_ctx));
                 return 0;
               }|]
    )

  -- On clear, call cuda_free_all, but only when no error is recorded in
  -- the context.
  GC.onClear
    [C.citem|if (ctx->error == NULL) {
               CUDA_SUCCEED(cuda_free_all(&ctx->cuda));
             }|]