{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}

-- | Code generation for Python with OpenCL.
module Futhark.CodeGen.Backends.PyOpenCL
  ( compileProg,
  )
where

import Control.Monad
import qualified Data.Map as M
import qualified Futhark.CodeGen.Backends.GenericPython as Py
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
import qualified Futhark.CodeGen.ImpCode.OpenCL as Imp
import qualified Futhark.CodeGen.ImpGen.OpenCL as ImpGen
import Futhark.IR.KernelsMem (KernelsMem, Prog)
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)

-- | Compile the program to Python with calls to OpenCL.
--
-- The program is first lowered to ImpCode via 'ImpGen.compileProg'.  The
-- resulting OpenCL prelude and kernel source are embedded in the generated
-- Python module as the global @fut_opencl_src@.  The constructor of the
-- generated class (named by the 'String' argument) runs 'openClInit', which
-- is also handed the statements that bind every kernel to a
-- @self.<kernel>_var@ attribute.  The warnings produced while lowering are
-- returned together with the Python source text.
compileProg ::
  MonadFreshNames m =>
  Py.CompilerMode ->
  String ->
  Prog KernelsMem ->
  m (ImpGen.Warnings, String)
compileProg mode class_name prog = do
  ( ws,
    Imp.Program
      opencl_code
      opencl_prelude
      kernels
      types
      sizes
      failures
      prog'
    ) <-
    ImpGen.compileProg prog

  -- Python source text (one statement per line) binding each kernel of the
  -- compiled program to an attribute of the generated class; it is spliced
  -- into the constructor by 'openClInit'.
  let assign = unlines $ map (pretty . bindKernel) $ M.keys kernels

  -- Module-level globals.  The 'None' defaults are the values the
  -- constructor's keyword arguments fall back to, and several of them are
  -- overwritten by the command-line options defined below.
  let defines =
        [ Assign (Var "synchronous") $ Bool False,
          Assign (Var "preferred_platform") None,
          Assign (Var "preferred_device") None,
          Assign (Var "default_threshold") None,
          Assign (Var "default_group_size") None,
          Assign (Var "default_num_groups") None,
          Assign (Var "default_tile_size") None,
          Assign (Var "default_reg_tile_size") None,
          Assign (Var "fut_opencl_src") $
            RawStringLiteral $ opencl_prelude ++ opencl_code
        ]

  let imports =
        [ Import "sys" Nothing,
          Import "numpy" $ Just "np",
          Import "ctypes" $ Just "ct",
          Escape openClPrelude,
          Import "pyopencl.array" Nothing,
          Import "time" Nothing
        ]

  let constructor =
        Py.Constructor
          [ "self",
            "command_queue=None",
            "interactive=False",
            "platform_pref=preferred_platform",
            "device_pref=preferred_device",
            "default_group_size=default_group_size",
            "default_num_groups=default_num_groups",
            "default_tile_size=default_tile_size",
            "default_reg_tile_size=default_reg_tile_size",
            "default_threshold=default_threshold",
            "sizes=sizes"
          ]
          [Escape $ openClInit types assign sizes failures]
      options =
        [ globalOption "platform" (Just 'p') "str" "preferred_platform",
          globalOption "device" (Just 'd') "str" "preferred_device",
          globalOption "default-threshold" Nothing "int" "default_threshold",
          globalOption "default-group-size" Nothing "int" "default_group_size",
          globalOption "default-num-groups" Nothing "int" "default_num_groups",
          globalOption "default-tile-size" Nothing "int" "default_tile_size",
          globalOption "default-reg-tile-size" Nothing "int" "default_reg_tile_size",
          -- --size takes a (key, value) pair and stores it as sizes[key] = value.
          Option
            { optionLongName = "size",
              optionShortName = Nothing,
              optionArgument = RequiredArgument "size_assignment",
              optionAction =
                [ Assign
                    ( Index
                        (Var "sizes")
                        (IdxExp (Index (Var "optarg") (IdxExp (Integer 0))))
                    )
                    (Index (Var "optarg") (IdxExp (Integer 1)))
                ]
            }
        ]

  (ws,)
    <$> Py.compileProg
      mode
      class_name
      constructor
      imports
      defines
      operations
      ()
      [Exp $ Py.simpleCall "sync" [Var "self"]]
      options
      prog'
  where
    -- @self.<kernel>_var = program.<kernel>@, with the kernel name
    -- z-encoded so that it is a valid Python identifier.
    bindKernel :: Imp.KernelName -> PyStmt
    bindKernel k =
      Assign
        (Var ("self." ++ zEncodeString (nameToString k) ++ "_var"))
        (Var $ "program." ++ zEncodeString (nameToString k))

    -- A command-line option taking one argument (described by @arg_type@)
    -- whose value is stored unchanged in the module-level global @var@.
    globalOption :: String -> Maybe Char -> String -> String -> Option
    globalOption long short arg_type var =
      Option
        { optionLongName = long,
          optionShortName = short,
          optionArgument = RequiredArgument arg_type,
          optionAction = [Assign (Var var) $ Var "optarg"]
        }

    -- The OpenCL-specific hooks used by the generic Python code generator.
    operations :: Py.Operations Imp.OpenCL ()
    operations =
      Py.Operations
        { Py.opsCompiler = callKernel,
          Py.opsWriteScalar = writeOpenCLScalar,
          Py.opsReadScalar = readOpenCLScalar,
          Py.opsAllocate = allocateOpenCLBuffer,
          Py.opsCopy = copyOpenCLMemory,
          Py.opsStaticArray = staticOpenCLArray,
          Py.opsEntryOutput = packArrayOutput,
          Py.opsEntryInput = unpackArrayInput
        }

-- | Wrap a Python expression in a call to @np.int64@.
--
-- We have many such casts to a 64-bit integer (historically "long") because
-- PyOpenCL may get confused at the 32-bit numbers that ImpCode uses for
-- offsets and the like.
asLong :: PyExp -> PyExp
asLong x = Py.simpleCall "np.int64" [x]

-- | Compile one OpenCL-specific ImpCode operation to Python statements.
--
-- * 'Imp.GetSize' binds the variable to @self.sizes[key]@.
-- * 'Imp.CmpSizeLe' binds the variable to @self.sizes[key] <= x@.
-- * 'Imp.GetSizeMax' binds the variable to @self.max_<size class>@.
-- * 'Imp.LaunchKernel' emits the launch produced by 'launchKernel', guarded
--   so that it only runs when the total number of work-items is non-zero
--   (OpenCL 1.x treats a zero global work size as an error).  For kernels
--   with 'Imp.SafetyFull' or higher it afterwards sets
--   @self.failure_is_an_option@ to 1.
callKernel :: Py.OpCompiler Imp.OpenCL ()
callKernel op = case op of
  Imp.GetSize v key -> do
    v' <- Py.compileVar v
    Py.stm $ Assign v' $ sizeEntry key
  Imp.CmpSizeLe v key x -> do
    v' <- Py.compileVar v
    x' <- Py.compileExp x
    Py.stm $ Assign v' $ BinOp "<=" (sizeEntry key) x'
  Imp.GetSizeMax v size_class -> do
    v' <- Py.compileVar v
    Py.stm $ Assign v' $ Var $ "self.max_" ++ pretty size_class
  Imp.LaunchKernel safety name args num_workgroups workgroup_size -> do
    num_workgroups' <- mapM compileDim num_workgroups
    workgroup_size' <- mapM compileDim workgroup_size
    -- Global work size per dimension = number of groups * group size.
    let kernel_size = zipWith mult_exp num_workgroups' workgroup_size'
        total_elements = foldl mult_exp (Integer 1) kernel_size
        cond = BinOp "!=" total_elements (Integer 0)
    body <- Py.collect $ launchKernel name safety kernel_size workgroup_size' args
    Py.stm $ If cond body []

    when (safety >= Imp.SafetyFull) $
      Py.stm $
        Assign (Var "self.failure_is_an_option") $
          Py.compilePrimValue (Imp.IntValue (Imp.Int32Value 1))
  where
    -- @self.sizes["<key>"]@
    sizeEntry key = Index (Var "self.sizes") (IdxExp $ String $ pretty key)

    -- Compile a grid dimension, cast to a 64-bit integer for PyOpenCL.
    compileDim e = asLong <$> Py.compileExp e

    mult_exp = BinOp "*"

-- | Emit the Python statements that launch a kernel.
--
-- Sets the kernel's arguments on @self.<kernel>_var@, prefixed by the first
-- 'Imp.numFailureParams' of the failure-tracking attributes
-- (@self.global_failure@, @self.failure_is_an_option@,
-- @self.global_failure_args@).  It then enqueues the kernel on @self.queue@
-- with the given global and local (work-group) sizes, and finally emits the
-- 'finishIfSynchronous' statements.
launchKernel ::
  Imp.KernelName ->
  Imp.KernelSafety ->
  -- | Global work size, one expression per dimension.
  [PyExp] ->
  -- | Work-group size, one expression per dimension.
  [PyExp] ->
  [Imp.KernelArg] ->
  Py.CompilerM op s ()
launchKernel kernel_name safety kernel_dims workgroup_dims args = do
  let kernel_dims' = Tuple kernel_dims
      workgroup_dims' = Tuple workgroup_dims
      -- Must match the attribute name bound in the generated constructor.
      kernel_name' = "self." ++ zEncodeString (nameToString kernel_name) ++ "_var"
  args' <- mapM processKernelArg args
  let failure_args =
        take
          (Imp.numFailureParams safety)
          [ Var "self.global_failure",
            Var "self.failure_is_an_option",
            Var "self.global_failure_args"
          ]
  Py.stm . Exp $
    Py.simpleCall (kernel_name' ++ ".set_args") $ failure_args ++ args'
  Py.stm . Exp $
    Py.simpleCall
      "cl.enqueue_nd_range_kernel"
      [Var "self.queue", Var kernel_name', kernel_dims', workgroup_dims']
  finishIfSynchronous
  where
    -- Scalars are wrapped in the numpy type for their primitive type,
    -- memory blocks are passed as-is, and local memory is requested with
    -- @cl.LocalMemory@ of the given (int64-cast) byte count.
    processKernelArg :: Imp.KernelArg -> Py.CompilerM op s PyExp
    processKernelArg (Imp.ValueKArg e bt) = do
      e' <- Py.compileExp e
      return $ Py.simpleCall (Py.compilePrimToNp bt) [e']
    processKernelArg (Imp.MemKArg v) = Py.compileVar v
    processKernelArg (Imp.SharedMemoryKArg (Imp.Count num_bytes)) = do
      num_bytes' <- Py.compileExp num_bytes
      return $ Py.simpleCall "cl.LocalMemory" [asLong num_bytes']

-- | Write a single scalar to device memory.
--
-- The value is wrapped in a numpy array of the element's type and copied
-- into @mem@ at element index @i@ (i.e. byte offset @i * sizeof(bt)@).
-- Whether the copy blocks is decided at run time by the @synchronous@
-- global.  Only the @"device"@ memory space is supported; any other space
-- is a compiler error.
writeOpenCLScalar :: Py.WriteScalar Imp.OpenCL ()
writeOpenCLScalar mem i bt "device" val = do
  let nparr =
        Call
          (Var "np.array")
          [Arg val, ArgKeyword "dtype" $ Var $ Py.compilePrimType bt]
      offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
  Py.stm . Exp $
    Call
      (Var "cl.enqueue_copy")
      [ Arg $ Var "self.queue",
        Arg mem,
        Arg nparr,
        ArgKeyword "device_offset" offset,
        ArgKeyword "is_blocking" $ Var "synchronous"
      ]
writeOpenCLScalar _ _ _ space _ =
  error $ "Cannot write to '" ++ space ++ "' memory space."

-- | Read a single scalar from device memory.
--
-- Allocates a fresh one-element numpy array (named @read_res…@) of the
-- element's type and copies element @i@ of @mem@ (byte offset
-- @i * sizeof(bt)@) into it.  It then calls @sync(self)@ and yields the
-- expression @read_res…[0]@.  Only the @"device"@ memory space is
-- supported; any other space is a compiler error.
readOpenCLScalar :: Py.ReadScalar Imp.OpenCL ()
readOpenCLScalar mem i bt "device" = do
  val <- newVName "read_res"
  let val' = Var $ pretty val
      nparr =
        Call
          (Var "np.empty")
          [ Arg $ Integer 1,
            ArgKeyword "dtype" $ Var $ Py.compilePrimType bt
          ]
      offset = BinOp "*" (asLong i) (Integer $ Imp.primByteSize bt)
  Py.stm $ Assign val' nparr
  Py.stm . Exp $
    Call
      (Var "cl.enqueue_copy")
      [ Arg $ Var "self.queue",
        Arg val',
        Arg mem,
        ArgKeyword "device_offset" offset,
        ArgKeyword "is_blocking" $ Var "synchronous"
      ]
  -- The copy may be non-blocking; sync(self) is emitted so that (presumably
  -- by draining the queue) the result is in place before it is indexed.
  Py.stm . Exp $ Py.simpleCall "sync" [Var "self"]
  return $ Index val' $ IdxExp $ Integer 0
readOpenCLScalar _ _ _ space =
  error $ "Cannot read from '" ++ space ++ "' memory space."

-- | Allocate a device buffer of the given size and bind it to @mem@.
--
-- Emits @mem = opencl_alloc(self, size, "mem")@; the third argument is the
-- pretty-printed memory expression as a string (presumably a tag for
-- diagnostics — confirm against the boilerplate's @opencl_alloc@).  Only the
-- @"device"@ memory space is supported; any other space is a compiler error.
allocateOpenCLBuffer :: Py.Allocate Imp.OpenCL ()
allocateOpenCLBuffer mem size "device" =
  Py.stm $
    Assign mem $
      Py.simpleCall "opencl_alloc" [Var "self", size, String $ pretty mem]
allocateOpenCLBuffer _ _ space =
  error $ "Cannot allocate in '" ++ space ++ "' space"

-- | Emit a copy of @nbytes@ bytes between two memory blocks.
--
-- Supported space combinations:
--
-- * device -> host: @cl.enqueue_copy@ into a slice of the host block.
--   The device read position is @srcidx@, passed as @device_offset@.
-- * host -> device: @cl.enqueue_copy@ from a slice of the host block.
--   The device write position is @destidx@, passed as @device_offset@.
-- * device -> device: @cl.enqueue_copy@ with explicit @dest_offset@,
--   @src_offset@ and @byte_count@, followed by 'finishIfSynchronous'.
-- * host -> host: delegated to 'Py.copyMemoryDefaultSpace'.
--
-- Host<->device copies pass @is_blocking=synchronous@. Every
-- @cl.enqueue_copy@ is wrapped in 'ifNotZeroSize', so nothing is
-- enqueued when @nbytes@ is zero. Any other combination of spaces
-- aborts with 'error'.
copyOpenCLMemory :: Py.Copy Imp.OpenCL ()
copyOpenCLMemory destmem destidx destspace srcmem srcidx srcspace nbytes bt =
  case (destspace, srcspace) of
    (Imp.DefaultSpace, Imp.Space "device") ->
      enqueueCopy
        [ Arg $ hostSlice destmem destidx,
          Arg srcmem,
          ArgKeyword "device_offset" $ asLong srcidx,
          blocking
        ]
    (Imp.Space "device", Imp.DefaultSpace) ->
      enqueueCopy
        [ Arg destmem,
          Arg $ hostSlice srcmem srcidx,
          ArgKeyword "device_offset" $ asLong destidx,
          blocking
        ]
    (Imp.Space "device", Imp.Space "device") -> do
      enqueueCopy
        [ Arg destmem,
          Arg srcmem,
          ArgKeyword "dest_offset" $ asLong destidx,
          ArgKeyword "src_offset" $ asLong srcidx,
          ArgKeyword "byte_count" $ asLong nbytes
        ]
      finishIfSynchronous
    (Imp.DefaultSpace, Imp.DefaultSpace) ->
      Py.copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes
    _ ->
      error $ "Cannot copy to " ++ show destspace ++ " from " ++ show srcspace
  where
    -- Every OpenCL transfer goes through @self.queue@. It is skipped when
    -- nbytes is zero, presumably because OpenCL rejects zero-byte copies.
    enqueueCopy :: [PyArg] -> Py.CompilerM Imp.OpenCL () ()
    enqueueCopy args =
      Py.stm . ifNotZeroSize nbytes . Exp $
        Call (Var "cl.enqueue_copy") (Arg (Var "self.queue") : args)

    blocking = ArgKeyword "is_blocking" $ Var "synchronous"

    -- Host-side region @mem[idx : idx + nbytes // sizeof(bt)]@.
    -- NOTE(review): the end adds an element count to @idx@. Confirm
    -- that callers supply @idx@ in the host array's index units.
    hostSlice mem idx =
      Index mem . IdxRange idx . BinOp "+" idx $
        BinOp "//" nbytes (Integer $ Imp.primByteSize bt)

-- | Emit code for a constant array @name@ living in the @\"device\"@ space.
--
-- These statements are registered with 'Py.atInit':
--
-- 1. build a host-side NumPy array. It uses @np.array@ for explicit
--    values and @np.zeros@ for zero-filled arrays, with the element
--    dtype of @t@;
-- 2. allocate a fresh OpenCL buffer of (element count * element size)
--    bytes;
-- 3. copy the NumPy array (after @normaliseArray@) into that buffer.
--    The copy is skipped when the size is zero;
-- 4. store the buffer as @self.<name>@.
--
-- At the point of use, the local @<name>@ is then rebound to
-- @self.<name>@. Any space other than @\"device\"@ is rejected with
-- 'error'.
staticOpenCLArray :: Py.StaticArray Imp.OpenCL ()
staticOpenCLArray name space t vs
  | space /= "device" =
      error $ "PyOpenCL backend cannot create static array in memory space '" ++ space ++ "'"
  | otherwise = do
      initStms <- Py.collect $ do
        Py.stm $ Assign (Var name') hostArray

        static_mem <- newVName "static_mem"
        let memVar = Var $ Py.compileName static_mem
        allocateOpenCLBuffer memVar byteSize "device"

        Py.stm . ifNotZeroSize byteSize . Exp $
          Call
            (Var "cl.enqueue_copy")
            [ Arg $ Var "self.queue",
              Arg memVar,
              Arg $ Call (Var "normaliseArray") [Arg (Var name')],
              ArgKeyword "is_blocking" $ Var "synchronous"
            ]

        Py.stm $ Assign selfField memVar
      mapM_ Py.atInit initStms

      Py.stm $ Assign (Var name') selfField
  where
    name' = Py.compileName name
    selfField = Field (Var "self") name'
    dtypeArg = ArgKeyword "dtype" $ Var $ Py.compilePrimToNp t

    hostArray = case vs of
      Imp.ArrayValues vs' ->
        Call (Var "np.array") [Arg $ List $ map Py.compilePrimValue vs', dtypeArg]
      Imp.ArrayZeros n ->
        Call (Var "np.zeros") [Arg $ Integer $ fromIntegral n, dtypeArg]

    numElems = case vs of
      Imp.ArrayValues vs' -> length vs'
      Imp.ArrayZeros n -> n

    byteSize = Integer $ toInteger numElems * Imp.primByteSize t

-- | Present the device memory block @mem@ to Python as an array.
--
-- The expression produced is
-- @cl.array.Array(self.queue, (d0, d1, ...), <dtype>, data=mem)@.
-- The dtype is derived from the primitive type and signedness. Arrays
-- in any other space than @\"device\"@ are rejected with 'error'.
packArrayOutput :: Py.EntryOutput Imp.OpenCL ()
packArrayOutput mem "device" bt ept dims =
  mkArray <$> Py.compileVar mem <*> traverse Py.compileDim dims
  where
    mkArray memExp dimExps =
      Call
        (Var "cl.array.Array")
        [ Arg $ Var "self.queue",
          Arg $ Tuple dimExps,
          Arg $ Var $ Py.compilePrimTypeExt bt ept,
          ArgKeyword "data" memExp
        ]
packArrayOutput _ sid _ _ _ =
  error $ "Cannot return array from " ++ sid ++ " space."

-- | Accept an entry-point argument @e@ as a device array bound to @mem@.
--
-- The emitted code does the following:
--
-- * asserts that @type(e)@ is @np.ndarray@ or @cl.array.Array@ and that
--   @e.dtype@ matches the expected element type;
-- * unpacks each dimension of @e@ via 'Py.unpackDim';
-- * if @e@ is a @cl.array.Array@, reuses its buffer (@mem = e.data@);
-- * otherwise, allocates @np.int64(e.nbytes)@ bytes on the device and
--   copies @normaliseArray(e)@ into them. The copy is skipped when the
--   size is zero.
--
-- Any space other than @\"device\"@ is rejected with 'error'.
unpackArrayInput :: Py.EntryInput Imp.OpenCL ()
unpackArrayInput mem "device" t s dims e = do
  Py.stm . Assert typeIsOk $ String "Parameter has unexpected type"

  zipWithM_ (Py.unpackDim e) dims [0 ..]

  copyFromHost <- Py.collect $ do
    allocateOpenCLBuffer mem byteCount "device"
    Py.stm . ifNotZeroSize byteCount . Exp $
      Call
        (Var "cl.enqueue_copy")
        [ Arg $ Var "self.queue",
          Arg mem,
          Arg $ Call (Var "normaliseArray") [Arg e],
          ArgKeyword "is_blocking" $ Var "synchronous"
        ]

  Py.stm $
    If
      (BinOp "==" typeOfE clArray)
      [Assign mem $ Field e "data"]
      copyFromHost
  where
    typeOfE = Py.simpleCall "type" [e]
    clArray = Var "cl.array.Array"
    acceptedTypes = List [Var "np.ndarray", clArray]
    dtypeMatches = BinOp "==" (Field e "dtype") (Var (Py.compilePrimToExtNp t s))
    typeIsOk = BinOp "and" (BinOp "in" typeOfE acceptedTypes) dtypeMatches
    byteCount = Py.simpleCall "np.int64" [Field e "nbytes"]
unpackArrayInput _ sid _ _ _ _ =
  error $ "Cannot accept array from " ++ sid ++ " space."

-- | Wrap statement @s@ so that it only runs when the size expression @e@
-- is non-zero. This emits @if e != 0: s@ with an empty else branch.
ifNotZeroSize :: PyExp -> PyStmt -> PyStmt
ifNotZeroSize e s = If nonZero [s] []
  where
    nonZero = BinOp "!=" e (Integer 0)

-- | Emit @if synchronous: sync(self)@.
--
-- @synchronous@ and @sync@ are names in the generated Python; they are
-- defined outside this block. Presumably @sync@ waits for outstanding
-- queue work; confirm against its definition.
finishIfSynchronous :: Py.CompilerM op s ()
finishIfSynchronous = Py.stm $ If (Var "synchronous") [syncCall] []
  where
    syncCall = Exp $ Py.simpleCall "sync" [Var "self"]