{-# LANGUAGE TypeFamilies #-}

-- | Generation of kernels with group-level bodies.
module Futhark.CodeGen.ImpGen.GPU.Group
  ( sKernelGroup,
    compileGroupResult,
    groupOperations,

    -- * Precomputation
    Precomputed,
    precomputeConstants,
    precomputedConstants,
    atomicUpdateLocking,
  )
where

import Control.Monad
import Data.Bifunctor
import Data.List (partition, zip4)
import Data.Map.Strict qualified as M
import Data.Maybe
import Data.Set qualified as S
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.GPU.Base
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (divUp, rem)
import Prelude hiding (quot, rem)

-- | @flattenArray k flat arr@ flattens the outer @k@ dimensions of
-- @arr@ to @flat@.  (Make sure @flat@ is the product of those
-- dimensions or you'll have a bad time.)
--
-- The result is a new array named after @arr@ with a @_flat@ suffix.
-- It is declared in the same memory block as @arr@ and its shape is
-- @flat@ followed by the dimensions of @arr@ that remain after
-- dropping the first @k@.  Its LMAD is the LMAD of @arr@ reshaped to
-- that shape.
--
-- Calls 'error' with the message @"flattenArray"@ if 'LMAD.reshape'
-- cannot express the reshape (i.e. returns 'Nothing').
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  -- The new outermost dimension replaces the first k old ones.
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    fromMaybe (error "flattenArray") $
      LMAD.reshape (memLocLMAD arr_loc) (map pe64 $ shapeDims flat_shape)

-- | @sliceArray start size arr@ declares a view of @arr@ that covers
-- @size@ elements of its outermost dimension, beginning at offset
-- @start@ and with stride 1.  The slice is completed with
-- 'fullSliceNum' using the dimensions of @arr@.
--
-- The result is a new array named after @arr@ with a @_chunk@ suffix.
-- It is declared in the same memory block as @arr@, has the element
-- type of @arr@, and has the shape of @arr@ with the outer dimension
-- replaced by @size@.
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ ixfun <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let slice =
        fullSliceNum
          (map Imp.pe64 (arrayDims arr_t))
          [DimSlice start (tvExp size) 1]
  sArray
    (baseString arr ++ "_chunk")
    (elemType arr_t)
    (arrayShape arr_t `setOuterDim` Var (tvVar size))
    mem
    $ LMAD.slice ixfun slice

-- | @applyLambda lam dests args@ emits code that:
--
-- 1. Binds each parameter of @lam@ to the corresponding element of
--    @args@, interpreted as a (name,slice) pair (as in 'copyDWIM').
--    Use an empty list for a scalar.
--
-- 2. Executes the body of @lam@.
--
-- 3. Binds the t'SubExp's that are the 'Result' of @lam@ to the
--    provided @dests@, again interpreted as the destination for a
--    'copyDWIM'.
--
-- Parameters are paired with @args@, and results with @dests@, using
-- 'zip'; if the lists differ in length the surplus elements are
-- silently ignored.
applyLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  dLParams $ lambdaParams lam
  forM_ (zip (lambdaParams lam) args) $ \(p, (arg, arg_slice)) ->
    copyDWIM (paramName p) [] arg arg_slice
  -- The results are copied out inside the continuation given to
  -- 'compileStms', i.e. after the statements of the body.
  compileStms mempty (bodyStms $ lambdaBody lam) $ do
    let res = map resSubExp $ bodyResult $ lambdaBody lam
    forM_ (zip dests res) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []

-- | As 'applyLambda', but first rename the names in the lambda (with
-- 'renameLambda').  This makes it safe to apply it in multiple
-- places.  (It might be safe anyway, but you have to be more careful
-- - use this if you are in doubt.)
applyRenamedLambda ::
  (Mem rep inner) =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args = do
  lam_renamed <- renameLambda lam
  applyLambda lam_renamed dests args

-- | @groupChunkLoop w m@ emits a loop that splits the index range
-- from 0 up to (but excluding) @w@ into consecutive chunks and runs
-- @m@ once per chunk.  A chunk holds at most as many elements as the
-- group size from the kernel constants; the final chunk is clamped to
-- @w@ and may therefore be shorter.
--
-- @m@ receives the start offset of the chunk (32-bit) and a variable
-- holding the number of elements in the chunk (64-bit).
groupChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
groupChunkLoop w m = do
  constants <- kernelConstants <$> askEnv
  let max_chunk_size = sExt32 $ kernelGroupSize constants
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` max_chunk_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    chunk_start <-
      dPrimVE "chunk_start" $ chunk_i * max_chunk_size
    -- Clamp to @w@ so that the last chunk does not run past the end.
    chunk_end <-
      dPrimVE "chunk_end" $ sMin32 w (chunk_start + max_chunk_size)
    chunk_size <-
      dPrimV "chunk_size" $ sExt64 $ chunk_end - chunk_start
    m chunk_start chunk_size

-- | @virtualisedGroupScan seg_flag w lam arrs@ scans the first @w@
-- elements of @arrs@ with @lam@, one chunk at a time, with the chunks
-- being produced by 'groupChunkLoop'.
--
-- For every chunk except the first, the thread with local ID 0 first
-- applies @lam@ to the last element of the preceding chunk (the
-- carry) and the first element of the current chunk, storing the
-- result in the first element of the current chunk.  This step is
-- skipped when @seg_flag@ is given and reports a segment boundary
-- between those two indexes.  The chunk is then scanned with
-- 'groupScan'.
virtualisedGroupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedGroupScan seg_flag w lam arrs = do
  groupChunkLoop w $ \chunk_start chunk_size -> do
    constants <- kernelConstants <$> askEnv
    let ltid = kernelLocalThreadId constants
        -- Whether a segment boundary separates the last element of
        -- the previous chunk from the first element of this one.
        crosses_segment =
          case seg_flag of
            Nothing -> false
            Just flag_true ->
              flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start)
    sComment "possibly incorporate carry" $
      sWhen (chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot crosses_segment) $ do
        carry_idx <- dPrimVE "carry_idx" $ sExt64 chunk_start - 1
        -- The lambda is applied once per chunk, hence the renaming.
        applyRenamedLambda
          lam
          (map (,[DimFix $ sExt64 chunk_start]) arrs)
          ( map ((,[DimFix carry_idx]) . Var) arrs
              ++ map ((,[DimFix $ sExt64 chunk_start]) . Var) arrs
          )

    arrs_chunks <- mapM (sliceArray (sExt64 chunk_start) chunk_size) arrs

    sOp $ Imp.ErrorSync Imp.FenceLocal

    groupScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks

-- | Copy compiler for group-level code.  The extent of the copy is
-- the shape of the source LMAD.
--
-- * When both the destination and the source memory are in a
--   'ScalarSpace', a single 'lmadCopy' is emitted.  On each side the
--   trailing dimensions (as many as the scalar space has) are taken
--   in full, and any leading dimensions are fixed to index 0.
--
-- * Otherwise the index space is covered with 'groupCoverSpace', one
--   element is copied per index, and a barrier with a local fence is
--   emitted afterwards.
copyInGroup :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

  let src_lmad = memLocLMAD srcloc
      dims = LMAD.shape src_lmad
      rank = length dims

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      let fullDim d = DimSlice 0 d 1
          destslice' =
            Slice $
              replicate (rank - length destds) (DimFix 0)
                ++ takeLast (length destds) (map fullDim dims)
          srcslice' =
            Slice $
              replicate (rank - length srcds) (DimFix 0)
                ++ takeLast (length srcds) (map fullDim dims)
      lmadCopy
        pt
        (sliceMemLoc destloc destslice')
        (sliceMemLoc srcloc srcslice')
    _ -> do
      -- The dimensions are narrowed to 32 bits for 'groupCoverSpace';
      -- the resulting indexes are widened again for the slices.
      groupCoverSpace (map sExt32 dims) $ \is ->
        lmadCopy
          pt
          (sliceMemLoc destloc (Slice $ map (DimFix . sExt64) is))
          (sliceMemLoc srcloc (Slice $ map (DimFix . sExt64) is))
      sOp $ Imp.Barrier Imp.FenceLocal

-- | Produce one 64-bit local thread index per dimension in @dims@.
--
-- If the kernel constants already hold an entry for exactly this
-- list of dimensions in 'kernelLocalIdMap', those (32-bit) indexes
-- are reused after sign extension.  Otherwise the indexes are derived
-- from the flat local thread ID via 'dIndexSpace''.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  -- Read the environment once; both the lookup table and the flat
  -- thread ID live in the same 'KernelConstants' record.
  constants <- kernelConstants <$> askEnv
  case M.lookup dims (kernelLocalIdMap constants) of
    Just precomputed ->
      pure $ map sExt64 precomputed
    Nothing ->
      -- No cached indexes for this shape: compute them from the flat
      -- ID (presumably by unflattening it over @dims@ -- confirm
      -- against the definition of dIndexSpace').
      dIndexSpace' "ltid" (map pe64 dims) (sExt64 (kernelLocalThreadId constants))

-- | Split the dimensions of a 'SegSpace' into two lists according to
-- their (zero-based) position in the space.
--
-- The first component of the result holds the dimensions whose
-- position occurs in the given 'SegSeqDims'; the second component
-- holds all the remaining dimensions.  Relative order is preserved in
-- both lists.
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  bimap dropPositions dropPositions $ partition isSequential indexed
  where
    -- Each dimension tagged with its position in the space.
    indexed = zip (unSegSpace space) [0 ..]
    isSequential = (`elem` seq_is) . snd
    dropPositions = map fst

-- | Bind the flat index variable of the given 'SegSpace' (its
-- 'segFlat' name) to the local thread ID taken from the kernel
-- constants.
compileFlatId :: SegSpace -> InKernelGen ()
compileFlatId space = do
  constants <- kernelConstants <$> askEnv
  dPrimV_ (segFlat space) (kernelLocalThreadId constants)

-- | Construct the necessary lock arrays for an intra-group histogram.
--
-- For every histogram operator an update function is returned that
-- takes the bucket indexes and performs the update on the operator's
-- destination arrays ('histDest') in the @"local"@ space.  How the
-- update is done is decided by 'atomicUpdateLocking':
--
-- * 'AtomicPrim' and 'AtomicCAS' updates need no extra state.
--
-- * 'AtomicLocking' updates need a lock array.  It is allocated the
--   first time such an operator is encountered and then threaded
--   through the traversal, so that all later locking operators share
--   the same array.
prepareIntraGroupSegHist ::
  Shape ->
  Count GroupSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist segments group_size =
  -- The accumulator is the lock array created so far (if any); only
  -- the per-operator update functions are returned to the caller.
  fmap snd . mapAccumLM onOp Nothing
  where
    local_space = Space "local"

    onOp l op = do
      env <- askEnv
      let constants = kernelConstants env
          atomicBinOp = kernelAtomics env
          local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) ->
          pure (l, f local_space local_subhistos)
        (_, AtomicCAS f) ->
          pure (l, f local_space local_subhistos)
        (Just l', AtomicLocking f) ->
          -- A lock array already exists; reuse it.
          pure (l, f l' local_space local_subhistos)
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"

          let num_locks = pe64 $ unCount group_size
              dims = map pe64 $ shapeDims (segments <> histOpShape op <> histShape op)
              -- The lock for a bucket is its flat index (over
              -- segments, operator shape and histogram shape) modulo
              -- the number of locks.  The three constants are
              -- presumably the unlocked/locked/unlock values of
              -- 'Locking' -- confirm against its definition.
              l' = Locking locks 0 1 0 (pure . (`rem` num_locks) . flattenIndex dims)
              -- One int32 lock per element of the group size.
              locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) local_space
          dArray locks int32 (arrayShape locks_t) locks_mem $
            LMAD.iota 0 . map pe64 . arrayDims $
              locks_t

          sComment "All locks start out unlocked" $
            groupCoverSpace [kernelGroupSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          pure (Just l', f l' local_space local_subhistos)

-- | Run the action @m@ for the points of the given 'SegSpace', with
-- the index variables of the space bound before each execution of
-- @m@.  The 'SegVirt' argument selects the strategy:
--
-- * 'SegVirt': loop over the flattened space.  If the kernel
--   constants hold a chunk count for these dimensions in
--   'kernelChunkItersMap', a sequential loop over chunks of
--   group-size elements is emitted with an explicit bounds check;
--   otherwise 'groupLoop' is used over the product of the dimensions.
--
-- * 'SegNoVirt': bind the indexes from 'localThreadIDs' and run @m@
--   only when 'isActive' holds for them.
--
-- * 'SegNoVirtFull': loop sequentially over the dimensions selected
--   by the 'SegSeqDims' and bind the remaining ones from
--   'localThreadIDs', with no activity check.
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace virt space m = do
  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
  -- Maybe we can statically detect that this is actually a
  -- SegNoVirtFull and generate ever-so-slightly simpler code.
  let virt' = if dims' == [group_size] then SegNoVirtFull (SegSeqDims []) else virt
  case virt' of
    SegVirt ->
      case M.lookup dims (kernelChunkItersMap constants) of
        Nothing -> do
          -- NOTE(review): the iteration count is computed in 32 bits.
          iterations <- dPrimVE "iterations" $ product $ map sExt32 dims'
          groupLoop iterations $ \i -> do
            dIndexSpace (zip ltids dims') $ sExt64 i
            m
        Just num_chunks -> localOps threadOperations $ do
          let ltid = kernelLocalThreadId constants
          sFor "chunk_i" num_chunks $ \chunk_i -> do
            -- Flat index handled by this thread in this chunk.
            i <- dPrimVE "i" $ chunk_i * sExt32 group_size + ltid
            dIndexSpace (zip ltids dims') $ sExt64 i
            -- The last chunk may extend past the end of the space.
            sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
    SegNoVirt -> localOps threadOperations $ do
      zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
      sWhen (isActive $ zip ltids dims) m
    SegNoVirtFull seq_dims -> do
      let ((ltids_seq, dims_seq), (ltids_par, dims_par)) =
            bimap unzip unzip $ partitionSeqDims seq_dims space
      sLoopNest (Shape dims_seq) $ \is_seq -> do
        zipWithM_ dPrimV_ ltids_seq is_seq
        localOps threadOperations $ do
          zipWithM_ dPrimV_ ltids_par =<< localThreadIDs dims_par
          m

compileGroupExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupExp :: ExpCompiler GPUMem KernelEnv KernelOp
compileGroupExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Opaque OpaqueOp
_ SubExp
se)) =
  -- Cannot print in GPU code.
  VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LetDecMem
pe) [] SubExp
se []
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (ArrayLit [SubExp]
es TypeBase (ShapeBase SubExp) NoUniqueness
_)) =
  [(Int64, SubExp)]
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Int64] -> [SubExp] -> [(Int64, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int64
0 ..] [SubExp]
es) (((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Int64
i, SubExp
e) ->
    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LetDecMem
dest) [Int64 -> TPrimExp Int64 VName
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int64
i :: Int64)] SubExp
e []
compileGroupExp Pat (LetDec GPUMem)
_ (BasicOp (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs)) = do
  TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
  TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc VName
acc [SubExp]
is [SubExp]
vs
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
-- Group-level Replicate with a non-empty replication shape.  One fresh
-- index variable is created per dimension of the destination, the
-- resulting index space is covered via 'groupCoverSegSpace', and each
-- destination element is copied from the replicated value.
compileGroupExp (Pat [dest]) (BasicOp (Replicate ds se))
  | ds /= mempty = do
      flat_id <- newVName "rep_flat"
      idx_names <- replicateM (arrayRank dest_t) (newVName "rep_i")
      let dest_idxs = map le64 idx_names
          -- The replicated value is indexed only by the dimensions
          -- that remain after dropping the replication shape.
          src_idxs = drop (shapeRank ds) dest_idxs
          rep_space = SegSpace flat_id (zip idx_names (arrayDims dest_t))
      groupCoverSegSpace SegVirt rep_space $
        copyDWIMFix (patElemName dest) dest_idxs se src_idxs
      -- Barrier (local fence) after all elements have been written.
      sOp (Imp.Barrier Imp.FenceLocal)
  where
    dest_t = patElemType dest
-- Group-level Iota: element @i@ of the destination is set to
-- @e + i * s@, with the arithmetic carried out in integer type @it@.
-- The iterations are distributed with 'groupLoop', and a barrier with
-- a local fence is emitted after the loop.
compileGroupExp (Pat [dest]) (BasicOp (Iota n e s it)) = do
  n' <- toExp n
  e' <- toExp e
  s' <- toExp s
  groupLoop (TPrimExp n') $ \i -> do
    -- The loop index is a 64-bit expression (it is also used directly
    -- as the destination index below), whereas the multiplication and
    -- addition are performed in type @it@.  Convert the index to @it@
    -- first so that both operands of the 'Mul' have the type the
    -- operator expects; this is a no-op when @it@ is already Int64.
    let i' = sExt it (untyped i)
    x <-
      dPrimV "x" $
        TPrimExp $
          BinOpExp (Add it OverflowUndef) e' $
            BinOpExp (Mul it OverflowUndef) i' s'
    copyDWIMFix (patElemName dest) [i] (Var (tvVar x)) []
  sOp $ Imp.Barrier Imp.FenceLocal

-- When generating code for a scalar in-place update, we must make
-- sure that only one thread performs the write.  When writing an
-- array, the group-level copy code will take care of doing the right
-- thing.
compileGroupExp (Pat [pe]) (BasicOp (Update safety _ slice se))
  | null (sliceDims slice) = do
      -- The write is bracketed by barriers, and performed only by the
      -- thread whose local thread id is zero.
      sOp barrier
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      sWhen (ltid .==. 0) guardedWrite
      sOp barrier
  where
    barrier = Imp.Barrier Imp.FenceLocal

    -- The slice and the dimensions of the result, as 64-bit expressions.
    slice' = fmap pe64 slice
    dims = map pe64 (arrayDims (patElemType pe))

    -- The actual scalar write into the pattern element.
    write = copyDWIM (patElemName pe) (unSlice slice') se []

    -- A 'Safe' update is additionally guarded by a bounds check; an
    -- 'Unsafe' one writes unconditionally.
    guardedWrite =
      case safety of
        Unsafe -> write
        Safe -> sWhen (inBounds slice' dims) write
-- Fallback for every expression not handled by a more specific clause:
-- defer to 'defCompileExp', preceded by an error sync for expressions
-- that introduce control flow.
compileGroupExp dest e = do
  -- It is messy to jump into control flow for error handling.  Avoid
  -- that by always doing an error sync before such expressions.
  -- Potential improvement: only do this if any errors are pending
  -- (this could also be handled in later codegen).
  when (doSync e) $
    sOp (Imp.ErrorSync Imp.FenceLocal)
  defCompileExp dest e
  where
    -- Only loops and matches get the preceding error sync.
    doSync ex =
      case ex of
        Loop {} -> True
        Match {} -> True
        _ -> False

-- | Compiler for the operations ('Alloc' and the segmented operations)
-- that occur in group-level bodies.
compileGroupOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
-- Allocations are delegated directly to 'kernelAlloc'.
compileGroupOp pat (Alloc size space) = kernelAlloc pat size space
-- SegMap: cover the index space, compile the body statements for each
-- point, and write each body result to the corresponding pattern
-- element.  An error sync with a local fence follows.
compileGroupOp pat (Inner (SegOp (SegMap lvl space _ body))) = do
  compileFlatId space

  let results = zip (patElems pat) (kernelBodyResult body)
      writeResults = forM_ results $ uncurry (compileThreadResult space)

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms body) writeResults

  sOp (Imp.ErrorSync Imp.FenceLocal)
-- SegScan: first every thread writes its body results into the
-- pattern arrays at its index in the segment space; then, after an
-- error sync, the scan itself is performed over flattened views of
-- those arrays.
compileGroupOp pat (Inner (SegOp (SegScan lvl space scans _ body))) = do
  compileFlatId space

  let (ltids, dims) = unzip (unSegSpace space)
      dims' = map pe64 dims
      out_names = patNames pat
      -- Store one body result at this thread's position.
      storeResult out res =
        copyDWIMFix out (map Imp.le64 ltids) (kernelResultSubExp res) []

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms body) $
      zipWithM_ storeResult out_names (kernelBodyResult body)

  -- The fence used for the sync is determined from the output arrays.
  fence <- fenceForArrays out_names
  sOp (Imp.ErrorSync fence)

  let -- The innermost dimension is the size of one segment.
      segment_size = last dims'
      -- True when going from flat index 'from' to flat index 'to'
      -- spans more elements than 'to' is offset into its own segment.
      crossesSegment from to =
        (sExt64 to - sExt64 from) .>. (sExt64 to `rem` segment_size)
      flat_size = product dims'
      -- NOTE: only the first operator in 'scans' is used here.
      scan = head scans
      scan_lam = segBinOpLambda scan
      num_scan_results = length (segBinOpNeutral scan)

  -- The group scan needs to treat the scan output as a
  -- one-dimensional array of scan elements, so we invent some new
  -- flattened arrays here.
  dims_flat <- dPrimV "dims_flat" flat_size
  arrs_flat <-
    forM (take num_scan_results out_names) $
      flattenArray (length dims') dims_flat

  case segVirt lvl of
    SegVirt ->
      virtualisedGroupScan
        (Just crossesSegment)
        (sExt32 (tvExp dims_flat))
        scan_lam
        arrs_flat
    _ ->
      groupScan
        (Just crossesSegment)
        flat_size
        flat_size
        scan_lam
        arrs_flat
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegSpace -> InKernelGen ()
compileFlatId SegSpace
space

  let dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
      mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
        [Char]
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
[Char]
-> PrimType -> ShapeBase SubExp -> Space -> ImpM rep r op VName
sAllocArray [Char]
"red_arr" (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> TypeBase (ShapeBase SubExp) NoUniqueness -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ [Char] -> Space
Space [Char]
"local"

  [VName]
tmp_arrs <- (TypeBase (ShapeBase SubExp) NoUniqueness
 -> ImpM GPUMem KernelEnv KernelOp VName)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> ImpM GPUMem KernelEnv KernelOp [VName])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (SegBinOp GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [SegBinOp GPUMem] -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall rep.
Lambda rep -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> (SegBinOp GPUMem -> Lambda GPUMem)
-> SegBinOp GPUMem
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      let ([KernelResult]
red_res, [KernelResult]
map_res) =
            Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
      [(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
      (PatElem LetDecMem -> KernelResult -> InKernelGen ())
-> [PatElem LetDecMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LetDecMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElem LetDecMem]
map_pes [KernelResult]
map_res

  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

  let tmps_for_ops :: [[VName]]
tmps_for_ops = [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp GPUMem -> Int) -> [SegBinOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp GPUMem -> [SubExp]) -> SegBinOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
    SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
  where
    ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
    ([PatElem LetDecMem]
red_pes, [PatElem LetDecMem]
map_pes) = Int
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) ([PatElem LetDecMem] -> ([PatElem LetDecMem], [PatElem LetDecMem]))
-> [PatElem LetDecMem]
-> ([PatElem LetDecMem], [PatElem LetDecMem])
forall a b. (a -> b) -> a -> b
$ Pat LetDecMem -> [PatElem LetDecMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LetDecMem
pat

    virtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
      TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
      TExp Int32
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
groupChunkLoop (TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
dim') ((TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_start TV Int64
chunk_size -> do
        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"possibly incorporate carry" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chunk_start TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0 TExp Bool -> TExp Bool -> TExp Bool
forall v.
Eq v =>
TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid TExp Int32 -> TExp Int32 -> TExp Bool
forall {k} v (t :: k).
Eq v =>
TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
              Lambda GPUMem
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> InKernelGen ()
forall rep (inner :: * -> *) r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda
                (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
                ((VName -> (VName, [DimIndex (TPrimExp Int64 VName)]))
-> [VName] -> [(VName, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map (,[TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]) [VName]
tmps)
                ( (PatElem LetDecMem -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> [PatElem LetDecMem]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map ((,[]) (SubExp -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> (PatElem LetDecMem -> SubExp)
-> PatElem LetDecMem
-> (SubExp, [DimIndex (TPrimExp Int64 VName)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var (VName -> SubExp)
-> (PatElem LetDecMem -> VName) -> PatElem LetDecMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem LetDecMem]
red_pes
                    [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a. [a] -> [a] -> [a]
++ (VName -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> [VName] -> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> [a] -> [b]
map ((,[TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]) (SubExp -> (SubExp, [DimIndex (TPrimExp Int64 VName)]))
-> (VName -> SubExp)
-> VName
-> (SubExp, [DimIndex (TPrimExp Int64 VName)])
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) [VName]
tmps
                )

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

        [(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
          [VName]
tmps_chunks <- (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (TPrimExp Int64 VName
-> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray (TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start) TV Int64
chunk_size) [VName]
tmps
          TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (TPrimExp Int64 VName -> TExp Int32
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 VName
forall {k} (t :: k). TV t -> TExp t
tvExp TV Int64
chunk_size)) (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op) [VName]
tmps_chunks

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

        Text -> InKernelGen () -> InKernelGen ()
forall rep r op. Text -> ImpM rep r op () -> ImpM rep r op ()
sComment Text
"Save result of reduction." (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [(PatElem LetDecMem, VName)]
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LetDecMem] -> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LetDecMem]
red_pes ([VName] -> [(PatElem LetDecMem, VName)])
-> [VName] -> [(PatElem LetDecMem, VName)]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) (((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LetDecMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LetDecMem
pe, VName
arr) ->
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LetDecMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LetDecMem
pe) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall {k} (t :: k) v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]

    --
    -- General (multi-dimensional) case when virtualisation is in use.
    -- The temporary arrays are viewed as flat arrays of @product dims'@
    -- elements and scanned with 'virtualisedGroupScan'; afterwards the
    -- element at the last index of the innermost dimension is copied
    -- to the corresponding result pattern element.
    virtCase dims' tmps_for_ops = do
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          -- Holds when the distance between the two flat indices is
          -- larger than the offset of @to@ inside its segment, where a
          -- segment is one run along the innermost dimension.
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        -- Flatten all dimensions of each temporary into 'dims_flat'.
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        virtualisedGroupScan
          (Just crossesSegment)
          (sExt32 $ tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            -- Full unit-stride slices of the outer dimensions, fixed at
            -- the final index of the innermost one.
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

    -- Nonsegmented case (or rather, a single segment) - this we can
    -- handle directly with a group-level reduction.
    nonvirtCase [dim'] tmps_for_ops = do
      -- One 'groupReduce' per operator, over that operator's temporaries.
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupReduce (sExt32 dim') (segBinOpLambda op) tmps
      sOp $ Imp.ErrorSync Imp.FenceLocal
      sComment "Save result of reduction." $
        -- Element 0 of each temporary is copied to its pattern element.
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIMFix (patElemName pe) [] (Var arr) [0]
      sOp $ Imp.ErrorSync Imp.FenceLocal

    -- Segmented intra-group reductions are turned into (regular)
    -- segmented scans.  It is possible that this can be done
    -- better, but at least this approach is simple.
    nonvirtCase dims' tmps_for_ops = do
      -- groupScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          -- Holds when the distance between the two flat indices is
          -- larger than the offset of @to@ inside its segment, where a
          -- segment is one run along the innermost dimension.
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        groupScan
          (Just crossesSegment)
          (product dims')
          (product dims')
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      sComment "Save result of reduction." $
        forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
          copyDWIM
            (patElemName pe)
            []
            (Var arr)
            -- Full unit-stride slices of the outer dimensions, fixed at
            -- the final index of the innermost one.
            (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

-- Group-level compilation of a 'SegHist'.  For every thread-level
-- result of the kernel body, the map results are written with
-- 'compileThreadResult', and each histogram operator's bucket index
-- and values are handed to the update action produced for that
-- operator by 'prepareIntraGroupSegHist'.
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileFlatId space
  let (ltids, dims) = unzip $ unSegSpace space

  -- We don't need the red_pes, because it is guaranteed by our type
  -- rules that they occupy the same memory as the destinations for
  -- the ops.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) =
        splitAt num_red_res $ patElems pat

  group_size <- kernelGroupSizeCount . kernelConstants <$> askEnv
  -- One update action per operator; the shape passed here is the
  -- segment space without its innermost dimension.
  ops' <- prepareIntraGroupSegHist (Shape $ init dims) group_size ops

  -- Ensure that all locks have been initialised.
  sOp $ Imp.Barrier Imp.FenceLocal

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      -- The body results are laid out as: one bucket index per
      -- operator, then the values for all operators, then the map
      -- results.
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      -- Group the values by operator, one value per destination array.
      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              bin_in_bounds = inBounds (Slice [DimFix bin']) dest_shape'
              -- Outer thread indices followed by the bucket index.
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              -- The trailing lambda parameters receive the new values.
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            -- Out-of-bounds bucket indices are skipped entirely.
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
-- Catch-all clause: any operation not matched by an earlier clause
-- is reported as a compiler bug, naming the pattern being bound.
compileGroupOp pat _ =
  compilerBugS $ "compileGroupOp: cannot compile rhs of binding " ++ prettyString pat

-- | The 'Operations' used when compiling group-level kernel bodies.
-- Starts from 'defaultOperations' with 'compileGroupOp' as the op
-- compiler, and overrides the copy, expression, statement and
-- allocation compilers.
groupOperations :: Operations GPUMem KernelEnv Imp.KernelOp
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler = copyInGroup,
      opsExpCompiler = compileGroupExp,
      -- The set of names handed to the statement compiler is ignored;
      -- 'defCompileStms' is always invoked with an empty set.
      opsStmsCompiler = \_ -> defCompileStms mempty,
      -- Only the "local" space gets a dedicated allocator.
      opsAllocCompilers =
        M.fromList [(Space "local", allocLocal)]
    }

-- | Is the operand an array whose memory block is in the @"local"@
-- space?  Yields 'False' for constants and for variables that are
-- not arrays.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
  res <- lookupVar name
  case res of
    ArrayVar _ entry ->
      -- Look up the memory block backing the array and compare its
      -- space against "local".
      (Space "local" ==) . entryMemSpace
        <$> lookupMemory (memLocName (entryArrayLoc entry))
    _ -> pure False
arrayInLocalMemory Constant {} = pure False

-- | 'sKernel' specialised to group-level code generation: the
-- operations are fixed to 'groupOperations' and 'kernelGroupId' is
-- passed as the accessor argument.  The remaining arguments (a
-- string, a name, the kernel attributes and the kernel body) are
-- forwarded to 'sKernel' unchanged.
sKernelGroup ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup = sKernel groupOperations kernelGroupId

-- | Generate code that writes one result of a group-level kernel
-- body into the given pattern element.  How the write is done
-- depends on the kind of 'KernelResult'.
compileGroupResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
-- One-dimensional tile: the group writes its elements of @what@ at
-- offset @per_group_elems * group id@ in the destination, guarded so
-- that no index at or beyond @w@ is written.
compileGroupResult _ pe (TileReturns _ [(w, per_group_elems)] what) = do
  -- Outermost dimension of the array being returned.
  n <- pe64 . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      group_size = kernelGroupSize constants
      per_group_elems' = pe64 per_group_elems
      w' = pe64 w
      offset = per_group_elems' * sExt64 (kernelGroupId constants)

  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  localOps threadOperations $
    if per_group_elems' == group_size
      then
        sWhen (ltid + offset .<. w') $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` group_size) $ \i -> do
        -- Element of @what@ handled by this thread in iteration @i@.
        j <- dPrimVE "j" $ group_size * i + ltid
        sWhen (j + offset .<. w') $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
-- General tile: each pair in @dims@ holds a bound (first component)
-- and a tile size (second component).  Every thread writes at most
-- one element, at index @gid * tile size + local thread index@ per
-- dimension.
compileGroupResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      (bounds, tile_sizes) = unzip dims
      out_tile_sizes = map pe64 tile_sizes
      group_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs tile_sizes
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is

  -- Only write when the destination index is active with respect to
  -- the bounds.
  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) bounds) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
-- Register tile: each triple in @dims_n_tiles@ holds a dimension
-- bound, a group tile size and a register tile size.  Every thread
-- loops over its register tile and copies the elements whose
-- destination index is within the bounds.
compileGroupResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map pe64 group_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which group tile is this group responsible for?
  let group_tile_is = map Imp.le64 gids

  -- Within the group tile, which register tile is this thread
  -- responsible for?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" group_tiles' $
      sExt64 $
        kernelLocalThreadId constants

  -- Compute output array slice for the register tile belonging to
  -- this thread.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip group_tiles' group_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
          -- NOTE(review): 'foldl1' requires at least one dimension;
          -- presumably a register-tiled result is never
          -- zero-dimensional -- confirm.
          in_bounds = foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims
      sWhen in_bounds $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
-- Plain result: the value is written at the position given by the
-- indices of the segment space.
compileGroupResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space
      writeResult = copyDWIMFix (patElemName pe) gids what []

  if in_local_memory
    then -- If the result of the group is an array in local memory, we
    -- store it by collective copying among all the threads of the
    -- group.  TODO: also do this if the array is in global memory
    -- (but this is a bit more tricky, synchronisation-wise).
      writeResult
    else
      -- Otherwise only the thread with local id 0 performs the write,
      -- using thread-level operations.
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) writeResult
-- Scatter-like results are not supported at the group level; this is
-- reported as a compiler limitation.
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."

-- | The sizes of nested iteration spaces in the kernel.  Each element
-- of the set is the list of dimension sizes of one such space.
type SegOpSizes = S.Set [SubExp]

-- | Information that is computed once for a kernel and then shared
-- by the group-level SegOps in its body.  Built by
-- 'precomputeConstants' and consumed by 'precomputedConstants'.
data Precomputed = Precomputed
  { -- | Every distinct iteration-space shape found in the kernel
    -- body.
    pcSegOpSizes :: SegOpSizes,
    -- | For each shape in 'pcSegOpSizes', the number of group-sized
    -- chunks needed to cover all of its elements.
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Collect the shapes of all nested parallelism in a t'SegOp' body.
--
-- The statements are traversed recursively (descending into every
-- branch of a @Match@ and into @Loop@ bodies), and the result is the
-- set of dimension lists of:
--
-- * every nested t'SegOp' (see @parDims@ for which dimensions count);
--
-- * every single-result @Replicate@, @Iota@ and @Manifest@, for which
--   the full shape of the bound array is recorded.
--
-- All other statements contribute nothing.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = foldMap sizesOfStm
  where
    sizesOfBody :: Body GPUMem -> SegOpSizes
    sizesOfBody = segOpSizes . bodyStms

    sizesOfStm :: Stm GPUMem -> SegOpSizes
    sizesOfStm (Let _ _ (Op (Inner (SegOp segop)))) =
      S.singleton . map snd $
        parDims (segVirt (segLevel segop)) (segSpace segop)
    sizesOfStm (Let (Pat [pat_elem]) _ (BasicOp bop))
      | freshArray bop =
          S.singleton (arrayDims (patElemType pat_elem))
    sizesOfStm (Let _ _ (Match _ cases defbody _)) =
      foldMap (sizesOfBody . caseBody) cases <> sizesOfBody defbody
    sizesOfStm (Let _ _ (Loop _ _ body)) =
      sizesOfBody body
    sizesOfStm _ =
      mempty

    -- The dimensions of a SegOp that are recorded.  For
    -- 'SegNoVirtFull' only the second component of
    -- 'partitionSeqDims' is kept (presumably the dimensions that are
    -- not sequentialised -- confirm against 'partitionSeqDims');
    -- otherwise the entire space is used.
    parDims (SegNoVirtFull seq_dims) space =
      snd (partitionSeqDims seq_dims space)
    parDims _ space =
      unSegSpace space

    -- The basic operations whose result shape is recorded as an
    -- iteration space of its own.
    freshArray (Replicate {}) = True
    freshArray (Iota {}) = True
    freshArray (Manifest {}) = True
    freshArray _ = False

-- | Precompute information about the iteration spaces occurring in
-- the given kernel body.
--
-- For every distinct shape reported by 'segOpSizes', a fresh
-- @num_chunks@ variable is declared.  It holds the product of the
-- dimensions divided by the group size, rounded up and converted to
-- 32 bits.  The shapes and their chunk counts are returned together
-- as a 'Precomputed'.
precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants group_size stms = do
  chunk_iters <- forM (S.toList nested_sizes) $ \dims -> do
    -- Total number of elements in this iteration space.
    let num_elems = product (map Imp.pe64 dims)
    iters <-
      dPrimVE "num_chunks" . sExt32 $
        num_elems `divUp` unCount group_size
    pure (dims, iters)
  pure
    Precomputed
      { pcSegOpSizes = nested_sizes,
        pcChunkItersMap = M.fromList chunk_iters
      }
  where
    nested_sizes = segOpSizes stms

-- | Run a kernel-level action with the precomputed information made
-- available through the kernel constants.
--
-- For every shape in 'pcSegOpSizes', index variables (named with the
-- prefix @ltid_pre@) are declared via @dIndexSpace'@ from the local
-- thread id; presumably this unflattens the thread id over that
-- shape (confirm against @dIndexSpace'@).  The action is then run in
-- an environment where 'kernelLocalIdMap' is replaced by the map
-- from shapes to these indices (truncated to 32 bits) and
-- 'kernelChunkItersMap' by 'pcChunkItersMap'.
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  local_tid <- kernelLocalThreadId . kernelConstants <$> askEnv
  id_entries <- forM (S.toList $ pcSegOpSizes pre) $ \dims -> do
    unflat <- dIndexSpace' "ltid_pre" (map pe64 dims) (sExt64 local_tid)
    pure (dims, map sExt32 unflat)
  let install consts =
        consts
          { kernelLocalIdMap = M.fromList id_entries,
            kernelChunkItersMap = pcChunkItersMap pre
          }
  localEnv (\env -> env {kernelConstants = install (kernelConstants env)}) m