{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TemplateHaskell #-}
-- | This module defines a translation from imperative code with
-- kernels to imperative code with OpenCL calls.
module Futhark.CodeGen.ImpGen.Kernels.ToOpenCL
  ( kernelsToOpenCL
  , kernelsToCUDA
  )
  where

import Control.Monad.State
import Control.Monad.Identity
import Control.Monad.Reader
import Data.FileEmbed
import Data.Maybe
import qualified Data.Set as S
import qualified Data.Map.Strict as M

import qualified Language.C.Syntax as C
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Quote.CUDA as CUDAC

import qualified Futhark.CodeGen.Backends.GenericC as GenericC
import Futhark.CodeGen.Backends.SimpleRepresentation
import Futhark.CodeGen.ImpCode.Kernels hiding (Program)
import qualified Futhark.CodeGen.ImpCode.Kernels as ImpKernels
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import qualified Futhark.CodeGen.ImpCode.OpenCL as ImpOpenCL
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (prettyOneLine)

-- | Entry points of the module: translate a kernels-program into an
-- 'ImpOpenCL.Program'.  Both are 'translateKernels' applied to the
-- corresponding 'KernelTarget' ('TargetCUDA' and 'TargetOpenCL').
kernelsToCUDA, kernelsToOpenCL :: ImpKernels.Program -> ImpOpenCL.Program
kernelsToCUDA = translateKernels TargetCUDA
kernelsToOpenCL = translateKernels TargetOpenCL

-- | Translate a kernels-program to an OpenCL-program.
--
-- The translation runs in a 'State' over 'ToOpenCL', starting from
-- 'initialOpenCL'.  Every host operation in the constants and in each
-- function is rewritten with 'onHostOp'; as a side effect this
-- accumulates the generated kernels, the primitive types used, the
-- sizes and the failure messages, which are then packaged into the
-- resulting 'ImpOpenCL.Program'.
translateKernels :: KernelTarget
                 -> ImpKernels.Program
                 -> ImpOpenCL.Program
translateKernels target prog =
  let (prog', ToOpenCL kernels used_types sizes failures) =
        flip runState initialOpenCL $ do
          let ImpKernels.Definitions
                (ImpKernels.Constants ps consts)
                (ImpKernels.Functions funs) = prog
          -- The reader environment is a 'Name': a fixed "val" for the
          -- constants, and the function's own name for each function.
          consts' <- runReaderT (traverse (onHostOp target) consts)
                     (nameFromString "val")
          funs' <- forM funs $ \(fname, fun) ->
            (fname,) <$> runReaderT (traverse (onHostOp target) fun) fname
          return $ ImpOpenCL.Definitions
            (ImpOpenCL.Constants ps consts')
            (ImpOpenCL.Functions funs')

      -- Only the safety level of each kernel is kept in the program
      -- record; the function bodies go into the pretty-printed code.
      kernels' = M.map fst kernels
      opencl_code = openClCode $ map snd $ M.elems kernels
      opencl_prelude = pretty $ genPrelude target used_types

  in ImpOpenCL.Program opencl_code opencl_prelude kernels'
     (S.toList used_types) (cleanSizes sizes) failures prog'

  where -- The OpenCL prelude depends on the set of used types; the
        -- CUDA prelude ignores it.
        genPrelude TargetOpenCL = genOpenClPrelude
        genPrelude TargetCUDA = const genCUDAPrelude

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist.  We remove these here.
--
-- Only 'SizeThreshold' entries are touched: every path element whose
-- name is not itself a key of the map is dropped.  All other size
-- classes are returned unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where clean (SizeThreshold path) =
          -- Look the name up in the map directly instead of scanning
          -- the list of keys for every path element.
          SizeThreshold $ filter ((`M.member` m) . fst) path
        clean s = s

-- | Map the textual name of a qualifier (@"global"@, @"local"@,
-- @"private"@, @"constant"@, @"write_only"@, @"read_only"@ or
-- @"kernel"@) to the corresponding double-underscore C type
-- qualifier.  Any other string aborts with 'error'.
pointerQuals ::  Monad m => String -> m [C.TypeQual]
pointerQuals "global"     = return [C.ctyquals|__global|]
pointerQuals "local"      = return [C.ctyquals|__local|]
pointerQuals "private"    = return [C.ctyquals|__private|]
pointerQuals "constant"   = return [C.ctyquals|__constant|]
pointerQuals "write_only" = return [C.ctyquals|__write_only|]
pointerQuals "read_only"  = return [C.ctyquals|__read_only|]
pointerQuals "kernel"     = return [C.ctyquals|__kernel|]
pointerQuals s            = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."

-- | A use of local memory inside a kernel: the in-kernel name of the
-- memory block and its per-workgroup size in bytes.
type LocalMemoryUse = (VName, Count Bytes Exp)

-- | User state threaded through the C code generator while a single
-- kernel body is compiled.  It starts out as 'newKernelState' and is
-- inspected by 'onKernel' once compilation has finished.
data KernelState =
  KernelState { kernelLocalMemory :: [LocalMemoryUse]
                -- ^ Local memory uses of the kernel; 'onKernel' turns
                -- each of them into a kernel argument, a parameter
                -- and an initialisation statement.
              , kernelFailures :: [FailureMsg]
                -- ^ Failure messages.  Seeded with the messages
                -- collected from previously translated kernels and
                -- written back to 'clFailures' afterwards.
              , kernelNextSync :: Int
                -- ^ Counter from which 'errorLabel' derives the name
                -- of the error label.
              , kernelSyncPending :: Bool
                -- ^ Has a potential failure occurred since the last
                -- ErrorSync?
              , kernelHasBarriers :: Bool
                -- ^ Consulted by 'onKernel' when choosing the
                -- failure-checking prologue of the kernel.
              }

-- | The state in which compilation of a kernel body starts: no local
-- memory uses, the given failure messages, sync counter zero, no
-- pending sync and no barriers.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  KernelState { kernelLocalMemory = mempty
              , kernelFailures = failures
              , kernelNextSync = 0
              , kernelSyncPending = False
              , kernelHasBarriers = False
              }

-- | Name of the C label for the current value of 'kernelNextSync':
-- the string @"error_"@ followed by the counter.
errorLabel :: KernelState -> String
errorLabel = ("error_"++) . show . kernelNextSync

-- | State accumulated across the whole translation performed by
-- 'translateKernels'.
data ToOpenCL = ToOpenCL { clKernels :: M.Map KernelName (Safety, C.Func)
                           -- ^ Generated kernel functions and their
                           -- safety level, keyed by kernel name.
                         , clUsedTypes :: S.Set PrimType
                           -- ^ Union of the primitive types used by
                           -- the kernels translated so far.
                         , clSizes :: M.Map Name SizeClass
                           -- ^ Sizes registered through 'addSize'.
                         , clFailures :: [FailureMsg]
                           -- ^ Failure messages collected so far; each
                           -- kernel starts from this list and replaces
                           -- it with its own final list.
                         }

-- | The empty translation state: no kernels, no used types, no sizes
-- and no failure messages.
initialOpenCL :: ToOpenCL
initialOpenCL =
  ToOpenCL { clKernels = mempty
           , clUsedTypes = mempty
           , clSizes = mempty
           , clFailures = mempty
           }

-- | The monad in which host operations are translated: 'ToOpenCL' as
-- state, plus a 'Name' in the reader environment ('translateKernels'
-- supplies the name of the function being translated, or "val" for
-- the constants).
type OnKernelM = ReaderT Name (State ToOpenCL)

-- | Record the size class of the size named @key@ in 'clSizes',
-- overwriting any class previously stored under the same name.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s { clSizes = M.insert key sclass $ clSizes s }

-- | Translate a single host-level operation.  Kernel calls are handed
-- to 'onKernel'.  'ImpKernels.GetSize' and 'ImpKernels.CmpSizeLe'
-- register their size class with 'addSize' and drop it from the
-- resulting operation.  'ImpKernels.GetSizeMax' is carried over
-- unchanged.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpKernels.GetSize v key size_class) = do
  addSize key size_class
  return $ ImpOpenCL.GetSize v key
onHostOp _ (ImpKernels.CmpSizeLe v key size_class x) = do
  addSize key size_class
  return $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpKernels.GetSizeMax v size_class) =
  return $ ImpOpenCL.GetSizeMax v size_class

-- | Translate one kernel.  The kernel body is compiled to C, wrapped
-- in a @__kernel@ function and stored in 'clKernels' under the
-- kernel's name; the used primitive types and the failure messages in
-- the state are updated as well.  The result is the 'LaunchKernel'
-- operation that replaces the kernel call on the host side.
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel target kernel = do
  -- Failure messages produced by previously translated kernels.
  failures <- gets clFailures
  let (kernel_body, cstate) =
        GenericC.runCompilerM (inKernelOperations (kernelBody kernel))
        blankNameSource
        (newKernelState failures) $
        GenericC.blockScope $ GenericC.compileCode $ kernelBody kernel
      kstate = GenericC.compUserState cstate

      use_params = mapMaybe useAsParam $ kernelUses kernel

      -- One (argument, parameter, initialiser) triple per local
      -- memory use of the kernel.
      (local_memory_args, local_memory_params, local_memory_init) =
        unzip3 $
        flip evalState (blankNameSource :: VNameSource) $
        mapM (prepareLocalMemory target) $ kernelLocalMemory kstate

      -- CUDA has very strict restrictions on the number of blocks
      -- permitted along the 'y' and 'z' dimensions of the grid
      -- (1<<16).  To work around this, we are going to dynamically
      -- permute the block dimensions to move the largest one to the
      -- 'x' dimension, which has a higher limit (1<<31).  This means
      -- we need to extend the kernel with extra parameters that
      -- contain information about this permutation, but we only do
      -- this for multidimensional kernels (at the time of this
      -- writing, only transposes).  The corresponding arguments are
      -- added automatically in CCUDA.hs.
      (perm_params, block_dim_init) =
        case (target, num_groups) of
          (TargetCUDA, [_, _, _]) -> ([[C.cparam|const int block_dim0|],
                                       [C.cparam|const int block_dim1|],
                                       [C.cparam|const int block_dim2|]],
                                      mempty)
          _ -> (mempty,
                [[C.citem|const int block_dim0 = 0;|],
                 [C.citem|const int block_dim1 = 1;|],
                 [C.citem|const int block_dim2 = 2;|]])

      (const_defs, const_undefs) = unzip $ mapMaybe constDef $ kernelUses kernel

  -- If compiling the body did not add any failure message, the kernel
  -- itself cannot fail.
  let (safety, error_init)
        | length (kernelFailures kstate) == length failures =
            if kernelFailureTolerant kernel
            then (SafetyNone, [])
            else -- No possible failures in this kernel, so if we make
                 -- it past an initial check, then we are good to go.
                 (SafetyCheap,
                  [C.citems|if (*global_failure >= 0) { return; }|])

        | otherwise =
            if not (kernelHasBarriers kstate)
            then (SafetyFull,
                  [C.citems|if (*global_failure >= 0) { return; }|])
            else (SafetyFull,
                  [C.citems|
                     volatile __local bool local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = false;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |])

      failure_params =
        [[C.cparam|__global int *global_failure|],
         [C.cparam|int failure_is_an_option|],
         [C.cparam|__global int *global_failure_args|]]

      -- Only as many failure parameters as the safety level asks for.
      params = perm_params ++
               take (numFailureParams safety) failure_params ++
               catMaybes local_memory_params ++
               use_params

      kernel_fun =
        [C.cfun|__kernel void $id:name ($params:params) {
                  $items:const_defs
                  $items:block_dim_init
                  $items:local_memory_init
                  $items:error_init
                  $items:kernel_body

                  $id:(errorLabel kstate): return;

                  $items:const_undefs
                }|]
  modify $ \s -> s
    { clKernels = M.insert name (safety, kernel_fun) $ clKernels s
    , clUsedTypes = typesInKernel kernel <> clUsedTypes s
    , clFailures = kernelFailures kstate
    }

  -- The argument corresponding to the global_failure parameters is
  -- added automatically later.
  let args = catMaybes local_memory_args ++
             kernelArgs kernel

  return $ LaunchKernel safety name args num_groups group_size
  where name = nameToString $ kernelName kernel
        num_groups = kernelNumGroups kernel
        group_size = kernelGroupSize kernel

        -- For OpenCL, local memory is passed as a pointer parameter
        -- with a fresh "_aligned" name, which is cast to the
        -- in-kernel name.
        prepareLocalMemory TargetOpenCL (mem, size) = do
          mem_aligned <- newVName $ baseString mem ++ "_aligned"
          return (Just $ SharedMemoryKArg size,
                  Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
                  [C.citem|__local volatile char* restrict $id:mem = (__local volatile char*)$id:mem_aligned;|])
        -- For CUDA, the parameter is an offset into shared_mem.
        prepareLocalMemory TargetCUDA (mem, size) = do
          param <- newVName $ baseString mem ++ "_offset"
          return (Just $ SharedMemoryKArg size,
                  Just [C.cparam|uint $id:param|],
                  [C.citem|volatile char *$id:mem = &shared_mem[$id:param];|])

-- | The kernel parameter, if any, through which a 'KernelUse' is
-- passed: scalars by value, memory blocks as
-- @__global unsigned char *@.  Constants produce no parameter (they
-- are handled by 'constDef').
useAsParam :: KernelUse -> Maybe C.Param
useAsParam (ScalarUse name bt) =
  let ctp = case bt of
        -- OpenCL does not permit bool as a kernel parameter type.
        Bool -> [C.cty|unsigned char|]
        _    -> GenericC.primTypeToCType bt
  in Just [C.cparam|$ty:ctp $id:name|]
useAsParam (MemoryUse name) =
  Just [C.cparam|__global unsigned char *$id:name|]
useAsParam ConstUse{} =
  Nothing

-- Constants are #defined as macros.  Since a constant name in one
-- kernel might potentially (although unlikely) also be used for
-- something else in another kernel, we #undef them after the kernel.
--
-- For a 'ConstUse' the result is the pair of (#define, #undef)
-- items; any other use yields 'Nothing'.
constDef :: KernelUse -> Maybe (C.BlockItem, C.BlockItem)
constDef (ConstUse v e) = Just ([C.citem|$escstm:def|],
                                [C.citem|$escstm:undef|])
  where e' = compilePrimExp e
        -- The expression is printed on a single line and wrapped in
        -- parentheses.
        def = "#define " ++ pretty (C.toIdent v mempty) ++ " (" ++ prettyOneLine e' ++ ")"
        undef = "#undef " ++ pretty (C.toIdent v mempty)
constDef _ = Nothing

-- | Pretty-print the given kernel functions as one C translation
-- unit, in the order in which they are given.
openClCode :: [C.Func] -> String
openClCode kernels =
  pretty [C.cunit|$edecls:funcs|]
  where funcs =
          [[C.cedecl|$func:kernel_func|] |
           kernel_func <- kernels ]

-- | The contents of @rts/c/atomics.h@, embedded at compile time.
-- 'genOpenClPrelude' appends it verbatim to the OpenCL prelude.
atomicsDefs :: String
atomicsDefs = $(embedStringFile "rts/c/atomics.h")

-- | The definitions placed in front of the generated OpenCL kernels:
-- extension pragmas, a dummy kernel, fixed-width integer typedefs,
-- memory fence helpers, the scalar operation definitions and the
-- atomics.  The @cl_khr_fp64@ pragma and the 64-bit float operations
-- are only included when @ts@ contains @FloatType Float64@.
genOpenClPrelude :: S.Set PrimType -> [C.Definition]
genOpenClPrelude ts =
  -- Clang-based OpenCL implementations need this for 'static' to work.
  [ [C.cedecl|$esc:("#ifdef cl_clang_storage_class_specifiers")|]
  , [C.cedecl|$esc:("#pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable")|]
  , [C.cedecl|$esc:("#endif")|]
  , [C.cedecl|$esc:("#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable")|]]
  ++
  [[C.cedecl|$esc:("#pragma OPENCL EXTENSION cl_khr_fp64 : enable")|] | uses_float64] ++
  [C.cunit|
/* Some OpenCL programs dislike empty progams, or programs with no kernels.
 * Declare a dummy kernel to ensure they remain our friends. */
__kernel void dummy_kernel(__global unsigned char *dummy, int n)
{
    const int thread_gid = get_global_id(0);
    if (thread_gid >= n) return;
}

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long int64_t;

typedef uchar uint8_t;
typedef ushort uint16_t;
typedef uint uint32_t;
typedef ulong uint64_t;

// NVIDIAs OpenCL does not create device-wide memory fences (see #734), so we
// use inline assembly if we detect we are on an NVIDIA GPU.
$esc:("#ifdef cl_nv_pragma_unroll")
static inline void mem_fence_global() {
  asm("membar.gl;");
}
$esc:("#else")
static inline void mem_fence_global() {
  mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
}
$esc:("#endif")
static inline void mem_fence_local() {
  mem_fence(CLK_LOCAL_MEM_FENCE);
}
|] ++
  cIntOps ++ cFloat32Ops ++ cFloat32Funs ++
  (if uses_float64 then cFloat64Ops ++ cFloat64Funs ++ cFloatConvOps else [])
  ++ [[C.cedecl|$esc:atomicsDefs|]]
  where uses_float64 = FloatType Float64 `S.member` ts

-- | Build the definitions placed in front of generated CUDA kernels.
--
-- The prelude consists of a compatibility layer ('cudafy') followed by
-- the scalar operation definitions and the embedded atomics header.
-- The compatibility layer defines @FUTHARK_CUDA@, provides the
-- fixed-width integer typedefs, turns the OpenCL address-space
-- qualifiers into empty macros, and implements the OpenCL work-item
-- functions (@get_group_id@, @get_local_id@, ...), @barrier@ and the
-- memory fences on top of the CUDA built-ins.
--
-- The group-level functions take the @block_dim0@..@block_dim2@
-- values as explicit parameters and are exposed through one-argument
-- macros that pick those names up from the call site.
-- @get_global_size@ used to lack such a macro even though kernel code
-- calls it with a single argument like the others; it now follows the
-- same @_fn@ + macro scheme.
genCUDAPrelude :: [C.Definition]
genCUDAPrelude =
  cudafy ++ ops
  where ops = cIntOps ++ cFloat32Ops ++ cFloat32Funs ++ cFloat64Ops
              ++ cFloat64Funs ++ cFloatConvOps
              ++ [[C.cedecl|$esc:atomicsDefs|]]
        cudafy = [CUDAC.cunit|
$esc:("#define FUTHARK_CUDA")

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long long int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef uint8_t uchar;
typedef uint16_t ushort;
typedef uint32_t uint;
typedef uint64_t ulong;
$esc:("#define __kernel extern \"C\" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)")
$esc:("#define __global")
$esc:("#define __local")
$esc:("#define __private")
$esc:("#define __constant")
$esc:("#define __write_only")
$esc:("#define __read_only")

static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch (d) {
    case 0: return blockIdx.x;
    case 1: return blockIdx.y;
    case 2: return blockIdx.z;
    default: return 0;
  }
}
$esc:("#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch(d) {
    case 0: return gridDim.x;
    case 1: return gridDim.y;
    case 2: return gridDim.z;
    default: return 0;
  }
}
$esc:("#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_local_id(int d)
{
  switch (d) {
    case 0: return threadIdx.x;
    case 1: return threadIdx.y;
    case 2: return threadIdx.z;
    default: return 0;
  }
}

static inline int get_local_size(int d)
{
  switch (d) {
    case 0: return blockDim.x;
    case 1: return blockDim.y;
    case 2: return blockDim.z;
    default: return 0;
  }
}

static inline int get_global_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  return get_group_id(d) * get_local_size(d) + get_local_id(d);
}
$esc:("#define get_global_id(d) get_global_id_fn(block_dim0, block_dim1, block_dim2, d)")

static inline int get_global_size_fn(int block_dim0, int block_dim1, int block_dim2, int d)
{
  return get_num_groups(d) * get_local_size(d);
}
$esc:("#define get_global_size(d) get_global_size_fn(block_dim0, block_dim1, block_dim2, d)")

$esc:("#define CLK_LOCAL_MEM_FENCE 1")
$esc:("#define CLK_GLOBAL_MEM_FENCE 2")
static inline void barrier(int x)
{
  __syncthreads();
}
static inline void mem_fence_local() {
  __threadfence_block();
}
static inline void mem_fence_global() {
  __threadfence();
}

$esc:("#define NAN (0.0/0.0)")
$esc:("#define INFINITY (1.0/0.0)")
extern volatile __shared__ char shared_mem[];
|]

-- | Compile a primitive expression over kernel constants to a C
-- expression.  A 'SizeConst' leaf becomes a reference to the
-- identifier obtained by z-encoding the pretty-printed size name; the
-- rest of the expression is handled by 'GenericC.compilePrimExp', run
-- purely in the 'Identity' monad.
compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp = runIdentity . GenericC.compilePrimExp constToExp
  where constToExp (SizeConst key) =
          let c_name = zEncodeString (pretty key)
          in return [C.cexp|$id:c_name|]

-- | The arguments that must be passed when launching a kernel, in the
-- order of the kernel's uses.  A memory use becomes a 'MemKArg', a
-- scalar use becomes a 'ValueKArg' holding the scalar variable as a
-- leaf expression of its type, and a constant use contributes no
-- argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs kernel = concatMap argsFor (kernelUses kernel)
  where argsFor use =
          case use of
            MemoryUse mem  -> [MemKArg mem]
            ScalarUse v bt -> [ValueKArg (LeafExp (ScalarVar v) bt) bt]
            ConstUse{}     -> []

-- | The error label derived (via 'errorLabel') from the current
-- kernel compilation state.  The state itself is left untouched; use
-- 'incErrorLabel' to advance it.
nextErrorLabel :: GenericC.CompilerM KernelOp KernelState String
nextErrorLabel = do
  kstate <- GenericC.getUserState
  return (errorLabel kstate)

-- | Advance the @kernelNextSync@ counter of the kernel compilation
-- state by one.
incErrorLabel :: GenericC.CompilerM KernelOp KernelState ()
incErrorLabel = GenericC.modifyUserState bump
  where bump s = s { kernelNextSync = kernelNextSync s + 1 }

-- | Set the @kernelSyncPending@ flag of the kernel compilation state
-- to the given value.
pendingError :: Bool -> GenericC.CompilerM KernelOp KernelState ()
pendingError b = GenericC.modifyUserState setPending
  where setPending s = s { kernelSyncPending = b }

-- | Does the kernel code contain any operation that synchronises the
-- threads of a group, i.e. an 'ErrorSync' or a 'Barrier'?
hasCommunication :: ImpKernels.KernelCode -> Bool
hasCommunication = foldr ((||) . isCommunication) False
  where isCommunication op =
          case op of
            ErrorSync{} -> True
            Barrier{}   -> True
            _           -> False

-- | The 'GenericC.Operations' used when compiling the body of a
-- kernel to C.
--
-- The kernel body is passed in only so that we can tell (via
-- 'hasCommunication') whether it contains barriers or error
-- synchronisation points, which decides how a failing assertion
-- leaves the kernel (see @errorInKernel@ below).
--
-- Allocation, deallocation, bulk copies and static arrays are not
-- supported inside kernels; the corresponding operations call
-- 'error'.  Fat memory is disabled.
inKernelOperations :: ImpKernels.KernelCode -> GenericC.Operations KernelOp KernelState
inKernelOperations body =
  GenericC.Operations
  { GenericC.opsCompiler = kernelOps
  , GenericC.opsMemoryType = kernelMemoryType
  , GenericC.opsWriteScalar = kernelWriteScalar
  , GenericC.opsReadScalar = kernelReadScalar
  , GenericC.opsAllocate = cannotAllocate
  , GenericC.opsDeallocate = cannotDeallocate
  , GenericC.opsCopy = copyInKernel
  , GenericC.opsStaticArray = noStaticArrays
  , GenericC.opsFatMemory = False
  , GenericC.opsError = errorInKernel
  }
  where has_communication = hasCommunication body

        -- The C-level fence flag corresponding to a 'Fence'.
        fence FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|]
        fence FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE|]

        -- Compile a single kernel-specific operation.
        kernelOps :: GenericC.OpCompiler KernelOp KernelState
        kernelOps (GetGroupId v i) =
          GenericC.stm [C.cstm|$id:v = get_group_id($int:i);|]
        kernelOps (GetLocalId v i) =
          GenericC.stm [C.cstm|$id:v = get_local_id($int:i);|]
        kernelOps (GetLocalSize v i) =
          GenericC.stm [C.cstm|$id:v = get_local_size($int:i);|]
        kernelOps (GetGlobalId v i) =
          GenericC.stm [C.cstm|$id:v = get_global_id($int:i);|]
        kernelOps (GetGlobalSize v i) =
          GenericC.stm [C.cstm|$id:v = get_global_size($int:i);|]
        kernelOps (GetLockstepWidth v) =
          GenericC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
        kernelOps (Barrier f) = do
          GenericC.stm [C.cstm|barrier($exp:(fence f));|]
          GenericC.modifyUserState $ \s -> s { kernelHasBarriers = True }
        kernelOps (MemFence FenceLocal) =
          GenericC.stm [C.cstm|mem_fence_local();|]
        kernelOps (MemFence FenceGlobal) =
          GenericC.stm [C.cstm|mem_fence_global();|]
        kernelOps (LocalAlloc name size) = do
          -- Record a fresh backing name together with the requested
          -- size in the kernel state, and point the user-visible name
          -- at it.
          name' <- newVName $ pretty name ++ "_backing"
          GenericC.modifyUserState $ \s ->
            s { kernelLocalMemory = (name', size) : kernelLocalMemory s }
          GenericC.stm [C.cstm|$id:name = (__local char*) $id:name';|]
        kernelOps (ErrorSync f) = do
          label <- nextErrorLabel
          pending <- kernelSyncPending <$> GenericC.getUserState
          -- Only emit the labelled landing point if some preceding
          -- assertion may jump to it.
          when pending $ do
            pendingError False
            GenericC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
            GenericC.stm [C.cstm|if (local_failure) { return; }|]
          GenericC.stm [C.cstm|barrier(CLK_LOCAL_MEM_FENCE);|] -- intentional
          GenericC.modifyUserState $ \s -> s { kernelHasBarriers = True }
          incErrorLabel
        kernelOps (Atomic space aop) = atomicOps space aop

        -- The pointee type used when casting the target array of an
        -- atomic operation: the given type, made volatile and
        -- qualified for the address space of the operation.
        atomicCast s t = do
          let volatile = [C.ctyquals|volatile|]
          quals <- pointerQuals (atomicSpace s)
          return [C.cty|$tyquals:(volatile++quals) $ty:t|]

        -- Name of the address space an atomic operates on; anything
        -- that is not an explicitly named space is treated as global.
        atomicSpace (Space sid) = sid
        atomicSpace _           = "global"

        -- Emit a call to the two-operand atomic @op_<type>_<space>@,
        -- storing the returned value in @old@.
        doAtomic s t old arr ind val op ty = do
          ind' <- GenericC.compileExp $ unCount ind
          val' <- GenericC.compileExp val
          cast <- atomicCast s ty
          GenericC.stm [C.cstm|$id:old = $id:op'(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|]
          where op' = op ++ "_" ++ pretty t ++ "_" ++ atomicSpace s

        atomicOps s (AtomicAdd t old arr ind val) =
          doAtomic s t old arr ind val "atomic_add" [C.cty|int|]

        atomicOps s (AtomicFAdd t old arr ind val) =
          doAtomic s t old arr ind val "atomic_fadd" [C.cty|float|]

        atomicOps s (AtomicSMax t old arr ind val) =
          doAtomic s t old arr ind val "atomic_smax" [C.cty|int|]

        atomicOps s (AtomicSMin t old arr ind val) =
          doAtomic s t old arr ind val "atomic_smin" [C.cty|int|]

        atomicOps s (AtomicUMax t old arr ind val) =
          doAtomic s t old arr ind val "atomic_umax" [C.cty|unsigned int|]

        atomicOps s (AtomicUMin t old arr ind val) =
          doAtomic s t old arr ind val "atomic_umin" [C.cty|unsigned int|]

        atomicOps s (AtomicAnd t old arr ind val) =
          doAtomic s t old arr ind val "atomic_and" [C.cty|int|]

        atomicOps s (AtomicOr t old arr ind val) =
          doAtomic s t old arr ind val "atomic_or" [C.cty|int|]

        atomicOps s (AtomicXor t old arr ind val) =
          doAtomic s t old arr ind val "atomic_xor" [C.cty|int|]

        atomicOps s (AtomicCmpXchg t old arr ind cmp val) = do
          ind' <- GenericC.compileExp $ unCount ind
          cmp' <- GenericC.compileExp cmp
          val' <- GenericC.compileExp val
          cast <- atomicCast s [C.cty|int|]
          GenericC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|]
          where op = "atomic_cmpxchg_" ++ pretty t ++ "_" ++ atomicSpace s

        atomicOps s (AtomicXchg t old arr ind val) = do
          ind' <- GenericC.compileExp $ unCount ind
          val' <- GenericC.compileExp val
          cast <- atomicCast s [C.cty|int|]
          GenericC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|]
          -- A plain exchange passes only the pointer and the new
          -- value, so it must use the exchange family and not the
          -- compare-and-exchange one, which expects an extra comparand.
          -- NOTE(review): assumes rts/c/atomics.h provides the
          -- atomic_xchg_<type>_<space> wrappers - confirm.
          where op = "atomic_xchg_" ++ pretty t ++ "_" ++ atomicSpace s

        cannotAllocate :: GenericC.Allocate KernelOp KernelState
        cannotAllocate _ =
          error "Cannot allocate memory in kernel"

        cannotDeallocate :: GenericC.Deallocate KernelOp KernelState
        cannotDeallocate _ _ =
          error "Cannot deallocate memory in kernel"

        copyInKernel :: GenericC.Copy KernelOp KernelState
        copyInKernel _ _ _ _ _ _ _ =
          error "Cannot bulk copy in kernel."

        noStaticArrays :: GenericC.StaticArray KernelOp KernelState
        noStaticArrays _ _ _ _ =
          error "Cannot create static array in kernel."

        -- The C type of a memory block residing in the given space.
        kernelMemoryType space = do
          quals <- pointerQuals space
          return [C.cty|$tyquals:quals $ty:defaultMemBlockType|]

        kernelWriteScalar =
          GenericC.writeScalarPointerWithQuals pointerQuals

        kernelReadScalar =
          GenericC.readScalarPointerWithQuals pointerQuals

        -- Compile a failing assertion.  The message is appended to the
        -- failures recorded in the kernel state, and its position in
        -- that list is the code the kernel tries to store in
        -- @global_failure@.  Only if that store succeeds are the
        -- integer parts of the message written to
        -- @global_failure_args@.
        errorInKernel msg@(ErrorMsg parts) backtrace = do
          n <- length . kernelFailures <$> GenericC.getUserState
          GenericC.modifyUserState $ \s ->
            s { kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace] }
          -- String parts need no storage; each integer part occupies
          -- the next argument slot.
          let setArgs _ [] = return []
              setArgs i (ErrorString{} : parts') = setArgs i parts'
              setArgs i (ErrorInt32 x : parts') = do
                x' <- GenericC.compileExp x
                stms <- setArgs (i+1) parts'
                return $ [C.cstm|global_failure_args[$int:i] = $exp:x';|] : stms
          argstms <- setArgs (0::Int) parts
          label <- nextErrorLabel
          pendingError True
          -- If the kernel synchronises, flag the failure and jump to
          -- the next synchronisation label instead of returning on
          -- the spot; otherwise just return.
          let what_next
                | has_communication = [C.citems|local_failure = true;
                                                goto $id:label;|]
                | otherwise         = [C.citems|return;|]
          GenericC.stm [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1)
                                 { $stms:argstms; }
                                 $items:what_next
                               }|]

--- Checking requirements

-- | The set of primitive types that 'typesInCode' finds in the body
-- of a kernel.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody

-- | Collect the primitive types that occur in a piece of kernel code.
--
-- The result is the union of: the integer type of every @For@ loop
-- counter, the element type of every scalar/array declaration and
-- every @Write@, and whatever 'typesInExp' finds in the expressions
-- embedded in the code.  Statements that carry no expression or
-- primitive type of their own (@Skip@, @DeclareMem@, @Free@,
-- @SetMem@, @Op@) contribute the empty set.
typesInCode :: ImpKernels.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
typesInCode (c1 :>>: c2) = typesInCode c1 <> typesInCode c2
typesInCode (For _ it e c) =
  -- The loop counter's integer type counts as used, too.
  S.insert (IntType it) (typesInExp e) <> typesInCode c
typesInCode (While e c) = typesInExp e <> typesInCode c
typesInCode DeclareMem{} = mempty
typesInCode (DeclareScalar _ _ t) = S.singleton t
typesInCode (DeclareArray _ _ t _) = S.singleton t
typesInCode (Allocate _ (Count e) _) = typesInExp e
typesInCode Free{} = mempty
typesInCode (Copy _ (Count e1) _ _ (Count e2) _ (Count e3)) =
  typesInExp e1 <> typesInExp e2 <> typesInExp e3
typesInCode (Write _ (Count e1) t _ _ e2) =
  typesInExp e1 <> S.singleton t <> typesInExp e2
typesInCode (SetScalar _ e) = typesInExp e
typesInCode SetMem{} = mempty
typesInCode (Call _ _ es) = foldMap typesInArg es
  where
    -- Only expression arguments can mention primitive types.
    typesInArg MemArg{} = mempty
    typesInArg (ExpArg e) = typesInExp e
typesInCode (If e c1 c2) =
  typesInExp e <> typesInCode c1 <> typesInCode c2
typesInCode (Assert e _ _) = typesInExp e
typesInCode (Comment _ c) = typesInCode c
typesInCode (DebugPrint _ v) = maybe mempty typesInExp v
typesInCode Op{} = mempty

-- | Collect the primitive types that occur in an expression: the
-- types of constants, both ends of every conversion, the declared
-- result type of function calls, the element type of memory reads
-- and of @SizeOf@, plus the types found recursively in all
-- subexpressions.  A plain scalar variable contributes nothing.
typesInExp :: Exp -> S.Set PrimType
typesInExp (ValueExp v) = S.singleton (primValueType v)
typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e
  where
    (from, to) = convOpType op
typesInExp (UnOpExp _ e) = typesInExp e
typesInExp (FunExp _ args t) = S.singleton t <> foldMap typesInExp args
typesInExp (LeafExp (Index _ (Count e) t _ _) _) =
  S.singleton t <> typesInExp e
typesInExp (LeafExp ScalarVar{} _) = mempty
typesInExp (LeafExp (SizeOf t) _) = S.singleton t