{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

-- | Compile a 'KernelsMem' program to imperative code with kernels.
-- This is mostly (but not entirely) the same process no matter if we
-- are targeting OpenCL or CUDA.  The important distinctions (the host
-- level code) are introduced later.
module Futhark.CodeGen.ImpGen.Kernels
  ( compileProgOpenCL,
    compileProgCUDA,
    Warnings,
  )
where

import Control.Monad.Except
import Data.Bifunctor (second)
import Data.List (foldl')
import qualified Data.Map as M
import Data.Maybe
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen hiding (compileProg)
import qualified Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Kernels.Base
import Futhark.CodeGen.ImpGen.Kernels.SegHist
import Futhark.CodeGen.ImpGen.Kernels.SegMap
import Futhark.CodeGen.ImpGen.Kernels.SegRed
import Futhark.CodeGen.ImpGen.Kernels.SegScan
import Futhark.CodeGen.ImpGen.Kernels.Transpose
import Futhark.CodeGen.SetDefaultSpace
import Futhark.Error
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Util.IntegralExp (IntegralExp, divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The operations used when compiling a 'KernelsMem' program to
-- host-level imperative code.  Expressions, copies and 'Op's go through
-- the kernel-aware compilers defined in this module, statements through
-- the generic 'defCompileStms', and no space-specific allocation
-- compilers are registered.
callKernelOperations :: Operations KernelsMem HostEnv Imp.HostOp
callKernelOperations =
  Operations
    { opsExpCompiler = expCompiler,
      opsCopyCompiler = callKernelCopy,
      opsOpCompiler = opCompiler,
      opsStmsCompiler = defCompileStms,
      opsAllocCompilers = mempty
    }

-- | Lookup tables from a 'BinOp' to the atomic operation that
-- implements it.  A 'Nothing' result means no atomic operation is
-- listed for that operator on the given target.
--
-- OpenCL gets integer add, signed/unsigned min and max, and, or and
-- xor on 'Int32' and 'Int64'.  CUDA gets the same entries plus
-- floating-point addition on 'Float32' and 'Float64'.
openclAtomics, cudaAtomics :: AtomicBinOp
(openclAtomics, cudaAtomics) = (flip lookup opencl, flip lookup cuda)
  where
    -- The atomic integer operations for one integer type.  Only the
    -- 'OverflowUndef' flavour of 'Add' is listed, so an 'Add' with any
    -- other overflow mode will not match in the lookup.
    intAtomics t =
      [ (Add t OverflowUndef, Imp.AtomicAdd t),
        (SMax t, Imp.AtomicSMax t),
        (SMin t, Imp.AtomicSMin t),
        (UMax t, Imp.AtomicUMax t),
        (UMin t, Imp.AtomicUMin t),
        (And t, Imp.AtomicAnd t),
        (Or t, Imp.AtomicOr t),
        (Xor t, Imp.AtomicXor t)
      ]
    opencl = intAtomics Int32 ++ intAtomics Int64
    cuda =
      opencl
        ++ [ (FAdd Float32, Imp.AtomicFAdd Float32),
             (FAdd Float64, Imp.AtomicFAdd Float64)
           ]

-- | Compile a 'KernelsMem' program using the given host environment.
-- The generic ImpGen compiler is run with 'callKernelOperations' and
-- @device@ as the default space.  The resulting definitions are then
-- passed through 'setDefaultSpace' with that same space.  Warnings are
-- returned untouched.
compileProg ::
  MonadFreshNames m =>
  HostEnv ->
  Prog KernelsMem ->
  m (Warnings, Imp.Program)
compileProg env prog =
  second (setDefaultSpace device)
    <$> Futhark.CodeGen.ImpGen.compileProg env callKernelOperations device prog
  where
    -- Named once so both uses are guaranteed to agree.
    device = Imp.Space "device"

-- | Compile a 'KernelsMem' program to low-level parallel code, with
-- either CUDA or OpenCL characteristics.
--
-- The two entry points differ only in the 'HostEnv' they pass to
-- 'compileProg': the table of available atomic operations and the
-- 'Target'.
compileProgOpenCL,
  compileProgCUDA ::
    MonadFreshNames m => Prog KernelsMem -> m (Warnings, Imp.Program)
compileProgOpenCL = compileProg $ HostEnv openclAtomics OpenCL
compileProgCUDA = compileProg $ HostEnv cudaAtomics CUDA

-- | Compile a kernel-level 'Op' bound to the given pattern.
--
-- Handles:
--
-- * memory allocations (via 'compileAlloc');
-- * the size operations 'GetSize', 'CmpSizeLe', 'GetSizeMax' and
--   'CalcNumGroups', which must be bound to a single-element pattern;
-- * 'SegOp's (via 'segOpCompiler').
--
-- Anything else is reported with 'compilerBugS'.
opCompiler ::
  Pattern KernelsMem ->
  Op KernelsMem ->
  CallKernelGen ()
opCompiler dest (Alloc e space) =
  compileAlloc dest e space
opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSize key size_class))) = do
  -- The key and size class are combined with the enclosing function
  -- name (if any) obtained from 'askFunction'.
  fname <- askFunction
  sOp $
    Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
      sizeClassWithEntryPoint fname size_class
opCompiler (Pattern _ [pe]) (Inner (SizeOp (CmpSizeLe key size_class x))) = do
  fname <- askFunction
  let size_class' = sizeClassWithEntryPoint fname size_class
  sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
    =<< toExp x
opCompiler (Pattern _ [pe]) (Inner (SizeOp (GetSizeMax size_class))) =
  sOp $ Imp.GetSizeMax (patElemName pe) size_class
opCompiler (Pattern _ [pe]) (Inner (SizeOp (CalcNumGroups w64 max_num_groups_key group_size))) = do
  fname <- askFunction
  max_num_groups :: TV Int32 <- dPrim "max_num_groups" int32
  sOp $
    Imp.GetSize (tvVar max_num_groups) (keyWithEntryPoint fname max_num_groups_key) $
      sizeClassWithEntryPoint fname SizeNumGroups

  -- If 'w64' is small, we launch fewer groups than we normally would.
  -- We don't want any idle groups.
  --
  -- The calculations are done with 64-bit integers to avoid overflow
  -- issues.
  let num_groups_maybe_zero =
        sMin64 (toInt64Exp w64 `divUp` toInt64Exp group_size) $
          sExt64 (tvExp max_num_groups)
  -- We also don't want zero groups.
  let num_groups = sMax64 1 num_groups_maybe_zero
  mkTV (patElemName pe) int32 <-- sExt32 num_groups
opCompiler dest (Inner (SegOp op)) =
  segOpCompiler dest op
opCompiler pat e =
  compilerBugS $
    "opCompiler: Invalid pattern\n  "
      ++ pretty pat
      ++ "\nfor expression\n  "
      ++ pretty e

-- | Qualify a size class with a (possibly absent) function name.
-- Only 'Imp.SizeThreshold' is rewritten: the name in each element of
-- its kernel path is passed through 'keyWithEntryPoint', while the
-- accompanying 'Bool' and the default are preserved.  Every other size
-- class is returned unchanged.
sizeClassWithEntryPoint :: Maybe Name -> Imp.SizeClass -> Imp.SizeClass
sizeClassWithEntryPoint fname (Imp.SizeThreshold path def) =
  Imp.SizeThreshold (map qualify path) def
  where
    qualify (name, x) = (keyWithEntryPoint fname name, x)
sizeClassWithEntryPoint _ size_class = size_class

-- | Compile a 'SegOp' by dispatching on its kind.
--
-- * 'SegMap' is accepted at any 'SegLevel'.
-- * 'SegRed', 'SegScan' and 'SegHist' are only handled at the
--   'SegThread' level.
--
-- Any other combination is reported with 'compilerBugS'.
segOpCompiler ::
  Pattern KernelsMem ->
  SegOp SegLevel KernelsMem ->
  CallKernelGen ()
segOpCompiler pat (SegMap lvl space _ kbody) =
  compileSegMap pat lvl space kbody
segOpCompiler pat (SegRed lvl@SegThread {} space reds _ kbody) =
  compileSegRed pat lvl space reds kbody
segOpCompiler pat (SegScan lvl@SegThread {} space scans _ kbody) =
  compileSegScan pat lvl space scans kbody
segOpCompiler pat (SegHist (SegThread num_groups group_size _) space ops _ kbody) =
  compileSegHist pat num_groups group_size space ops kbody
segOpCompiler pat segop =
  compilerBugS $
    "segOpCompiler: unexpected "
      ++ pretty (segLevel segop)
      ++ " for rhs of pattern "
      ++ pretty pat

-- | Create a boolean expression that checks whether every kernel in the
-- given code uses no more local memory than is available.
--
-- For each kernel, the sizes of its 'Imp.LocalAlloc's are padded to a
-- multiple of 8 bytes and summed.  The sum is compared against the
-- capacity obtained via @Imp.GetSizeMax ... SizeLocalMemory@.
--
-- Returns 'Nothing' when any allocation size mentions a variable that
-- is not in scope at this point, because the check cannot be expressed
-- here.
--
-- We look at *all* the kernels here, even those that might be
-- otherwise protected by their own multi-versioning branches deeper
-- down.  Currently the compiler will not generate multi-versioning
-- that makes this a problem, but it might in the future.
checkLocalMemoryReqs :: Imp.Code -> CallKernelGen (Maybe (Imp.TExp Bool))
checkLocalMemoryReqs code = do
  scope <- askScope
  -- Total (padded) local memory required by each kernel.
  let alloc_sizes =
        map (sum . map alignedSize . localAllocSizes . Imp.kernelBody) $
          getKernels code

  -- If any of the sizes involve a variable that is not known at this
  -- point, then we cannot check the requirements.
  if any (`M.notMember` scope) (namesToList $ freeIn alloc_sizes)
    then return Nothing
    else do
      local_memory_capacity :: TV Int32 <- dPrim "local_memory_capacity" int32
      sOp $ Imp.GetSizeMax (tvVar local_memory_capacity) SizeLocalMemory

      -- The capacity is an Int32; widen it so it can be compared with
      -- the 64-bit allocation sizes.
      let local_memory_capacity_64 =
            sExt64 $ tvExp local_memory_capacity
          fits size =
            unCount size .<=. local_memory_capacity_64
      return $ Just $ foldl' (.&&.) true (map fits alloc_sizes)
  where
    getKernels = foldMap getKernel
    getKernel (Imp.CallKernel k) = [k]
    getKernel _ = []

    localAllocSizes = foldMap localAllocSize
    localAllocSize (Imp.LocalAlloc _ size) = [size]
    localAllocSize _ = []

    -- These allocations will actually be padded to an 8-byte aligned
    -- size, so we should take that into account when checking whether
    -- they fit.
    alignedSize x = x + ((8 - (x `rem` 8)) `rem` 8)

-- | Compile a host-level expression.
--
-- * 'Iota' and 'Replicate' get their own kernels.
-- * Allocations in the @local@ space emit nothing.
-- * Multi-versioning 'If's ('IfEquiv') get an extra local-memory check.
--
-- Everything else is handled by 'defCompileExp'.
expCompiler :: ExpCompiler KernelsMem HostEnv Imp.HostOp
-- We generate a simple kernel for iota and replicate.
expCompiler (Pattern _ [pe]) (BasicOp (Iota n x s et)) = do
  x' <- toExp x
  s' <- toExp s
  sIota (patElemName pe) (toInt64Exp n) x' s' et
expCompiler (Pattern _ [pe]) (BasicOp (Replicate _ se)) =
  sReplicate (patElemName pe) se
-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
  return ()
-- This is a multi-versioning If created by incremental flattening.
-- We need to augment the conditional with a check that any local
-- memory requirements in tbranch are compatible with the hardware.
-- We do not check anything for fbranch, as we assume that it will
-- always be safe (and what would we do if none of the branches would
-- work?).
expCompiler dest (If cond tbranch fbranch (IfDec _ IfEquiv)) = do
  tcode <- collect $ compileBody dest tbranch
  fcode <- collect $ compileBody dest fbranch
  check <- checkLocalMemoryReqs tcode
  emit $ case check of
    -- The requirements of tbranch could not be checked here, so only
    -- the fbranch code is emitted.
    Nothing -> fcode
    Just ok -> Imp.If (ok .&&. toBoolExp cond) tcode fcode
expCompiler dest e =
  defCompileExp dest e

-- | Copy compiler used for host-level code.  Three strategies are tried
-- in order:
--
-- 1. If 'isMapTransposeCopy' recognises the copy as a (batched) 2D
--    transpose, emit a call to the per-type transpose function (generated
--    on demand by 'mapTransposeForType').
-- 2. If both the sliced destination and the sliced source index functions
--    are linear ('IxFun.linearWithOffset' succeeds), emit a single
--    contiguous 'Imp.Copy' between the two memory spaces.
-- 3. Otherwise fall back to the generic 'sCopy'.
callKernelCopy :: CopyCompiler KernelsMem HostEnv Imp.HostOp
callKernelCopy
  bt
  destloc@(MemLocation destmem _ destIxFun)
  destslice
  srcloc@(MemLocation srcmem srcshape srcIxFun)
  srcslice
    | Just transpose_args <- isMapTransposeCopy bt destloc destslice srcloc srcslice =
      callTranspose transpose_args
    | Just (destoffset, srcoffset) <- linearOffsets =
      contiguousCopy destoffset srcoffset
    | otherwise =
      sCopy bt destloc destslice srcloc srcslice
    where
      -- Element size in bytes; the offsets returned by
      -- 'IxFun.linearWithOffset' are scaled by it and used as byte counts.
      elem_bytes = primByteSize bt

      -- Both offsets must exist for the contiguous path; the destination is
      -- checked first, exactly like the original pattern guards.
      linearOffsets =
        (,)
          <$> IxFun.linearWithOffset (IxFun.slice destIxFun destslice) elem_bytes
          <*> IxFun.linearWithOffset (IxFun.slice srcIxFun srcslice) elem_bytes

      -- Call the host-side transpose function with the arguments in the
      -- order its parameter list expects: dest mem/offset, src mem/offset,
      -- then the number of arrays and the two dimensions.
      callTranspose (destoffset, srcoffset, num_arrays, size_x, size_y) = do
        fname <- mapTransposeForType bt
        let expArg = Imp.ExpArg . untyped
        emit . Imp.Call [] fname $
          [Imp.MemArg destmem, expArg destoffset, Imp.MemArg srcmem, expArg srcoffset]
            ++ map expArg [num_arrays, size_x, size_y]

      -- A single memory-to-memory copy.  The memory spaces are looked up
      -- source first, then destination.
      --
      -- NOTE(review): the element count is the product of the *unsliced*
      -- source shape, while the offsets above come from the *sliced* index
      -- functions.  This is only correct if callers pass trivial source
      -- slices here; confirm (a 'sliceDims'-based count may be intended).
      contiguousCopy destoffset srcoffset = do
        srcspace <- entryMemSpace <$> lookupMemory srcmem
        destspace <- entryMemSpace <$> lookupMemory destmem
        let num_elems = Imp.elements $ product $ map toInt64Exp srcshape
        emit $
          Imp.Copy
            destmem
            (bytes $ sExt64 destoffset)
            destspace
            srcmem
            (bytes $ sExt64 srcoffset)
            srcspace
            (num_elems `Imp.withElemType` bt)

-- | Name of the host-level transpose function for an element type.  The
-- function definition ('mapTransposeFunction') is emitted the first time a
-- given type is requested; later requests only return the name.
mapTransposeForType :: PrimType -> CallKernelGen Name
mapTransposeForType bt = do
  already_defined <- hasFunction fname
  unless already_defined $
    emitFunction fname (mapTransposeFunction bt)
  pure fname
  where
    fname = nameFromString ("builtin#" <> mapTransposeName bt)

-- | Base name shared by the transpose function and its kernels: the prefix
-- @gpu_map_transpose_@ followed by the pretty-printed element type.
mapTransposeName :: PrimType -> String
mapTransposeName = ("gpu_map_transpose_" ++) . pretty

-- | Host-level function that transposes @num_arrays@ consecutive
-- @x_elems@-by-@y_elems@ arrays from @srcmem@ (at @srcoffset@) into
-- @destmem@ (at @destoffset@), for elements of the given type.
--
-- The function body:
--
-- * does nothing if any of the three sizes is zero;
-- * performs a plain memory copy when there is a single array and one of
--   its dimensions is 1 (the transpose is then the identity on memory);
-- * otherwise launches one of four kernel variants (low-width, low-height,
--   small, normal), selected by comparing the dimensions to the block size.
--
-- NOTE(review): the scalar parameters are declared 'Int32', but
-- 'callKernelCopy' passes 'Int64' expressions for them.  Values above
-- 2^31-1 may therefore be truncated at the call boundary; confirm.
mapTransposeFunction :: PrimType -> Imp.Function
mapTransposeFunction bt =
  Imp.Function False [] params transpose_code [] []
  where
    params =
      [ memparam destmem,
        intparam destoffset,
        memparam srcmem,
        intparam srcoffset,
        intparam num_arrays,
        intparam x,
        intparam y
      ]

    space = Space "device"
    memparam v = Imp.MemParam v space
    intparam v = Imp.ScalarParam v $ IntType Int32

    [ destmem,
      destoffset,
      srcmem,
      srcoffset,
      num_arrays,
      x,
      y,
      mulx,
      muly,
      block
      ] =
        zipWith
          (VName . nameFromString)
          [ "destmem",
            "destoffset",
            "srcmem",
            "srcoffset",
            "num_arrays",
            "x_elems",
            "y_elems",
            -- The following is only used for low width/height
            -- transpose kernels
            "mulx",
            "muly",
            "block"
          ]
          [0 ..]

    block_dim_int :: Integer
    block_dim_int = 16

    block_dim :: IntegralExp a => a
    block_dim = 16

    -- When an input array has either width==1 or height==1, performing a
    -- transpose will be the same as performing a copy.
    can_use_copy =
      let onearr = Imp.vi32 num_arrays .==. 1
          height_is_one = Imp.vi32 y .==. 1
          width_is_one = Imp.vi32 x .==. 1
       in onearr .&&. (width_is_one .||. height_is_one)

    transpose_code =
      Imp.If input_is_empty mempty $
        mconcat
          [ Imp.DeclareScalar muly Imp.Nonvolatile (IntType Int32),
            Imp.SetScalar muly $ untyped $ block_dim `quot` Imp.vi32 x,
            Imp.DeclareScalar mulx Imp.Nonvolatile (IntType Int32),
            Imp.SetScalar mulx $ untyped $ block_dim `quot` Imp.vi32 y,
            Imp.If can_use_copy copy_code $
              Imp.If should_use_lowwidth (callTransposeKernel TransposeLowWidth) $
                Imp.If should_use_lowheight (callTransposeKernel TransposeLowHeight) $
                  Imp.If should_use_small (callTransposeKernel TransposeSmall) $
                    callTransposeKernel TransposeNormal
          ]

    input_is_empty =
      Imp.vi32 num_arrays .==. 0 .||. Imp.vi32 x .==. 0 .||. Imp.vi32 y .==. 0

    should_use_small =
      Imp.vi32 x .<=. (block_dim `quot` 2)
        .&&. Imp.vi32 y .<=. (block_dim `quot` 2)

    should_use_lowwidth =
      Imp.vi32 x .<=. (block_dim `quot` 2)
        .&&. block_dim .<. Imp.vi32 y

    should_use_lowheight =
      Imp.vi32 y .<=. (block_dim `quot` 2)
        .&&. block_dim .<. Imp.vi32 x

    -- Plain byte copy used when 'can_use_copy' holds.  Every factor is
    -- widened to 64 bits *before* multiplying.  Computing
    -- x * y * sizeof(bt) in 32 bits and extending afterwards overflows
    -- (undefined behaviour in the generated C) for arrays larger than
    -- about 2^31 / sizeof(bt) elements.
    copy_code =
      let elem_size =
            sExt64 $ isInt32 (Imp.LeafExp (Imp.SizeOf bt) (IntType Int32))
          num_bytes =
            sExt64 (Imp.vi32 x) * sExt64 (Imp.vi32 y) * elem_size
       in Imp.Copy
            destmem
            (Imp.Count $ sExt64 $ Imp.vi32 destoffset)
            space
            srcmem
            (Imp.Count $ sExt64 $ Imp.vi32 srcoffset)
            space
            (Imp.Count num_bytes)

    callTransposeKernel =
      Imp.Op . Imp.CallKernel
        . mapTransposeKernel
          (mapTransposeName bt)
          block_dim_int
          ( destmem,
            Imp.vi32 destoffset,
            srcmem,
            Imp.vi32 srcoffset,
            Imp.vi32 x,
            Imp.vi32 y,
            Imp.vi32 mulx,
            Imp.vi32 muly,
            Imp.vi32 num_arrays,
            block
          )
          bt