-- | Imperative code with an OpenCL component.
--
-- Apart from ordinary imperative code, this also carries around an
-- OpenCL program as a string, as well as a list of kernels defined by
-- the OpenCL program.
--
-- The imperative code has been augmented with a 'LaunchKernel'
-- operation that allows one to execute an OpenCL kernel.
module Futhark.CodeGen.ImpCode.OpenCL
  ( Program (..),
    KernelName,
    KernelArg (..),
    CLCode,
    OpenCL (..),
    KernelSafety (..),
    numFailureParams,
    KernelTarget (..),
    FailureMsg (..),
    GroupDim,
    KernelConst (..),
    module Futhark.CodeGen.ImpCode,
    module Futhark.IR.GPU.Sizes,
  )
where

import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.CodeGen.ImpCode
import Futhark.CodeGen.ImpCode.GPU (GroupDim, KernelConst (..))
import Futhark.IR.GPU.Sizes
import Futhark.Util.Pretty

-- | A program calling OpenCL kernels.
data Program = Program
  { -- | The OpenCL program, as text.
    openClProgram :: T.Text,
    -- | Must be prepended to the program.
    openClPrelude :: T.Text,
    -- | The kernels defined by the OpenCL program, each mapped to its
    -- safety characteristics.
    openClKernelNames :: M.Map KernelName KernelSafety,
    -- | So we can detect whether the device is capable.
    openClUsedTypes :: [PrimType],
    -- | Runtime-configurable constants.
    openClParams :: ParamMap,
    -- | Assertion failure error messages.
    openClFailures :: [FailureMsg],
    -- | The imperative definitions, in which the embedded operations
    -- are 'OpenCL' operations.
    hostDefinitions :: Definitions OpenCL
  }

-- | Something that can go wrong in a kernel.  Part of the machinery
-- for reporting error messages from within kernels.
data FailureMsg = FailureMsg
  { -- | The error message itself.
    failureError :: ErrorMsg Exp,
    -- | The backtrace accompanying the error, as a plain string.
    failureBacktrace :: String
  }

-- | A piece of imperative code whose embedded operations are 'OpenCL'
-- operations.
type CLCode = Code OpenCL

-- | The name of a kernel.  Used as the key of 'openClKernelNames' and
-- to identify the kernel in 'LaunchKernel'.
type KernelName = Name

-- | An argument to be passed to a kernel.
data KernelArg
  = -- | Pass the value of this scalar expression as argument.  The
    -- 'PrimType' is carried alongside the expression.
    ValueKArg Exp PrimType
  | -- | Pass this pointer as argument.
    MemKArg VName
  deriving (Show)

-- | Whether a kernel can potentially fail (because it contains bounds
-- checks and such).
data MayFail
  = -- | The kernel may fail.
    MayFail
  | -- | The kernel cannot fail.
    CannotFail
  deriving (Show)

-- | Information about bounds checks and how sensitive it is to
-- errors.  Ordered by least demanding to most; the derived 'Ord'
-- instance relies on this constructor order, so do not reorder the
-- constructors.
data KernelSafety
  = -- | Does not need to know if we are in a failing state, and also
    -- cannot fail.
    SafetyNone
  | -- | Needs to be told if there's a global failure, and that's it,
    -- and cannot fail.
    SafetyCheap
  | -- | Needs all parameters, may fail itself.
    SafetyFull
  deriving (Eq, Ord, Show)

-- | How many leading failure arguments we must pass when launching a
-- kernel with these safety characteristics.
--
-- The count is 0 for 'SafetyNone', 1 for 'SafetyCheap' and 3 for
-- 'SafetyFull'.
numFailureParams :: KernelSafety -> Int
numFailureParams SafetyNone = 0
numFailureParams SafetyCheap = 1
numFailureParams SafetyFull = 3

-- | Host-level OpenCL operation.
data OpenCL
  = -- | Execute the named kernel with the given safety
    -- characteristics and kernel arguments.
    --
    -- NOTE(review): the 'Count' 'Bytes' field and the trailing @['Exp']@
    -- and @['GroupDim']@ fields presumably describe local memory usage
    -- and the launch grid/group dimensions - confirm against the code
    -- generator before relying on this.
    LaunchKernel KernelSafety KernelName (Count Bytes (TExp Int64)) [KernelArg] [Exp] [GroupDim]
  | GetSize VName Name
  | CmpSizeLe VName Name Exp
  | GetSizeMax VName SizeClass
  deriving (Show)

-- | The target platform when compiling imperative code to a 'Program'.
data KernelTarget
  = TargetOpenCL
  | TargetCUDA
  | TargetHIP
  deriving (Eq)

-- | Host-level operations are pretty-printed by rendering their
-- derived 'Show' output.
instance Pretty OpenCL where
  pretty = pretty . show