{-# LANGUAGE QuasiQuotes #-}

module Futhark.CodeGen.Backends.COpenCL.Boilerplate
  ( generateBoilerplate,
    profilingEvent,
    copyDevToDev,
    copyDevToHost,
    copyHostToDev,
    copyScalarToDev,
    copyScalarFromDev,
    commonOptions,
    failureMsgFunction,
    costCentreReport,
    kernelRuntime,
    kernelRuns,
    sizeLoggingCode,
  )
where

import Control.Monad
import Control.Monad.State
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC qualified as GC
import Futhark.CodeGen.Backends.GenericC.Options
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.ImpCode.OpenCL
import Futhark.CodeGen.OpenCL.Heuristics
import Futhark.CodeGen.RTS.C (backendsOpenclH)
import Futhark.Util (chunk)
import Futhark.Util.Pretty (prettyTextOneLine)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | The number of argument types reported by 'errorMsgArgTypes' for
-- an error message, i.e. how many values the message interpolates.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs = length . errorMsgArgTypes

-- | Generate the C function @get_failure_msg@, which maps a failure
-- index and its argument array to a freshly formatted error message.
--
-- Each 'FailureMsg' becomes one @case@ of a @switch@ on the failure
-- index; cases are numbered from zero in the order of the input list.
-- An index not matched by any case falls through to a generic
-- "compiler bug" message.
failureMsgFunction :: [FailureMsg] -> C.Definition
failureMsgFunction failures =
  [C.cedecl|static char* get_failure_msg(int failure_idx, typename int64_t args[]) {
                  switch (failure_idx) { $stms:failure_cases }
                  return strdup("Unknown error.  This is a compiler bug.");
                }|]
  where
    -- One @case@ statement per failure, indexed from zero.
    failure_cases :: [C.Stm]
    failure_cases = zipWith onFailure [(0 :: Int) ..] failures

    -- Build the @case@ for failure number @i@: the format string is the
    -- message parts followed by a newline and the (escaped) backtrace,
    -- and the printf arguments are the leading elements of @args@.
    onFailure :: Int -> FailureMsg -> C.Stm
    onFailure i (FailureMsg emsg@(ErrorMsg parts) backtrace) =
      let msg = concatMap onPart parts ++ "\n" ++ printfEscape backtrace
          msgargs = [[C.cexp|args[$int:j]|] | j <- [0 .. errorMsgNumArgs emsg - 1]]
       in [C.cstm|case $int:i: {return msgprintf($string:msg, $args:msgargs); break;}|]

    -- Literal text is escaped; every interpolated value becomes a
    -- @%lld@ conversion.
    onPart :: ErrorMsgPart a -> String
    onPart (ErrorString s) = printfEscape $ T.unpack s
    -- FIXME: bogus for non-ints.
    onPart ErrorVal {} = "%lld"

    -- Double every @%@ so that literal text is not interpreted as a
    -- printf conversion.
    printfEscape :: String -> String
    printfEscape = concatMap escapeChar
      where
        escapeChar '%' = "%%"
        escapeChar c = [c]

-- | Names identifying the built-in memory copy operations.
-- NOTE(review): presumably these are passed as cost centres to
-- 'profilingEvent' and 'generateBoilerplate' - confirm against callers.
copyDevToDev, copyDevToHost, copyHostToDev, copyScalarToDev, copyScalarFromDev :: Name
copyDevToDev = "copy_dev_to_dev"
copyDevToHost = "copy_dev_to_host"
copyHostToDev = "copy_host_to_dev"
copyScalarToDev = "copy_scalar_to_dev"
copyScalarFromDev = "copy_scalar_from_dev"

-- | A C expression for the profiling event argument of an operation
-- with the given name.
--
-- The expression evaluates to @NULL@ when profiling is paused or not
-- enabled; otherwise it calls @opencl_get_event@ with the addresses of
-- the program-struct fields named by 'kernelRuns' and 'kernelRuntime'.
profilingEvent :: Name -> C.Exp
profilingEvent name =
  [C.cexp|(ctx->profiling_paused || !ctx->profiling) ? NULL
          : opencl_get_event(ctx,
                             &ctx->program->$id:(kernelRuns name),
                             &ctx->program->$id:(kernelRuntime name))|]

-- | C statement that releases the kernel object stored in the program
-- struct field named after the kernel.  The safety level is ignored.
releaseKernel :: (KernelName, KernelSafety) -> C.Stm
releaseKernel (name, _) = [C.cstm|OPENCL_SUCCEED_FATAL(clReleaseKernel(ctx->program->$id:name));|]

-- | C statement that creates the kernel object for the given kernel,
-- stores it in the program struct, and binds the failure-reporting
-- arguments required by the kernel's safety level.
--
-- A log line is written when @ctx->debugging@ is set.
loadKernel :: (KernelName, KernelSafety) -> C.Stm
loadKernel (name, safety) =
  [C.cstm|{
  ctx->program->$id:name = clCreateKernel(ctx->clprogram, $string:(T.unpack (idText (C.toIdent name mempty))), &error);
  OPENCL_SUCCEED_FATAL(error);
  $items:set_args
  if (ctx->debugging) {
    fprintf(ctx->log, "Created kernel %s.\n", $string:(prettyString name));
  }
  }|]
  where
    -- Which failure arguments to bind depends on the safety level:
    -- none, only the failure flag, or the flag plus its argument buffer.
    set_args :: [C.BlockItem]
    set_args = case safety of
      SafetyNone -> []
      SafetyCheap -> [set_global_failure]
      SafetyFull -> [set_global_failure, set_global_failure_args]

    -- Bind @ctx->global_failure@ to kernel argument slot 0.
    set_global_failure :: C.BlockItem
    set_global_failure =
      [C.citem|OPENCL_SUCCEED_FATAL(
                     clSetKernelArg(ctx->program->$id:name, 0, sizeof(typename cl_mem),
                                    &ctx->global_failure));|]

    -- Bind @ctx->global_failure_args@ to kernel argument slot 2.
    set_global_failure_args :: C.BlockItem
    set_global_failure_args =
      [C.citem|OPENCL_SUCCEED_FATAL(
                     clSetKernelArg(ctx->program->$id:name, 2, sizeof(typename cl_mem),
                                    &ctx->global_failure_args));|]

-- | Declare the per-kernel and per-cost-centre context fields, and
-- emit the @post_opencl_setup@ function.
--
-- For every kernel a @cl_kernel@ field is declared, set up with
-- 'loadKernel' and torn down with 'releaseKernel'.  For every cost
-- centre and every kernel, a total-runtime field (@int64_t@) and a
-- run-count field (@int@) are declared, both initialised to zero.
-- @post_opencl_setup@ consists of one statement per entry of
-- 'sizeHeuristicsTable'.
generateOpenCLDecls ::
  -- | Additional cost centres (beyond the kernels themselves).
  [Name] ->
  -- | The kernels of the program and their safety levels.
  M.Map KernelName KernelSafety ->
  GC.CompilerM op s ()
generateOpenCLDecls cost_centres kernels = do
  forM_ (M.toList kernels) $ \kernel@(name, _) ->
    GC.contextFieldDyn
      (C.toIdent name mempty)
      [C.cty|typename cl_kernel|]
      (loadKernel kernel)
      (releaseKernel kernel)
  forM_ (cost_centres <> M.keys kernels) $ \name -> do
    GC.contextField
      (C.toIdent (kernelRuntime name) mempty)
      [C.cty|typename int64_t|]
      (Just [C.cexp|0|])
    GC.contextField
      (C.toIdent (kernelRuns name) mempty)
      [C.cty|int|]
      (Just [C.cexp|0|])
  GC.earlyDecl
    [C.cedecl|
void post_opencl_setup(struct futhark_context *ctx, struct opencl_device_option *option) {
  $stms:(map sizeHeuristicsCode sizeHeuristicsTable)
}|]

-- | Called after most code has been generated to generate the bulk of
-- the boilerplate.
--
-- This emits, in order: the early declarations (maximum number of
-- failure arguments, whether @f64@ is required, the embedded OpenCL
-- program text and the OpenCL runtime header), the failure message
-- function, the kernel and cost centre context fields, the public
-- header declarations for the OpenCL-specific configuration API, the
-- program struct, the context-clearing hook, and the profiling report.
generateBoilerplate ::
  -- | The OpenCL program text.
  T.Text ->
  -- | The OpenCL prelude, which is placed before the program text.
  T.Text ->
  -- | Additional cost centres (beyond the kernels themselves).
  [Name] ->
  -- | The kernels of the program and their safety levels.
  M.Map KernelName KernelSafety ->
  -- | Primitive types used; only consulted for the presence of @f64@.
  [PrimType] ->
  -- | The failure messages that kernels can report.
  [FailureMsg] ->
  GC.CompilerM OpenCL () ()
generateBoilerplate opencl_program opencl_prelude cost_centres kernels types failures = do
  let opencl_program_fragments =
        -- Some C compilers limit the size of literal strings, so
        -- chunk the entire program into small bits here, and
        -- concatenate it again at runtime.
        [[C.cinit|$string:s|] | s <- chunk 2000 $ T.unpack $ opencl_prelude <> opencl_program]
      -- The fragment array is NULL-terminated.
      program_fragments = opencl_program_fragments ++ [[C.cinit|NULL|]]
      f64_required
        | FloatType Float64 `elem` types = [C.cexp|1|]
        | otherwise = [C.cexp|0|]
      -- Largest argument count over all failures; zero if there are none.
      max_failure_args :: Int
      max_failure_args = maximum $ 0 : map (errorMsgNumArgs . failureError) failures
  mapM_
    GC.earlyDecl
    [C.cunit|static const int max_failure_args = $int:max_failure_args;
             static const int f64_required = $exp:f64_required;
             static const char *opencl_program[] = {$inits:program_fragments};
             $esc:(T.unpack backendsOpenclH)
            |]
  GC.earlyDecl $ failureMsgFunction failures

  generateOpenCLDecls cost_centres kernels

  -- Public configuration API specific to this backend.
  mapM_
    (GC.headerDecl GC.InitDecl)
    [ [C.cedecl|void futhark_context_config_add_build_option(struct futhark_context_config *cfg, const char* opt);|],
      [C.cedecl|void futhark_context_config_set_device(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_set_platform(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_select_device_interactively(struct futhark_context_config *cfg);|],
      [C.cedecl|void futhark_context_config_list_devices(struct futhark_context_config *cfg);|],
      [C.cedecl|void futhark_context_config_dump_program_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_program_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_dump_binary_to(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_load_binary_from(struct futhark_context_config *cfg, const char* s);|],
      [C.cedecl|void futhark_context_config_set_default_group_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_num_groups(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_reg_tile_size(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_default_threshold(struct futhark_context_config *cfg, int size);|],
      [C.cedecl|void futhark_context_config_set_command_queue(struct futhark_context_config *cfg, typename cl_command_queue);|]
    ]
  GC.headerDecl GC.MiscDecl [C.cedecl|typename cl_command_queue futhark_context_get_command_queue(struct futhark_context* ctx);|]

  GC.generateProgramStruct

  GC.onClear
    [C.citem|if (ctx->error == NULL) { ctx->error = OPENCL_SUCCEED_NONFATAL(opencl_free_all(ctx)); }|]

  GC.profileReport [C.citem|OPENCL_SUCCEED_FATAL(opencl_tally_profiling_records(ctx));|]
  mapM_ GC.profileReport $ costCentreReport $ cost_centres ++ M.keys kernels

-- | Name of the field accumulating the total runtime of the given
-- kernel or cost centre: the name with @_total_runtime@ appended.
kernelRuntime :: KernelName -> Name
kernelRuntime = (<> "_total_runtime")

-- | Name of the field counting the runs of the given kernel or cost
-- centre: the name with @_runs@ appended.
kernelRuns :: KernelName -> Name
kernelRuns = (<> "_runs")

-- | C block items that append a profiling report to @builder@.
--
-- For every name, one line is written giving the number of runs, the
-- average runtime and the total runtime, and the name's counters are
-- added to @ctx->total_runtime@ and @ctx->total_runs@.  A final line
-- reports those accumulated totals.  Names are padded with spaces to
-- the width of the longest one so that the columns line up.
costCentreReport :: [Name] -> [C.BlockItem]
costCentreReport names = report_kernels ++ [report_total]
  where
    -- Width of the widest pretty-printed name; zero for no names.
    longest_name :: Int
    longest_name = maximum $ 0 : map (length . prettyString) names

    report_kernels :: [C.BlockItem]
    report_kernels = concatMap reportKernel names

    -- The printf format for one report line, with the name padded on
    -- the right.
    format_string :: String -> String
    format_string name =
      let padding = replicate (longest_name - length name) ' '
       in unwords
            [ name ++ padding,
              "ran %5d times; avg: %8ldus; total: %8ldus\n"
            ]

    reportKernel :: Name -> [C.BlockItem]
    reportKernel name =
      let runs = kernelRuns name
          total_runtime = kernelRuntime name
       in -- The average divides by 1 instead of a zero run count.
          [ [C.citem|
               str_builder(&builder,
                           $string:(format_string (prettyString name)),
                           ctx->program->$id:runs,
                           (long int) ctx->program->$id:total_runtime / (ctx->program->$id:runs != 0 ? ctx->program->$id:runs : 1),
                           (long int) ctx->program->$id:total_runtime);
              |],
            [C.citem|ctx->total_runtime += ctx->program->$id:total_runtime;|],
            [C.citem|ctx->total_runs += ctx->program->$id:runs;|]
          ]

    report_total :: C.BlockItem
    report_total =
      [C.citem|str_builder(&builder, "%d operations with cumulative runtime: %6ldus\n",
                           ctx->total_runs, ctx->total_runtime);|]

-- | C statement applying one size heuristic inside
-- @post_opencl_setup@.
--
-- The heuristic only fires when the size in question is still zero,
-- the platform name contains the heuristic's platform string, and the
-- device type matches.  It then queries whatever device information
-- the heuristic's expression refers to and assigns the computed value
-- to the size.
sizeHeuristicsCode :: SizeHeuristic -> C.Stm
sizeHeuristicsCode (SizeHeuristic platform_name device_type which (TPrimExp what)) =
  [C.cstm|
   if ($exp:which' == 0 &&
       strstr(option->platform_name, $string:platform_name) != NULL &&
       (option->device_type & $exp:(clDeviceType device_type)) == $exp:(clDeviceType device_type)) {
     $items:get_size
   }|]
  where
    clDeviceType :: DeviceType -> C.Exp
    clDeviceType DeviceGPU = [C.cexp|CL_DEVICE_TYPE_GPU|]
    clDeviceType DeviceCPU = [C.cexp|CL_DEVICE_TYPE_CPU|]

    -- The context or configuration field holding the size.
    which' :: C.Exp
    which' = case which of
      LockstepWidth -> [C.cexp|ctx->lockstep_width|]
      NumGroups -> [C.cexp|ctx->cfg->default_num_groups|]
      GroupSize -> [C.cexp|ctx->cfg->default_group_size|]
      TileSize -> [C.cexp|ctx->cfg->default_tile_size|]
      RegTileSize -> [C.cexp|ctx->cfg->default_reg_tile_size|]
      Threshold -> [C.cexp|ctx->cfg->default_threshold|]

    -- The device queries collected while compiling the expression,
    -- followed by the assignment of its value to the size.
    get_size :: [C.BlockItem]
    get_size =
      let (e, m) = runState (GC.compilePrimExp onLeaf what) mempty
       in concat (M.elems m) ++ [[C.citem|$exp:which' = $exp:e;|]]

    -- Translate a reference to a piece of device information into the
    -- C variable holding it, recording the declaration and query of
    -- that variable in the state (keyed by the @CL_DEVICE_*@ constant)
    -- the first time it is seen.
    onLeaf :: DeviceInfo -> State (M.Map String [C.BlockItem]) C.Exp
    onLeaf (DeviceInfo s) = do
      let s' = "CL_DEVICE_" ++ s
          v = s ++ "_val"
      -- The membership test must use the same key as the insertion
      -- below, or the query would be re-recorded on every reference.
      already_declared <- gets (M.member s')
      unless already_declared $
        -- XXX: Cheating with the type here; works for the infos we
        -- currently use because we zero-initialise and assume a
        -- little-endian platform, but should be made more
        -- size-aware in the future.
        modify $
          M.insert
            s'
            [C.citems|size_t $id:v = 0;
                        clGetDeviceInfo(ctx->device, $id:s',
                                        sizeof($id:v), &$id:v,
                                        NULL);|]

      pure [C.cexp|$id:v|]

-- | Output size information if logging is enabled.
--
-- Emits a C statement that, when @ctx->logging@ is set, writes to
-- @ctx->log@ the key, the value it was compared against, and whether
-- the variable holding the comparison result is true.
--
-- The autotuner depends on the format of this output, so use caution if
-- changing it.
sizeLoggingCode :: VName -> Name -> C.Exp -> GC.CompilerM op () ()
sizeLoggingCode v key x' =
  GC.stm
    [C.cstm|if (ctx->logging) {
    fprintf(ctx->log, "Compared %s <= %ld: %s.\n", $string:(T.unpack (prettyTextOneLine key)), (long)$exp:x', $id:v ? "true" : "false");
    }|]

-- | Options that are common to multiple GPU-like backends.
--
-- Provides @--device@ (short form @-d@) for selecting a device by
-- name, and integer-valued options for the default group size, number
-- of groups, tile size, register tile size, and threshold.
commonOptions :: [Option]
commonOptions =
  [ Option
      { optionLongName = "device",
        optionShortName = Just 'd',
        optionArgument = RequiredArgument "NAME",
        optionDescription = "Use the first OpenCL device whose name contains the given string.",
        optionAction = [C.cstm|futhark_context_config_set_device(cfg, optarg);|]
      },
    intOption
      "default-group-size"
      "The default size of OpenCL workgroups that are launched."
      [C.cstm|futhark_context_config_set_default_group_size(cfg, atoi(optarg));|],
    intOption
      "default-num-groups"
      "The default number of OpenCL workgroups that are launched."
      [C.cstm|futhark_context_config_set_default_num_groups(cfg, atoi(optarg));|],
    intOption
      "default-tile-size"
      "The default tile size used when performing two-dimensional tiling."
      [C.cstm|futhark_context_config_set_default_tile_size(cfg, atoi(optarg));|],
    intOption
      "default-reg-tile-size"
      "The default register tile size used when performing two-dimensional tiling."
      [C.cstm|futhark_context_config_set_default_reg_tile_size(cfg, atoi(optarg));|],
    intOption
      "default-threshold"
      "The default parallelism threshold."
      [C.cstm|futhark_context_config_set_default_threshold(cfg, atoi(optarg));|]
  ]
  where
    -- A long-only option taking a required @INT@ argument.
    intOption :: String -> String -> C.Stm -> Option
    intOption long_name description action =
      Option
        { optionLongName = long_name,
          optionShortName = Nothing,
          optionArgument = RequiredArgument "INT",
          optionDescription = description,
          optionAction = action
        }

-- Tell GHC not to inline 'generateBoilerplate' at its use sites.
-- NOTE(review): the reason for this pragma is not recorded in this
-- file - confirm that it is still needed.
{-# NOINLINE generateBoilerplate #-}