{-# LANGUAGE TypeFamilies #-}

-- | Compile a 'GPUMem' program to imperative code with kernels.
-- This is mostly (but not entirely) the same process no matter if we
-- are targeting OpenCL or CUDA.  The important distinctions (the host
-- level code) are introduced later.
module Futhark.CodeGen.ImpGen.GPU
  ( compileProgOpenCL,
    compileProgCUDA,
    Warnings,
  )
where

import Control.Monad
import Control.Monad.State
import Data.Foldable (toList)
import Data.List (foldl')
import Data.Map qualified as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import Futhark.CodeGen.ImpGen qualified
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.CodeGen.ImpGen.GPU.Copy
import Futhark.CodeGen.ImpGen.GPU.SegHist
import Futhark.CodeGen.ImpGen.GPU.SegMap
import Futhark.CodeGen.ImpGen.GPU.SegRed
import Futhark.CodeGen.ImpGen.GPU.SegScan
import Futhark.CodeGen.ImpGen.GPU.Transpose
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.IxFun qualified as IxFun
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The GPU-specific 'Operations' handed to the generic ImpGen driver.
-- Expressions and 'Op's go through 'expCompiler' and 'opCompiler'
-- (defined in this module), copies through 'callKernelCopy', and
-- statements through the default 'defCompileStms'.  No custom
-- allocation compilers are registered.
callKernelOperations :: Operations GPUMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsExpCompiler = expCompiler,
      opsCopyCompiler = callKernelCopy,
      opsOpCompiler = opCompiler,
      opsStmsCompiler = defCompileStms,
      opsAllocCompilers = mempty
    }

-- | Binary operators that can be compiled to a native atomic
-- instruction on OpenCL and CUDA, respectively.  Applying either
-- table to a 'BinOp' yields the matching 'Imp.AtomicOp' constructor,
-- or 'Nothing' when the operator has no native atomic form on that
-- target.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where
    -- The integer atomics supported at a given width.  The 32- and
    -- 64-bit tables are identical apart from the integer type, so both
    -- are generated from this one template.
    --
    -- NOTE(review): the 'Add' key uses 'OverflowUndef'; whether an
    -- 'Add' with a different overflow mode also matches depends on the
    -- 'Eq' instance of 'Overflow' -- confirm.
    intAtomics t =
      [ (Add t OverflowUndef, Imp.AtomicAdd t),
        (SMax t, Imp.AtomicSMax t),
        (SMin t, Imp.AtomicSMin t),
        (UMax t, Imp.AtomicUMax t),
        (UMin t, Imp.AtomicUMin t),
        (And t, Imp.AtomicAnd t),
        (Or t, Imp.AtomicOr t),
        (Xor t, Imp.AtomicXor t)
      ]

    opencl = intAtomics Int32 ++ intAtomics Int64

    -- CUDA supports everything OpenCL does, plus floating-point
    -- atomic addition.
    cuda =
      opencl
        ++ [ (FAdd Float32, Imp.AtomicFAdd Float32),
             (FAdd Float64, Imp.AtomicFAdd Float64)
           ]

-- | Compile a program using the GPU-specific 'callKernelOperations'
-- and the given host environment.  The @"device"@ space is passed as
-- the 'Space' argument of the generic compiler (presumably the default
-- memory space for arrays -- confirm against
-- "Futhark.CodeGen.ImpGen").
compileProg ::
  MonadFreshNames m =>
  HostEnv ->
  Prog GPUMem ->
  m (Warnings, Imp.Program)
compileProg env =
  Futhark.CodeGen.ImpGen.compileProg env callKernelOperations device_space
  where
    device_space = Imp.Space "device"

-- | Compile a 'GPUMem' program to low-level parallel code, with
-- either CUDA or OpenCL characteristics.  The two variants differ only
-- in the table of native atomics and the 'Target' stored in the host
-- environment.  Both start with an empty map of lock arrays.
compileProgOpenCL,
  compileProgCUDA ::
    MonadFreshNames m => Prog GPUMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg $ HostEnv openclAtomics OpenCL mempty
compileProgCUDA = compileProg $ HostEnv cudaAtomics CUDA mempty

-- | Compile a host-level 'Op'.  Handled cases:
--
-- * memory allocations;
-- * size queries and comparisons against tunable sizes;
-- * 'SegOp's;
-- * 'GPUBody' (compiled as a one-thread kernel).
--
-- Any other combination of pattern and operation is reported as a
-- compiler bug.
opCompiler ::
  Pat LetDecMem ->
  Op GPUMem ->
  CallKernelGen ()
opCompiler dest (Alloc e space) =
  compileAlloc dest e space
opCompiler (Pat [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  -- Size keys and threshold paths are qualified with the current
  -- function name, if any.
  fname <- askFunction
  sOp $
    Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
      sizeClassWithEntryPoint fname size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x
opCompiler (Pat [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class
opCompiler (Pat [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups :: TV Int32 <- dPrim "max_num_groups" int32
  sOp $
    Imp.GetSize (tvVar max_num_groups) (keyWithEntryPoint fname max_num_groups_key) $
      sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero =
        sMin64 (pe64 w64 `divUp` pe64 group_size) $
          sExt64 (tvExp max_num_groups)
  -- We also don't want zero groups.
  let num_groups = sMax64 1 num_groups_maybe_zero
  -- The result variable is 32 bits wide.
  mkTV (patElemName pe) int32 <-- sExt32 num_groups
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op
opCompiler (Pat pes) (Inner (GPUBody _ (Body _ stms res))) = do
  -- A GPUBody is run as a kernel with one group containing one
  -- thread.  Each result is copied to element 0 of its pattern
  -- element.
  tid <- newVName "tid"
  let one = Count (intConst Int64 1)
  sKernelThread "gpuseq" tid (defKernelAttrs one one) $
    compileStms (freeIn res) stms $
      forM_ (zip pes res) $ \(pe, SubExpRes _ se) ->
        copyDWIM (patElemName pe) [DimFix 0] se []
opCompiler pat e =
  compilerBugS $
    "opCompiler: Invalid pattern\n  "
      ++ prettyString pat
      ++ "\nfor expression\n  "
      ++ prettyString e

-- | Qualify the names on the kernel path of an 'Imp.SizeThreshold'
-- with the given function name, via 'keyWithEntryPoint'.  The boolean
-- attached to each path entry and the default value are kept as-is.
-- Every other size class is returned unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path def) =
  Imp.SizeThreshold (map f path) def
  where
    f (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- | Compile a 'SegOp' to host code by dispatching to the specialised
-- compilers.
--
-- * 'SegMap' is accepted at any level.
-- * 'SegRed', 'SegScan' and 'SegHist' are only handled at
--   'SegThread' level.
--
-- Any other combination is reported as a compiler bug.
segOpCompiler ::
  Pat LetDecMem ->
  SegOp SegLevel GPUMem ->
  CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@(SegThread _ _) space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@(SegThread _ _) space scans _ kbody) =
  compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist lvl@(SegThread _ _) space ops _ kbody) =
  compileSegHist pat lvl space ops kbody
segOpCompiler pat segop =
  compilerBugS $
    "segOpCompiler: unexpected "
      ++ prettyString (segLevel segop)
      ++ " for rhs of pattern "
      ++ prettyString pat

-- | Create a boolean expression that checks whether all kernels in the
-- enclosed code do not use more local memory than we have available.
-- We look at *all* the kernels here, even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down.  Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
--
-- Only kernels with 'Imp.kernelCheckLocalMemory' set are considered.
-- The result is 'Nothing' when some allocation size mentions a variable
-- that is not in scope at this point, since the check cannot be
-- emitted here in that case.
checkLocalMemoryReqs :: Imp.HostCode -> CallKernelGen (Maybe (Imp.TExp Bool))
checkLocalMemoryReqs code = do
  scope <- askScope
  -- Total padded local-memory footprint of each checked kernel.
  let alloc_sizes =
        map (sum . map alignedSize . localAllocSizes . Imp.kernelBody) $
          getGPU code

  -- If any of the sizes involve a variable that is not known at this
  -- point, then we cannot check the requirements.
  if any (`M.notMember` scope) (namesToList $ freeIn alloc_sizes)
    then pure Nothing
    else do
      local_memory_capacity :: TV Int32 <- dPrim "local_memory_capacity" int32
      sOp $ Imp.GetSizeMax (tvVar local_memory_capacity) SizeLocalMemory

      let local_memory_capacity_64 = sExt64 $ tvExp local_memory_capacity
          -- A kernel fits if its total does not exceed the capacity.
          fits size = unCount size .<=. local_memory_capacity_64
      pure $ Just $ foldl' (.&&.) true (map fits alloc_sizes)
  where
    -- Kernel launches in the host code that request a local memory
    -- check.
    getGPU = foldMap getKernel
    getKernel (Imp.CallKernel k) | Imp.kernelCheckLocalMemory k = [k]
    getKernel _ = []

    -- Sizes of all 'Imp.LocalAlloc's in a kernel body.
    localAllocSizes = foldMap localAllocSize
    localAllocSize (Imp.LocalAlloc _ size) = [size]
    localAllocSize _ = []

    -- These allocations will actually be padded to an 8-byte aligned
    -- size, so we should take that into account when checking whether
    -- they fit.
    alignedSize x = x + ((8 - (x `rem` 8)) `rem` 8)

-- | Compile a 'WithAcc' expression.
--
-- Each accumulator input is checked in turn.  When an input has a
-- combining operator for which 'atomicUpdateLocking' yields
-- 'AtomicLocking', a lock array is created with 'genZeroes' and
-- inserted into 'hostLocks', keyed by the matching lambda parameter
-- name.  The 'WithAcc' is then compiled with 'defCompileExp' inside
-- the extended environment.  Presumably this is so that the
-- accumulator updates in the body can find their locks (confirm
-- against the kernel code generators).
withAcc ::
  Pat LetDecMem ->
  [(Shape, [VName], Maybe (Lambda GPUMem, [SubExp]))] ->
  Lambda GPUMem ->
  CallKernelGen ()
withAcc pat inputs lam = do
  atomics <- hostAtomics <$> askEnv
  locksForInputs atomics $ zip accs inputs
  where
    -- 'zip' truncates to the number of inputs, so each input is paired
    -- with one of the leading lambda parameters.
    accs = map paramName $ lambdaParams lam

    -- All inputs processed: compile the WithAcc itself in the
    -- environment extended with every lock array created so far.
    locksForInputs _ [] =
      defCompileExp pat $ WithAcc inputs lam
    locksForInputs atomics ((c, (_, _, op)) : inputs')
      | Just (op_lam, _) <- op,
        AtomicLocking _ <- atomicUpdateLocking atomics op_lam = do
          -- NOTE(review): fixed lock-array size; the rationale for this
          -- particular value is not visible here.
          let num_locks = 100151
          locks_arr <- genZeroes "withacc_locks" num_locks
          let locks = Locks locks_arr num_locks
              extend env = env {hostLocks = M.insert c locks $ hostLocks env}
          localEnv extend $ locksForInputs atomics inputs'
      | otherwise =
          locksForInputs atomics inputs'

-- | Host-level expression compiler.  A handful of expression forms get
-- dedicated GPU code (iota, replicate, accumulators, and
-- multi-versioned matches); everything else is handed to
-- 'defCompileExp'.
expCompiler :: ExpCompiler GPUMem HostEnv Imp.HostOp
-- An iota is compiled to its own simple kernel via 'sIota'.
expCompiler (Pat [pe]) (BasicOp (Iota n x s et)) = do
  start <- toExp x
  step <- toExp s
  sIota (patElemName pe) (pe64 n) start step et
-- A replicate producing an accumulator emits nothing; a rank-0
-- replicate is just a copy; anything else goes through 'sReplicate'.
expCompiler (Pat [pe]) (BasicOp (Replicate shape se))
  | Acc {} <- patElemType pe = pure ()
  | shapeRank shape == 0 = copyDWIM target [] se []
  | otherwise = sReplicate target se
  where
    target = patElemName pe
-- Allocations in the "local" space are placeholders, so no code is
-- generated for them here.
expCompiler _ (Op (Alloc _ (Space "local"))) = pure ()
expCompiler pat (WithAcc inputs lam) = withAcc pat inputs lam
-- A multi-versioning Match as produced by incremental flattening.  The
-- first case is only taken if its local memory requirements (as
-- reported by 'checkLocalMemoryReqs') can be satisfied; otherwise we
-- fall back to compiling the remaining cases.  When the requirements
-- cannot be determined at all, only the fallback code is emitted.  The
-- default branch is never checked, as it is assumed to always be safe.
expCompiler dest (Match cond (first_case : cases) defbranch sort@(MatchDec _ MatchEquiv)) = do
  tcode <- collect . compileBody dest $ caseBody first_case
  fcode <- collect . expCompiler dest $ Match cond cases defbranch sort
  check <- checkLocalMemoryReqs tcode
  let matches = caseMatch cond $ casePat first_case
      guarded ok = Imp.If (matches .&&. ok) tcode fcode
  emit $ maybe fcode guarded check
expCompiler dest e = defCompileExp dest e

-- | Return the name of the builtin copy function for arrays of the given
-- rank and element type, emitting its definition the first time it is
-- requested.
gpuCopyForType :: Rank -> PrimType -> CallKernelGen Name
gpuCopyForType rank elem_t = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (gpuCopyFunction rank elem_t)
  pure fname
  where
    fname = nameFromString $ "builtin#" <> gpuCopyName rank elem_t

-- | Base name of the copy function for a given rank and element type,
-- e.g. @gpu_copy_2d_f32@ for rank 2 (assuming 'prettyString' renders
-- the element type as @f32@).
gpuCopyName :: Rank -> PrimType -> String
gpuCopyName (Rank rank) elem_t =
  concat ["gpu_copy_", show rank, "d_", prettyString elem_t]

-- | Build a host function that copies an array of the given rank and
-- element type between two device memory blocks.
--
-- The destination and the source are each passed as a memory parameter
-- followed by the scalar (int64) parameters of a freshly named LMAD, in
-- the LMAD's 'toList' order.  At run time, the function performs a plain
-- 'Imp.Copy' when 'LMAD.memcpyable' holds for the two layouts.
-- Otherwise it launches the kernel built by 'copyKernel'.
gpuCopyFunction :: Rank -> PrimType -> Imp.Function Imp.HostOp
gpuCopyFunction (Rank r) pt =
  Imp.Function Nothing [] params $
    Imp.DebugPrint ("\n# Copy " <> tdesc) Nothing
      <> copy_code
      <> Imp.DebugPrint "" Nothing
  where
    -- Human-readable type, e.g. "[][]f32", used only in the debug print.
    tdesc = concat (replicate r "[]") <> prettyString pt

    device = Space "device"
    memparam v = Imp.MemParam v device
    intparam v = Imp.ScalarParam v $ IntType Int64

    -- Fresh names for an LMAD offset plus a stride and a shape for each
    -- of the r dimensions, all prefixed with desc.
    freshLMAD desc = do
      let fresh x = newVName (desc <> "_" <> x)
          freshDim i = LMAD.LMADDim <$> fresh "stride" <*> fresh "shape" <*> pure i
      LMAD.LMAD <$> fresh "offset" <*> traverse freshDim [0 .. r - 1]

    -- Name generation order is significant for the generated code, so
    -- keep it exactly: dest mem, dest LMAD, src mem, src LMAD, sizes.
    (params, copy_code) = flip evalState blankNameSource $ do
      dest_mem <- newVName "destmem"
      dest_lmad <- freshLMAD "dest"
      src_mem <- newVName "srcmem"
      src_lmad <- freshLMAD "src"
      group_size <- newVName "group_size"
      num_groups <- newVName "num_groups"

      let dest_ix = le64 <$> dest_lmad
          src_ix = le64 <$> src_lmad
          inBytes e = Imp.elements e `Imp.withElemType` pt

          kernel =
            copyKernel
              pt
              (le64 num_groups, Left $ untyped $ le64 group_size)
              (dest_mem, dest_ix)
              (src_mem, src_ix)

          do_copy =
            Imp.Copy
              pt
              dest_mem
              (inBytes $ le64 $ LMAD.offset dest_lmad)
              device
              src_mem
              (inBytes $ le64 $ LMAD.offset src_lmad)
              device
              (inBytes $ product $ le64 <$> LMAD.shape src_lmad)

          copy_params =
            concat
              [ [memparam dest_mem],
                map intparam (toList dest_lmad),
                [memparam src_mem],
                map intparam (toList src_lmad)
              ]

          choose =
            Imp.If
              (LMAD.memcpyable dest_ix src_ix)
              (Imp.DebugPrint "## Simple copy" Nothing <> do_copy)
              (Imp.DebugPrint "## Kernel copy" Nothing <> Imp.Op (Imp.CallKernel kernel))

          code =
            Imp.DeclareScalar group_size Imp.Nonvolatile int64
              <> Imp.DeclareScalar num_groups Imp.Nonvolatile int64
              <> Imp.Op (Imp.GetSize group_size "copy_group_size" Imp.SizeGroup)
              <> Imp.Op (Imp.GetSize num_groups "copy_num_groups" Imp.SizeNumGroups)
              <> choose

      pure (copy_params, code)

-- | Return the name of the builtin map-transpose function for the given
-- element type, emitting its definition the first time it is requested.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType elem_t = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (mapTransposeFunction elem_t)
  pure fname
  where
    fname = nameFromString $ "builtin#" <> mapTransposeName elem_t

-- | Base name of the map-transpose function for an element type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("gpu_map_transpose_" <>) . prettyString

mapTransposeFunction :: PrimType -> Imp.Function Imp.HostOp
mapTransposeFunction :: PrimType -> Function HostOp
mapTransposeFunction PrimType
bt =
  Maybe EntryPoint
-> [Param] -> [Param] -> Code HostOp -> Function HostOp
forall a.
Maybe EntryPoint -> [Param] -> [Param] -> Code a -> FunctionT a
Imp.Function Maybe EntryPoint
forall a. Maybe a
Nothing [] [Param]
params (Code HostOp -> Function HostOp) -> Code HostOp -> Function HostOp
forall a b. (a -> b) -> a -> b
$
    [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint ([Char]
"\n# Transpose " [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> PrimType -> [Char]
forall a. Pretty a => a -> [Char]
prettyString PrimType
bt) Maybe Exp
forall a. Maybe a
Nothing
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Number of arrays  " (Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
num_arrays)
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"X elements        " (Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
x)
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Y elements        " (Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
y)
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Source      offset" (Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
srcoffset)
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"Destination offset" (Exp -> Maybe Exp
forall a. a -> Maybe a
Just (Exp -> Maybe Exp) -> Exp -> Maybe Exp
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped (TPrimExp Int64 VName -> Exp) -> TPrimExp Int64 VName -> Exp
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
destoffset)
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> Code HostOp
transpose_code
      Code HostOp -> Code HostOp -> Code HostOp
forall a. Semigroup a => a -> a -> a
<> [Char] -> Maybe Exp -> Code HostOp
forall a. [Char] -> Maybe Exp -> Code a
Imp.DebugPrint [Char]
"" Maybe Exp
forall a. Maybe a
Nothing
  where
    params :: [Param]
params =
      [ VName -> Param
memparam VName
destmem,
        VName -> Param
intparam VName
destoffset,
        VName -> Param
memparam VName
srcmem,
        VName -> Param
intparam VName
srcoffset,
        VName -> Param
intparam VName
num_arrays,
        VName -> Param
intparam VName
x,
        VName -> Param
intparam VName
y
      ]

    space :: Space
space = [Char] -> Space
Space [Char]
"device"
    memparam :: VName -> Param
memparam VName
v = VName -> Space -> Param
Imp.MemParam VName
v Space
space
    intparam :: VName -> Param
intparam VName
v = VName -> PrimType -> Param
Imp.ScalarParam VName
v (PrimType -> Param) -> PrimType -> Param
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
Int64

    [ VName
destmem,
      VName
destoffset,
      VName
srcmem,
      VName
srcoffset,
      VName
num_arrays,
      VName
x,
      VName
y,
      VName
mulx,
      VName
muly,
      VName
block,
      VName
use_32b
      ] =
        ([Char] -> Int -> VName) -> [[Char]] -> [Int] -> [VName]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
          (Name -> Int -> VName
VName (Name -> Int -> VName)
-> ([Char] -> Name) -> [Char] -> Int -> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char] -> Name
nameFromString)
          [ [Char]
"destmem",
            [Char]
"destoffset",
            [Char]
"srcmem",
            [Char]
"srcoffset",
            [Char]
"num_arrays",
            [Char]
"x_elems",
            [Char]
"y_elems",
            -- The following is only used for low width/height
            -- transpose kernels
            [Char]
"mulx",
            [Char]
"muly",
            [Char]
"block",
            [Char]
"use_32b"
          ]
          [Int
0 ..]

    block_dim_int :: Integer
block_dim_int = Integer
16

    block_dim :: IntegralExp a => a
    block_dim :: forall a. IntegralExp a => a
block_dim = a
16

    -- When an input array has either width==1 or height==1, performing a
    -- transpose will be the same as performing a copy.
    can_use_copy :: TExp Bool
can_use_copy =
      let onearr :: TExp Bool
onearr = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
num_arrays TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
1
          height_is_one :: TExp Bool
height_is_one = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
y TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
1
          width_is_one :: TExp Bool
width_is_one = VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 VName
x TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
1
       in TExp Bool
onearr TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. (TExp Bool
width_is_one TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
height_is_one)

    -- Body of the generated transpose function.  Does nothing when the
    -- input is empty.  Otherwise it declares and sets the scalars muly,
    -- mulx and use_32b, and then either performs a plain copy or
    -- dispatches to one of the transpose kernel variants.
    transpose_code =
      Imp.If input_is_empty mempty . mconcat $
        [ Imp.DeclareScalar muly Imp.Nonvolatile (IntType Int64),
          Imp.SetScalar muly . untyped $ block_dim `quot` Imp.le64 x,
          Imp.DeclareScalar mulx Imp.Nonvolatile (IntType Int64),
          Imp.SetScalar mulx . untyped $ block_dim `quot` Imp.le64 y,
          Imp.DeclareScalar use_32b Imp.Nonvolatile Bool,
          -- 32-bit indexing is chosen only if, on both the source and the
          -- destination side, offset + num_arrays*x*y fits in a signed
          -- 32-bit integer.  See Note [32-bit transpositions].
          Imp.SetScalar use_32b . untyped $
            fits_32b (le64 destoffset) .&&. fits_32b (le64 srcoffset),
          -- Use a copy if possible.  Otherwise use the first specialised
          -- kernel whose shape condition holds, falling back to the
          -- general kernel.
          Imp.If can_use_copy copy_code $
            Imp.If should_use_lowwidth (callTransposeKernel TransposeLowWidth) $
              Imp.If should_use_lowheight (callTransposeKernel TransposeLowHeight) $
                Imp.If should_use_small (callTransposeKernel TransposeSmall) $
                  callTransposeKernel TransposeNormal
        ]
      where
        fits_32b offset =
          offset + le64 num_arrays * le64 x * le64 y .<=. 2 ^ (31 :: Int) - 1

    -- True when there is nothing to move: some dimension of the input
    -- (number of arrays, x or y) is zero.
    input_is_empty =
      is_zero num_arrays .||. is_zero x .||. is_zero y
      where
        is_zero v = Imp.le64 v .==. 0

    -- Select the small-matrix kernel when both x and y are at most half
    -- of the block dimension.
    should_use_small =
      at_most_half_block x .&&. at_most_half_block y
      where
        at_most_half_block d = Imp.le64 d .<=. (block_dim `quot` 2)

    -- Select the low-width kernel when x is at most half the block
    -- dimension while y exceeds the block dimension.
    should_use_lowwidth =
      let narrow_x = Imp.le64 x .<=. (block_dim `quot` 2)
          tall_y = block_dim .<. Imp.le64 y
       in narrow_x .&&. tall_y

    -- Select the low-height kernel when y is at most half the block
    -- dimension while x exceeds the block dimension.
    should_use_lowheight =
      let short_y = Imp.le64 y .<=. (block_dim `quot` 2)
          wide_x = block_dim .<. Imp.le64 x
       in short_y .&&. wide_x

    -- A single contiguous copy of x*y elements of type bt from srcmem to
    -- destmem, both in 'space'.  Used when 'can_use_copy' holds.  The
    -- offsets are passed to 'Imp.Copy' unchanged as byte counts.
    copy_code =
      Imp.Copy
        bt
        destmem
        dest_start
        space
        srcmem
        src_start
        space
        (Imp.Count size_in_bytes)
      where
        dest_start = Imp.Count $ Imp.le64 destoffset
        src_start = Imp.Count $ Imp.le64 srcoffset
        size_in_bytes = sExt64 $ Imp.le64 x * Imp.le64 y * primByteSize bt

    -- Launch the given transpose kernel variant.  The runtime scalar
    -- use_32b picks the 32-bit or 64-bit version, and a debug message
    -- records which one was chosen.
    callTransposeKernel which =
      Imp.If use_32b_set with_32b with_64b
      where
        use_32b_set = isBool $ LeafExp use_32b Bool
        say msg = Imp.DebugPrint msg Nothing
        with_32b = say "Using 32-bit indexing" <> callTransposeKernel32 which
        with_64b = say "Using 64-bit indexing" <> callTransposeKernel64 which

    -- Launch the transpose kernel of the given kind with 64-bit index
    -- arithmetic.  This is the variant needed when indices may not fit in
    -- 32 bits (see Note [32-bit transpositions]).
    callTransposeKernel64 kind =
      Imp.Op . Imp.CallKernel $
        mapTransposeKernel
          (int64, le64)
          (mapTransposeName bt)
          block_dim_int
          ( destmem,
            le64 destoffset,
            srcmem,
            le64 srcoffset,
            le64 x,
            le64 y,
            le64 mulx,
            le64 muly,
            le64 num_arrays,
            block
          )
          bt
          kind

    -- Launch the transpose kernel of the given kind with 32-bit index
    -- arithmetic.  Every index scalar is stored as a 64-bit value and is
    -- narrowed here with 'sExt32'.  That is intended to be sound only
    -- when use_32b holds, which 'callTransposeKernel' checks before
    -- choosing this variant.
    callTransposeKernel32 kind =
      Imp.Op . Imp.CallKernel $
        mapTransposeKernel
          (int32, le32)
          (mapTransposeName bt)
          block_dim_int
          kernel_args
          bt
          kind
      where
        narrow = sExt32 . le64
        kernel_args =
          ( destmem,
            narrow destoffset,
            srcmem,
            narrow srcoffset,
            narrow x,
            narrow y,
            narrow mulx,
            narrow muly,
            narrow num_arrays,
            block
          )

-- Note [32-bit transpositions]
--
-- Transposition kernels are much slower when they have to use 64-bit
-- arithmetic.  I observed about 0.67x slowdown on an A100 GPU when
-- transposing four-byte elements (much less when transposing 8-byte
-- elements).  Unfortunately, 64-bit arithmetic is a requirement for
-- large arrays (see #1953 for what happens otherwise).  We generate
-- both 32- and 64-bit index arithmetic versions of transpositions,
-- and dynamically pick between them at runtime.  This is an
-- unfortunate code bloat, and it would be preferable if we could
-- simply optimise the 64-bit version to make this distinction
-- unnecessary.  Fortunately these kernels are quite small.

-- | Copy compiler for GPU code.
--
-- If 'isMapTransposeCopy' recognises the copy as a batched 2D
-- transposition, the generated transpose function (from
-- 'mapTransposeForType') is called whenever the accompanying runtime
-- condition @is_transpose@ holds.  Otherwise the general path is taken.
--
-- On the general path, device-to-device copies call a generated GPU
-- copy function (from 'gpuCopyForType') that receives both LMADs in
-- full.  Any other combination of memory spaces is handled by 'sCopy'
-- of a contiguous range of elements starting at each LMAD's offset.
callKernelCopy :: CopyCompiler GPUMem HostEnv Imp.HostOp
callKernelCopy pt destloc@(MemLoc destmem _ dest_ixfun) srcloc@(MemLoc srcmem _ src_ixfun)
  | Just (is_transpose, (destoffset, srcoffset, num_arrays, size_x, size_y)) <-
      isMapTransposeCopy pt destloc srcloc = do
      fname <- mapTransposeForType pt
      sIf
        is_transpose
        ( emit . Imp.Call [] fname $
            [ Imp.MemArg destmem,
              Imp.ExpArg $ untyped destoffset,
              Imp.MemArg srcmem,
              Imp.ExpArg $ untyped srcoffset,
              Imp.ExpArg $ untyped num_arrays,
              Imp.ExpArg $ untyped size_x,
              Imp.ExpArg $ untyped size_y
            ]
        )
        nontranspose
  | otherwise = nontranspose
  where
    nontranspose = do
      dest_space <- entryMemSpace <$> lookupMemory destmem
      src_space <- entryMemSpace <$> lookupMemory srcmem
      let dest_lmad = LMAD.noPermutation $ IxFun.ixfunLMAD dest_ixfun
          src_lmad = LMAD.noPermutation $ IxFun.ixfunLMAD src_ixfun
          num_elems = Imp.elements $ product $ LMAD.shape dest_lmad
      if dest_space == Space "device" && src_space == Space "device"
        then do
          -- Only request the GPU copy function on the path that calls
          -- it.  'gpuCopyForType' is a 'CallKernelGen' action that
          -- presumably generates/registers the builtin function
          -- (NOTE(review): confirm).  Requesting it unconditionally
          -- could therefore emit an unused function for host-side copies.
          fname <- gpuCopyForType (Rank (IxFun.rank dest_ixfun)) pt
          emit . Imp.Call [] fname $
            [Imp.MemArg destmem]
              ++ map (Imp.ExpArg . untyped) (toList dest_lmad)
              ++ [Imp.MemArg srcmem]
              ++ map (Imp.ExpArg . untyped) (toList src_lmad)
        else
          -- FIXME: this assumes a linear representation!  Currently we
          -- never generate code where this is not the case, but we
          -- might in the future.
          sCopy
            destmem
            (Imp.elements (LMAD.offset dest_lmad) `Imp.withElemType` pt)
            dest_space
            srcmem
            (Imp.elements (LMAD.offset src_lmad) `Imp.withElemType` pt)
            src_space
            num_elems
            pt