{-# LANGUAGE QuasiQuotes #-}

-- | This module defines a translation from imperative code with
-- kernels to imperative code with OpenCL or CUDA calls.
module Futhark.CodeGen.ImpGen.GPU.ToOpenCL
  ( kernelsToOpenCL,
    kernelsToCUDA,
    kernelsToHIP,
  )
where

import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.Bifunctor (second)
import Data.Foldable (toList)
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Fun qualified as GC
import Futhark.CodeGen.Backends.GenericC.Pretty
import Futhark.CodeGen.Backends.SimpleRep
import Futhark.CodeGen.ImpCode.GPU hiding (Program)
import Futhark.CodeGen.ImpCode.GPU qualified as ImpGPU
import Futhark.CodeGen.ImpCode.OpenCL hiding (Program)
import Futhark.CodeGen.ImpCode.OpenCL qualified as ImpOpenCL
import Futhark.CodeGen.RTS.C (atomicsH, halfH)
import Futhark.CodeGen.RTS.CUDA (preludeCU)
import Futhark.CodeGen.RTS.OpenCL (copyCL, preludeCL, transposeCL)
import Futhark.Error (compilerLimitationS)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.IntegralExp (rem)
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C
import NeatInterpolation (untrimming)
import Prelude hiding (rem)

-- | Generate HIP host and device code.
--
-- This is 'translateGPU' specialised to the 'TargetHIP' target.
kernelsToHIP :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToHIP = translateGPU TargetHIP

-- | Generate CUDA host and device code.
--
-- This is 'translateGPU' specialised to the 'TargetCUDA' target.
kernelsToCUDA :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToCUDA = translateGPU TargetCUDA

-- | Generate OpenCL host and device code.
--
-- This is 'translateGPU' specialised to the 'TargetOpenCL' target.
kernelsToOpenCL :: ImpGPU.Program -> ImpOpenCL.Program
kernelsToOpenCL = translateGPU TargetOpenCL

-- | Translate a kernels-program to an OpenCL-program.
--
-- Every host operation in the constants and in each function is
-- rewritten with 'onHostOp'.  The 'ToOpenCL' state accumulated during
-- that traversal (kernels, device functions, used types, sizes,
-- failure messages and constants) is then used to fill in the fields
-- of the resulting 'ImpOpenCL.Program'.
translateGPU ::
  KernelTarget ->
  ImpGPU.Program ->
  ImpOpenCL.Program
translateGPU target prog =
  let env = envFromProg prog
      ( prog',
        ToOpenCL kernels device_funs used_types sizes failures constants
        ) =
          (`runState` initialOpenCL) . (`runReaderT` env) $ do
            let ImpGPU.Definitions
                  types
                  (ImpGPU.Constants ps consts)
                  (ImpGPU.Functions funs) = prog
            consts' <- traverse (onHostOp target) consts
            funs' <- forM funs $ \(fname, fun) ->
              (fname,) <$> traverse (onHostOp target) fun

            pure $
              ImpOpenCL.Definitions
                types
                (ImpOpenCL.Constants ps consts')
                (ImpOpenCL.Functions funs')

      -- Each device function is stored as a prototype together with
      -- the text of its definition.
      (device_prototypes, device_defs) = unzip $ M.elems device_funs
      kernels' = M.map fst kernels
      opencl_code = T.unlines $ map snd $ M.elems kernels

      -- The prelude is the target-specific prelude, followed by all
      -- device function prototypes, followed by their definitions.
      opencl_prelude =
        T.unlines
          [ genPrelude target used_types,
            definitionsText device_prototypes,
            T.unlines device_defs
          ]
   in ImpOpenCL.Program
        { openClProgram = opencl_code,
          openClPrelude = opencl_prelude,
          openClMacroDefs = constants,
          openClKernelNames = kernels',
          openClUsedTypes = S.toList used_types,
          openClParams = findParamUsers env prog' (cleanSizes sizes),
          openClFailures = failures,
          hostDefinitions = prog'
        }
  where
    -- Only the OpenCL prelude depends on the set of used types; the
    -- CUDA and HIP preludes ignore it.
    genPrelude :: KernelTarget -> S.Set PrimType -> T.Text
    genPrelude TargetOpenCL = genOpenClPrelude
    genPrelude TargetCUDA = const genCUDAPrelude
    genPrelude TargetHIP = const genHIPPrelude

-- | Due to simplifications after kernel extraction, some threshold
-- parameters may contain KernelPaths that reference threshold
-- parameters that no longer exist.  We remove these here.
--
-- Only 'SizeThreshold' entries are modified: every path element whose
-- name is not itself a key of the map is dropped.  All other size
-- classes are returned unchanged.
cleanSizes :: M.Map Name SizeClass -> M.Map Name SizeClass
cleanSizes m = M.map clean m
  where
    clean (SizeThreshold path def) =
      -- Query the map directly instead of scanning a list of its keys.
      SizeThreshold (filter ((`M.member` m) . fst) path) def
    clean s = s

-- | Pair every size parameter with the set of functions that use it.
--
-- A function uses a parameter directly if its body contains a
-- 'ImpOpenCL.GetSize' or 'ImpOpenCL.CmpSizeLe' operation naming that
-- parameter.  It uses a parameter indirectly if a function that the
-- call graph in the 'Env' associates with it uses the parameter
-- directly.
findParamUsers ::
  Env ->
  Definitions ImpOpenCL.OpenCL ->
  M.Map Name SizeClass ->
  ParamMap
findParamUsers env defs = M.mapWithKey onParam
  where
    cg = envCallGraph env

    -- The size parameter named by a host operation, if any.
    getSize (ImpOpenCL.GetSize _ v) = Just v
    getSize (ImpOpenCL.CmpSizeLe _ v _) = Just v
    getSize (ImpOpenCL.GetSizeMax {}) = Nothing
    getSize (ImpOpenCL.LaunchKernel {}) = Nothing

    directUseInFun fun = mapMaybe getSize $ toList $ functionBody fun
    direct_uses = map (second directUseInFun) $ unFunctions $ defFuns defs

    -- The set of names the call graph records for 'fname' (empty if
    -- the function has no entry).
    calledBy fname = M.findWithDefault mempty fname cg

    -- All parameters used directly by any function in 'calledBy fname'.
    indirectUseInFun fname =
      ( fname,
        foldMap snd $ filter ((`S.member` calledBy fname) . fst) direct_uses
      )

    -- Each function appears twice: once with its direct uses and once
    -- with its indirect uses.  'onParam' collapses this into a set.
    indirect_uses = direct_uses <> map (indirectUseInFun . fst) direct_uses

    onParam k c =
      (c, S.fromList $ map fst $ filter ((k `elem`) . snd) indirect_uses)

-- | The C type qualifiers corresponding to the name of an address
-- space (or other kernel-level qualifier).  Calls 'error' if the name
-- is not recognised.
pointerQuals :: String -> [C.TypeQual]
pointerQuals "global" = [C.ctyquals|__global|]
pointerQuals "shared" = [C.ctyquals|__local|]
pointerQuals "private" = [C.ctyquals|__private|]
pointerQuals "constant" = [C.ctyquals|__constant|]
pointerQuals "write_only" = [C.ctyquals|__write_only|]
pointerQuals "read_only" = [C.ctyquals|__read_only|]
pointerQuals "kernel" = [C.ctyquals|__kernel|]
-- OpenCL does not actually have a "device" space, but we use it in
-- the compiler pipeline to defer to memory on the device, as opposed
-- to the host.  From a kernel's perspective, this is "global".
pointerQuals "device" = pointerQuals "global"
pointerQuals s = error $ "'" ++ s ++ "' is not an OpenCL kernel address space."

-- | A use of shared memory: the in-kernel name and the
-- per-threadblock size in bytes.
type SharedMemoryUse = (VName, Count Bytes (TExp Int64))

-- | The user state carried by the C compiler monad while generating
-- GPU code (see 'genGPUCode').
data KernelState = KernelState
  { -- | The shared memory uses; empty in a fresh state.
    kernelSharedMemory :: [SharedMemoryUse],
    -- | The failure messages.  A fresh state is seeded with the
    -- messages that were already known when code generation started.
    kernelFailures :: [FailureMsg],
    -- | Counter from which 'errorLabel' builds its label; starts at 0.
    kernelNextSync :: Int,
    -- | Has a potential failure occurred since the last
    -- ErrorSync?
    kernelSyncPending :: Bool,
    -- | NOTE(review): presumably whether the generated code contains
    -- barriers -- confirm against the sites that set this field.
    kernelHasBarriers :: Bool
  }

-- | A fresh 'KernelState' that starts from the given failure
-- messages: no shared memory uses, sync counter at zero, no sync
-- pending and no barriers.
newKernelState :: [FailureMsg] -> KernelState
newKernelState failures =
  KernelState
    { kernelSharedMemory = mempty,
      kernelFailures = failures,
      kernelNextSync = 0,
      kernelSyncPending = False,
      kernelHasBarriers = False
    }

-- | The label @error_N@, where @N@ is the current value of
-- 'kernelNextSync'.
errorLabel :: KernelState -> String
errorLabel = ("error_" ++) . show . kernelNextSync

-- | The state accumulated while translating host code; see
-- 'translateGPU' for how each field ends up in the final program.
data ToOpenCL = ToOpenCL
  { -- | For each kernel, its safety level and the text of its code.
    clGPU :: M.Map KernelName (KernelSafety, T.Text),
    -- | For each device function, its prototype and the text of its
    -- definition.
    clDevFuns :: M.Map Name (C.Definition, T.Text),
    -- | The primitive types used by the code translated so far.
    clUsedTypes :: S.Set PrimType,
    -- | The size parameters registered with 'addSize'.
    clSizes :: M.Map Name SizeClass,
    -- | The failure messages collected so far.
    clFailures :: [FailureMsg],
    -- | Named constant expressions; these become the macro
    -- definitions of the final program.
    clConstants :: [(Name, KernelConstExp)]
  }

-- | The initial translation state, in which every field is empty.
initialOpenCL :: ToOpenCL
initialOpenCL =
  ToOpenCL
    { clGPU = mempty,
      clDevFuns = mempty,
      clUsedTypes = mempty,
      clSizes = mempty,
      clFailures = mempty,
      clConstants = mempty
    }

-- | The read-only environment of the translation, constructed from
-- the source program by 'envFromProg'.
data Env = Env
  { -- | The functions of the source program.
    envFuns :: ImpGPU.Functions ImpGPU.HostOp,
    -- | The names of the functions that may fail, as computed by
    -- 'funsMayFail'.
    envFunsMayFail :: S.Set Name,
    -- | The call graph of the source program's functions.
    envCallGraph :: M.Map Name (S.Set Name)
  }

-- | Whether a piece of code may fail.
--
-- The code may fail if it contains an 'Assert', or an 'Op' for which
-- the given predicate holds.  The search descends into sequences,
-- loops, branches and comments; every other construct is considered
-- unable to fail.
codeMayFail :: (a -> Bool) -> ImpGPU.Code a -> Bool
codeMayFail _ (Assert {}) = True
codeMayFail f (Op x) = f x
codeMayFail f (x :>>: y) = codeMayFail f x || codeMayFail f y
codeMayFail f (For _ _ x) = codeMayFail f x
codeMayFail f (While _ x) = codeMayFail f x
codeMayFail f (If _ x y) = codeMayFail f x || codeMayFail f y
codeMayFail f (Comment _ x) = codeMayFail f x
codeMayFail _ _ = False

-- | Whether a host operation may fail.  Only a kernel call whose body
-- may fail (according to 'codeMayFail') counts; every other host
-- operation is considered unable to fail.
hostOpMayFail :: ImpGPU.HostOp -> Bool
hostOpMayFail (CallKernel k) = codeMayFail kernelOpMayFail $ kernelBody k
hostOpMayFail _ = False

-- | Whether a kernel operation may fail.  No kernel operation is
-- considered able to fail on its own.
kernelOpMayFail :: ImpGPU.KernelOp -> Bool
kernelOpMayFail = const False

-- | The names of the functions that may fail, given the call graph.
--
-- A function may fail if its own body may fail (see 'codeMayFail' and
-- 'hostOpMayFail'), or if any function that the call graph associates
-- with it has a body that may fail.
funsMayFail :: M.Map Name (S.Set Name) -> ImpGPU.Functions ImpGPU.HostOp -> S.Set Name
funsMayFail cg (Functions funs) =
  S.fromList $ filter mayFail $ map fst funs
  where
    -- Functions whose own body may fail.  Kept as a set so that the
    -- membership tests in 'mayFail' do not scan a list.
    base_mayfail =
      S.fromList
        [ fname
          | (fname, fun) <- funs,
            codeMayFail hostOpMayFail $ ImpGPU.functionBody fun
        ]
    mayFail fname =
      fname `S.member` base_mayfail
        || any (`S.member` base_mayfail) (M.findWithDefault mempty fname cg)

-- | Construct the translation environment for a program: its
-- functions, their call graph, and the set of functions that may
-- fail.
envFromProg :: ImpGPU.Program -> Env
envFromProg prog =
  Env
    { envFuns = funs,
      envFunsMayFail = funsMayFail cg funs,
      envCallGraph = cg
    }
  where
    funs = defFuns prog
    cg = ImpGPU.callGraph calledInHostOp funs

-- | Look up a function of the source program by name; 'Nothing' if
-- the environment has no function with that name.
lookupFunction :: Name -> Env -> Maybe (ImpGPU.Function HostOp)
lookupFunction fname = lookup fname . unFunctions . envFuns

-- | Whether the named function is in the environment's set of
-- functions that may fail.
functionMayFail :: Name -> Env -> Bool
functionMayFail fname = S.member fname . envFunsMayFail

-- | The monad in which host code is translated: a read-only 'Env' on
-- top of mutable 'ToOpenCL' state.
type OnKernelM = ReaderT Env (State ToOpenCL)

-- | Record a size parameter and its class in the translation state.
-- An existing entry with the same key is replaced.
addSize :: Name -> SizeClass -> OnKernelM ()
addSize key sclass =
  modify $ \s -> s {clSizes = M.insert key sclass $ clSizes s}

-- | Translate a single host operation.
--
-- Kernel calls are handed to 'onKernel'.  'ImpGPU.GetSize' and
-- 'ImpGPU.CmpSizeLe' register their size parameter with 'addSize' and
-- drop the size class from the resulting operation.
-- 'ImpGPU.GetSizeMax' is carried over unchanged.
onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
onHostOp target (CallKernel k) = onKernel target k
onHostOp _ (ImpGPU.GetSize v key size_class) = do
  addSize key size_class
  pure $ ImpOpenCL.GetSize v key
onHostOp _ (ImpGPU.CmpSizeLe v key size_class x) = do
  addSize key size_class
  pure $ ImpOpenCL.CmpSizeLe v key x
onHostOp _ (ImpGPU.GetSizeMax v size_class) =
  pure $ ImpOpenCL.GetSizeMax v size_class

-- | Run a C code generation action for GPU code.
--
-- The action runs with the operations from 'inKernelOperations', a
-- blank name source, and a fresh 'KernelState' seeded with the given
-- failure messages.  The result is returned together with the final
-- compiler state.
genGPUCode ::
  Env ->
  OpsMode ->
  KernelCode ->
  [FailureMsg] ->
  GC.CompilerM KernelOp KernelState a ->
  (a, GC.CompilerState KernelState)
genGPUCode env mode body failures =
  GC.runCompilerM
    (inKernelOperations env mode body)
    blankNameSource
    (newKernelState failures)

-- Compilation of a device function that is not invoked from the
-- host, but is invoked by (perhaps multiple) kernels.
--
-- The function is compiled in 'FunMode' and stored in 'clDevFuns'
-- under its name.  The types used in its body are added to
-- 'clUsedTypes', and 'clFailures' is replaced by the failure messages
-- of the resulting kernel state.  Functions with memory parameters
-- are rejected as a compiler limitation.
generateDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
generateDeviceFun fname device_func = do
  when (any memParam $ functionInput device_func) bad

  env <- ask
  failures <- gets clFailures

  let body = declsFirst $ functionBody device_func
      -- A function that may fail also takes the global failure
      -- parameters; otherwise it is compiled as a void function.
      compile
        | functionMayFail fname env =
            GC.compileFun mempty failure_params (fname, device_func)
        | otherwise =
            GC.compileVoidFun mempty (fname, device_func)
      (func, cstate) = genGPUCode env FunMode body failures compile
      kstate = GC.compUserState cstate

  modify $ \s ->
    s
      { clUsedTypes = typesInCode (functionBody device_func) <> clUsedTypes s,
        clDevFuns = M.insert fname (second funcText func) $ clDevFuns s,
        clFailures = kernelFailures kstate
      }

  -- Important to do this after the 'modify' call, so we propagate the
  -- right clFailures.
  void $ ensureDeviceFuns $ functionBody device_func
  where
    memParam MemParam {} = True
    memParam ScalarParam {} = False

    failure_params =
      [ [C.cparam|__global int *global_failure|],
        [C.cparam|__global typename int64_t *global_failure_args|]
      ]

    bad = compilerLimitationS "Cannot generate GPU functions that use arrays."

-- Ensure that this device function is available, but don't regenerate
-- it if it already exists.
--
-- The function is looked up by name in 'clDevFuns'; only when no entry
-- is present is 'generateDeviceFun' invoked for it.
ensureDeviceFun :: Name -> ImpGPU.Function ImpGPU.KernelOp -> OnKernelM ()
ensureDeviceFun fname device_func = do
  already_generated <- gets $ M.member fname . clDevFuns
  unless already_generated $ generateDeviceFun fname device_func

-- The names of the functions called by a host-level operation.  Only
-- kernel launches contribute: the body of the launched kernel is
-- searched with 'calledFuncs'.  Every other host operation yields the
-- empty set.
calledInHostOp :: HostOp -> S.Set Name
calledInHostOp (CallKernel k) = calledFuncs calledInKernelOp (kernelBody k)
calledInHostOp _ = mempty

-- The names of the functions called by a kernel-level operation.
-- This is always the empty set; it exists to serve as the operation
-- callback expected by 'calledFuncs'.
calledInKernelOp :: KernelOp -> S.Set Name
calledInKernelOp = const mempty

-- Ensure that a device version exists for every function called from
-- the given kernel code, and return the names of the functions for
-- which that was done.  Called names for which 'lookupFunction' finds
-- no definition are skipped and do not appear in the result.
ensureDeviceFuns :: ImpGPU.KernelCode -> OnKernelM [Name]
ensureDeviceFuns code =
  catMaybes <$> forM (S.toList called) ensure
  where
    called = calledFuncs calledInKernelOp code

    ensure fname = do
      def <- asks $ lookupFunction fname
      case def of
        Nothing -> pure Nothing
        Just host_func -> do
          -- Functions are a priori always considered host-level, so we
          -- have to convert them to device code.  This is where most of
          -- our limitations on device-side functions (no arrays, no
          -- parallelism) come from.
          ensureDeviceFun fname $ fmap toDevice host_func
          pure $ Just fname

    -- Any host operation inside a function body is rejected as soon as
    -- the converted operation is forced.
    toDevice :: HostOp -> KernelOp
    toDevice _ =
      compilerLimitationS "Cannot generate GPU functions that contain parallelism."

-- Extract a kernel-constant expression from a block dimension, if it
-- has one.  A 'Left' dimension qualifies only when it is a literal
-- integer value, which is rewrapped as a constant expression; a
-- 'Right' dimension already is one.  Anything else yields 'Nothing'.
isConst :: BlockDim -> Maybe KernelConstExp
isConst (Left (ValueExp (IntValue x))) = Just $ ValueExp (IntValue x)
isConst (Right e) = Just e
isConst _ = Nothing

-- Translate a single kernel to target code.
--
-- The kernel body is compiled to C block items, wrapped in a function
-- taking the shared memory parameter (OpenCL only), the failure
-- parameters required by the chosen safety level, and the parameters
-- derived from the kernel's uses.  The resulting function text is
-- recorded in 'clGPU' together with its safety level, and the used
-- types, failure messages and constants in the state are updated.
-- The result is the 'LaunchKernel' operation that replaces the kernel
-- in host code.
onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
onKernel target kernel = do
  called <- ensureDeviceFuns $ kernelBody kernel

  -- Crucial that this is done after 'ensureDeviceFuns', as the device
  -- functions may themselves define failure points.
  failures <- gets clFailures
  env <- ask

  let (kernel_body, cstate) =
        genGPUCode env KernelMode (kernelBody kernel) failures . GC.collect $ do
          body <- GC.collect $ GC.compileCode $ declsFirst $ kernelBody kernel
          -- No need to free, as we cannot allocate memory in kernels.
          mapM_ GC.item =<< GC.declAllocatedMem
          mapM_ GC.item body
      kstate = GC.compUserState cstate

      (kernel_consts, (const_defs, const_undefs)) =
        second unzip $
          unzip $
            mapMaybe (constDef (kernelName kernel)) $
              kernelUses kernel

  -- Each shared memory block is placed at the running offset; only the
  -- generated initialisation items are needed, not the final offset.
  let (_, shared_memory_init) =
        L.mapAccumL prepareSharedMemory [C.cexp|0|] (kernelSharedMemory kstate)
      shared_memory_bytes =
        sum $ map (padTo8 . snd) $ kernelSharedMemory kstate

  let (use_params, unpack_params) =
        unzip $ mapMaybe useAsParam $ kernelUses kernel

  -- The local_failure variable is an int despite only really storing
  -- a single bit of information, as some OpenCL implementations
  -- (e.g. AMD) do not like byte-sized shared memory (and the others
  -- likely pad to a whole word anyway).
  let (safety, error_init)
        -- We conservatively assume that any called function can fail.
        | not $ null called =
            ( SafetyFull,
              [C.citems|volatile __local int local_failure;
                        // Harmless for all threads to write this.
                        local_failure = 0;|]
            )
        -- Compiling the body added no failure messages.
        | length (kernelFailures kstate) == length failures =
            if kernelFailureTolerant kernel
              then (SafetyNone, [])
              else
                -- No possible failures in this kernel, so if we make
                -- it past an initial check, then we are good to go.
                ( SafetyCheap,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
        | otherwise =
            if not (kernelHasBarriers kstate)
              then
                ( SafetyFull,
                  [C.citems|if (*global_failure >= 0) { return; }|]
                )
              else
                ( SafetyFull,
                  [C.citems|
                     volatile __local int local_failure;
                     if (failure_is_an_option) {
                       int failed = *global_failure >= 0;
                       if (failed) {
                         return;
                       }
                     }
                     // All threads write this value - it looks like CUDA has a compiler bug otherwise.
                     local_failure = 0;
                     barrier(CLK_LOCAL_MEM_FENCE);
                  |]
                )

      -- Only a prefix of these is passed, as determined by the safety
      -- level.
      failure_params =
        [ [C.cparam|__global int *global_failure|],
          [C.cparam|int failure_is_an_option|],
          [C.cparam|__global typename int64_t *global_failure_args|]
        ]

      -- Only the OpenCL target receives shared memory as an explicit
      -- kernel parameter.
      (shared_memory_param, prepare_shared_memory) =
        case target of
          TargetOpenCL ->
            ( [[C.cparam|__local typename uint64_t* shared_mem_aligned|]],
              [C.citems|__local unsigned char* shared_mem = (__local unsigned char*)shared_mem_aligned;|]
            )
          TargetCUDA -> (mempty, mempty)
          TargetHIP -> (mempty, mempty)

      params =
        shared_memory_param
          ++ take (numFailureParams safety) failure_params
          ++ use_params

      -- When every block dimension is a constant, the kernel is
      -- annotated with its size, expressed through per-kernel
      -- constants named after the kernel.  Missing trailing dimensions
      -- are given size 1.
      (attribute_consts, attribute) =
        case mapM isConst $ kernelBlockSize kernel of
          Just [x, y, z] ->
            ( [(dim1, x), (dim2, y), (dim3, z)],
              "FUTHARK_KERNEL_SIZED" <> prettyText (dim1, dim2, dim3) <> "\n"
            )
          Just [x, y] ->
            ( [(dim1, x), (dim2, y)],
              "FUTHARK_KERNEL_SIZED" <> prettyText (dim1, dim2, 1 :: Int) <> "\n"
            )
          Just [x] ->
            ( [(dim1, x)],
              "FUTHARK_KERNEL_SIZED" <> prettyText (dim1, 1 :: Int, 1 :: Int) <> "\n"
            )
          _ -> (mempty, "FUTHARK_KERNEL\n")

      kernel_fun =
        attribute
          <> funcText
            [C.cfun|void $id:name ($params:params) {
                    $items:(mconcat unpack_params)
                    $items:const_defs
                    $items:prepare_shared_memory
                    $items:(mconcat shared_memory_init)
                    $items:error_init
                    $items:kernel_body

                    $id:(errorLabel kstate): return;

                    $items:const_undefs
                }|]

  modify $ \s ->
    s
      { clGPU = M.insert name (safety, kernel_fun) $ clGPU s,
        clUsedTypes = typesInKernel kernel <> clUsedTypes s,
        clFailures = kernelFailures kstate,
        clConstants = attribute_consts <> kernel_consts <> clConstants s
      }

  -- The error handling stuff is automatically added later.
  let args = kernelArgs kernel

  pure $ LaunchKernel safety name shared_memory_bytes args num_tblocks tblock_size
  where
    name = kernelName kernel
    num_tblocks = kernelNumBlocks kernel
    tblock_size = kernelBlockSize kernel

    -- Names of the constants carrying the block dimensions.
    dimConstName suffix = nameFromText $ zEncodeText $ nameToText name <> suffix
    dim1 = dimConstName "_dim1"
    dim2 = dimConstName "_dim2"
    dim3 = dimConstName "_dim3"

    -- Add whatever is missing to reach a multiple of 8 (nothing if the
    -- value already is one).
    padTo8 e = e + ((8 - (e `rem` 8)) `rem` 8)

    -- Declare one shared memory block as a pointer into 'shared_mem'
    -- at the given offset, and define a variable holding the offset
    -- just past the (padded) block.  That variable is returned as the
    -- offset for the next block.
    prepareSharedMemory offset (mem, Count size) =
      let offset_v = nameFromText $ prettyText mem <> "_offset"
       in ( [C.cexp|$id:offset_v|],
            [C.citems|
             volatile __local $ty:defaultMemBlockType $id:mem = &shared_mem[$exp:offset];
             const typename int64_t $id:offset_v = $exp:offset + $exp:(padTo8 size);
             |]
          )

-- Turn a kernel use into a kernel parameter, paired with the block
-- items needed at the start of the kernel to unpack it.
--
-- A scalar is passed directly when its storage type coincides with
-- its C type.  Otherwise it is passed under a @_bits@ name in its
-- storage type, and converted back with 'fromStorage' into a local
-- variable carrying the original name.  Memory blocks become
-- @__global@ pointers.  Constants are not passed as parameters at all.
useAsParam :: KernelUse -> Maybe (C.Param, [C.BlockItem])
useAsParam (ScalarUse name pt)
  | ctp == primTypeToCType pt =
      Just ([C.cparam|$ty:ctp $id:name|], [])
  | otherwise =
      Just
        ( [C.cparam|$ty:ctp $id:name_bits|],
          [[C.citem|$ty:(primTypeToCType pt) $id:name = $exp:(fromStorage pt name_bits_e);|]]
        )
  where
    name_bits = zEncodeText (prettyText name) <> "_bits"
    name_bits_e = [C.cexp|$id:name_bits|]
    ctp = case pt of
      -- OpenCL does not permit bool as a kernel parameter type.
      Bool -> [C.cty|unsigned char|]
      Unit -> [C.cty|unsigned char|]
      _ -> primStorageType pt
useAsParam (MemoryUse name) =
  Just ([C.cparam|__global $ty:defaultMemBlockType $id:name|], [])
useAsParam ConstUse {} =
  Nothing

-- Constants are #defined as macros.  Since a constant name in one
-- kernel might potentially (although unlikely) also be used for
-- something else in another kernel, we #undef them after the kernel.
--
-- For a constant use, the result pairs the kernel-qualified constant
-- name and its defining expression with the @#define@ and @#undef@
-- items for the macro.  The macro is named after the variable and
-- expands to the kernel-qualified name.  Other uses yield 'Nothing'.
constDef :: Name -> KernelUse -> Maybe ((Name, KernelConstExp), (C.BlockItem, C.BlockItem))
constDef kernel_name (ConstUse v e) =
  Just
    ( (nameFromText qualified, e),
      ( [C.citem|$escstm:(T.unpack def)|],
        [C.citem|$escstm:(T.unpack undef)|]
      )
    )
  where
    qualified = zEncodeText $ nameToText kernel_name <> "." <> prettyText v
    macro = idText (C.toIdent v mempty)
    def = "#define " <> macro <> " (" <> qualified <> ")"
    undef = "#undef " <> macro
constDef _ _ = Nothing

-- | Prelude text shared by all GPU targets: the concatenation, in
-- this order, of 'halfH', 'cScalarDefs', 'atomicsH', 'transposeCL'
-- and 'copyCL'.
commonPrelude :: T.Text
commonPrelude =
  T.concat
    [ halfH,
      cScalarDefs,
      atomicsH,
      transposeCL,
      copyCL
    ]

-- | Assemble the prelude text for the OpenCL target.
--
-- The result is @#define FUTHARK_OPENCL@, optionally followed by
-- @#define FUTHARK_F64_ENABLED@ (only when 'Float64' occurs in the
-- given set of primitive types), then 'preludeCL' and
-- 'commonPrelude'.
genOpenClPrelude :: S.Set PrimType -> T.Text
genOpenClPrelude ts =
  "#define FUTHARK_OPENCL\n"
    <> enable_f64
    <> preludeCL
    <> commonPrelude
  where
    -- A preprocessor directive extends to the end of its line, so the
    -- define must be newline-terminated; otherwise whatever
    -- 'preludeCL' starts with would be glued onto the directive line.
    enable_f64
      | FloatType Float64 `S.member` ts =
          [untrimming|#define FUTHARK_F64_ENABLED|] <> "\n"
      | otherwise = mempty

-- | Assemble the prelude text for the CUDA target:
-- @#define FUTHARK_CUDA@, then 'preludeCU', then 'commonPrelude'.
genCUDAPrelude :: T.Text
genCUDAPrelude =
  T.concat
    [ "#define FUTHARK_CUDA\n",
      preludeCU,
      commonPrelude
    ]

-- | Assemble the prelude text for the HIP target:
-- @#define FUTHARK_HIP@, then 'preludeCU' (the same prelude the CUDA
-- target uses), then 'commonPrelude'.
genHIPPrelude :: T.Text
genHIPPrelude =
  T.concat
    [ "#define FUTHARK_HIP\n",
      preludeCU,
      commonPrelude
    ]

-- | The arguments to pass when launching a kernel, in the order of
-- its 'kernelUses'.  Memory uses become 'MemKArg's and scalar uses
-- become 'ValueKArg's; constant uses contribute no argument.
kernelArgs :: Kernel -> [KernelArg]
kernelArgs kernel =
  [arg | use <- kernelUses kernel, Just arg <- [toArg use]]
  where
    toArg use = case use of
      MemoryUse mem -> Just (MemKArg mem)
      ScalarUse v pt -> Just (ValueKArg (LeafExp v pt) pt)
      ConstUse {} -> Nothing

-- | The result of applying 'errorLabel' to the current kernel
-- compilation state.  The state itself is not modified.
nextErrorLabel :: GC.CompilerM KernelOp KernelState String
nextErrorLabel = do
  st <- GC.getUserState
  pure (errorLabel st)

-- | Increment the 'kernelNextSync' counter of the kernel compilation
-- state by one.
incErrorLabel :: GC.CompilerM KernelOp KernelState ()
incErrorLabel = GC.modifyUserState bump
  where
    bump s = s {kernelNextSync = kernelNextSync s + 1}

-- | Set the 'kernelSyncPending' flag of the kernel compilation state
-- to the given value.
pendingError :: Bool -> GC.CompilerM KernelOp KernelState ()
pendingError b = GC.modifyUserState setPending
  where
    setPending s = s {kernelSyncPending = b}

-- | Does the kernel code contain any 'ErrorSync' or 'Barrier'
-- operation?
hasCommunication :: ImpGPU.KernelCode -> Bool
hasCommunication = any $ \op ->
  case op of
    ErrorSync {} -> True
    Barrier {} -> True
    _ -> False

-- | Whether we are generating code for a kernel or a device function.
-- This has minor effects, such as exactly how failures are
-- propagated.
data OpsMode
  = KernelMode
  | FunMode
  deriving (Eq)

-- | Build the 'GC.Operations' record used when compiling the given
-- kernel (or device function) body to C.  Every field is wired to a
-- handler defined in the @where@ clause below, except for the copy
-- table and the critical-section items (both empty) and fat memory
-- (disabled).
inKernelOperations ::
  Env ->
  OpsMode ->
  ImpGPU.KernelCode ->
  GC.Operations KernelOp KernelState
inKernelOperations env mode body =
  GC.Operations
    { -- Compilation of operations, errors and calls.
      GC.opsCompiler = kernelOps,
      GC.opsError = errorInKernel,
      GC.opsCall = callInKernel,
      -- Memory representation and scalar access.
      GC.opsMemoryType = kernelMemoryType,
      GC.opsWriteScalar = kernelWriteScalar,
      GC.opsReadScalar = kernelReadScalar,
      GC.opsFatMemory = False,
      -- Allocation, deallocation and bulk copies are routed to
      -- handlers that call 'error'.
      GC.opsAllocate = cannotAllocate,
      GC.opsDeallocate = cannotDeallocate,
      GC.opsCopy = copyInKernel,
      GC.opsCopies = mempty,
      GC.opsCritical = mempty
    }
  where
    -- Whether the body being compiled contains a barrier or an error
    -- synchronisation point (see 'hasCommunication').
    has_communication :: Bool
    has_communication = hasCommunication body

    -- The C expression for the fence flags passed to barrier().  A
    -- global fence includes the local fence flag as well.
    fence f = case f of
      FenceLocal -> [C.cexp|CLK_LOCAL_MEM_FENCE|]
      FenceGlobal -> [C.cexp|CLK_GLOBAL_MEM_FENCE | CLK_LOCAL_MEM_FENCE|]

    -- Compile a single kernel-level operation to C statements.
    kernelOps :: GC.OpCompiler KernelOp KernelState
    kernelOps op = case op of
      GetBlockId v i ->
        GC.stm [C.cstm|$id:v = get_tblock_id($int:i);|]
      GetLocalId v i ->
        GC.stm [C.cstm|$id:v = get_local_id($int:i);|]
      GetLocalSize v i ->
        GC.stm [C.cstm|$id:v = get_local_size($int:i);|]
      GetLockstepWidth v ->
        GC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
      Barrier f -> do
        GC.stm [C.cstm|barrier($exp:(fence f));|]
        noteBarrier
      MemFence FenceLocal ->
        GC.stm [C.cstm|mem_fence_local();|]
      MemFence FenceGlobal ->
        GC.stm [C.cstm|mem_fence_global();|]
      SharedAlloc name size -> do
        -- Record a freshly named backing buffer together with its
        -- size in the state, and point 'name' at that buffer.
        backing <- newVName $ prettyString name ++ "_backing"
        GC.modifyUserState $ \s ->
          s {kernelSharedMemory = (backing, size) : kernelSharedMemory s}
        GC.stm [C.cstm|$id:name = (__local unsigned char*) $id:backing;|]
      ErrorSync f -> do
        label <- nextErrorLabel
        pending <- kernelSyncPending <$> GC.getUserState
        -- Only when a sync is pending do we emit the labelled barrier
        -- and the early return on local_failure; the pending flag is
        -- cleared at the same time.
        when pending $ do
          pendingError False
          GC.stm [C.cstm|$id:label: barrier($exp:(fence f));|]
          GC.stm [C.cstm|if (local_failure) { return; }|]
        GC.stm [C.cstm|barrier($exp:(fence f));|]
        noteBarrier
        incErrorLabel
      Atomic space aop ->
        atomicOps space aop
      where
        -- Remember that the kernel contains at least one barrier.
        noteBarrier =
          GC.modifyUserState $ \s -> s {kernelHasBarriers = True}

    -- Qualify the C type 't' with 'volatile' plus the pointer
    -- qualifiers of the memory space; anything other than a named
    -- 'Space' is treated as "global".
    atomicCast s t = pure [C.cty|$tyquals:(volatile ++ addr_space) $ty:t|]
      where
        volatile = [C.ctyquals|volatile|]
        addr_space = pointerQuals $ case s of
          Space sid -> sid
          _ -> "global"

    -- The memory space name used as a suffix of atomic function
    -- names: the name of a named 'Space', and "global" otherwise.
    atomicSpace s = case s of
      Space sid -> sid
      _ -> "global"

    -- Emit a call to the atomic function "<op>_<t>_<space>" on element
    -- 'ind' of 'arr' (viewed through the cast type for 'ty'), passing
    -- 'val' cast to 'ty', and store the result in 'old'.
    doAtomic s t old arr ind val op ty = do
      -- The index is compiled before the value.
      index_e <- GC.compileExp . untyped $ unCount ind
      value_e <- GC.compileExp val
      elem_ty <- atomicCast s ty
      GC.stm
        [C.cstm|$id:old = $id:fname(&(($ty:elem_ty *)$id:arr)[$exp:index_e], ($ty:ty) $exp:value_e);|]
      where
        fname = concat [op, "_", prettyString t, "_", atomicSpace s]

    -- Emit a call to "atomic_cmpxchg_<t>_<space>" on element 'ind' of
    -- 'arr' (viewed through the cast type for 'ty') with the compiled
    -- 'cmp' and 'val' expressions, and store the result in 'old'.
    doAtomicCmpXchg s t old arr ind cmp val ty = do
      -- Compilation order: index, then comparand, then new value.
      index_e <- GC.compileExp . untyped $ unCount ind
      cmp_e <- GC.compileExp cmp
      value_e <- GC.compileExp val
      elem_ty <- atomicCast s ty
      GC.stm
        [C.cstm|$id:old = $id:fname(&(($ty:elem_ty *)$id:arr)[$exp:index_e], $exp:cmp_e, $exp:value_e);|]
      where
        fname = concat ["atomic_cmpxchg_", prettyString t, "_", atomicSpace s]
    -- Emit a call to "atomic_chg_<t>_<space>" on element 'ind' of
    -- 'arr' (viewed through the cast type for 'ty') with the compiled
    -- 'val' expression, and store the result in 'old'.
    --
    -- NOTE(review): the prefix is "atomic_chg_", not "atomic_xchg_";
    -- confirm it matches the function names defined in the atomics
    -- prelude.
    doAtomicXchg s t old arr ind val ty = do
      elem_ty <- atomicCast s ty
      -- The index is compiled before the value.
      index_e <- GC.compileExp . untyped $ unCount ind
      value_e <- GC.compileExp val
      GC.stm
        [C.cstm|$id:old = $id:fname(&(($ty:elem_ty *)$id:arr)[$exp:index_e], $exp:value_e);|]
      where
        fname = concat ["atomic_chg_", prettyString t, "_", atomicSpace s]
    -- Compile an atomic operation in the given memory space by
    -- delegating to 'doAtomic', 'doAtomicCmpXchg' or 'doAtomicXchg'
    -- with the appropriate function name prefix and C type.  The
    -- alternatives are tried top to bottom, so the 64-bit cases must
    -- stay ahead of the catch-all ones.
    atomicOps s aop = case aop of
      -- First the 64-bit operations.
      AtomicAdd Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_add" i64
      AtomicFAdd Float64 old arr ind val ->
        doAtomic s Float64 old arr ind val "atomic_fadd" [C.cty|double|]
      AtomicSMax Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_smax" i64
      AtomicSMin Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_smin" i64
      AtomicUMax Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_umax" u64
      AtomicUMin Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_umin" u64
      AtomicAnd Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_and" i64
      AtomicOr Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_or" i64
      AtomicXor Int64 old arr ind val ->
        doAtomic s Int64 old arr ind val "atomic_xor" i64
      AtomicCmpXchg (IntType Int64) old arr ind cmp val ->
        doAtomicCmpXchg s (IntType Int64) old arr ind cmp val i64
      AtomicXchg (IntType Int64) old arr ind val ->
        doAtomicXchg s (IntType Int64) old arr ind val i64
      -- Every remaining element type uses the C types int, unsigned
      -- int or float.
      AtomicAdd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_add" i32
      AtomicFAdd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_fadd" [C.cty|float|]
      AtomicSMax t old arr ind val ->
        doAtomic s t old arr ind val "atomic_smax" i32
      AtomicSMin t old arr ind val ->
        doAtomic s t old arr ind val "atomic_smin" i32
      AtomicUMax t old arr ind val ->
        doAtomic s t old arr ind val "atomic_umax" u32
      AtomicUMin t old arr ind val ->
        doAtomic s t old arr ind val "atomic_umin" u32
      AtomicAnd t old arr ind val ->
        doAtomic s t old arr ind val "atomic_and" i32
      AtomicOr t old arr ind val ->
        doAtomic s t old arr ind val "atomic_or" i32
      AtomicXor t old arr ind val ->
        doAtomic s t old arr ind val "atomic_xor" i32
      AtomicCmpXchg t old arr ind cmp val ->
        doAtomicCmpXchg s t old arr ind cmp val i32
      AtomicXchg t old arr ind val ->
        doAtomicXchg s t old arr ind val i32
      where
        i64 = [C.cty|typename int64_t|]
        -- NOTE(review): unlike 'i64' this is spelled without
        -- "typename"; confirm it produces the intended unsigned
        -- 64-bit C type.
        u64 = [C.cty|unsigned int64_t|]
        i32 = [C.cty|int|]
        u32 = [C.cty|unsigned int|]

    -- Allocation handler for in-kernel code: always fails with
    -- 'error'.
    cannotAllocate :: GC.Allocate KernelOp KernelState
    cannotAllocate _ =
      error "Cannot allocate memory in kernel"

    -- Deallocation handler for in-kernel code: always fails with
    -- 'error'.
    cannotDeallocate :: GC.Deallocate KernelOp KernelState
    cannotDeallocate _ _ =
      error "Cannot deallocate memory in kernel"

    -- Bulk copy handler for in-kernel code: always fails with
    -- 'error', whatever the arguments.
    copyInKernel :: GC.Copy KernelOp KernelState
    copyInKernel _ _ _ _ _ _ _ _ =
      error "Cannot bulk copy in kernel."

    -- The C type of a memory block in the given space: the default
    -- memory block type carrying that space's pointer qualifiers.
    kernelMemoryType space = pure [C.cty|$tyquals:quals $ty:defaultMemBlockType|]
      where
        quals = pointerQuals space

    -- Scalar writes go through pointers qualified by 'pointerQuals'.
    kernelWriteScalar =
      GC.writeScalarPointerWithQuals pointerQuals

    -- Scalar reads go through pointers qualified by 'pointerQuals'.
    kernelReadScalar =
      GC.readScalarPointerWithQuals pointerQuals

    -- Mark an error sync as pending and produce the C block items to
    -- run once a failure has been detected:
    --
    -- * with communication in the body: set local_failure and jump to
    --   the current error label;
    -- * otherwise, in a device function: return 1;
    -- * otherwise, in a kernel: plain return.
    whatNext = do
      label <- nextErrorLabel
      pendingError True
      pure (bailOut label)
      where
        bailOut label
          | has_communication = [C.citems|local_failure = 1; goto $id:label;|]
          | mode == FunMode = [C.citems|return 1;|]
          | otherwise = [C.citems|return;|]

    -- Emit a call to the function @fname@ from kernel code.  The
    -- destinations @dests@ are passed by address ahead of the ordinary
    -- arguments @args@.
    --
    -- When 'functionMayFail' holds for the callee, the call is
    -- additionally given @global_failure@ and @global_failure_args@ as
    -- leading arguments, and a nonzero return value triggers the
    -- statements produced by 'whatNext'.  Otherwise the call is
    -- emitted as a plain expression statement.
    callInKernel dests fname args
      | functionMayFail fname env = do
          let args' =
                [C.cexp|global_failure|]
                  : [C.cexp|global_failure_args|]
                  : out_args
                  ++ args

          what_next <- whatNext

          GC.item [C.citem|if ($id:(funName fname)($args:args') != 0) { $items:what_next; }|]
      | otherwise = do
          let args' = out_args ++ args
          GC.item [C.citem|$id:(funName fname)($args:args');|]
      where
        -- Address-of expressions for every destination variable.
        out_args = [[C.cexp|&$id:d|] | d <- dests]

    -- Error callback for kernel code.  Appends the failure (message
    -- plus backtrace) to 'kernelFailures' in the user state; @n@ is the
    -- number of failures recorded *before* this one, and so is this
    -- failure's position in that list.
    --
    -- The generated C tries to swap @global_failure@ from -1 to @n@;
    -- only if the swap returned -1 are the message's value arguments
    -- stored into @global_failure_args@.  In either case the
    -- statements produced by 'whatNext' follow.
    errorInKernel msg@(ErrorMsg parts) backtrace = do
      n <- length . kernelFailures <$> GC.getUserState
      GC.modifyUserState $ \s ->
        s {kernelFailures = kernelFailures s ++ [FailureMsg msg backtrace]}
      -- Build one assignment per 'ErrorVal' part; string parts occupy
      -- no argument slot, so the index only advances on values.
      let setArgs _ [] = pure []
          setArgs i (ErrorString {} : parts') = setArgs i parts'
          -- FIXME: bogus for non-ints.
          setArgs i (ErrorVal _ x : parts') = do
            x' <- GC.compileExp x
            stms <- setArgs (i + 1) parts'
            pure $ [C.cstm|global_failure_args[$int:i] = (typename int64_t)$exp:x';|] : stms
      argstms <- setArgs (0 :: Int) parts

      what_next <- whatNext

      GC.stm
        [C.cstm|{ if (atomic_cmpxchg_i32_global(global_failure, -1, $int:n) == -1)
                                 { $stms:argstms; }
                                 $items:what_next
                               }|]

--- Checking requirements

-- | The primitive types mentioned in the body of a kernel, as
-- computed by 'typesInCode'.
typesInKernel :: Kernel -> S.Set PrimType
typesInKernel = typesInCode . kernelBody

-- | The primitive types mentioned in a piece of kernel code: the
-- declared types of scalars and arrays, the element types of reads
-- and writes, and (via 'typesInExp') the types occurring in every
-- expression inspected below.  Memory declarations, frees, memory
-- assignments and kernel operations ('Op') contribute nothing.
typesInCode :: ImpGPU.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
typesInCode (c1 :>>: c2) = typesInCode c1 <> typesInCode c2
typesInCode (For _ e c) = typesInExp e <> typesInCode c
typesInCode (While (TPrimExp e) c) = typesInExp e <> typesInCode c
typesInCode DeclareMem {} = mempty
typesInCode (DeclareScalar _ _ t) = S.singleton t
typesInCode (DeclareArray _ t _) = S.singleton t
typesInCode (Allocate _ (Count (TPrimExp e)) _) = typesInExp e
typesInCode Free {} = mempty
typesInCode (Copy _ shape _ (Count (TPrimExp dstoffset), dststrides) _ (Count (TPrimExp srcoffset), srcstrides)) =
  typesInCounts shape
    <> typesInExp dstoffset
    <> typesInCounts dststrides
    <> typesInExp srcoffset
    <> typesInCounts srcstrides
  where
    -- Types occurring in a list of counted expressions.
    typesInCounts = foldMap (typesInExp . untyped . unCount)
typesInCode (Write _ (Count (TPrimExp e1)) t _ _ e2) =
  typesInExp e1 <> S.singleton t <> typesInExp e2
typesInCode (Read _ _ (Count (TPrimExp e1)) t _ _) =
  typesInExp e1 <> S.singleton t
typesInCode (SetScalar _ e) = typesInExp e
typesInCode SetMem {} = mempty
typesInCode (Call _ _ es) = foldMap typesInArg es
  where
    -- Only expression arguments carry primitive types.
    typesInArg MemArg {} = mempty
    typesInArg (ExpArg e) = typesInExp e
typesInCode (If (TPrimExp e) c1 c2) =
  typesInExp e <> typesInCode c1 <> typesInCode c2
-- Only the asserted condition is inspected; the error message is not.
typesInCode (Assert e _ _) = typesInExp e
typesInCode (Comment _ c) = typesInCode c
typesInCode (DebugPrint _ v) = maybe mempty typesInExp v
typesInCode (TracePrint msg) = foldMap typesInExp msg
typesInCode Op {} = mempty

-- | The primitive types mentioned in an expression: the types of
-- constants, both endpoints of every conversion, and the return type
-- of every function application, collected recursively over all
-- subexpressions.  Leaves contribute nothing, and unary, binary and
-- comparison operators contribute only through their operands.
typesInExp :: Exp -> S.Set PrimType
typesInExp (ValueExp v) = S.singleton $ primValueType v
typesInExp (BinOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (CmpOpExp _ e1 e2) = typesInExp e1 <> typesInExp e2
typesInExp (ConvOpExp op e) = S.fromList [from, to] <> typesInExp e
  where
    (from, to) = convOpType op
typesInExp (UnOpExp _ e) = typesInExp e
typesInExp (FunExp _ args t) = S.singleton t <> foldMap typesInExp args
typesInExp LeafExp {} = mempty