-- | General implementation of GPU copying, using LMAD representation.
-- That means the dynamic performance of this kernel depends crucially
-- on the LMAD.  In most cases we should use a more specialised kernel.
-- Written in ImpCode so we can compile it to both CUDA and OpenCL.
module Futhark.CodeGen.ImpGen.GPU.Copy (copyKernel) where

import Control.Monad
import Control.Monad.State
import Data.Foldable (toList)
import Futhark.CodeGen.ImpCode.GPU
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop.Reshape
import Futhark.MonadFreshNames
import Futhark.Util (nubOrd)
import Futhark.Util.IntegralExp (divUp)
import Prelude hiding (quot, rem)

-- | Construct a kernel that copies every element of an array from one
-- memory block to another, with the element offsets on each side
-- computed by indexing the corresponding LMAD.
--
-- The iteration space is the shape of the /destination/ LMAD; the
-- source LMAD is indexed with the same indices.  The kernel is named
-- @copy_\<rank\>d_\<type\>@.
copyKernel ::
  -- | Type of the elements being copied.
  PrimType ->
  -- | Number of groups and the group dimension to launch with.
  (TExp Int64, GroupDim) ->
  -- | Destination memory block and the LMAD used to index it.
  (VName, LMAD.LMAD (TExp Int64)) ->
  -- | Source memory block and the LMAD used to index it.
  (VName, LMAD.LMAD (TExp Int64)) ->
  Kernel
copyKernel pt (num_groups, group_dim) (dest_mem, dest_lmad) (src_mem, src_lmad) =
  Kernel
    { kernelBody = body,
      kernelUses =
        -- Every variable name occurring in either LMAD or in the group
        -- count is passed as an Int64 scalar; duplicates are removed.
        let frees =
              nubOrd
                ( foldMap toList dest_lmad
                    <> foldMap toList src_lmad
                    <> toList num_groups
                )
         in map (`ScalarUse` IntType Int64) frees
              ++ map MemoryUse [dest_mem, src_mem],
      kernelNumGroups = [untyped num_groups],
      kernelGroupSize = [group_dim],
      kernelName = nameFromString ("copy_" <> show rank <> "d_" <> prettyString pt),
      kernelFailureTolerant = True,
      kernelCheckLocalMemory = False
    }
  where
    -- The iteration space is taken from the destination LMAD only.
    shape = LMAD.shape dest_lmad
    rank = length shape

    -- Names are drawn from a private name source seeded with 1000.
    body = flip evalState (newNameSource 1000) $ do
      group_id <- newVName "group_id"
      local_id <- newVName "local_id"
      local_size <- newVName "local_size"
      global_id <- newVName "global_id"
      group_iter <- newVName "group_iter"
      -- Flat index handled in the current iteration of the 'For' loop
      -- below: each iteration advances the group index by 'num_groups'.
      let global_id_e =
            ((le64 group_id + le64 group_iter * num_groups) * le64 local_size)
              + le64 local_id

      -- One index variable per dimension of 'shape'.
      is <- replicateM rank $ newVName "i"
      let is_e = map untyped (unflattenIndex shape (le64 global_id))
          -- Every index must be strictly below its dimension.
          --
          -- NOTE(review): 'foldl1' is partial; with a rank-0 shape the
          -- list is empty and this fails during code generation.
          -- Confirm that callers never pass a rank-0 LMAD.
          in_bounds = foldl1 (.&&.) (zipWith (.<.) (map le64 is) shape)

      element <- newVName "element"

      let dec v = DeclareScalar v Nonvolatile $ IntType Int64
          src_o = Count (LMAD.index src_lmad (map le64 is))
          dest_o = Count (LMAD.index dest_lmad (map le64 is))
          -- Read one element from the source into a scalar, then write
          -- that scalar to the destination.
          copy_elem =
            DeclareScalar element Nonvolatile pt
              <> Read element src_mem src_o pt (Space "device") Nonvolatile
              <> Write dest_mem dest_o pt (Space "device") Nonvolatile (LeafExp element pt)
      pure $
        foldMap dec [group_id, local_id, local_size, global_id]
          <> Op (GetLocalId local_id 0)
          <> Op (GetLocalSize local_size 0)
          <> Op (GetGroupId group_id 0)
          <> For
            group_iter
            -- Iteration count: total number of elements divided, rounding
            -- up, by @local_size * num_groups@.
            (untyped (product shape `divUp` (le64 local_size * num_groups)))
            ( SetScalar global_id (untyped global_id_e)
                <> foldMap dec is
                <> mconcat (zipWith SetScalar is is_e)
                -- The rounded-up iteration count can overshoot, so the
                -- copy is guarded by the bounds check.
                <> If in_bounds copy_elem mempty
            )