{-# LANGUAGE QuasiQuotes #-}

-- | Various boilerplate definitions for the PyOpenCL backend.
module Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
  ( openClInit,
  )
where

import Control.Monad.Identity
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython qualified as Py
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.ImpCode.OpenCL
  ( ErrorMsg (..),
    ErrorMsgPart (..),
    FailureMsg (..),
    PrimType (..),
    SizeClass (..),
    errorMsgArgTypes,
    sizeDefault,
    untyped,
  )
import Futhark.CodeGen.OpenCL.Heuristics
import Futhark.Util.Pretty (prettyString, prettyText)
import NeatInterpolation (text)

-- | How many runtime arguments an error message takes, i.e. the
-- length of the list returned by 'errorMsgArgTypes'.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs msg = length (errorMsgArgTypes msg)

-- | Python code (as a string) that calls the
-- @initialise_opencl_object@ procedure.  Should be put in the
-- class constructor.
--
-- The generated snippet binds @size_heuristics@, sets
-- @self.global_failure_args_max@ and @self.failure_msgs@, calls
-- @initialise_opencl_object@ with the required primitive types and
-- the size classes, and finally splices in @assign@ verbatim.
openClInit :: [PrimType] -> String -> M.Map Name SizeClass -> [FailureMsg] -> T.Text
openClInit types assign sizes failures =
  [text|
size_heuristics=$size_heuristics
self.global_failure_args_max = $max_num_args
self.failure_msgs=$failure_msgs
program = initialise_opencl_object(self,
                                   program_src=fut_opencl_src,
                                   build_options=build_options,
                                   command_queue=command_queue,
                                   interactive=interactive,
                                   platform_pref=platform_pref,
                                   device_pref=device_pref,
                                   default_group_size=default_group_size,
                                   default_num_groups=default_num_groups,
                                   default_tile_size=default_tile_size,
                                   default_reg_tile_size=default_reg_tile_size,
                                   default_threshold=default_threshold,
                                   size_heuristics=size_heuristics,
                                   required_types=$types',
                                   user_sizes=sizes,
                                   all_sizes=$sizes')
$assign'
|]
  where
    assign' :: T.Text
    assign' = T.pack assign

    size_heuristics :: T.Text
    size_heuristics = prettyText $ sizeHeuristicsToPython sizeHeuristicsTable

    -- Haskell string literals produced by 'show' look enough like Python.
    types' :: T.Text
    types' = prettyText $ map (show . prettyString) types

    sizes' :: T.Text
    sizes' = prettyText $ sizeClassesToPython sizes

    -- Largest argument count over all failure messages, or 0 when there
    -- are none.  Prepending 0 keeps 'maximum' total and avoids the lazy
    -- 'foldl' accumulator the previous formulation used.
    max_num_args :: T.Text
    max_num_args =
      prettyText . maximum $
        0 : map (errorMsgNumArgs . failureError) failures

    failure_msgs :: T.Text
    failure_msgs = prettyText $ List $ map formatFailure failures

-- | Render a failure message as a Python string expression.
--
-- Literal text parts are copied with every @{@ and @}@ doubled, each
-- 'ErrorVal' part becomes a @{}@ placeholder, and the (equally escaped)
-- backtrace follows after a newline.  The brace doubling suggests the
-- result is later fed to Python's @str.format@; confirm at the consumer.
formatFailure :: FailureMsg -> PyExp
formatFailure (FailureMsg (ErrorMsg msgParts) trace) =
  String (renderedParts <> "\n" <> escapeBraces trace)
  where
    renderedParts = concat [renderPart part | part <- msgParts]

    renderPart (ErrorString str) = escapeBraces (T.unpack str)
    renderPart ErrorVal {} = "{}"

    escapeBraces :: String -> String
    escapeBraces str = do
      ch <- str
      case ch of
        '{' -> "{{"
        '}' -> "}}"
        _ -> [ch]

-- | Describe every size parameter as a Python dictionary that maps the
-- size's name to @{"class": <size class>, "value": <default or None>}@.
-- Entries appear in 'M.toList' (ascending key) order.
sizeClassesToPython :: M.Map Name SizeClass -> PyExp
sizeClassesToPython sizes =
  Dict [sizeEntry sizeName sizeCls | (sizeName, sizeCls) <- M.toList sizes]
  where
    sizeEntry sizeName sizeCls =
      (String (prettyString sizeName), Dict [classField, valueField])
      where
        classField = (String "class", String (prettyString sizeCls))
        valueField = (String "value", defaultValue (sizeDefault sizeCls))

    -- No default becomes Python's None; otherwise an integer literal.
    defaultValue Nothing = None
    defaultValue (Just v) = Integer (fromIntegral v)

-- | Translate size heuristics into a Python list of 4-tuples
-- @(platform_name, device_type, size_kind, lambda device: ...)@.
--
-- The lambda is the heuristic's expression compiled to Python, where
-- each 'DeviceInfo' leaf becomes
-- @device.get_info(getattr(cl.device_info, <name>))@.
sizeHeuristicsToPython :: [SizeHeuristic] -> PyExp
sizeHeuristicsToPython heuristics = List (map toTuple heuristics)
  where
    toTuple (SizeHeuristic platform devType sizeKind expr) =
      Tuple
        [ String platform,
          deviceTypeToPython devType,
          String (whichSizeName sizeKind),
          Lambda "device" (compileHeuristic expr)
        ]

    deviceTypeToPython DeviceGPU = Var "cl.device_type.GPU"
    deviceTypeToPython DeviceCPU = Var "cl.device_type.CPU"

    whichSizeName LockstepWidth = "lockstep_width"
    whichSizeName NumGroups = "num_groups"
    whichSizeName GroupSize = "group_size"
    whichSizeName TileSize = "tile_size"
    whichSizeName RegTileSize = "reg_tile_size"
    whichSizeName Threshold = "threshold"

    -- Leaf translation is pure, so run the compiler in 'Identity'.
    compileHeuristic expr =
      runIdentity (Py.compilePrimExp (pure . queryDevice) (untyped expr))

    queryDevice (DeviceInfo infoName) =
      Py.simpleCall
        "device.get_info"
        [Py.simpleCall "getattr" [Var "cl.device_info", String infoName]]