-- | Code generation for Python with OpenCL.
module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg,
  )
where

import Control.Monad
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython qualified as Py
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import Futhark.CodeGen.ImpCode.OpenCL qualified as Imp
import Futhark.CodeGen.ImpGen.OpenCL qualified as ImpGen
import Futhark.CodeGen.RTS.Python (openclPy)
import Futhark.IR.GPUMem (GPUMem, Prog)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.Pretty (prettyString, prettyText)

-- | Compile the program to Python with calls to OpenCL.
--
-- The program is first lowered with 'ImpGen.compileProg'; the resulting
-- OpenCL source, kernel table, sizes and failure messages are then used
-- to build the Python class constructor, module-level definitions,
-- imports and command line options that are handed to 'Py.compileProg'.
-- The warnings produced by the lowering are returned alongside the
-- generated Python text.
compileProg ::
  MonadFreshNames m =>
  Py.CompilerMode ->
  String ->
  Prog GPUMem ->
  m (ImpGen.Warnings, T.Text)
compileProg mode class_name prog = do
  ( ws,
    Imp.Program
      opencl_code
      opencl_prelude
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog
  -- prepare the strings for assigning the kernels and set them as global
  let assign =
        unlines
          $ map
            ( \x ->
                prettyString $
                  Assign
                    (Var (T.unpack ("self." <> zEncodeText (nameToText x) <> "_var")))
                    (Var $ T.unpack $ "program." <> zEncodeText (nameToText x))
            )
          $ M.keys kernels

  -- Module-level variables; the command line options below overwrite
  -- them, and the constructor uses them as default argument values.
  let defines =
        [ Assign (Var "synchronous") $ Bool False,
          Assign (Var "preferred_platform") None,
          Assign (Var "build_options") $ List [],
          Assign (Var "preferred_device") None,
          Assign (Var "default_threshold") None,
          Assign (Var "default_group_size") None,
          Assign (Var "default_num_groups") None,
          Assign (Var "default_tile_size") None,
          Assign (Var "default_reg_tile_size") None,
          Assign (Var "fut_opencl_src") $ RawStringLiteral $ opencl_prelude <> opencl_code
        ]

  let imports =
        [ Import "sys" Nothing,
          Import "numpy" $ Just "np",
          Import "ctypes" $ Just "ct",
          Escape openclPy,
          Import "pyopencl.array" Nothing,
          Import "time" Nothing
        ]

  let constructor =
        Py.Constructor
          [ "self",
            "build_options=build_options",
            "command_queue=None",
            "interactive=False",
            "platform_pref=preferred_platform",
            "device_pref=preferred_device",
            "default_group_size=default_group_size",
            "default_num_groups=default_num_groups",
            "default_tile_size=default_tile_size",
            "default_reg_tile_size=default_reg_tile_size",
            "default_threshold=default_threshold",
            "sizes=sizes"
          ]
          [Escape $ openClInit types assign sizes failures]
      options =
        [ Option
            { optionLongName = "platform",
              optionShortName = Just 'p',
              optionArgument = RequiredArgument "str",
              optionAction = setFromOptarg "preferred_platform"
            },
          Option
            { optionLongName = "device",
              optionShortName = Just 'd',
              optionArgument = RequiredArgument "str",
              optionAction = setFromOptarg "preferred_device"
            },
          Option
            { optionLongName = "build-option",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "str",
              optionAction =
                -- Repeated uses of the option accumulate in the list.
                [ Assign (Var "build_options") $
                    BinOp "+" (Var "build_options") $
                      List [Var "optarg"]
                ]
            },
          intOption "default-threshold" "default_threshold",
          intOption "default-group-size" "default_group_size",
          intOption "default-num-groups" "default_num_groups",
          intOption "default-tile-size" "default_tile_size",
          intOption "default-reg-tile-size" "default_reg_tile_size",
          Option
            { optionLongName = "param",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "param_assignment",
              optionAction =
                -- Generates: params[optarg[0]] = optarg[1]
                [ Assign
                    ( Index
                        (Var "params")
                        ( IdxExp
                            ( Index
                                (Var "optarg")
                                (IdxExp (Integer 0))
                            )
                        )
                    )
                    (Index (Var "optarg") (IdxExp (Integer 1)))
                ]
            }
        ]

  (ws,)
    <$> Py.compileProg
      mode
      class_name
      constructor
      imports
      defines
      operations
      ()
      [Exp $ Py.simpleCall "sync" [Var "self"]]
      options
      prog'
  where
    -- The OpenCL-specific hooks plugged into the generic Python backend.
    operations :: Py.Operations Imp.OpenCL ()
    operations =
      Py.Operations
        { Py.opsCompiler = callKernel,
          Py.opsWriteScalar = writeOpenCLScalar,
          Py.opsReadScalar = readOpenCLScalar,
          Py.opsAllocate = allocateOpenCLBuffer,
          Py.opsCopy = copyOpenCLMemory,
          Py.opsEntryOutput = packArrayOutput,
          Py.opsEntryInput = unpackArrayInput
        }

    -- Option action that stores the option argument in the given
    -- Python variable.
    setFromOptarg :: String -> [PyStmt]
    setFromOptarg var = [Assign (Var var) $ Var "optarg"]

    -- A long-name-only option taking an @int@ argument that is stored
    -- in the given Python variable.
    intOption :: T.Text -> String -> Option
    intOption long_name var =
      Option
        { optionLongName = long_name,
          optionShortName = Nothing,
          optionArgument = RequiredArgument "int",
          optionAction = setFromOptarg var
        }

-- We have many casts to 'long', because PyOpenCL may get confused at
-- the 32-bit numbers that ImpCode uses for offsets and the like.

-- | Wrap a Python expression in a call to @np.int64@.
asLong :: PyExp -> PyExp
asLong x = Py.simpleCall "np.int64" [x]

-- | Translate a kernel constant to the Python expression that reads it
-- from the generated class instance: a size constant becomes a lookup
-- in @self.sizes@ keyed by the pretty-printed name, and a size maximum
-- becomes the attribute @self.max_\<size class\>@.
kernelConstToExp :: Imp.KernelConst -> PyExp
kernelConstToExp (Imp.SizeConst key) =
  Index (Var "self.sizes") (IdxExp $ String $ prettyText key)
kernelConstToExp (Imp.SizeMaxConst size_class) =
  Var $ "self.max_" <> prettyString size_class

-- | Compile one workgroup dimension.  A dimension given as an
-- expression is compiled and cast with 'asLong'; one given as a kernel
-- constant is translated with 'kernelConstToExp' (without a cast).
compileGroupDim :: Imp.GroupDim -> Py.CompilerM op s PyExp
compileGroupDim (Left e) = asLong <$> Py.compileExp e
compileGroupDim (Right kc) = pure $ kernelConstToExp kc

-- | Compile a host-level OpenCL operation to Python statements.
--
-- * 'Imp.GetSize' assigns the @self.sizes@ entry for the key to the
--   variable.
--
-- * 'Imp.CmpSizeLe' assigns the result of comparing that entry with
--   the compiled expression using @<=@.
--
-- * 'Imp.GetSizeMax' assigns the @self.max_...@ attribute of the size
--   class to the variable.
--
-- * 'Imp.LaunchKernel' emits the kernel launch, guarded so that nothing
--   is enqueued when the total number of work items is zero.
callKernel :: Py.OpCompiler Imp.OpenCL ()
callKernel (Imp.GetSize v key) = do
  v' <- Py.compileVar v
  Py.stm $ Assign v' $ kernelConstToExp $ Imp.SizeConst key
callKernel (Imp.CmpSizeLe v key x) = do
  v' <- Py.compileVar v
  x' <- Py.compileExp x
  Py.stm $
    Assign v' $
      BinOp "<=" (kernelConstToExp (Imp.SizeConst key)) x'
callKernel (Imp.GetSizeMax v size_class) = do
  v' <- Py.compileVar v
  Py.stm $ Assign v' $ kernelConstToExp $ Imp.SizeMaxConst size_class
callKernel (Imp.LaunchKernel safety name args num_workgroups workgroup_size) = do
  num_workgroups' <- mapM (fmap asLong . Py.compileExp) num_workgroups
  workgroup_size' <- mapM compileGroupDim workgroup_size
  -- The global size in each dimension is the number of workgroups
  -- times the workgroup size; the launch is skipped when the product
  -- over all dimensions is zero.
  let kernel_size = zipWith mult_exp num_workgroups' workgroup_size'
      total_elements = foldl mult_exp (Integer 1) kernel_size
      cond = BinOp "!=" total_elements (Integer 0)

  body <- Py.collect $ launchKernel name safety kernel_size workgroup_size' args
  Py.stm $ If cond body []

  -- Emitted outside the guard above, for kernels whose safety level is
  -- at least 'Imp.SafetyFull'.
  when (safety >= Imp.SafetyFull) $
    Py.stm $
      Assign (Var "self.failure_is_an_option") $
        Py.compilePrimValue (Imp.IntValue (Imp.Int32Value 1))
  where
    mult_exp = BinOp "*"

-- | Emit the Python statements that launch a single kernel: a
-- @set_args@ call on the kernel object stored in
-- @self.\<z-encoded name\>_var@, followed by
-- @cl.enqueue_nd_range_kernel@ on @self.queue@ with the given global
-- and workgroup dimensions, and finally 'finishIfSynchronous'.
--
-- The kernel arguments are preceded by the first
-- @'Imp.numFailureParams' safety@ of the failure-related attributes.
launchKernel ::
  Imp.KernelName ->
  Imp.KernelSafety ->
  [PyExp] ->
  [PyExp] ->
  [Imp.KernelArg] ->
  Py.CompilerM op s ()
launchKernel kernel_name safety kernel_dims workgroup_dims args = do
  let kernel_dims' = Tuple kernel_dims
      workgroup_dims' = Tuple workgroup_dims
      kernel_name' = "self." <> zEncodeText (nameToText kernel_name) <> "_var"
  args' <- mapM processKernelArg args
  let failure_args =
        take
          (Imp.numFailureParams safety)
          [ Var "self.global_failure",
            Var "self.failure_is_an_option",
            Var "self.global_failure_args"
          ]
  Py.stm $
    Exp $
      Py.simpleCall (T.unpack $ kernel_name' <> ".set_args") $
        failure_args ++ args'
  Py.stm $
    Exp $
      Py.simpleCall
        "cl.enqueue_nd_range_kernel"
        [Var "self.queue", Var (T.unpack kernel_name'), kernel_dims', workgroup_dims']
  finishIfSynchronous
  where
    -- Scalar values are converted with 'Py.toStorage', memory blocks
    -- are passed as the compiled variable, and shared memory becomes a
    -- @cl.LocalMemory@ object of the requested number of bytes.
    processKernelArg :: Imp.KernelArg -> Py.CompilerM op s PyExp
    processKernelArg (Imp.ValueKArg e bt) =
      Py.toStorage bt <$> Py.compileExp e
    processKernelArg (Imp.MemKArg v) = Py.compileVar v
    processKernelArg (Imp.SharedMemoryKArg (Imp.Count num_bytes)) = do
      num_bytes' <- Py.compileExp num_bytes
      pure $ Py.simpleCall "cl.LocalMemory" [asLong num_bytes']

-- | Emit Python code that writes a scalar into a buffer in the
-- @"device"@ memory space.  The value is wrapped in an @np.array@ of
-- the matching dtype and copied with @cl.enqueue_copy@; the device
-- offset is the element index times the byte size of the primitive
-- type, and the copy is blocking exactly when the Python variable
-- @synchronous@ is true.
--
-- Any other memory space is an internal error.
writeOpenCLScalar :: Py.WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt "device" val = do
  let nparr =
        Call
          (Var "np.array")
          [Arg val, ArgKeyword "dtype" $ Var $ Py.compilePrimType bt]
  Py.stm $
    Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue",
          Arg mem,
          Arg nparr,
          ArgKeyword "device_offset" $ BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt),
          ArgKeyword "is_blocking" $ Var "synchronous"
        ]
writeOpenCLScalar _ _ _ space _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Emit Python code that reads a scalar from a buffer in the
-- @"device"@ memory space and return the expression holding the
-- result.  A fresh one-element @np.empty@ array of the matching dtype
-- is the copy destination; after @cl.enqueue_copy@ a call to
-- @sync(self)@ is emitted, and the returned expression indexes
-- element 0 of the array.
--
-- Any other memory space is an internal error.
readOpenCLScalar :: Py.ReadScalar Imp.OpenCL ()
readOpenCLScalar mem i bt "device" = do
  val <- newVName "read_res"
  let val' = Var $ prettyString val
  let nparr =
        Call
          (Var "np.empty")
          [ Arg $ Integer 1,
            ArgKeyword "dtype" (Var $ Py.compilePrimType bt)
          ]
  Py.stm $ Assign val' nparr
  Py.stm $
    Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue",
          Arg val',
          Arg mem,
          ArgKeyword "device_offset" $ BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt),
          ArgKeyword "is_blocking" $ Var "synchronous"
        ]
  Py.stm $ Exp $ Py.simpleCall "sync" [Var "self"]
  pure $ Index val' $ IdxExp $ Integer 0
readOpenCLScalar _ _ _ space =
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Emit Python code that allocates a buffer in the @"device"@ memory
-- space by assigning the result of @opencl_alloc(self, size, name)@ to
-- the memory variable, where @name@ is the pretty-printed variable
-- expression passed as a string.
--
-- Any other memory space is an internal error.
allocateOpenCLBuffer :: Py.Allocate Imp.OpenCL ()
allocateOpenCLBuffer mem size "device" =
  Py.stm $
    Assign mem $
      Py.simpleCall "opencl_alloc" [Var "self", size, String $ prettyText mem]
allocateOpenCLBuffer _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' space"

-- | Emit Python code copying @nbytes@ bytes between two memory blocks.
--
-- The generated code depends on the (destination, source) space pair:
--
-- * default <- @"device"@: @cl.enqueue_copy@ into a slice of the
--   destination, with @device_offset@ set to the source index.
--
-- * @"device"@ <- default: @cl.enqueue_copy@ from a sliced byte view of
--   the source, with @device_offset@ set to the destination index.
--
-- * @"device"@ <- @"device"@: @cl.enqueue_copy@ with explicit
--   @dest_offset@, @src_offset@ and @byte_count@, followed by
--   'finishIfSynchronous'.
--
-- * default <- default: delegated to 'Py.copyMemoryDefaultSpace'.
--
-- Every OpenCL copy is wrapped in 'ifNotZeroSize', so nothing is enqueued
-- when @nbytes@ is zero.  Any other combination of spaces is a compiler
-- error.
copyOpenCLMemory :: Py.Copy Imp.OpenCL ()
copyOpenCLMemory destmem destidx Imp.DefaultSpace srcmem srcidx (Imp.Space "device") nbytes bt = do
  -- The destination slice length is the byte count divided by the size of
  -- one element of type @bt@.
  let divide = BinOp "//" nbytes (Integer $ Imp.primByteSize bt)
      end = BinOp "+" destidx divide
      dest = Index destmem (IdxRange destidx end)
  Py.stm $
    ifNotZeroSize nbytes $
      Exp $
        Call
          (Var "cl.enqueue_copy")
          [ Arg $ Var "self.queue",
            Arg dest,
            Arg srcmem,
            ArgKeyword "device_offset" $ asLong srcidx,
            ArgKeyword "is_blocking" $ Var "synchronous"
          ]
copyOpenCLMemory destmem destidx (Imp.Space "device") srcmem srcidx Imp.DefaultSpace nbytes _ = do
  -- NOTE(review): the byte view is built with @[nbytes]@ as its second
  -- argument (presumably the shape) and then sliced from @srcidx@; confirm
  -- this yields the intended range when @srcidx@ is non-zero.
  let end = BinOp "+" srcidx nbytes
      src =
        Index
          (Py.simpleCall "createArray" [srcmem, List [nbytes], Var "np.byte"])
          (IdxRange srcidx end)
  Py.stm $
    ifNotZeroSize nbytes $
      Exp $
        Call
          (Var "cl.enqueue_copy")
          [ Arg $ Var "self.queue",
            Arg destmem,
            Arg src,
            ArgKeyword "device_offset" $ asLong destidx,
            ArgKeyword "is_blocking" $ Var "synchronous"
          ]
copyOpenCLMemory destmem destidx (Imp.Space "device") srcmem srcidx (Imp.Space "device") nbytes _ = do
  Py.stm $
    ifNotZeroSize nbytes $
      Exp $
        Call
          (Var "cl.enqueue_copy")
          [ Arg $ Var "self.queue",
            Arg destmem,
            Arg srcmem,
            ArgKeyword "dest_offset" $ asLong destidx,
            ArgKeyword "src_offset" $ asLong srcidx,
            ArgKeyword "byte_count" $ asLong nbytes
          ]
  -- This call passes no @is_blocking@ argument, so synchronisation is
  -- emitted as a separate statement.
  finishIfSynchronous
copyOpenCLMemory destmem destidx Imp.DefaultSpace srcmem srcidx Imp.DefaultSpace nbytes _ =
  Py.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
copyOpenCLMemory _ _ destspace _ _ srcspace _ _ =
  error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace

-- | Build the Python expression that wraps an entry-point result array.
--
-- For memory in the @"device"@ space the result is a call
-- @cl.array.Array(self.queue, shape, dtype, data=mem)@, where @shape@ is
-- the tuple of compiled dimensions and @dtype@ comes from
-- 'Py.compilePrimToExtNp'.  When the element type is 'Imp.Unit', an extra
-- trailing dimension of size 0 is appended to the shape.
--
-- Any other space name is a compiler error.
packArrayOutput :: Py.EntryOutput Imp.OpenCL ()
packArrayOutput mem "device" bt ept dims = do
  mem' <- Py.compileVar mem
  dims' <- mapM Py.compileDim dims
  pure $
    Call
      (Var "cl.array.Array")
      [ Arg $ Var "self.queue",
        Arg $ Tuple $ dims' <> [Integer 0 | bt == Imp.Unit],
        Arg $ Var $ Py.compilePrimToExtNp bt ept,
        ArgKeyword "data" mem'
      ]
packArrayOutput _ sid _ _ _ =
  error $ "Cannot return array from " ++ sid ++ " space."

-- | Emit Python code that receives an array argument of an entry point
-- and binds its device memory to @mem@.
--
-- For the @"device"@ space the generated code:
--
-- 1. asserts that the argument's type is @np.ndarray@ or
--    @cl.array.Array@ and that its @dtype@ matches the expected element
--    type;
--
-- 2. unpacks each dimension via 'Py.unpackDim';
--
-- 3. if the argument is a @cl.array.Array@, assigns its @data@ field to
--    @mem@; otherwise allocates a device buffer of @np.int64(e.nbytes)@
--    bytes and, when that size is non-zero, enqueues a copy of
--    @normaliseArray(e)@ into it.
--
-- Any other space name is a compiler error.
unpackArrayInput :: Py.EntryInput Imp.OpenCL ()
unpackArrayInput mem "device" t s dims e = do
  let type_is_ok =
        BinOp
          "and"
          (BinOp "in" (Py.simpleCall "type" [e]) (List [Var "np.ndarray", Var "cl.array.Array"]))
          (BinOp "==" (Field e "dtype") (Var (Py.compilePrimToExtNp t s)))
  Py.stm $ Assert type_is_ok $ String "Parameter has unexpected type"

  zipWithM_ (Py.unpackDim e) dims [0 ..]

  let memsize' = Py.simpleCall "np.int64" [Field e "nbytes"]
      pyOpenCLArrayCase =
        [Assign mem $ Field e "data"]
  -- Collect the statements for the NumPy case so they can be placed in
  -- the else-branch of the type test below.
  numpyArrayCase <- Py.collect $ do
    allocateOpenCLBuffer mem memsize' "device"
    Py.stm $
      ifNotZeroSize memsize' $
        Exp $
          Call
            (Var "cl.enqueue_copy")
            [ Arg $ Var "self.queue",
              Arg mem,
              Arg $ Call (Var "normaliseArray") [Arg e],
              ArgKeyword "is_blocking" $ Var "synchronous"
            ]

  Py.stm $
    If
      (BinOp "==" (Py.simpleCall "type" [e]) (Var "cl.array.Array"))
      pyOpenCLArrayCase
      numpyArrayCase
unpackArrayInput _ sid _ _ _ _ =
  error $ "Cannot accept array from " ++ sid ++ " space."

-- | Guard a statement so that it only runs when the given size expression
-- is non-zero: produces the Python @if e != 0: s@ with an empty
-- else-branch.
ifNotZeroSize :: PyExp -> PyStmt -> PyStmt
ifNotZeroSize e s =
  If (BinOp "!=" e (Integer 0)) [s] []

-- | Emit the Python statement @if synchronous: sync(self)@, with an empty
-- else-branch.
finishIfSynchronous :: Py.CompilerM op s ()
finishIfSynchronous =
  Py.stm $ If (Var "synchronous") [Exp $ Py.simpleCall "sync" [Var "self"]] []