{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TupleSections #-}

-- | This module defines a translation from imperative code with
-- kernels to imperative code with OpenCL or CUDA calls.
module Futhark.CodeGen.ImpGen.GPU.ToOpenCL
  ( kernelsToOpenCL,
    kernelsToCUDA,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Data.Text as T
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.GPU hiding (Program)
import qualified Futhark.CodeGen.ImpCode.GPU as ImpGPU
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import qualified Futhark.CodeGen.ImpCode.OpenCL as ImpOpenCL
import Futhark.CodeGen.RTS.C (atomicsH, halfH)
import Futhark.Error (compilerLimitationS)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (prettyOneLine, prettyText)
import qualified Language.C.Quote.OpenCL as C
import qualified Language.C.Syntax as C
import NeatInterpolation (untrimming)

-- | Generate CUDA host and device code.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA = translateGPU TargetCUDA

-- | Generate OpenCL host and device code.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL = translateGPU TargetOpenCL
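
-- A rough usage sketch (hypothetical driver code; the real pipeline lives
-- elsewhere in the compiler): both entry points take an 'ImpGPU.Program'
-- produced by the GPU ImpGen pass and differ only in the 'KernelTarget'
-- they pass to 'translateGPU':
--
-- > toBackend :: Bool -> ImpGPU.Program -> ImpOpenCL.Program
-- > toBackend useCuda
-- >   | useCuda = kernelsToCUDA
-- >   | otherwise = kernelsToOpenCL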

-- | Translate a kernels-program to an OpenCL-program.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  let ( prog',
        ToOpenCL kernels device_funs used_types sizes failures
        ) =
          (`runState` initialOpenCL) . (`runReaderT` defFuns prog) $ do
            let ImpGPU.Definitions
                  types
                  (ImpGPU.Constants ps consts)
                  (ImpGPU.Functions funs) = prog
            consts' <- traverse (onHostOp target) consts
            funs' <- forM funs $ \(fname, fun) ->
              (fname,) <$> traverse (onHostOp target) fun

            pure $
              ImpOpenCL.Definitions
                types
                (ImpOpenCL.Constants ps consts')
                (ImpOpenCL.Functions funs')

      (device_prototypes, device_defs) = unzip $ M.elems device_funs
      kernels' = M.map fst kernels
      opencl_code = openClCode $ map snd $ M.elems kernels

      opencl_prelude =
        T.unlines
          [ genPrelude target used_types,
            T.unlines $ map prettyText device_prototypes,
            T.unlines $ map prettyText device_defs
          ]
   in ImpOpenCL.Program
        opencl_code
        opencl_prelude
        kernels'
        (S.toList used_types)
        (cleanSizes sizes)
        failures
        prog'
  where
    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist.  We remove these here.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    known = M.keys m
    clean (SizeThreshold path def) =
      SizeThreshold (filter ((`elem` known) . fst) path) def
    clean s = s
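
-- For instance (illustrative names and values, assuming OverloadedStrings
-- for the 'Name' literals): if @m@ contains only a threshold @t0@ whose
-- path mentions a since-removed threshold @t1@, the @t1@ step is filtered
-- out while the default value is kept:
--
-- > cleanSizes (M.fromList [("t0", SizeThreshold [("t1", True)] (Just 32))])
-- >   == M.fromList [("t0", SizeThreshold [] (Just 32))]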

pointerQuals :: Monad m => String -> m [C.TypeQual]
pointerQuals "global" = pure [C.ctyquals|__global|]
pointerQuals "local" = pure [C.ctyquals|__local|]
pointerQuals "private" = pure [C.ctyquals|__private|]
pointerQuals "constant" = pure [C.ctyquals|__constant|]
pointerQuals "write_only" = pure [C.ctyquals|__write_only|]
pointerQuals "read_only" = pure [C.ctyquals|__read_only|]
pointerQuals "kernel" = pure [C.ctyquals|__kernel|]
pointerQuals s = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."

-- In-kernel name and per-workgroup size in bytes.
type LocalMemoryUse = (VName, Count Bytes Exp)

data KernelState = KernelState
  { kernelLocalMemory :: [LocalMemoryUse],
    kernelFailures :: [FailureMsg],
    kernelNextSync :: Int,
    -- | Has a potential failure occurred since the last
    -- ErrorSync?
    kernelSyncPending :: Bool,
    kernelHasBarriers :: Bool
  }

newKernelState :: [FailureMsg] -> KernelState
newKernelState failures = KernelState mempty failures 0 False False

errorLabel :: KernelState -> String
errorLabel = ("error_" ++) . show . kernelNextSync

data ToOpenCL = ToOpenCL
  { clGPU :: M.Map KernelName (KernelSafety, C.Func),
    clDevFuns :: M.Map Name (C.Definition, C.Func),
    clUsedTypes :: S.Set PrimType,
    clSizes :: M.Map Name SizeClass,
    clFailures :: [FailureMsg]
  }

initialOpenCL :: ToOpenCL
initialOpenCL = ToOpenCL mempty mempty mempty mempty mempty

type AllFunctions = ImpGPU.Functions ImpGPU.HostOp

lookupFunction :: Name -> AllFunctions -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname (ImpGPU.Functions fs) = lookup fname fs

type OnKernelM = ReaderT AllFunctions (State ToOpenCL)

addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s}

onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpGPU.GetSize v key size_class) = do
  addSize key size_class
  pure $ ImpOpenCL.GetSize v key
onHostOp _ (ImpGPU.CmpSizeLe v key size_class x) = do
  addSize key size_class
  pure $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpGPU.GetSizeMax v size_class) =
  pure $ ImpOpenCL.GetSizeMax v size_class

genGPUCode ::
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode mode body failures =
  GC.runCompilerM
    (inKernelOperations mode body)
    blankNameSource
    (newKernelState failures)

-- Compilation of a device function that is not invoked from the
-- host, but is invoked by (perhaps multiple) kernels.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
generateDeviceFun fname device_func = do
  when (any memParam $ functionInput device_func) bad

  failures <- gets clFailures

  let params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]
      (func, cstate) =
        genGPUCode FunMode (functionBody device_func) failures $
          GC.compileFun mempty params (fname, device_func)
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode (functionBody device_func) <> clUsedTypes s,
        clDevFuns = M.insert fname func $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- Important to do this after the 'modify' call, so we propagate the
  -- right clFailures.
  void $ ensureDeviceFuns $ functionBody device_func
  where
    memParam MemParam {} = True
    memParam ScalarParam {} = False

    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."

-- Ensure that this device function is available, but don't regenerate
-- it if it already exists.
ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
ensureDeviceFun fname host_func = do
  exists <- gets $ M.member fname . clDevFuns
  unless exists $ generateDeviceFun fname host_func

ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code = do
  let called = calledFuncs code
  fmap catMaybes $
    forM (S.toList called) $ \fname -> do
      def <- asks $ lookupFunction fname
      case def of
        Just host_func -> do
          -- Functions are a priori always considered host-level, so we have
          -- to convert them to device code.  This is where most of our
          -- limitations on device-side functions (no arrays, no parallelism)
          -- come from.
          let device_func = fmap toDevice host_func
          ensureDeviceFun fname device_func
          pure $ Just fname
        Nothing -> pure Nothing
  where
    bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism."
    toDevice :: HostOp -> KernelOp
    toDevice _ = bad

onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel target kernel = do
  called <- ensureDeviceFuns $ kernelBody kernel

  -- Crucial that this is done after 'ensureDeviceFuns', as the device
  -- functions may themselves define failure points.
  failures <- gets clFailures

  let (kernel_body, cstate) =
        genGPUCode KernelMode (kernelBody kernel) failures . GC.collect $ do
          body <- GC.collect $ GC.compileCode $ kernelBody kernel
          -- No need to free, as we cannot allocate memory in kernels.
          mapM_ GC.item =<< GC.declAllocatedMem
          mapM_ GC.item body
      kstate = GC.compUserState cstate

      (local_memory_args, local_memory_params, local_memory_init) =
        unzip3 . flip evalState (blankNameSource :: VNameSource) $
          mapM (prepareLocalMemory target) $
            kernelLocalMemory kstate

      -- CUDA has very strict restrictions on the number of blocks
      -- permitted along the 'y' and 'z' dimensions of the grid
      -- (1<<16).  To work around this, we are going to dynamically
      -- permute the block dimensions to move the largest one to the
      -- 'x' dimension, which has a higher limit (1<<31).  This means
      -- we need to extend the kernel with extra parameters that
      -- contain information about this permutation, but we only do
      -- this for multidimensional kernels (at the time of this
      -- writing, only transposes).  The corresponding arguments are
      -- added automatically in CCUDA.hs.
      (perm_params, block_dim_init) =
        case (target, num_groups) of
          (TargetCUDA, [_, _, _]) ->
            ( [ [C.cparam|const int block_dim0|],
                [C.cparam|const int block_dim1|],
                [C.cparam|const int block_dim2|]
              ],
              mempty
            )
          _ ->
            ( mempty,
              [ [C.citem|const int block_dim0 = 0;|],
                [C.citem|const int block_dim1 = 1;|],
                [C.citem|const int block_dim2 = 2;|]
              ]
            )

      (const_defs, const_undefs) = unzip $ mapMaybe constDef $ kernelUses kernel

  let (use_params, unpack_params) =
        unzip $ mapMaybe useAsParam $ kernelUses kernel

  let (safety, error_init)
        -- We conservatively assume that any called function can fail.
        | not $ null called =
            (SafetyFull, [])
        | length (kernelFailures kstate) == length failures =
            if kernelFailureTolerant kernel
              then (SafetyNone, [])
              else -- No possible failures in this kernel, so if we make
              -- it past an initial check, then we are good to go.

                ( SafetyCheap,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
        | otherwise =
            if not (kernelHasBarriers kstate)
              then
                ( SafetyFull,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
              else
                ( SafetyFull,
                  [C.citems|
                     volatile __local bool local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = false;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |]
                )

      failure_params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|int failure_is_an_option|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]

      params =
        perm_params
          ++ take (numFailureParams safety) failure_params
          ++ catMaybes local_memory_params
          ++ use_params

      kernel_fun =
        [C.cfun|__kernel void $id:name ($params:params) {
                  $items:(mconcat unpack_params)
                  $items:const_defs
                  $items:block_dim_init
                  $items:local_memory_init
                  $items:error_init
                  $items:kernel_body

                  $id:(errorLabel kstate): return;

                  $items:const_undefs
                }|]
  modify $ \s ->
    s
      { clGPU = M.insert name (safety, kernel_fun) $ clGPU s,
        clUsedTypes = typesInKernel kernel <> clUsedTypes s,
        clFailures = kernelFailures kstate
      }

  -- The argument corresponding to the global_failure parameters is
  -- added automatically later.
  let args = catMaybes local_memory_args ++ kernelArgs kernel

  pure $ LaunchKernel safety name args num_groups group_size
  where
    name = kernelName kernel
    num_groups = kernelNumGroups kernel
    group_size = kernelGroupSize kernel

    prepareLocalMemory TargetOpenCL (mem, size) = do
      mem_aligned <- newVName $ baseString mem ++ "_aligned"
      pure
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
          [C.citem|__local volatile unsigned char* restrict $id:mem = (__local volatile unsigned char*) $id:mem_aligned;|]
        )
    prepareLocalMemory TargetCUDA (mem, size) = do
      param <- newVName $ baseString mem ++ "_offset"
      pure
        ( Just $ SharedMemoryKArg size,
          Just [C.cparam|uint $id:param|],
          [C.citem|volatile $ty:defaultMemBlockType $id:mem = &shared_mem[$id:param];|]
        )
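
-- Putting the pieces together, an emitted kernel has roughly this shape
-- (illustrative; the parameter list depends on 'safety', local memory and
-- the kernel's uses):
--
-- > __kernel void foo_kernel(__global int *global_failure, ...,
-- >                          __local volatile int64_t *mem_aligned,
-- >                          int32_t n) {
-- >   /* unpack_params, const_defs, block_dim_init, local_memory_init, */
-- >   /* error_init, then the compiled kernel body                     */
-- >   error_0: return;
-- >   /* const_undefs */
-- > }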

useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem])
useAsParam (ScalarUse name pt) = do
  let name_bits = zEncodeString (pretty name) <> "_bits"
      ctp = case pt of
        -- OpenCL does not permit bool as a kernel parameter type.
        Bool -> [C.cty|unsigned char|]
        Unit -> [C.cty|unsigned char|]
        _ -> primStorageType pt
  if ctp == primTypeToCType pt
    then Just ([C.cparam|$ty:ctp $id:name|], [])
    else
      let name_bits_e = [C.cexp|$id:name_bits|]
       in Just
            ( [C.cparam|$ty:ctp $id:name_bits|],
              [[C.citem|$ty:(primTypeToCType pt) $id:name = $exp:(fromStorage pt name_bits_e);|]]
            )
useAsParam (MemoryUse name) =
  Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], [])
useAsParam ConstUse {} =
  Nothing
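
-- As an illustration (hypothetical variable @cond@): a 'ScalarUse' of a
-- 'Bool' cannot be passed directly, so it arrives as an @unsigned char@
-- named @cond_bits@ and is unpacked at the top of the kernel, roughly
--
-- > __kernel void ...(..., unsigned char cond_bits) {
-- >   bool cond = cond_bits;  /* modulo the exact fromStorage conversion */
-- >   ...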

-- Constants are #defined as macros.  Since a constant name in one
-- kernel might potentially (although unlikely) also be used for
-- something else in another kernel, we #undef them after the kernel.
constDef :: KernelUse -> Maybe (C.BlockItem, C.BlockItem)
constDef (ConstUse v e) =
  Just
    ( [C.citem|$escstm:def|],
      [C.citem|$escstm:undef|]
    )
  where
    e' = compilePrimExp e
    def = "#define " ++ pretty (C.toIdent v mempty) ++ " (" ++ prettyOneLine e' ++ ")"
    undef = "#undef " ++ pretty (C.toIdent v mempty)
constDef _ = Nothing
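
-- For a hypothetical constant @v_5@ whose expression pretty-prints as @e@,
-- the generated pair is (as C preprocessor lines):
--
-- > #define v_5 (e)
-- > #undef v_5
--
-- The first item is spliced in before the kernel body and the second after
-- it (see const_defs/const_undefs in 'onKernel').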

openClCode :: [C.Func] -> T.Text
openClCode kernels =
  prettyText [C.cunit|$edecls:funcs|]
  where
    funcs =
      [ [C.cedecl|$func:kernel_func|]
        | kernel_func <- kernels
      ]

genOpenClPrelude :: S.Set PrimType -> T.Text
genOpenClPrelude ts =
  [untrimming|
// Clang-based OpenCL implementations need this for 'static' to work.
#ifdef cl_clang_storage_class_specifiers
#pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable
#endif
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
$enable_f64
// Some OpenCL implementations dislike empty programs, or programs with no kernels.
// Declare a dummy kernel to ensure they remain our friends.
__kernel void dummy_kernel(__global unsigned char *dummy, int n)
{
    const int thread_gid = get_global_id(0);
    if (thread_gid >= n) return;
}

#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long int64_t;

typedef uchar uint8_t;
typedef ushort uint16_t;
typedef uint uint32_t;
typedef ulong uint64_t;

// NVIDIA's OpenCL does not create device-wide memory fences (see #734), so we
// use inline assembly if we detect we are on an NVIDIA GPU.
#ifdef cl_nv_pragma_unroll
static inline void mem_fence_global() {
  asm("membar.gl;");
}
#else
static inline void mem_fence_global() {
  mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
}
#endif
static inline void mem_fence_local() {
  mem_fence(CLK_LOCAL_MEM_FENCE);
}
|]
    <> halfH
    <> cScalarDefs
    <> atomicsH
  where
    enable_f64
      | FloatType Float64 `S.member` ts =
          [untrimming|
         #pragma OPENCL EXTENSION cl_khr_fp64 : enable
         #define FUTHARK_F64_ENABLED
         |]
      | otherwise = mempty

genCUDAPrelude :: T.Text
genCUDAPrelude =
  [untrimming|
#define FUTHARK_CUDA
#define FUTHARK_F64_ENABLED

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long long int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef uint8_t uchar;
typedef uint16_t ushort;
typedef uint32_t uint;
typedef uint64_t ulong;
#define __kernel extern "C" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)
#define __global
#define __local
#define __private
#define __constant
#define __write_only
#define __read_only

static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch (d) {
    case 0: return blockIdx.x;
    case 1: return blockIdx.y;
    case 2: return blockIdx.z;
    default: return 0;
  }
}
#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)

static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch(d) {
    case 0: return gridDim.x;
    case 1: return gridDim.y;
    case 2: return gridDim.z;
    default: return 0;
  }
}
#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)

static inline int get_local_id(int d) {
  switch (d) {
    case 0: return threadIdx.x;
    case 1: return threadIdx.y;
    case 2: return threadIdx.z;
    default: return 0;
  }
}

static inline int get_local_size(int d) {
  switch (d) {
    case 0: return blockDim.x;
    case 1: return blockDim.y;
    case 2: return blockDim.z;
    default: return 0;
  }
}

#define CLK_LOCAL_MEM_FENCE 1
#define CLK_GLOBAL_MEM_FENCE 2
static inline void barrier(int x) {
  __syncthreads();
}
static inline void mem_fence_local() {
  __threadfence_block();
}
static inline void mem_fence_global() {
  __threadfence();
}

#define NAN (0.0/0.0)
#define INFINITY (1.0/0.0)
extern volatile __shared__ unsigned char shared_mem[];
|]
    <> halfH
    <> cScalarDefs
    <> atomicsH
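
-- An illustrative launch of the permutation machinery above (numbers
-- invented): if a transpose kernel's largest grid dimension is logically
-- dimension 1, the host can swap the physical 'x' and 'y' axes and pass
-- block_dim0=1, block_dim1=0, block_dim2=2.  Inside the kernel,
-- get_group_id(1) then reads blockIdx.x and get_group_id(0) reads
-- blockIdx.y, so the large dimension gets the roomier 'x' limit without
-- any change to the kernel code itself.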

compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp e = runIdentity $ GC.compilePrimExp compileKernelConst e
  where
    compileKernelConst (SizeConst key) =
      pure [C.cexp|$id:(zEncodeString (pretty key))|]

kernelArgs :: Kernel -> [KernelArg]
kernelArgs = mapMaybe useToArg . kernelUses
  where
    useToArg (MemoryUse mem) = Just $ MemKArg mem
    useToArg (ScalarUse v pt) = Just $ ValueKArg (LeafExp v pt) pt
    useToArg ConstUse {} = Nothing

nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel =
  errorLabel <$> GC.getUserState

incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel =
  GC.modifyUserState $ \s -> s {kernelNextSync = kernelNextSync s + 1}

pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError b =
  GC.modifyUserState $ \s -> s {kernelSyncPending = b}

hasCommunication :: ImpGPU.KernelCode -> Bool
hasCommunication = any communicates
  where
    communicates ErrorSync {} = True
    communicates Barrier {} = True
    communicates _ = False

-- Whether we are generating code for a kernel or a device function.
-- This has minor effects, such as exactly how failures are
-- propagated.
data OpsMode = KernelMode | FunMode deriving (Eq)

inKernelOperations ::
  OpsMode ->
  ImpGPU.KernelCode ->
  GC.Operations KernelOp KernelState
inKernelOperations mode body =
  GC.Operations
    { GC.opsCompiler = kernelOps,
      GC.opsMemoryType = kernelMemoryType,
      GC.opsWriteScalar = kernelWriteScalar,
      GC.opsReadScalar = kernelReadScalar,
      GC.opsAllocate = cannotAllocate,
      GC.opsDeallocate = cannotDeallocate,
      GC.opsCopy = copyInKernel,
      GC.opsStaticArray = noStaticArrays,
      GC.opsFatMemory = False,
      GC.opsError = errorInKernel,
      GC.opsCall = callInKernel,
      GC.opsCritical = mempty
    }
  where
    has_communication = hasCommunication body

    fence FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|]
    fence FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]

    kernelOps :: GC.OpCompiler KernelOp KernelState
    kernelOps (GetGroupId v i) =
      GC.stm [C.cstm|$id:v = get_group_id($int:i);|]
    kernelOps (GetLocalId v i) =
      GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
    kernelOps (GetLocalSize v i) =
      GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
    kernelOps (GetLockstepWidth v) =
      GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
    kernelOps (Barrier f) = do
      GC.stm [C.cstm|barrier($exp:(fence f));|]
      GC.modifyUserState $ \s -> s {kernelHasBarriers = True}
    kernelOps (MemFence FenceLocal) =
      GC.stm [C.cstm|mem_fence_local();|]
    kernelOps (MemFence FenceGlobal) =
      GC.stm [C.cstm|mem_fence_global();|]
    kernelOps (LocalAlloc name size) = do
      name' <- newVName $ pretty name ++ "_backing"
      GC.modifyUserState $ \s ->
        s {kernelLocalMemory = (name', fmap untyped size) : kernelLocalMemory s}
      GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:name';|]
    kernelOps (ErrorSync f) = do
      label <- nextErrorLabel
      pending <- kernelSyncPending <$> GC.getUserState
      when pending $ do
        pendingError False
        GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
        GC.stm [C.cstm|if (local_failure) { return; }|]
      GC.stm [C.cstm|barrier($exp:(fence f));|]
      GC.modifyUserState $ \s -> s {kernelHasBarriers = True}
      incErrorLabel
    kernelOps (Atomic space aop) = atomicOps space aop

    atomicCast s t = do
      let volatile = [C.ctyquals|volatile|]
      quals <- case s of
        Space sid -> pointerQuals sid
        _ -> pointerQuals "global"
      pure [C.cty|$tyquals:(volatile++quals) $ty:t|]

    atomicSpace (Space sid) = sid
    atomicSpace _ = "global"

    doAtomic s t old arr ind val op ty = do
      ind' <- GC.compileExp $ untyped $ unCount ind
      val' <- GC.compileExp val
      cast <- atomicCast s ty
      GC.stm [C.cstm|$id:old = $id:op'(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|]
      where
        op' = op ++ "_" ++ pretty t ++ "_" ++ atomicSpace s

    doAtomicCmpXchg s t old arr ind cmp val ty = do
      ind' <- GC.compileExp $ untyped $ unCount ind
      cmp' <- GC.compileExp cmp
      val' <- GC.compileExp val
      cast <- atomicCast s ty
      GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|]
      where
        op = "atomic_cmpxchg_" ++ pretty t ++ "_" ++ atomicSpace s
    doAtomicXchg s t old arr ind val ty = do
      cast <- atomicCast s ty
      ind' <- GC.compileExp $ untyped $ unCount ind
      val' <- GC.compileExp val
      GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|]
      where
        op = "atomic_chg_" ++ pretty t ++ "_" ++ atomicSpace s
    -- First the 64-bit operations.
    atomicOps s (AtomicAdd Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_add" [C.cty|typename int64_t|]
    atomicOps s (AtomicFAdd Float64 old arr ind val) =
      doAtomic s Float64 old arr ind val "atomic_fadd" [C.cty|double|]
    atomicOps s (AtomicSMax Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_smax" [C.cty|typename int64_t|]
    atomicOps s (AtomicSMin Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_smin" [C.cty|typename int64_t|]
    atomicOps s (AtomicUMax Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_umax" [C.cty|unsigned int64_t|]
    atomicOps s (AtomicUMin Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_umin" [C.cty|unsigned int64_t|]
    atomicOps s (AtomicAnd Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_and" [C.cty|typename int64_t|]
    atomicOps s (AtomicOr Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_or" [C.cty|typename int64_t|]
    atomicOps s (AtomicXor Int64 old arr ind val) =
      doAtomic s Int64 old arr ind val "atomic_xor" [C.cty|typename int64_t|]
    atomicOps s (AtomicCmpXchg (IntType Int64) old arr ind cmp val) =
      doAtomicCmpXchg s (IntType Int64) old arr ind cmp val [C.cty|typename int64_t|]
    atomicOps s (AtomicXchg (IntType Int64) old arr ind val) =
      doAtomicXchg s (IntType Int64) old arr ind val [C.cty|typename int64_t|]
    --
    atomicOps s (AtomicAdd t old arr ind val) =
      doAtomic s t old arr ind val "atomic_add" [C.cty|int|]
    atomicOps s (AtomicFAdd t old arr ind val) =
      doAtomic s t old arr ind val "atomic_fadd" [C.cty|float|]
    atomicOps s (AtomicSMax t old arr ind val) =
      doAtomic s t old arr ind val "atomic_smax" [C.cty|int|]
    atomicOps s (AtomicSMin t old arr ind val) =
      doAtomic s t old arr ind val "atomic_smin" [C.cty|int|]
    atomicOps s (AtomicUMax t old arr ind val) =
      doAtomic s t old arr ind val "atomic_umax" [C.cty|unsigned int|]
    atomicOps s (AtomicUMin t old arr ind val) =
      doAtomic s t old arr ind val "atomic_umin" [C.cty|unsigned int|]
    atomicOps s (AtomicAnd t old arr ind val) =
      doAtomic s t old arr ind val "atomic_and" [C.cty|int|]
    atomicOps s (AtomicOr t old arr ind val) =
      doAtomic s t old arr ind val "atomic_or" [C.cty|int|]
    atomicOps s (AtomicXor t old arr ind val) =
      doAtomic s t old arr ind val "atomic_xor" [C.cty|int|]
    atomicOps s (AtomicCmpXchg t old arr ind cmp val) =
      doAtomicCmpXchg s t old arr ind cmp val [C.cty|int|]
    atomicOps s (AtomicXchg t old arr ind val) =
      doAtomicXchg s t old arr ind val [C.cty|int|]

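    -- None of the following operations are ever supposed to occur in
    -- kernel code, so reaching them indicates a bug elsewhere in the
    -- compiler.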
    cannotAllocate :: GC.Allocate KernelOp KernelState
    cannotAllocate _ =
      error "Cannot allocate memory in kernel"

    cannotDeallocate :: GC.Deallocate KernelOp KernelState
    cannotDeallocate _ _ =
      error "Cannot deallocate memory in kernel"

    copyInKernel :: GC.Copy KernelOp KernelState
    copyInKernel _ _ _ _ _ _ _ _ =
      error "Cannot bulk copy in kernel."

    noStaticArrays :: GC.StaticArray KernelOp KernelState
    noStaticArrays _ _ _ _ =
      error "Cannot create static array in kernel."

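    -- A memory block in kernel code is just a raw pointer, qualified
    -- with whatever address-space qualifiers its space implies.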
    kernelMemoryType space = do
      quals <- pointerQuals space
      pure [C.cty|$tyquals:quals $ty:defaultMemBlockType|]

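    -- Scalar reads and writes likewise go through pointers carrying
    -- the address-space qualifiers of the memory space in question.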
    kernelWriteScalar =
      GC.writeScalarPointerWithQuals pointerQuals

    kernelReadScalar =
      GC.readScalarPointerWithQuals pointerQuals

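    -- Statements to execute after a failure has been flagged.  If the
    -- kernel contains communication (barriers), we set 'local_failure'
    -- and jump to the next error label so the work-group can still
    -- synchronise; otherwise we simply return (with a nonzero result
    -- when compiling a device function).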
    whatNext = do
      label <- nextErrorLabel
      pendingError True
      pure $
        if has_communication
          then [C.citems|local_failure = true; goto $id:label;|]
          else
            if mode == FunMode
              then [C.citems|return 1;|]
              else [C.citems|return;|]

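    -- Function calls in kernel code pass the failure state as extra
    -- arguments, and a nonzero return value from the callee means a
    -- failure occurred.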
    callInKernel dests fname args = do
      let out_args = [[C.cexp|&$id:d|] | d <- dests]
          args' =
            [C.cexp|global_failure|]
              : [C.cexp|global_failure_args|]
              : out_args
              ++ args

      what_next <- whatNext

      GC.item [C.citem|if ($id:(funName fname)($args:args') != 0) { $items:what_next; }|]

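    -- Record the failure message so the host can report it later, and
    -- generate code that atomically claims the global failure slot (if
    -- it is still unclaimed) and stores the error arguments before
    -- bailing out of the kernel.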
    errorInKernel msg@(ErrorMsg parts) backtrace = do
      n <- length . kernelFailures <$> GC.getUserState
      GC.modifyUserState $ \s ->
        s {kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace]}
      let setArgs _ [] = pure []
          setArgs i (ErrorString {} : parts') = setArgs i parts'
          -- FIXME: bogus for non-ints.
          setArgs i (ErrorVal _ x : parts') = do
            x' <- GC.compileExp x
            stms <- setArgs (i + 1) parts'
            pure $ [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|] : stms
      argstms <- setArgs (0 :: Int) parts

      what_next <- whatNext

      GC.stm
        [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1)
                                 { $stms:argstms; }
                                 $items:what_next
                               }|]

--- Checking requirements

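-- Compute the set of scalar types used by a kernel.  This information
-- is used to decide which type-specific definitions (e.g. support for
-- half-precision floats) must be included in the generated device
-- code.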
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel kernel = typesInCode $ kernelBody kernel

typesInCode :: ImpGPU.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
typesInCode (c1 :>>: c2) = typesInCode c1 <> typesInCode c2
typesInCode (For _ e c) = typesInExp e <> typesInCode c
typesInCode (While (TPrimExp e) c) = typesInExp e <> typesInCode c
typesInCode DeclareMem {} = mempty
typesInCode (DeclareScalar _ _ t) = S.singleton t
typesInCode (DeclareArray _ _ t _) = S.singleton t
typesInCode (Allocate _ (Count (TPrimExp e)) _) = typesInExp e
typesInCode Free {} = mempty
typesInCode (Copy _ _ (Count (TPrimExp e1)) _ _ (Count (TPrimExp e2)) _ (Count (TPrimExp e3))) =
  typesInExp e1 <> typesInExp e2 <> typesInExp e3
typesInCode (Write _ (Count (TPrimExp e1)) t _ _ e2) =
  typesInExp e1 <> S.singleton t <> typesInExp e2
typesInCode (Read _ _ (Count (TPrimExp e1)) t _ _) =
  typesInExp e1 <> S.singleton t
typesInCode (SetScalar _ e) = typesInExp e
typesInCode SetMem {} = mempty
typesInCode (Call _ _ es) = mconcat $ map typesInArg es
  where
    typesInArg MemArg {} = mempty
    typesInArg (ExpArg e) = typesInExp e
typesInCode (If (TPrimExp e) c1 c2) =
  typesInExp e <> typesInCode c1 <> typesInCode c2
typesInCode (Assert e _ _) = typesInExp e
typesInCode (Comment _ c) = typesInCode c
typesInCode (DebugPrint _ v) = maybe mempty typesInExp v
typesInCode (TracePrint msg) = foldMap typesInExp msg
typesInCode Op {} = mempty

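-- The primitive types occurring in a scalar expression, including both
-- endpoints of any conversions.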
typesInExp :: Exp -> S.Set PrimType
typesInExp (ValueExp v) = S.singleton $ primValueType v
typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e
  where
    (from, to) = convOpType op
typesInExp (UnOpExp _ e) = typesInExp e
typesInExp (FunExp _ args t) = S.singleton t <> mconcat (map typesInExp args)
typesInExp LeafExp {} = mempty