{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
module Futhark.CodeGen.Backends.CCUDA.Boilerplate
  (
    generateBoilerplate
  ) where

import qualified Language.C.Quote.OpenCL as C

import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.Backends.COpenCL.Boilerplate (failureSwitch)
import Futhark.Util (chunk, zEncodeString)

import qualified Data.Map as M
import Data.FileEmbed (embedStringFile)

-- | How many runtime arguments an error message interpolates: the number
-- of entries in its argument-type list.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs msg = length (errorMsgArgTypes msg)

-- | Emit the complete CUDA runtime boilerplate.
--
-- First, an early declaration block is emitted containing:
--
--   * the CUDA and NVRTC includes and the @fl_mem_t@ typedef;
--   * the embedded @rts/c/free_list.h@ and @rts/c/cuda.h@ runtime sources;
--   * the NULL-terminated @cuda_program@ string array, holding the prelude
--     followed by the program text.
--
-- After that, the size tables, configuration functions and context
-- functions are generated, in that order.
generateBoilerplate :: String -> String -> M.Map KernelName Safety
                    -> M.Map Name SizeClass
                    -> [FailureMsg]
                    -> GC.CompilerM OpenCL () ()
generateBoilerplate cuda_program cuda_prelude kernels sizes failures = do
  let cudaHeader = $(embedStringFile "rts/c/cuda.h") :: String
      freeListHeader = $(embedStringFile "rts/c/free_list.h") :: String
      -- The full source text is split with 'chunk 2000' and each piece
      -- becomes its own C string literal in the array. This is presumably
      -- to stay under C compilers' string-literal length limits.
      programPieces =
        [ [C.cinit|$string:piece|]
        | piece <- chunk 2000 (cuda_prelude ++ cuda_program) ]

  mapM_ GC.earlyDecl [C.cunit|
      $esc:("#include <cuda.h>")
      $esc:("#include <nvrtc.h>")
      $esc:("typedef CUdeviceptr fl_mem_t;")
      $esc:freeListHeader
      $esc:cudaHeader
      const char *cuda_program[] = {$inits:programPieces, NULL};
      |]

  generateSizeFuns sizes
  cfgName <- generateConfigFuns sizes
  generateContextFuns cfgName kernels sizes failures

-- | Emit the static tables describing the tunable sizes, plus the public
-- functions that query them.
--
-- Three parallel C string arrays are declared. Each follows the 'M.keys'
-- order of @sizes@:
--
--   * @size_names@ holds the pretty-printed size names;
--   * @size_vars@ holds the z-encoded versions of those names;
--   * @size_classes@ holds the pretty-printed size classes.
--
-- The public functions are @get_num_sizes@ (the number of sizes), plus
-- @get_size_name@ and @get_size_class@ (index into the name and class
-- tables). The index is not bounds-checked.
generateSizeFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () ()
generateSizeFuns sizes = do
  let cString str = [C.cinit|$string:str|]
      names = M.keys sizes
      nameInits = [ cString (pretty name) | name <- names ]
      varInits = [ cString (zEncodeString (pretty name)) | name <- names ]
      classInits = [ cString (pretty sizeClass) | sizeClass <- M.elems sizes ]
      count = length names
      -- A public accessor returning entry @i@ of the named static table.
      tableAccessor fun table =
        GC.publicDef_ fun GC.InitDecl $ \fname ->
          ([C.cedecl|const char* $id:fname(int);|],
           [C.cedecl|const char* $id:fname(int i) {
                      return $id:table[i];
                    }|])

  GC.earlyDecl [C.cedecl|static const char *size_names[] = { $inits:nameInits };|]
  GC.earlyDecl [C.cedecl|static const char *size_vars[] = { $inits:varInits };|]
  GC.earlyDecl [C.cedecl|static const char *size_classes[] = { $inits:classInits };|]

  GC.publicDef_ "get_num_sizes" GC.InitDecl $ \fname ->
    ([C.cedecl|int $id:fname(void);|],
     [C.cedecl|int $id:fname(void) {
                return $int:count;
              }|])

  tableAccessor "get_size_name" "size_names"
  tableAccessor "get_size_class" "size_classes"

-- | Emit the @struct sizes@ declaration and the public @context_config@
-- type, together with its constructor, destructor and setter functions.
--
-- The configuration struct contains:
--
--   * a @struct cuda_config@ that is handed to the CUDA runtime;
--   * one @size_t@ slot per tunable size, in 'M.keys' order (the same
--     order as the tables emitted by 'generateSizeFuns');
--   * a NULL-terminated, heap-allocated list of extra NVRTC options.
--
-- Returns the generated name of the configuration struct, which
-- 'generateContextFuns' needs in order to refer to it.
generateConfigFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () String
generateConfigFuns sizes = do
  let size_decls = map (\k -> [C.csdecl|size_t $id:k;|]) $ M.keys sizes
      num_sizes = M.size sizes
  GC.earlyDecl [C.cedecl|struct sizes { $sdecls:size_decls };|]
  cfg <- GC.publicDef "context_config" GC.InitDecl $ \s ->
    ([C.cedecl|struct $id:s;|],
     [C.cedecl|struct $id:s { struct cuda_config cu_cfg;
                              size_t sizes[$int:num_sizes];
                              int num_nvrtc_opts;
                              const char **nvrtc_opts;
                            };|])

  -- Bespoke sizes start at the value recorded in their size class.
  -- Every other size class starts at 0.
  let size_value_inits = zipWith sizeInit [(0::Int)..] (M.elems sizes)
      sizeInit i size = [C.cstm|cfg->sizes[$int:i] = $int:val;|]
        where val = case size of SizeBespoke _ x -> x
                                 _               -> 0

  -- Both allocations are checked. If the option list cannot be allocated,
  -- the half-built config is released instead of writing through NULL.
  GC.publicDef_ "context_config_new" GC.InitDecl $ \s ->
    ([C.cedecl|struct $id:cfg* $id:s(void);|],
     [C.cedecl|struct $id:cfg* $id:s(void) {
                         struct $id:cfg *cfg = (struct $id:cfg*) malloc(sizeof(struct $id:cfg));
                         if (cfg == NULL) {
                           return NULL;
                         }

                         cfg->num_nvrtc_opts = 0;
                         cfg->nvrtc_opts = (const char**) malloc(sizeof(const char*));
                         if (cfg->nvrtc_opts == NULL) {
                           free(cfg);
                           return NULL;
                         }
                         cfg->nvrtc_opts[0] = NULL;
                         $stms:size_value_inits
                         cuda_config_init(&cfg->cu_cfg, $int:num_sizes,
                                          size_names, size_vars,
                                          cfg->sizes, size_classes);
                         return cfg;
                       }|])

  GC.publicDef_ "context_config_free" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg) {
                         free(cfg->nvrtc_opts);
                         free(cfg);
                       }|])

  -- Append an option and keep the list NULL-terminated. The option string
  -- itself is not copied; only the pointer is stored.
  GC.publicDef_ "context_config_add_nvrtc_option" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *opt);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *opt) {
                         cfg->nvrtc_opts[cfg->num_nvrtc_opts] = opt;
                         cfg->num_nvrtc_opts++;
                         cfg->nvrtc_opts = (const char**) realloc(cfg->nvrtc_opts, (cfg->num_nvrtc_opts+1) * sizeof(const char*));
                         cfg->nvrtc_opts[cfg->num_nvrtc_opts] = NULL;
                       }|])

  -- Debugging mode also turns on logging.
  GC.publicDef_ "context_config_set_debugging" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
                         cfg->cu_cfg.logging = cfg->cu_cfg.debugging = flag;
                       }|])

  GC.publicDef_ "context_config_set_logging" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
                         cfg->cu_cfg.logging = flag;
                       }|])

  GC.publicDef_ "context_config_set_device" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s) {
                         set_preferred_device(&cfg->cu_cfg, s);
                       }|])

  GC.publicDef_ "context_config_dump_program_to" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                         cfg->cu_cfg.dump_program_to = path;
                       }|])

  GC.publicDef_ "context_config_load_program_from" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                         cfg->cu_cfg.load_program_from = path;
                       }|])

  GC.publicDef_ "context_config_dump_ptx_to" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                          cfg->cu_cfg.dump_ptx_to = path;
                      }|])

  GC.publicDef_ "context_config_load_ptx_from" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
                          cfg->cu_cfg.load_ptx_from = path;
                      }|])

  -- The default_* setters below also raise the matching *_changed flag,
  -- where the runtime config has one.
  GC.publicDef_ "context_config_set_default_group_size" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_block_size = size;
                         cfg->cu_cfg.default_block_size_changed = 1;
                       }|])

  GC.publicDef_ "context_config_set_default_num_groups" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int num);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int num) {
                         cfg->cu_cfg.default_grid_size = num;
                         cfg->cu_cfg.default_grid_size_changed = 1;
                       }|])

  GC.publicDef_ "context_config_set_default_tile_size" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_tile_size = size;
                         cfg->cu_cfg.default_tile_size_changed = 1;
                       }|])

  GC.publicDef_ "context_config_set_default_threshold" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
     [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
                         cfg->cu_cfg.default_threshold = size;
                       }|])

  -- Set a size by name. Program-specific sizes are tried first, then the
  -- built-in defaults. Returns 0 if the name is recognised and 1 otherwise.
  -- Setting a default through this route raises the same *_changed flags
  -- as the dedicated setters above, so both routes behave the same.
  GC.publicDef_ "context_config_set_size" GC.InitDecl $ \s ->
    ([C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value);|],
     [C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value) {

                         for (int i = 0; i < $int:num_sizes; i++) {
                           if (strcmp(size_name, size_names[i]) == 0) {
                             cfg->sizes[i] = size_value;
                             return 0;
                           }
                         }

                         if (strcmp(size_name, "default_group_size") == 0) {
                           cfg->cu_cfg.default_block_size = size_value;
                           cfg->cu_cfg.default_block_size_changed = 1;
                           return 0;
                         }

                         if (strcmp(size_name, "default_num_groups") == 0) {
                           cfg->cu_cfg.default_grid_size = size_value;
                           cfg->cu_cfg.default_grid_size_changed = 1;
                           return 0;
                         }

                         if (strcmp(size_name, "default_threshold") == 0) {
                           cfg->cu_cfg.default_threshold = size_value;
                           return 0;
                         }

                         if (strcmp(size_name, "default_tile_size") == 0) {
                           cfg->cu_cfg.default_tile_size = size_value;
                           cfg->cu_cfg.default_tile_size_changed = 1;
                           return 0;
                         }
                         return 1;
                       }|])
  return cfg

-- | Emit the public @context@ type and its lifecycle and query functions.
--
-- The context struct contains:
--
--   * the compiler-generated fields from 'GC.contextContents';
--   * one @CUfunction@ handle per kernel in @kernels@, loaded in
--     @context_new@;
--   * two device buffers used to report delayed kernel failures;
--   * the CUDA runtime state;
--   * a @struct sizes@ filled from the configuration.
--
-- @cfg@ is the configuration struct name returned by
-- 'generateConfigFuns'. @failures@ lists the failure messages that
-- @context_sync@ can report.
generateContextFuns :: String -> M.Map KernelName Safety
                    -> M.Map Name SizeClass
                    -> [FailureMsg]
                    -> GC.CompilerM OpenCL () ()
generateContextFuns cfg kernels sizes failures = do
  final_inits <- GC.contextFinalInits
  (fields, init_fields) <- GC.contextContents
  let kernel_fields = map (\k -> [C.csdecl|typename CUfunction $id:k;|]) $
                      M.keys kernels

  ctx <- GC.publicDef "context" GC.InitDecl $ \s ->
    ([C.cedecl|struct $id:s;|],
     [C.cedecl|struct $id:s {
                         int detail_memory;
                         int debugging;
                         int profiling;
                         typename lock_t lock;
                         char *error;
                         $sdecls:fields
                         $sdecls:kernel_fields
                         typename CUdeviceptr global_failure;
                         typename CUdeviceptr global_failure_args;
                         struct cuda_context cuda;
                         struct sizes sizes;
                         // True if a potentially failing kernel has been enqueued.
                         typename int32_t failure_is_an_option;
                       };|])

  let set_sizes = zipWith (\i k -> [C.cstm|ctx->sizes.$id:k = cfg->sizes[$int:i];|])
                          [(0::Int)..] $ M.keys sizes
      -- Number of int32_t argument slots needed by the failure message with
      -- the most arguments; 0 when there are no failure messages.
      max_failure_args =
        maximum $ 0 : map (errorMsgNumArgs . failureError) failures

  GC.publicDef_ "context_new" GC.InitDecl $ \s ->
    ([C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg);|],
     [C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg) {
                 struct $id:ctx* ctx = (struct $id:ctx*) malloc(sizeof(struct $id:ctx));
                 if (ctx == NULL) {
                   return NULL;
                 }
                 ctx->profiling = ctx->debugging = ctx->detail_memory = cfg->cu_cfg.debugging;
                 ctx->error = NULL;

                 ctx->cuda.cfg = cfg->cu_cfg;
                 create_lock(&ctx->lock);

                 ctx->failure_is_an_option = 0;
                 $stms:init_fields

                 cuda_setup(&ctx->cuda, cuda_program, cfg->nvrtc_opts);

                 typename int32_t no_error = -1;
                 CUDA_SUCCEED(cuMemAlloc(&ctx->global_failure, sizeof(no_error)));
                 CUDA_SUCCEED(cuMemcpyHtoD(ctx->global_failure, &no_error, sizeof(no_error)));
                 // The +1 is to avoid zero-byte allocations.
                 CUDA_SUCCEED(cuMemAlloc(&ctx->global_failure_args, sizeof(int32_t)*($int:max_failure_args+1)));

                 $stms:(map loadKernel (M.toList kernels))

                 $stms:final_inits
                 $stms:set_sizes

                 init_constants(ctx);
                 // Clear the free list of any deallocations that occurred while initialising constants.
                 CUDA_SUCCEED(cuda_free_all(&ctx->cuda));

                 futhark_context_sync(ctx);

                 return ctx;
               }|])

  -- The two failure buffers allocated in context_new are released here,
  -- before cuda_cleanup tears down the CUDA state they belong to.
  GC.publicDef_ "context_free" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:ctx* ctx);|],
     [C.cedecl|void $id:s(struct $id:ctx* ctx) {
                                 free_constants(ctx);
                                 CUDA_SUCCEED(cuMemFree(ctx->global_failure));
                                 CUDA_SUCCEED(cuMemFree(ctx->global_failure_args));
                                 cuda_cleanup(&ctx->cuda);
                                 free_lock(&ctx->lock);
                                 free(ctx);
                               }|])

  -- Synchronise the device. If a potentially failing kernel was enqueued,
  -- read the failure index back from the device. When a failure is found
  -- (index >= 0), the index is reset to -1 on the device, the arguments
  -- are fetched, and the message is reported via failureSwitch; the
  -- function then returns 1. Otherwise it returns 0.
  GC.publicDef_ "context_sync" GC.InitDecl $ \s ->
    ([C.cedecl|int $id:s(struct $id:ctx* ctx);|],
     [C.cedecl|int $id:s(struct $id:ctx* ctx) {
                 CUDA_SUCCEED(cuCtxSynchronize());
                 if (ctx->failure_is_an_option) {
                   // Check for any delayed error.
                   typename int32_t failure_idx;
                   CUDA_SUCCEED(
                     cuMemcpyDtoH(&failure_idx,
                                  ctx->global_failure,
                                  sizeof(int32_t)));
                   ctx->failure_is_an_option = 0;

                   if (failure_idx >= 0) {
                     // Clear global_failure so the next entry point does not
                     // start out looking like a failure.
                     typename int32_t no_failure = -1;
                     CUDA_SUCCEED(
                       cuMemcpyHtoD(ctx->global_failure,
                                    &no_failure,
                                    sizeof(int32_t)));

                     typename int32_t args[$int:max_failure_args+1];
                     CUDA_SUCCEED(
                       cuMemcpyDtoH(&args,
                                    ctx->global_failure_args,
                                    sizeof(args)));

                     $stm:(failureSwitch failures)

                     return 1;
                   }
                 }
                 return 0;
               }|])

  GC.publicDef_ "context_get_error" GC.InitDecl $ \s ->
    ([C.cedecl|char* $id:s(struct $id:ctx* ctx);|],
     [C.cedecl|char* $id:s(struct $id:ctx* ctx) {
                         return ctx->error;
                       }|])

  -- Profiling pause/unpause are no-ops in this backend.
  GC.publicDef_ "context_pause_profiling" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:ctx* ctx);|],
     [C.cedecl|void $id:s(struct $id:ctx* ctx) {
                 (void)ctx;
               }|])

  GC.publicDef_ "context_unpause_profiling" GC.InitDecl $ \s ->
    ([C.cedecl|void $id:s(struct $id:ctx* ctx);|],
     [C.cedecl|void $id:s(struct $id:ctx* ctx) {
                 (void)ctx;
               }|])

  where
    -- Look up the kernel with the given name in the loaded CUDA module and
    -- store its handle in the same-named context field.
    loadKernel (name, _) =
      [C.cstm|CUDA_SUCCEED(cuModuleGetFunction(&ctx->$id:name,
                ctx->cuda.module, $string:name));|]