{-# LANGUAGE QuasiQuotes #-}

-- | Various boilerplate definitions for the PyOpenCL backend.
module Futhark.CodeGen.Backends.PyOpenCL.Boilerplate
  ( openClInit,
  )
where

import Control.Monad.Identity
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython qualified as Py
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.ImpCode.OpenCL
  ( ErrorMsg (..),
    ErrorMsgPart (..),
    FailureMsg (..),
    KernelConst (..),
    KernelConstExp,
    ParamMap,
    PrimType (..),
    errorMsgArgTypes,
    sizeDefault,
    untyped,
  )
import Futhark.CodeGen.OpenCL.Heuristics
import Futhark.Util.Pretty (prettyString, prettyText)
import NeatInterpolation (text)

-- | How many arguments an error message takes: the number of argument
-- types reported by 'errorMsgArgTypes'.
errorMsgNumArgs :: ErrorMsg a -> Int
errorMsgNumArgs msg = length (errorMsgArgTypes msg)

-- | Python expression @self.sizes[<key>]@, where the key is the
-- parameter name rendered with 'prettyText' as a Python string.
getParamByKey :: Name -> PyExp
getParamByKey key =
  let keyLiteral = String (prettyText key)
   in Index (Var "self.sizes") (IdxExp keyLiteral)

-- | The Python expression that supplies a kernel constant at runtime.
--
-- * A 'SizeConst' is looked up by name through 'getParamByKey'.
--
-- * A 'SizeMaxConst' refers to the attribute @self.max_<size class>@.
kernelConstToExp :: KernelConst -> PyExp
kernelConstToExp kc = case kc of
  SizeConst key _ -> getParamByKey key
  SizeMaxConst size_class -> Var ("self.max_" <> prettyString size_class)

-- | Compile a kernel constant expression to Python.  Every leaf constant
-- is translated with 'kernelConstToExp'.  The generic compiler is run in
-- 'Identity' because no effects are needed.
compileConstExp :: KernelConstExp -> PyExp
compileConstExp = runIdentity . Py.compilePrimExp leaf
  where
    leaf c = pure (kernelConstToExp c)

-- | Python code (as text) that calls the @initialise_opencl_object@
-- procedure.  Should be put in the class constructor.
--
-- Before the call, the generated code binds @size_heuristics@ and
-- @constants@, and sets @self.global_failure_args_max@ and
-- @self.failure_msgs@.  After the call, it appends the caller-supplied
-- Python code verbatim.
openClInit ::
  -- | Named kernel constants.  Each becomes a @(name, lambda: expr)@
  -- tuple in the Python @constants@ list.
  [(Name, KernelConstExp)] ->
  -- | Primitive types, passed as @required_types@.
  [PrimType] ->
  -- | Python code emitted verbatim after the initialisation call.
  String ->
  -- | Tunable parameters, passed as @all_sizes@.
  ParamMap ->
  -- | Failure messages.  They determine @self.failure_msgs@ and
  -- @self.global_failure_args_max@.
  [FailureMsg] ->
  T.Text
openClInit constants types assign sizes failures =
  [text|
size_heuristics=$size_heuristics
self.global_failure_args_max = $max_num_args
self.failure_msgs=$failure_msgs
constants = $constants'
program = initialise_opencl_object(self,
                                   program_src=fut_opencl_src,
                                   build_options=build_options,
                                   command_queue=command_queue,
                                   interactive=interactive,
                                   platform_pref=platform_pref,
                                   device_pref=device_pref,
                                   default_group_size=default_group_size,
                                   default_num_groups=default_num_groups,
                                   default_tile_size=default_tile_size,
                                   default_reg_tile_size=default_reg_tile_size,
                                   default_threshold=default_threshold,
                                   size_heuristics=size_heuristics,
                                   required_types=$types',
                                   user_sizes=sizes,
                                   all_sizes=$sizes',
                                   constants=constants)
$assign'
|]
  where
    assign' = T.pack assign
    size_heuristics = prettyText $ sizeHeuristicsToPython sizeHeuristicsTable
    -- Haskell-'show'n type names; looks enough like Python string literals.
    types' = prettyText $ map (show . prettyString) types
    sizes' = prettyText $ sizeClassesToPython sizes
    -- Largest argument count over all failure messages, or 0 if there are
    -- none.  Seeding the list with 0 keeps 'maximum' total.
    max_num_args =
      prettyText $ maximum (0 : map (errorMsgNumArgs . failureError) failures)
    failure_msgs = prettyText $ List $ map formatFailure failures
    -- Each constant becomes a pair of its name and a parameterless Python
    -- lambda that computes its value.
    onConstant (name, e) =
      Tuple
        [ String (nameToText name),
          Lambda "" (compileConstExp e)
        ]
    constants' = prettyText $ List $ map onConstant constants

-- | Render a failure message as a Python string literal.
--
-- * Literal message text and the backtrace have every @{@ and @}@ doubled.
--
-- * Each 'ErrorVal' becomes a @{}@ placeholder.
--
-- * The backtrace is appended after a newline.
--
-- This is Python format-string syntax.  Presumably the runtime fills the
-- placeholders with @str.format@ (TODO confirm against the Python runtime).
formatFailure :: FailureMsg -> PyExp
formatFailure (FailureMsg (ErrorMsg parts) backtrace) =
  String $ T.concat (map renderPart parts) <> "\n" <> escapeBraces (T.pack backtrace)
  where
    renderPart :: ErrorMsgPart a -> T.Text
    renderPart (ErrorString s) = escapeBraces s
    renderPart ErrorVal {} = "{}"

    escapeBraces :: T.Text -> T.Text
    escapeBraces = T.concatMap escapeChar

    escapeChar :: Char -> T.Text
    escapeChar '{' = "{{"
    escapeChar '}' = "}}"
    escapeChar c = T.singleton c

-- | Describe every tunable parameter as a Python dict entry of the form
-- @name: {"class": <size class>, "value": <default>}@.  The default is
-- taken from 'sizeDefault' and is @None@ when there is none.
sizeClassesToPython :: ParamMap -> PyExp
sizeClassesToPython params =
  Dict
    [ (String (prettyText size_name), describe size_class)
      | (size_name, (size_class, _)) <- M.toList params
    ]
  where
    describe size_class =
      Dict
        [ (String "class", String (prettyText size_class)),
          (String "value", defaultValue (sizeDefault size_class))
        ]

    defaultValue Nothing = None
    defaultValue (Just v) = Integer (toInteger v)

-- | Render size heuristics as a Python list of 4-tuples
-- @(platform_name, device_type, size_name, lambda device: value)@.
-- The lambda evaluates the heuristic's expression.  Each device property
-- it needs is read via @device.get_info(getattr(cl.device_info, name))@.
sizeHeuristicsToPython :: [SizeHeuristic] -> PyExp
sizeHeuristicsToPython heuristics = List [heuristicEntry h | h <- heuristics]
  where
    heuristicEntry (SizeHeuristic platform_name device_type which what) =
      Tuple
        [ String (T.pack platform_name),
          deviceTypeVar device_type,
          sizeKey which,
          Lambda "device" (compileHeuristic what)
        ]

    -- PyOpenCL device-type constants.
    deviceTypeVar DeviceGPU = Var "cl.device_type.GPU"
    deviceTypeVar DeviceCPU = Var "cl.device_type.CPU"

    -- Python-side names of the sizes a heuristic may determine.
    sizeKey LockstepWidth = String "lockstep_width"
    sizeKey NumBlocks = String "num_groups"
    sizeKey BlockSize = String "group_size"
    sizeKey TileSize = String "tile_size"
    sizeKey RegTileSize = String "reg_tile_size"
    sizeKey Threshold = String "threshold"

    compileHeuristic e =
      runIdentity (Py.compilePrimExp queryDevice (untyped e))

    -- Every leaf of a heuristic expression is a device property lookup.
    queryDevice (DeviceInfo s) =
      pure $
        Py.simpleCall
          "device.get_info"
          [Py.simpleCall "getattr" [Var "cl.device_info", String (T.pack s)]]