{-# LANGUAGE QuasiQuotes #-}

-- | This module defines a translation from imperative code with
-- kernels to imperative code with OpenCL or CUDA calls.
module Futhark.CodeGen.ImpGen.GPU.ToOpenCL
  ( kernelsToOpenCL,
    kernelsToCUDA,
  )
where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Fun qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.GPU hiding (Program)
import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import Futhark.CodeGen.ImpCode.OpenCL qualified as ImpOpenCL
import Futhark.CodeGen.RTS.C (atomicsH, halfH)
import Futhark.Error (compilerLimitationS)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)

-- | Generate CUDA host and device code.  This is 'translateGPU'
-- specialised to the 'TargetCUDA' kernel target.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA = translateGPU TargetCUDA

-- | Generate OpenCL host and device code.  This is 'translateGPU'
-- specialised to the 'TargetOpenCL' kernel target.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL = translateGPU TargetOpenCL

-- | Translate a kernels-program to an OpenCL-program.
--
-- Every host operation in the constants and functions of the input
-- program is translated with 'onHostOp'.  The kernels, device
-- functions, used primitive types, tuning parameters and failure
-- messages produced along the way are collected in the 'ToOpenCL'
-- state and packaged into the resulting 'ImpOpenCL.Program'.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  ImpOpenCL.Program
    opencl_code
    opencl_prelude
    kernel_safety
    (S.toList used_types)
    (cleanSizes sizes)
    failures
    prog'
  where
    (prog', ToOpenCL kernels device_funs used_types sizes failures) =
      runState (runReaderT translateDefinitions (envFromProg prog)) initialOpenCL

    -- Constants are translated before functions, so their effects on
    -- the 'ToOpenCL' state happen first.
    translateDefinitions :: OnKernelM (ImpOpenCL.Definitions OpenCL)
    translateDefinitions = do
      let ImpGPU.Definitions
            types
            (ImpGPU.Constants ps consts)
            (ImpGPU.Functions funs) = prog
          translateFun (fname, fun) =
            (fname,) <$> traverse (onHostOp target) fun
      consts' <- traverse (onHostOp target) consts
      funs' <- mapM translateFun funs
      pure $
        ImpOpenCL.Definitions
          types
          (ImpOpenCL.Constants ps consts')
          (ImpOpenCL.Functions funs')

    -- Safety level of every generated kernel, keyed by kernel name.
    kernel_safety = M.map fst kernels

    -- Source text of all kernels, in the key order of the kernel map.
    opencl_code = T.unlines $ map snd $ M.elems kernels

    (device_prototypes, device_defs) = unzip $ M.elems device_funs

    opencl_prelude =
      T.unlines
        [ genPrelude target used_types,
          definitionsText device_prototypes,
          T.unlines device_defs
        ]

    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist.  We remove these here.
--
-- Only 'SizeThreshold' entries are affected.  Each step of a
-- 'KernelPath' whose name is not itself a key of the map is dropped.
-- All other size classes are returned unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    -- Check membership against the map itself (logarithmic) rather
    -- than scanning a list of its keys (linear per lookup).
    known name = name `M.member` m

    clean (SizeThreshold path def) =
      SizeThreshold (filter (known . fst) path) def
    clean s = s

-- | The OpenCL type qualifiers for a named memory space.  Calls
-- 'error' on any name that is not a known address space.
pointerQuals :: String -> [C.TypeQual]
pointerQuals space =
  case space of
    "global" -> [C.ctyquals|__global|]
    "local" -> [C.ctyquals|__local|]
    "private" -> [C.ctyquals|__private|]
    "constant" -> [C.ctyquals|__constant|]
    "write_only" -> [C.ctyquals|__write_only|]
    "read_only" -> [C.ctyquals|__read_only|]
    "kernel" -> [C.ctyquals|__kernel|]
    -- OpenCL does not actually have a "device" space, but we use it in
    -- the compiler pipeline to refer to memory on the device, as opposed
    -- to the host.  From a kernel's perspective, this is "global".
    "device" -> pointerQuals "global"
    _ -> error $ "'" ++ space ++ "' is not an OpenCL kernel address space."

-- | A local memory block used by a kernel: its in-kernel name, paired
-- with its per-workgroup size in bytes.
type LocalMemoryUse = (VName, Count Bytes Exp)

-- | User state of the C code generator while compiling kernel-side
-- code (kernel bodies and device functions, see 'genGPUCode').
data KernelState = KernelState
  { -- | Local memory blocks used by the kernel.
    kernelLocalMemory :: [LocalMemoryUse],
    -- | Failure messages known so far.  Seeded by 'newKernelState'
    -- and written back to 'clFailures' afterwards.
    kernelFailures :: [FailureMsg],
    -- | Number embedded in the name produced by 'errorLabel'.
    kernelNextSync :: Int,
    -- | Has a potential failure occurred since the last
    -- ErrorSync?
    kernelSyncPending :: Bool,
    -- | Whether the kernel contains barriers; consulted when choosing
    -- how the kernel initialises its failure state.
    kernelHasBarriers :: Bool
  }

-- | A fresh 'KernelState', seeded with the failure messages collected
-- so far.  It starts with no local memory, a sync counter of zero,
-- and both flags unset.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  KernelState
    { kernelLocalMemory = mempty,
      kernelFailures = failures,
      kernelNextSync = 0,
      kernelSyncPending = False,
      kernelHasBarriers = False
    }

-- | The name of a C label of the form @error_N@, where @N@ is the
-- current 'kernelNextSync' of the state.
errorLabel :: KernelState -> String
errorLabel kstate = "error_" ++ show (kernelNextSync kstate)

-- | State accumulated while translating a program.
data ToOpenCL = ToOpenCL
  { -- | Generated kernels: their safety level and source text.
    clGPU :: M.Map KernelName (KernelSafety, T.Text),
    -- | Generated device functions: prototype and definition text.
    clDevFuns :: M.Map Name (C.Definition, T.Text),
    -- | Primitive types used by generated kernels and device functions.
    clUsedTypes :: S.Set PrimType,
    -- | Tuning parameters registered through 'addSize'.
    clSizes :: M.Map Name SizeClass,
    -- | All failure messages collected so far.
    clFailures :: [FailureMsg]
  }

-- | The translation state before anything has been generated.
initialOpenCL :: ToOpenCL
initialOpenCL =
  ToOpenCL
    { clGPU = M.empty,
      clDevFuns = M.empty,
      clUsedTypes = S.empty,
      clSizes = M.empty,
      clFailures = []
    }

-- | Read-only information about the program being translated.
data Env = Env
  { -- | All functions of the program.
    envFuns :: ImpGPU.Functions ImpGPU.HostOp,
    -- | Names of the functions that may fail, as computed by
    -- 'funsMayFail'.
    envFunsMayFail :: S.Set Name
  }

-- | Can this code fail at run-time?  It can if it contains an
-- 'Assert', or an operation satisfying the given predicate, anywhere
-- inside sequencing, loops, branches or comments.  Every other kind
-- of statement (including function calls) is treated as safe here.
-- Calls are accounted for separately by 'funsMayFail'.
codeMayFail :: (a -> Bool) -> ImpGPU.Code a -> Bool
codeMayFail opMayFail = go
  where
    go code =
      case code of
        Assert {} -> True
        Op op -> opMayFail op
        x :>>: y -> go x || go y
        For _ _ body -> go body
        While _ body -> go body
        If _ tbranch fbranch -> go tbranch || go fbranch
        Comment _ body -> go body
        _ -> False

-- | A host operation can fail only if it launches a kernel whose body
-- can fail.
hostOpMayFail :: ImpGPU.HostOp -> Bool
hostOpMayFail op =
  case op of
    CallKernel k -> codeMayFail kernelOpMayFail (kernelBody k)
    _ -> False

-- | No kernel operation is by itself considered a failure point.
kernelOpMayFail :: ImpGPU.KernelOp -> Bool
kernelOpMayFail _ = False

-- | The names of the functions that may fail.  These are the functions
-- whose own body may fail (per 'codeMayFail' with 'hostOpMayFail'),
-- plus those whose entry in the call graph @cg@ contains such a
-- function.
funsMayFail :: M.Map Name (S.Set Name) -> ImpGPU.Functions ImpGPU.HostOp -> S.Set Name
funsMayFail cg (Functions funs) =
  S.fromList [fname | (fname, _) <- funs, mayFail fname]
  where
    -- Functions whose own body may fail.  Kept as a set so that the
    -- membership tests below are logarithmic rather than a linear
    -- list scan per callee.
    base_mayfail :: S.Set Name
    base_mayfail =
      S.fromList
        [ fname
          | (fname, fun) <- funs,
            codeMayFail hostOpMayFail (ImpGPU.functionBody fun)
        ]

    mayFail fname =
      fname `S.member` base_mayfail
        || any (`S.member` base_mayfail) (M.findWithDefault mempty fname cg)

-- | Build the translation environment for a program.  It holds the
-- program's functions and the set of those that may fail, computed
-- over the call graph induced by 'calledInHostOp'.
envFromProg :: ImpGPU.Program -> Env
envFromProg prog =
  Env
    { envFuns = funs,
      envFunsMayFail = funsMayFail callgraph funs
    }
  where
    funs = defFuns prog
    callgraph = ImpGPU.callGraph calledInHostOp funs

-- | Find a function of the program by name.
lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname env = lookup fname (unFunctions (envFuns env))

-- | Is the named function among those recorded in 'envFunsMayFail'?
functionMayFail :: Name -> Env -> Bool
functionMayFail fname env = fname `S.member` envFunsMayFail env

-- | The monad in which host operations and kernels are translated.
-- It reads the program-wide 'Env' and accumulates generated code and
-- metadata in the 'ToOpenCL' state.
type OnKernelM = ReaderT Env (State ToOpenCL)

-- | Record a tuning parameter and its size class in the translation
-- state.  Any earlier entry with the same key is replaced.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass = modify insertSize
  where
    insertSize s = s {clSizes = M.insert key sclass (clSizes s)}

-- | Translate a single host operation.  Kernel launches are compiled
-- by 'onKernel'.  Size queries are passed through unchanged, and the
-- sizes they name are registered with 'addSize'; 'GetSizeMax' names
-- no specific size, so nothing is registered for it.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target op =
  case op of
    CallKernel k ->
      onKernel target k
    ImpGPU.GetSize v key size_class ->
      ImpOpenCL.GetSize v key <$ addSize key size_class
    ImpGPU.CmpSizeLe v key size_class x ->
      ImpOpenCL.CmpSizeLe v key x <$ addSize key size_class
    ImpGPU.GetSizeMax v size_class ->
      pure $ ImpOpenCL.GetSizeMax v size_class

-- | Run a C code generator action for kernel-side code.  The action
-- uses the operations from 'inKernelOperations', a blank name source
-- and a fresh 'KernelState' seeded with the given failures.  The
-- action's result is returned together with the final compiler state.
genGPUCode ::
  Env ->
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode env mode body failures action =
  GC.runCompilerM ops blankNameSource initial_state action
  where
    ops = inKernelOperations env mode body
    initial_state = newKernelState failures

-- Compilation of a device function that is not invoked from the
-- host, but is invoked by (perhaps multiple) kernels.
--
-- Only scalar parameters are supported.  If the function may fail
-- (see 'functionMayFail'), it is compiled with two extra parameters
-- for reporting failures; otherwise it goes through
-- 'GC.compileVoidFun'.  The resulting prototype and definition are
-- stored in 'clDevFuns', and any device functions it calls are
-- generated in turn.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
generateDeviceFun fname device_func = do
  when (any memParam $ functionInput device_func) bad

  env <- ask
  failures <- gets clFailures

  let body = functionBody device_func
      compile
        | functionMayFail fname env =
            GC.compileFun mempty failure_params (fname, device_func)
        | otherwise =
            GC.compileVoidFun mempty (fname, device_func)
      (func, cstate) = genGPUCode env FunMode body failures compile
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode body <> clUsedTypes s,
        clDevFuns = M.insert fname (second funcText func) $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- Important to do this after the 'modify' call, so we propagate the
  -- right clFailures.
  void $ ensureDeviceFuns body
  where
    failure_params =
      [ [C.cparam|__global int *global_failure|],
        [C.cparam|__global typename int64_t *global_failure_args|]
      ]

    memParam MemParam {} = True
    memParam ScalarParam {} = False

    bad :: a
    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."

-- | Ensure that this device function is available, but don't
-- regenerate it if it already exists in 'clDevFuns'.
ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
ensureDeviceFun fname device_func = do
  already_generated <- gets (M.member fname . clDevFuns)
  if already_generated
    then pure ()
    else generateDeviceFun fname device_func

-- | The functions called by a host operation.  For a kernel launch,
-- these are the functions called from the kernel body; every other
-- operation calls nothing.
calledInHostOp :: HostOp -> S.Set Name
calledInHostOp op =
  case op of
    CallKernel k -> calledFuncs calledInKernelOp (kernelBody k)
    _ -> S.empty

-- | Kernel operations are not considered to call any functions.
calledInKernelOp :: KernelOp -> S.Set Name
calledInKernelOp _ = S.empty

-- | Ensure that every function called by the given kernel code is
-- available as a device function, generating it if necessary.
-- Returns the names of the called functions that 'lookupFunction'
-- finds in the program; other names are skipped.
ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code =
  catMaybes <$> mapM ensure (S.toList (calledFuncs calledInKernelOp code))
  where
    ensure fname = do
      def <- asks (lookupFunction fname)
      case def of
        Nothing -> pure Nothing
        Just host_func -> do
          -- Functions are a priori always considered host-level, so we
          -- have to convert them to device code.  This is where most of
          -- our limitations on device-side functions (no arrays, no
          -- parallelism) come from.
          ensureDeviceFun fname (fmap toDevice host_func)
          pure (Just fname)

    -- A host operation cannot be expressed inside device code.
    toDevice :: HostOp -> KernelOp
    toDevice _ = bad

    bad :: a
    bad = compilerLimitationS "Cannot generate GPU functions that contain parallelism."

-- | Render a group-size dimension as text when it is statically known.
-- This covers an integer literal, a 'SizeConst' (its name,
-- z-encoded), or a 'SizeMaxConst' (its size class prefixed with
-- @max_@).  Any other dimension yields 'Nothing'.
isConst :: GroupDim -> Maybe T.Text
isConst dim =
  case dim of
    Left (ValueExp (IntValue x)) -> Just (prettyText (intToInt64 x))
    Right (SizeConst v) -> Just (zEncodeText (nameToText v))
    Right (SizeMaxConst size_class) -> Just ("max_" <> prettyText size_class)
    _ -> Nothing

onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel KernelTarget
target Kernel
kernel = do
  [Name]
called <- Code KernelOp -> OnKernelM [Name]
ensureDeviceFuns forall a b. (a -> b) -> a -> b
$ Kernel -> Code KernelOp
kernelBody Kernel
kernel

  -- Crucial that this is done after 'ensureDeviceFuns', as the device
  -- functions may themselves define failure points.
  [FailureMsg]
failures <- forall s (m :: * -> *) a. MonadState s m => (s -> a) -> m a
gets ToOpenCL -> [FailureMsg]
clFailures
  Env
env <- forall r (m :: * -> *). MonadReader r m => m r
ask

  let ([BlockItem]
kernel_body, CompilerState KernelState
cstate) =
        forall a.
Env
-> OpsMode
-> Code KernelOp
-> [FailureMsg]
-> CompilerM KernelOp KernelState a
-> (a, CompilerState KernelState)
genGPUCode Env
env OpsMode
KernelMode (Kernel -> Code KernelOp
kernelBody Kernel
kernel) [FailureMsg]
failures forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
GC.collect forall a b. (a -> b) -> a -> b
$ do
          [BlockItem]
body <- forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
GC.collect forall a b. (a -> b) -> a -> b
$ forall op s. Code op -> CompilerM op s ()
GC.compileCode forall a b. (a -> b) -> a -> b
$ Kernel -> Code KernelOp
kernelBody Kernel
kernel
          -- No need to free, as we cannot allocate memory in kernels.
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall op s. BlockItem -> CompilerM op s ()
GC.item forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall op s. CompilerM op s [BlockItem]
GC.declAllocatedMem
          forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall op s. BlockItem -> CompilerM op s ()
GC.item [BlockItem]
body
      kstate :: KernelState
kstate = forall s. CompilerState s -> s
GC.compUserState CompilerState KernelState
cstate

      ([Maybe KernelArg]
local_memory_args, [Maybe Param]
local_memory_params, [BlockItem]
local_memory_init) =
        forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b c. (a -> b -> c) -> b -> a -> c
flip forall s a. State s a -> s -> a
evalState (VNameSource
blankNameSource :: VNameSource) forall a b. (a -> b) -> a -> b
$
          forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall {m :: * -> *}.
MonadFreshNames m =>
KernelTarget
-> LocalMemoryUse -> m (Maybe KernelArg, Maybe Param, BlockItem)
prepareLocalMemory KernelTarget
target) forall a b. (a -> b) -> a -> b
$
            KernelState -> [LocalMemoryUse]
kernelLocalMemory KernelState
kstate

      -- CUDA has very strict restrictions on the number of blocks
      -- permitted along the 'y' and 'z' dimensions of the grid
      -- (1<<16).  To work around this, we are going to dynamically
      -- permute the block dimensions to move the largest one to the
      -- 'x' dimension, which has a higher limit (1<<31).  This means
      -- we need to extend the kernel with extra parameters that
      -- contain information about this permutation, but we only do
      -- this for multidimensional kernels (at the time of this
      -- writing, only transposes).  The corresponding arguments are
      -- added automatically in CCUDA.hs.
      ([Param]
perm_params, [BlockItem]
block_dim_init) =
        case (KernelTarget
target, [Exp]
num_groups) of
          (KernelTarget
TargetCUDA, [Exp
_, Exp
_, Exp
_]) ->
            ( [ [C.cparam|const int block_dim0|],
                [C.cparam|const int block_dim1|],
                [C.cparam|const int block_dim2|]
              ],
              forall a. Monoid a => a
mempty
            )
          (KernelTarget, [Exp])
_ ->
            ( forall a. Monoid a => a
mempty,
              [ [C.citem|const int block_dim0 = 0;|],
                [C.citem|const int block_dim1 = 1;|],
                [C.citem|const int block_dim2 = 2;|]
              ]
            )

      ([BlockItem]
const_defs, [BlockItem]
const_undefs) = forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelUse -> Maybe (BlockItem, BlockItem)
constDef forall a b. (a -> b) -> a -> b
$ Kernel -> [KernelUse]
kernelUses Kernel
kernel

  let ([Param]
use_params, [[BlockItem]]
unpack_params) =
        forall a b. [(a, b)] -> ([a], [b])
unzip forall a b. (a -> b) -> a -> b
$ forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe KernelUse -> Maybe (Param, [BlockItem])
useAsParam forall a b. (a -> b) -> a -> b
$ Kernel -> [KernelUse]
kernelUses Kernel
kernel

  -- The local_failure variable is an int despite only really storing
  -- a single bit of information, as some OpenCL implementations
  -- (e.g. AMD) does not like byte-sized local memory (and the others
  -- likely pad to a whole word anyway).
  let (KernelSafety
safety, [BlockItem]
error_init)
        -- We conservatively assume that any called function can fail.
        | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Name]
called =
            ( KernelSafety
SafetyFull,
              [C.citems|volatile __local int local_failure;
                        // Harmless for all threads to write this.
                        local_failure = 0;|]
            )
        | forall (t :: * -> *) a. Foldable t => t a -> Int
length (KernelState -> [FailureMsg]
kernelFailures KernelState
kstate) forall a. Eq a => a -> a -> Bool
== forall (t :: * -> *) a. Foldable t => t a -> Int
length [FailureMsg]
failures =
            if Kernel -> Bool
kernelFailureTolerant Kernel
kernel
              then (KernelSafety
SafetyNone, [])
              else -- No possible failures in this kernel, so if we make
              -- it past an initial check, then we are good to go.

                ( KernelSafety
SafetyCheap,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
        | Bool
otherwise =
            if Bool -> Bool
not (KernelState -> Bool
kernelHasBarriers KernelState
kstate)
              then
                ( KernelSafety
SafetyFull,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
              else
                ( KernelSafety
SafetyFull,
                  [C.citems|
                     volatile __local int local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = 0;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |]
                )

      failure_params :: [Param]
failure_params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|int failure_is_an_option|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]

      params :: [Param]
params =
        [Param]
perm_params
          forall a. [a] -> [a] -> [a]
++ forall a. Int -> [a] -> [a]
take (KernelSafety -> Int
numFailureParams KernelSafety
safety) [Param]
failure_params
          forall a. [a] -> [a] -> [a]
++ forall a. [Maybe a] -> [a]
catMaybes [Maybe Param]
local_memory_params
          forall a. [a] -> [a] -> [a]
++ [Param]
use_params

      attribute :: Text
attribute =
        case (KernelTarget
target, forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM GroupDim -> Maybe Text
isConst forall a b. (a -> b) -> a -> b
$ Kernel -> [GroupDim]
kernelGroupSize Kernel
kernel) of
          (KernelTarget
TargetOpenCL, Just [Text
x, Text
y, Text
z]) ->
            Text
"__attribute__((reqd_work_group_size" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Text
y, Text
z) forall a. Semigroup a => a -> a -> a
<> Text
"))\n"
          (KernelTarget
TargetOpenCL, Just [Text
x, Text
y]) ->
            Text
"__attribute__((reqd_work_group_size" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Text
y, Int
1 :: Int) forall a. Semigroup a => a -> a -> a
<> Text
"))\n"
          (KernelTarget
TargetOpenCL, Just [Text
x]) ->
            Text
"__attribute__((reqd_work_group_size" forall a. Semigroup a => a -> a -> a
<> forall a. Pretty a => a -> Text
prettyText (Text
x, Int
1 :: Int, Int
1 :: Int) forall a. Semigroup a => a -> a -> a
<> Text
"))\n"
          (KernelTarget, Maybe [Text])
_ -> Text
""

      kernel_fun :: Text
kernel_fun =
        Text
attribute
          forall a. Semigroup a => a -> a -> a
<> Func -> Text
funcText
            [C.cfun|__kernel void $id:name ($params:params) {
                    $items:(mconcat unpack_params)
                    $items:const_defs
                    $items:block_dim_init
                    $items:local_memory_init
                    $items:error_init
                    $items:kernel_body

                    $id:(errorLabel kstate): return;

                    $items:const_undefs
                }|]
  forall s (m :: * -> *). MonadState s m => (s -> s) -> m ()
modify forall a b. (a -> b) -> a -> b
$ \ToOpenCL
s ->
    ToOpenCL
s
      { clGPU :: Map Name (KernelSafety, Text)
clGPU = forall k a. Ord k => k -> a -> Map k a -> Map k a
M.insert Name
name (KernelSafety
safety, Text
kernel_fun) forall a b. (a -> b) -> a -> b
$ ToOpenCL -> Map Name (KernelSafety, Text)
clGPU ToOpenCL
s,
        clUsedTypes :: Set PrimType
clUsedTypes = Kernel -> Set PrimType
typesInKernel Kernel
kernel forall a. Semigroup a => a -> a -> a
<> ToOpenCL -> Set PrimType
clUsedTypes ToOpenCL
s,
        clFailures :: [FailureMsg]
clFailures = KernelState -> [FailureMsg]
kernelFailures KernelState
kstate
      }

  -- The argument corresponding to the global_failure parameters is
  -- added automatically later.
  let args :: [KernelArg]
args = forall a. [Maybe a] -> [a]
catMaybes [Maybe KernelArg]
local_memory_args forall a. [a] -> [a] -> [a]
++ Kernel -> [KernelArg]
kernelArgs Kernel
kernel

  forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ KernelSafety
-> Name -> [KernelArg] -> [Exp] -> [GroupDim] -> OpenCL
LaunchKernel KernelSafety
safety Name
name [KernelArg]
args [Exp]
num_groups [GroupDim]
group_size
  where
    name :: Name
name = Kernel -> Name
kernelName Kernel
kernel
    num_groups :: [Exp]
num_groups = Kernel -> [Exp]
kernelNumGroups Kernel
kernel
    group_size :: [GroupDim]
group_size = Kernel -> [GroupDim]
kernelGroupSize Kernel
kernel

    prepareLocalMemory :: KernelTarget
-> LocalMemoryUse -> m (Maybe KernelArg, Maybe Param, BlockItem)
prepareLocalMemory KernelTarget
TargetOpenCL (VName
mem, Count Bytes Exp
size) = do
      VName
mem_aligned <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
mem forall a. [a] -> [a] -> [a]
++ [Char]
"_aligned"
      forall (f :: * -> *) a. Applicative f => a -> f a
pure
        ( forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> KernelArg
SharedMemoryKArg Count Bytes Exp
size,
          forall a. a -> Maybe a
Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
          [C.citem|__local volatile unsigned char* restrict $id:mem = (__local volatile unsigned char*) $id:mem_aligned;|]
        )
    prepareLocalMemory KernelTarget
TargetCUDA (VName
mem, Count Bytes Exp
size) = do
      VName
param <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ VName -> [Char]
baseString VName
mem forall a. [a] -> [a] -> [a]
++ [Char]
"_offset"
      forall (f :: * -> *) a. Applicative f => a -> f a
pure
        ( forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ Count Bytes Exp -> KernelArg
SharedMemoryKArg Count Bytes Exp
size,
          forall a. a -> Maybe a
Just [C.cparam|uint $id:param|],
          [C.citem|volatile $ty:defaultMemBlockType $id:mem = &shared_mem[$id:param];|]
        )

-- | Turn a kernel use into a C kernel parameter, together with any
-- block items that must run at the start of the kernel body.
--
-- * A scalar whose parameter type equals its C type is passed directly
--   under its own name.  Otherwise it is passed under its z-encoded name
--   with a @_bits@ suffix, and a local with the original name is
--   initialised from it via 'fromStorage'.  Booleans and units use
--   @unsigned char@ as the parameter type.
-- * Memory becomes a @__global@ pointer of 'defaultMemBlockType'.
-- * Constants produce no parameter.
useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem])
useAsParam (ScalarUse name pt)
  | param_ty == primTypeToCType pt =
      Just ([C.cparam|$ty:param_ty $id:name|], [])
  | otherwise =
      Just
        ( [C.cparam|$ty:param_ty $id:name_bits|],
          [[C.citem|$ty:(primTypeToCType pt) $id:name = $exp:(fromStorage pt name_bits_e);|]]
        )
  where
    name_bits :: T.Text
    name_bits = zEncodeText (prettyText name) <> "_bits"

    name_bits_e :: C.Exp
    name_bits_e = [C.cexp|$id:name_bits|]

    -- OpenCL does not permit bool as a kernel parameter type.
    param_ty :: C.Type
    param_ty = case pt of
      Bool -> [C.cty|unsigned char|]
      Unit -> [C.cty|unsigned char|]
      _ -> primStorageType pt
useAsParam (MemoryUse name) =
  Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], [])
useAsParam ConstUse {} = Nothing

-- Constants are emitted as C preprocessor macro definitions.  A
-- constant name in one kernel might (although unlikely) also be used
-- for something else in another kernel, so every definition is paired
-- with an undefinition that is emitted after the kernel.
--
-- For a ConstUse the result is (definition, undefinition); every other
-- kind of use yields Nothing.
constDef :: KernelUse -> Maybe (C.BlockItem, C.BlockItem)
constDef (ConstUse v e) =
  let macro_name = idText (C.toIdent v mempty)
      define_line =
        "#define " <> macro_name <> " (" <> expText (compilePrimExp e) <> ")"
      undefine_line = "#undef " <> macro_name
   in Just
        ( [C.citem|$escstm:(T.unpack define_line)|],
          [C.citem|$escstm:(T.unpack undefine_line)|]
        )
constDef _ = Nothing

-- | Generate the OpenCL prelude text.  It contains extension pragmas,
-- a dummy kernel, fixed-width integer typedefs and memory-fence
-- helpers, followed by the 'halfH', 'cScalarDefs' and 'atomicsH'
-- snippets, in that order.
--
-- Only one property of @ts@ matters here: whether it contains the
-- 'Float64' type.  If it does, the @cl_khr_fp64@ extension is enabled
-- and @FUTHARK_F64_ENABLED@ is defined.
genOpenClPrelude :: S.Set PrimType -> T.Text
genOpenClPrelude ts = mconcat [opencl_defs, halfH, cScalarDefs, atomicsH]
  where
    uses_f64 :: Bool
    uses_f64 = FloatType Float64 `S.member` ts

    enable_f64 :: T.Text
    enable_f64 =
      if uses_f64
        then
          [untrimming|
         #pragma OPENCL EXTENSION cl_khr_fp64 : enable
         #define FUTHARK_F64_ENABLED
         |]
        else mempty

    opencl_defs :: T.Text
    opencl_defs =
      [untrimming|
// Clang-based OpenCL implementations need this for 'static' to work.
#ifdef cl_clang_storage_class_specifiers
#pragma OPENCL EXTENSION cl_clang_storage_class_specifiers : enable
#endif
#pragma OPENCL EXTENSION cl_khr_byte_addressable_store : enable
$enable_f64
// Some OpenCL programs dislike empty progams, or programs with no kernels.
// Declare a dummy kernel to ensure they remain our friends.
__kernel void dummy_kernel(__global unsigned char *dummy, int n)
{
    const int thread_gid = get_global_id(0);
    if (thread_gid >= n) return;
}

#pragma OPENCL EXTENSION cl_khr_int64_base_atomics : enable
#pragma OPENCL EXTENSION cl_khr_int64_extended_atomics : enable

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long int64_t;

typedef uchar uint8_t;
typedef ushort uint16_t;
typedef uint uint32_t;
typedef ulong uint64_t;

// NVIDIAs OpenCL does not create device-wide memory fences (see #734), so we
// use inline assembly if we detect we are on an NVIDIA GPU.
#ifdef cl_nv_pragma_unroll
static inline void mem_fence_global() {
  asm("membar.gl;");
}
#else
static inline void mem_fence_global() {
  mem_fence(CLK_LOCAL_MEM_FENCE | CLK_GLOBAL_MEM_FENCE);
}
#endif
static inline void mem_fence_local() {
  mem_fence(CLK_LOCAL_MEM_FENCE);
}
|]

-- | Generate the CUDA prelude text.  It maps the OpenCL-style names
-- used by generated kernels (address-space qualifiers, @__kernel@,
-- work-item queries, @barrier@ and the memory-fence helpers) onto CUDA
-- builtins.  It always defines @FUTHARK_F64_ENABLED@.  The
-- 'halfH', 'cScalarDefs' and 'atomicsH' snippets follow, in that order.
--
-- The @get_group_id@ and @get_num_groups@ macros refer to variables
-- @block_dim0@, @block_dim1@ and @block_dim2@.  Those variables must be
-- in scope wherever the macros are used.  Each macro first remaps the
-- requested dimension through them.
genCUDAPrelude :: T.Text
genCUDAPrelude = mconcat [cuda_defs, halfH, cScalarDefs, atomicsH]
  where
    cuda_defs :: T.Text
    cuda_defs =
      [untrimming|
#define FUTHARK_CUDA
#define FUTHARK_F64_ENABLED

typedef char int8_t;
typedef short int16_t;
typedef int int32_t;
typedef long long int64_t;
typedef unsigned char uint8_t;
typedef unsigned short uint16_t;
typedef unsigned int uint32_t;
typedef unsigned long long uint64_t;
typedef uint8_t uchar;
typedef uint16_t ushort;
typedef uint32_t uint;
typedef uint64_t ulong;
#define __kernel extern "C" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)
#define __global
#define __local
#define __private
#define __constant
#define __write_only
#define __read_only

static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch (d) {
    case 0: return blockIdx.x;
    case 1: return blockIdx.y;
    case 2: return blockIdx.z;
    default: return 0;
  }
}
#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)

static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d) {
  switch (d) {
    case 0: d = block_dim0; break;
    case 1: d = block_dim1; break;
    case 2: d = block_dim2; break;
  }
  switch(d) {
    case 0: return gridDim.x;
    case 1: return gridDim.y;
    case 2: return gridDim.z;
    default: return 0;
  }
}
#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)

static inline int get_local_id(int d) {
  switch (d) {
    case 0: return threadIdx.x;
    case 1: return threadIdx.y;
    case 2: return threadIdx.z;
    default: return 0;
  }
}

static inline int get_local_size(int d) {
  switch (d) {
    case 0: return blockDim.x;
    case 1: return blockDim.y;
    case 2: return blockDim.z;
    default: return 0;
  }
}

#define CLK_LOCAL_MEM_FENCE 1
#define CLK_GLOBAL_MEM_FENCE 2
static inline void barrier(int x) {
  __syncthreads();
}
static inline void mem_fence_local() {
  __threadfence_block();
}
static inline void mem_fence_global() {
  __threadfence();
}

#define NAN (0.0/0.0)
#define INFINITY (1.0/0.0)
extern volatile __shared__ unsigned char shared_mem[];
|]

-- | Compile a kernel-constant expression to a C expression.
--
-- * A 'SizeConst' becomes an identifier: the z-encoded name of its key.
-- * A 'SizeMaxConst' becomes an identifier: @max_@ followed by the
--   pretty-printed size class.
compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp = runIdentity . GC.compilePrimExp (pure . constExp)
  where
    constExp (SizeConst key) =
      [C.cexp|$id:(zEncodeText (prettyText key))|]
    constExp (SizeMaxConst size_class) =
      [C.cexp|$id:("max_" <> prettyString size_class)|]

-- | The host-side arguments for a kernel's uses, in the same order as
-- 'kernelUses'.  Memory is passed as 'MemKArg'.  A scalar is passed as
-- a 'ValueKArg' of a leaf expression.  Constants contribute no
-- argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs = concatMap argsFor . kernelUses
  where
    argsFor (MemoryUse mem) = [MemKArg mem]
    argsFor (ScalarUse v pt) = [ValueKArg (LeafExp v pt) pt]
    argsFor ConstUse {} = []

-- | The label that the next 'ErrorSync' will use.  It is computed by
-- applying 'errorLabel' to the current kernel state.
nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel = fmap errorLabel GC.getUserState

-- | Advance the kernel's synchronisation counter, 'kernelNextSync', by
-- one.  Presumably this yields a fresh 'errorLabel' for the next
-- 'ErrorSync'; confirm against the definition of 'errorLabel'.
incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel = GC.modifyUserState bump
  where
    bump s = s {kernelNextSync = kernelNextSync s + 1}

-- | Set or clear the 'kernelSyncPending' flag.  While the flag is set,
-- the next 'ErrorSync' emits a labelled barrier followed by a
-- @local_failure@ check, and then clears the flag.
pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError flag = GC.modifyUserState setPending
  where
    setPending s = s {kernelSyncPending = flag}

-- | Does the kernel code contain an 'ErrorSync' or a 'Barrier'
-- operation anywhere?  Both kinds of operation emit a @barrier@ call
-- when compiled.
hasCommunication :: ImpGPU.KernelCode -> Bool
hasCommunication code = any isSync code
  where
    isSync op = case op of
      ErrorSync {} -> True
      Barrier {} -> True
      _ -> False

-- | Whether we are generating code for a kernel or for a device
-- function.  This has minor effects, such as exactly how failures are
-- propagated.
data OpsMode
  = KernelMode
  | FunMode
  deriving (Eq)

inKernelOperations ::
  Env ->
  OpsMode ->
  ImpGPU.KernelCode ->
  GC.Operations KernelOp KernelState
inKernelOperations :: Env -> OpsMode -> Code KernelOp -> Operations KernelOp KernelState
inKernelOperations Env
env OpsMode
mode Code KernelOp
body =
  GC.Operations
    { opsCompiler :: OpCompiler KernelOp KernelState
GC.opsCompiler = OpCompiler KernelOp KernelState
kernelOps,
      opsMemoryType :: MemoryType KernelOp KernelState
GC.opsMemoryType = forall {f :: * -> *}. Applicative f => [Char] -> f Type
kernelMemoryType,
      opsWriteScalar :: WriteScalar KernelOp KernelState
GC.opsWriteScalar = forall {op} {s}. WriteScalar op s
kernelWriteScalar,
      opsReadScalar :: ReadScalar KernelOp KernelState
GC.opsReadScalar = forall {op} {s}. ReadScalar op s
kernelReadScalar,
      opsAllocate :: Allocate KernelOp KernelState
GC.opsAllocate = Allocate KernelOp KernelState
cannotAllocate,
      opsDeallocate :: Allocate KernelOp KernelState
GC.opsDeallocate = Allocate KernelOp KernelState
cannotDeallocate,
      opsCopy :: Copy KernelOp KernelState
GC.opsCopy = Copy KernelOp KernelState
copyInKernel,
      opsFatMemory :: Bool
GC.opsFatMemory = Bool
False,
      opsError :: ErrorCompiler KernelOp KernelState
GC.opsError = ErrorCompiler KernelOp KernelState
errorInKernel,
      opsCall :: CallCompiler KernelOp KernelState
GC.opsCall = forall {a}.
ToIdent a =>
[a] -> Name -> [Exp] -> CompilerM KernelOp KernelState ()
callInKernel,
      opsCritical :: ([BlockItem], [BlockItem])
GC.opsCritical = forall a. Monoid a => a
mempty
    }
  where
    has_communication :: Bool
has_communication = Code KernelOp -> Bool
hasCommunication Code KernelOp
body

    fence :: Fence -> Exp
fence Fence
FenceLocal = [C.cexp|CLK_LOCAL_MEM_FENCE|]
    fence Fence
FenceGlobal = [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]

    kernelOps :: GC.OpCompiler KernelOp KernelState
    kernelOps :: OpCompiler KernelOp KernelState
kernelOps (GetGroupId VName
v Int
i) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:v = get_group_id($int:i);|]
    kernelOps (GetLocalId VName
v Int
i) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
    kernelOps (GetLocalSize VName
v Int
i) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
    kernelOps (GetLockstepWidth VName
v) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
    kernelOps (Barrier Fence
f) = do
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|barrier($exp:(fence f));|]
      forall s op. (s -> s) -> CompilerM op s ()
GC.modifyUserState forall a b. (a -> b) -> a -> b
$ \KernelState
s -> KernelState
s {kernelHasBarriers :: Bool
kernelHasBarriers = Bool
True}
    kernelOps (MemFence Fence
FenceLocal) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|mem_fence_local();|]
    kernelOps (MemFence Fence
FenceGlobal) =
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|mem_fence_global();|]
    kernelOps (LocalAlloc VName
name Count Bytes (TExp Int64)
size) = do
      VName
name' <- forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName forall a b. (a -> b) -> a -> b
$ forall a. Pretty a => a -> [Char]
prettyString VName
name forall a. [a] -> [a] -> [a]
++ [Char]
"_backing"
      forall s op. (s -> s) -> CompilerM op s ()
GC.modifyUserState forall a b. (a -> b) -> a -> b
$ \KernelState
s ->
        KernelState
s {kernelLocalMemory :: [LocalMemoryUse]
kernelLocalMemory = (VName
name', forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped Count Bytes (TExp Int64)
size) forall a. a -> [a] -> [a]
: KernelState -> [LocalMemoryUse]
kernelLocalMemory KernelState
s}
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:name';|]
    kernelOps (ErrorSync Fence
f) = do
      [Char]
label <- CompilerM KernelOp KernelState [Char]
nextErrorLabel
      Bool
pending <- KernelState -> Bool
kernelSyncPending forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall op s. CompilerM op s s
GC.getUserState
      forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
pending forall a b. (a -> b) -> a -> b
$ do
        Bool -> CompilerM KernelOp KernelState ()
pendingError Bool
False
        forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
        forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|if (local_failure) { return; }|]
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|barrier($exp:(fence f));|]
      forall s op. (s -> s) -> CompilerM op s ()
GC.modifyUserState forall a b. (a -> b) -> a -> b
$ \KernelState
s -> KernelState
s {kernelHasBarriers :: Bool
kernelHasBarriers = Bool
True}
      CompilerM KernelOp KernelState ()
incErrorLabel
    kernelOps (Atomic Space
space AtomicOp
aop) = forall {op} {s}. Space -> AtomicOp -> CompilerM op s ()
atomicOps Space
space AtomicOp
aop

    atomicCast :: Space -> Type -> f Type
atomicCast Space
s Type
t = do
      let volatile :: [TypeQual]
volatile = [C.ctyquals|volatile|]
      let quals :: [TypeQual]
quals = case Space
s of
            Space [Char]
sid -> [Char] -> [TypeQual]
pointerQuals [Char]
sid
            Space
_ -> [Char] -> [TypeQual]
pointerQuals [Char]
"global"
      forall (f :: * -> *) a. Applicative f => a -> f a
pure [C.cty|$tyquals:(volatile++quals) $ty:t|]

    atomicSpace :: Space -> [Char]
atomicSpace (Space [Char]
sid) = [Char]
sid
    atomicSpace Space
_ = [Char]
"global"

    doAtomic :: Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s p
t a
old a
arr Count u (TPrimExp t VName)
ind Exp
val [Char]
op Type
ty = do
      Exp
ind' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count u (TPrimExp t VName)
ind
      Exp
val' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp Exp
val
      Type
cast <- forall {f :: * -> *}. Applicative f => Space -> Type -> f Type
atomicCast Space
s Type
ty
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:old = $id:op'(&(($ty:cast *)$id:arr)[$exp:ind'], ($ty:ty) $exp:val');|]
      where
        op' :: [Char]
op' = [Char]
op forall a. [a] -> [a] -> [a]
++ [Char]
"_" forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString p
t forall a. [a] -> [a] -> [a]
++ [Char]
"_" forall a. [a] -> [a] -> [a]
++ Space -> [Char]
atomicSpace Space
s

    doAtomicCmpXchg :: Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Exp
-> Type
-> CompilerM op s ()
doAtomicCmpXchg Space
s p
t a
old a
arr Count u (TPrimExp t VName)
ind Exp
cmp Exp
val Type
ty = do
      Exp
ind' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count u (TPrimExp t VName)
ind
      Exp
cmp' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp Exp
cmp
      Exp
val' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp Exp
val
      Type
cast <- forall {f :: * -> *}. Applicative f => Space -> Type -> f Type
atomicCast Space
s Type
ty
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:cmp', $exp:val');|]
      where
        op :: [Char]
op = [Char]
"atomic_cmpxchg_" forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString p
t forall a. [a] -> [a] -> [a]
++ [Char]
"_" forall a. [a] -> [a] -> [a]
++ Space -> [Char]
atomicSpace Space
s
    doAtomicXchg :: Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Type
-> CompilerM op s ()
doAtomicXchg Space
s p
t a
old a
arr Count u (TPrimExp t VName)
ind Exp
val Type
ty = do
      Type
cast <- forall {f :: * -> *}. Applicative f => Space -> Type -> f Type
atomicCast Space
s Type
ty
      Exp
ind' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp forall a b. (a -> b) -> a -> b
$ forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped forall a b. (a -> b) -> a -> b
$ forall {k} (u :: k) e. Count u e -> e
unCount Count u (TPrimExp t VName)
ind
      Exp
val' <- forall op s. Exp -> CompilerM op s Exp
GC.compileExp Exp
val
      forall op s. Stm -> CompilerM op s ()
GC.stm [C.cstm|$id:old = $id:op(&(($ty:cast *)$id:arr)[$exp:ind'], $exp:val');|]
      where
        op :: [Char]
op = [Char]
"atomic_chg_" forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString p
t forall a. [a] -> [a] -> [a]
++ [Char]
"_" forall a. [a] -> [a] -> [a]
++ Space -> [Char]
atomicSpace Space
s
    -- First the 64-bit operations.
    atomicOps :: Space -> AtomicOp -> CompilerM op s ()
atomicOps Space
s (AtomicAdd IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_add" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicFAdd FloatType
Float64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s FloatType
Float64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_fadd" [C.cty|double|]
    atomicOps Space
s (AtomicSMax IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_smax" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicSMin IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_smin" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicUMax IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_umax" [C.cty|unsigned int64_t|]
    atomicOps Space
s (AtomicUMin IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_umin" [C.cty|unsigned int64_t|]
    atomicOps Space
s (AtomicAnd IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_and" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicOr IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_or" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicXor IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
Int64 VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_xor" [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicCmpXchg (IntType IntType
Int64) VName
old VName
arr Count Elements (TExp Int64)
ind Exp
cmp Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Exp
-> Type
-> CompilerM op s ()
doAtomicCmpXchg Space
s (IntType -> PrimType
IntType IntType
Int64) VName
old VName
arr Count Elements (TExp Int64)
ind Exp
cmp Exp
val [C.cty|typename int64_t|]
    atomicOps Space
s (AtomicXchg (IntType IntType
Int64) VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Type
-> CompilerM op s ()
doAtomicXchg Space
s (IntType -> PrimType
IntType IntType
Int64) VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [C.cty|typename int64_t|]
    --
    atomicOps Space
s (AtomicAdd IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_add" [C.cty|int|]
    atomicOps Space
s (AtomicFAdd FloatType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s FloatType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_fadd" [C.cty|float|]
    atomicOps Space
s (AtomicSMax IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_smax" [C.cty|int|]
    atomicOps Space
s (AtomicSMin IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_smin" [C.cty|int|]
    atomicOps Space
s (AtomicUMax IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_umax" [C.cty|unsigned int|]
    atomicOps Space
s (AtomicUMin IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_umin" [C.cty|unsigned int|]
    atomicOps Space
s (AtomicAnd IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_and" [C.cty|int|]
    atomicOps Space
s (AtomicOr IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_or" [C.cty|int|]
    atomicOps Space
s (AtomicXor IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> [Char]
-> Type
-> CompilerM op s ()
doAtomic Space
s IntType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [Char]
"atomic_xor" [C.cty|int|]
    atomicOps Space
s (AtomicCmpXchg PrimType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
cmp Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Exp
-> Type
-> CompilerM op s ()
doAtomicCmpXchg Space
s PrimType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
cmp Exp
val [C.cty|int|]
    atomicOps Space
s (AtomicXchg PrimType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val) =
      forall {k} {k} {a} {a} {p} {u :: k} {t :: k} {op} {s}.
(ToIdent a, ToIdent a, Pretty p) =>
Space
-> p
-> a
-> a
-> Count u (TPrimExp t VName)
-> Exp
-> Type
-> CompilerM op s ()
doAtomicXchg Space
s PrimType
t VName
old VName
arr Count Elements (TExp Int64)
ind Exp
val [C.cty|int|]

    -- Memory allocation is not supported inside kernel code: any attempt
    -- aborts code generation via 'error'.
    cannotAllocate :: GC.Allocate KernelOp KernelState
    cannotAllocate = const $ error "Cannot allocate memory in kernel"

    -- Memory deallocation is not supported inside kernel code: any attempt
    -- aborts code generation via 'error'.
    cannotDeallocate :: GC.Deallocate KernelOp KernelState
    cannotDeallocate _ = const $ error "Cannot deallocate memory in kernel"

    -- Bulk copies are not supported inside kernel code; every argument is
    -- ignored and code generation is aborted via 'error'.
    copyInKernel :: GC.Copy KernelOp KernelState
    copyInKernel _ _ _ _ _ _ _ _ =
      error "Cannot bulk copy in kernel."

    -- The C type of a memory block in the given space: a
    -- 'defaultMemBlockType' carrying whatever qualifiers 'pointerQuals'
    -- produces for that space.
    kernelMemoryType space = pure mem_type
      where
        mem_type = [C.cty|$tyquals:(pointerQuals space) $ty:defaultMemBlockType|]

    -- Scalar writes through a pointer, with the pointer qualified
    -- according to 'pointerQuals' for the target space.
    kernelWriteScalar =
      GC.writeScalarPointerWithQuals pointerQuals

    -- Scalar reads through a pointer, with the pointer qualified
    -- according to 'pointerQuals' for the source space.
    kernelReadScalar =
      GC.readScalarPointerWithQuals pointerQuals

    -- Produce the statements that run after a failure has been detected.
    -- It first obtains a fresh error label and marks an error as pending.
    -- Kernels with communication set local_failure and jump to that
    -- label. Otherwise, function-mode code returns 1 and other code does
    -- a plain return.
    whatNext = do
      lbl <- nextErrorLabel
      pendingError True
      let exit_items
            | has_communication = [C.citems|local_failure = 1; goto $id:lbl;|]
            | mode == FunMode = [C.citems|return 1;|]
            | otherwise = [C.citems|return;|]
      pure exit_items

    -- Emit a call to 'fname' from kernel code. The address of each
    -- destination is passed first, followed by the ordinary arguments.
    -- When 'functionMayFail' holds, global_failure and
    -- global_failure_args are prepended, and a non-zero return value
    -- triggers the statements produced by 'whatNext'.
    callInKernel dests fname args
      | functionMayFail fname env = do
          what_next <- whatNext
          let failing_args =
                [C.cexp|global_failure|] : [C.cexp|global_failure_args|] : call_args
          GC.item [C.citem|if ($id:callee($args:failing_args) != 0) { $items:what_next; }|]
      | otherwise =
          GC.item [C.citem|$id:callee($args:call_args);|]
      where
        callee = funName fname
        call_args = [[C.cexp|&$id:d|] | d <- dests] ++ args

    -- Compile a failure raised inside a kernel.
    --
    -- The message is appended to kernelFailures in the user state. Its
    -- index there (the list length before the append) is the code passed
    -- to atomic_cmpxchg_i32_global(global_failure, -1, code). Only when
    -- that call returns -1 are the message's argument values stored in
    -- global_failure_args. The statements from 'whatNext' always follow.
    errorInKernel msg@(ErrorMsg parts) backtrace = do
      failure_code <- length . kernelFailures <$> GC.getUserState
      GC.modifyUserState $ \st ->
        st {kernelFailures = kernelFailures st <> [FailureMsg msg backtrace]}
      -- Only the ErrorVal parts carry runtime values. They are compiled in
      -- order and stored one per consecutive global_failure_args slot.
      -- FIXME: every value is cast to int64_t, which is bogus for non-ints.
      let storeArg (i, x) = do
            x' <- GC.compileExp x
            pure [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|]
      argstms <- traverse storeArg $ zip [0 :: Int ..] [x | ErrorVal _ x <- parts]
      what_next <- whatNext
      GC.stm
        [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:failure_code) == -1)
                                 { $stms:argstms; }
                                 $items:what_next
                               }|]

--- Checking requirements

-- | The primitive types mentioned anywhere in the body of a kernel, as
-- computed by 'typesInCode'.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody

-- | Collect the primitive types mentioned by a piece of kernel code.
--
-- Included are the types of declared scalars and arrays, the element
-- types of 'Read' and 'Write', and whatever 'typesInExp' finds in every
-- embedded expression (loop bounds, conditions, sizes, indices, call
-- arguments, assertions, debug and trace messages). Memory declarations,
-- 'SetMem', 'Free' and kernel operations ('Op') contribute nothing.
typesInCode :: ImpGPU.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
typesInCode (lhs :>>: rhs) = typesInCode lhs <> typesInCode rhs
typesInCode (For _ bound loop_body) = typesInExp bound <> typesInCode loop_body
typesInCode (While (TPrimExp cond) loop_body) = typesInExp cond <> typesInCode loop_body
typesInCode DeclareMem {} = mempty
typesInCode (DeclareScalar _ _ t) = S.singleton t
typesInCode (DeclareArray _ t _) = S.singleton t
typesInCode (Allocate _ (Count (TPrimExp size)) _) = typesInExp size
typesInCode Free {} = mempty
typesInCode (Copy _ _ (Count (TPrimExp e1)) _ _ (Count (TPrimExp e2)) _ (Count (TPrimExp e3))) =
  foldMap typesInExp [e1, e2, e3]
typesInCode (Write _ (Count (TPrimExp i)) t _ _ v) =
  S.insert t $ typesInExp i <> typesInExp v
typesInCode (Read _ _ (Count (TPrimExp i)) t _ _) =
  S.insert t $ typesInExp i
typesInCode (SetScalar _ e) = typesInExp e
typesInCode SetMem {} = mempty
typesInCode (Call _ _ call_args) = foldMap argTypes call_args
  where
    -- Memory arguments carry no primitive type.
    argTypes (ExpArg e) = typesInExp e
    argTypes MemArg {} = mempty
typesInCode (If (TPrimExp cond) tbranch fbranch) =
  mconcat [typesInExp cond, typesInCode tbranch, typesInCode fbranch]
typesInCode (Assert e _ _) = typesInExp e
typesInCode (Comment _ inner) = typesInCode inner
typesInCode (DebugPrint _ v) = foldMap typesInExp v
typesInCode (TracePrint msg) = foldMap typesInExp msg
typesInCode Op {} = mempty

-- | Collect the primitive types mentioned by an expression.
--
-- The result includes the types of constants, of variable references
-- ('LeafExp'), of function results, and both sides of every conversion.
-- Operators contribute nothing themselves; their operand types are found
-- through their subexpressions.
--
-- A variable reference contributes its 'PrimType', just as every other
-- type-carrying constructor does. Otherwise a type used only through
-- variables (e.g. an arithmetic expression over two @f64@ variables)
-- would be missing from the used-type set.
typesInExp :: Exp -> S.Set PrimType
typesInExp (ValueExp v) = S.singleton $ primValueType v
typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e
  where
    (from, to) = convOpType op
typesInExp (UnOpExp _ e) = typesInExp e
typesInExp (FunExp _ args t) = S.insert t $ foldMap typesInExp args
typesInExp (LeafExp _ t) = S.singleton t