{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.CodeGen.ImpGen.GPU.Base
  ( KernelConstants (..),
    keyWithEntryPoint,
    CallKernelGen,
    InKernelGen,
    Locks (..),
    HostEnv (..),
    Target (..),
    KernelEnv (..),
    computeThreadChunkSize,
    groupReduce,
    groupScan,
    isActive,
    sKernelThread,
    sKernelGroup,
    KernelAttrs (..),
    defKernelAttrs,
    sReplicate,
    sIota,
    sRotateKernel,
    sCopy,
    compileThreadResult,
    compileGroupResult,
    virtualiseGroups,
    kernelLoop,
    groupCoverSpace,
    Precomputed,
    precomputeConstants,
    precomputedConstants,
    atomicUpdateLocking,
    AtomicBinOp,
    Locking (..),
    AtomicUpdate (..),
    DoAtomicUpdate,
  )
where

import Control.Monad.Except
import Data.Bifunctor
import Data.List (foldl', partition, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.CodeGen.ImpCode.GPU as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Construct (fullSliceNum)
import Futhark.Error
import Futhark.IR.GPUMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, dropLast, mapAccumLM, nubOrd, splitFromEnd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)

-- | The low-level GPU API that code is ultimately generated for.
-- Most kernel code is identical across targets, but in a few places
-- we emit code specific to the chosen API.
data Target
  = CUDA
  | OpenCL

-- | Information about the locks available for accumulators.
data Locks = Locks
  { -- | The array holding the locks.
    locksArray :: VName,
    -- | Number of locks in 'locksArray'; lock indices are reduced
    -- modulo this value when accumulator updates use locking.
    locksCount :: Int
  }

-- | Environment for host-side code generation (the environment
-- component of 'CallKernelGen').
data HostEnv = HostEnv
  { -- | Atomic-operation capabilities (see 'AtomicBinOp').
    hostAtomics :: AtomicBinOp,
    -- | The low-level API being targeted.
    hostTarget :: Target,
    -- | Locks for accumulators, indexed by name.
    hostLocks :: M.Map VName Locks
  }

-- | Environment for in-kernel code generation (the environment
-- component of 'InKernelGen').
data KernelEnv = KernelEnv
  { -- | Atomic-operation capabilities; passed to
    -- 'atomicUpdateLocking' when compiling accumulator updates.
    kernelAtomics :: AtomicBinOp,
    -- | Thread, group and size information for the current kernel.
    kernelConstants :: KernelConstants,
    -- | Locks for accumulators that need lock-based updates, keyed by
    -- the first component returned by 'lookupAcc'.
    kernelLocks :: M.Map VName Locks
  }

-- | Code generator monad for host code: environment 'HostEnv',
-- emitted operations 'Imp.HostOp'.
type CallKernelGen = ImpM GPUMem HostEnv Imp.HostOp

-- | Code generator monad for kernel code: environment 'KernelEnv',
-- emitted operations 'Imp.KernelOp'.
type InKernelGen = ImpM GPUMem KernelEnv Imp.KernelOp

-- | Thread-, group- and size-related values available to code
-- generated inside a kernel (stored in 'kernelConstants').
data KernelConstants = KernelConstants
  { kernelGlobalThreadId :: Imp.TExp Int32,
    kernelLocalThreadId :: Imp.TExp Int32,
    kernelGroupId :: Imp.TExp Int32,
    kernelGlobalThreadIdVar :: VName,
    kernelLocalThreadIdVar :: VName,
    kernelGroupIdVar :: VName,
    kernelNumGroupsCount :: Count NumGroups SubExp,
    kernelGroupSizeCount :: Count GroupSize SubExp,
    kernelNumGroups :: Imp.TExp Int64,
    kernelGroupSize :: Imp.TExp Int64,
    kernelNumThreads :: Imp.TExp Int32,
    kernelWaveSize :: Imp.TExp Int32,
    -- | A mapping from dimensions of nested SegOps to already
    -- computed local thread IDs.  Only valid in non-virtualised case.
    -- Populated by 'precomputedConstants'.
    kernelLocalIdMap :: M.Map [SubExp] [Imp.TExp Int32],
    -- | Mapping from dimensions of nested SegOps to how many
    -- iterations the virtualisation loop needs.  Populated by
    -- 'precomputedConstants'.
    kernelChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | The sizes of nested iteration spaces in the kernel.  Each element
-- is one iteration-space shape, given as its list of dimension sizes.
type SegOpSizes = S.Set [SubExp]

-- | Find the sizes of nested parallelism in a t'SegOp' body.
--
-- Collects the iteration-space shape of every nested 'SegOp', the
-- shape of every 'Replicate', and recurses into the bodies of 'Match'
-- and 'DoLoop'.  The bodies of the nested 'SegOp's themselves are not
-- inspected.
segOpSizes :: Stms GPUMem -> SegOpSizes
segOpSizes = onStms
  where
    onStms :: Stms GPUMem -> SegOpSizes
    onStms = foldMap (onExp . stmExp)

    onExp :: Exp GPUMem -> SegOpSizes
    onExp (Op (Inner (SegOp op))) =
      case segVirt $ segLevel op of
        SegNoVirtFull seq_dims ->
          -- Keep only the second component of 'partitionSeqDims'
          -- (presumably the non-sequentialised dimensions; confirm
          -- against its definition).
          S.singleton $ map snd $ snd $ partitionSeqDims seq_dims $ segSpace op
        _ -> S.singleton $ map snd $ unSegSpace $ segSpace op
    onExp (BasicOp (Replicate shape _)) =
      S.singleton $ shapeDims shape
    onExp (Match _ cases defbody _) =
      foldMap (onStms . bodyStms . caseBody) cases <> onStms (bodyStms defbody)
    onExp (DoLoop _ _ body) =
      onStms (bodyStms body)
    onExp _ = mempty

-- | Various useful precomputed information.
data Precomputed = Precomputed
  { -- | Sizes of nested iteration spaces, as found by 'segOpSizes'.
    pcSegOpSizes :: SegOpSizes,
    -- | For each of those sizes, the number of group-sized chunks
    -- needed to cover it (see 'precomputeConstants').
    pcChunkItersMap :: M.Map [SubExp] (Imp.TExp Int32)
  }

-- | Precompute various constants and useful information.
--
-- For every nested iteration-space size found by 'segOpSizes' in the
-- given statements, declares a variable (via 'dPrimVE') holding the
-- number of @group_size@-sized chunks needed to cover that space,
-- rounded up and converted to 32 bits.
precomputeConstants :: Count GroupSize (Imp.TExp Int64) -> Stms GPUMem -> CallKernelGen Precomputed
precomputeConstants group_size stms = do
  let sizes = segOpSizes stms
  iters_map <- M.fromList <$> mapM mkMap (S.toList sizes)
  pure $ Precomputed sizes iters_map
  where
    mkMap :: [SubExp] -> CallKernelGen ([SubExp], Imp.TExp Int32)
    mkMap dims = do
      -- Total number of points in the iteration space.
      let n = product $ map Imp.pe64 dims
      num_chunks <- dPrimVE "num_chunks" $ sExt32 $ n `divUp` unCount group_size
      pure (dims, num_chunks)

-- | Make use of various precomputed constants.
--
-- Runs the given action in an environment where 'kernelLocalIdMap'
-- maps each precomputed size to the local thread ID split into an
-- index in that space (via 'dIndexSpace''), and where
-- 'kernelChunkItersMap' is the precomputed chunk-count map.
precomputedConstants :: Precomputed -> InKernelGen a -> InKernelGen a
precomputedConstants pre m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  new_ids <- M.fromList <$> mapM (mkMap ltid) (S.toList (pcSegOpSizes pre))
  let f env =
        env
          { kernelConstants =
              (kernelConstants env)
                { kernelLocalIdMap = new_ids,
                  kernelChunkItersMap = pcChunkItersMap pre
                }
          }
  localEnv f m
  where
    mkMap :: Imp.TExp Int32 -> [SubExp] -> InKernelGen ([SubExp], [Imp.TExp Int32])
    mkMap ltid dims = do
      let dims' = map pe64 dims
      ids' <- dIndexSpace' "ltid_pre" dims' (sExt64 ltid)
      pure (dims, map sExt32 ids')

-- | Prefix a key with the name of an entry point, if one is given.
-- With @Just f@ the result is @f.key@; with 'Nothing' it is @key@.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ maybe "" ((++ ".") . nameToString) fname ++ nameToString key

-- | Allocation compiler for local memory: emits an 'Imp.LocalAlloc'
-- kernel operation for the given memory block and byte size.
allocLocal :: AllocCompiler GPUMem r Imp.KernelOp
allocLocal mem size =
  sOp $ Imp.LocalAlloc mem size

-- | Compile an allocation occurring inside a kernel.  Scalar-space
-- allocations need no code, and local-memory allocations become
-- 'Imp.LocalAlloc'.  Any other space for a single-element pattern is
-- reported as a compiler limitation.  Any other pattern shape is an
-- internal error.
kernelAlloc ::
  Pat LetDecMem ->
  SubExp ->
  Space ->
  InKernelGen ()
kernelAlloc (Pat [_]) _ ScalarSpace {} =
  -- Handled by the declaration of the memory block, which is then
  -- translated to an actual scalar variable during C code generation.
  pure ()
kernelAlloc (Pat [mem]) size (Space "local") =
  allocLocal (patElemName mem) $ Imp.bytes $ pe64 size
kernelAlloc (Pat [mem]) _ _ =
  compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
kernelAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest

-- | Bind the single pattern element to the chunk size that
-- 'computeThreadChunkSize' computes for index @i@, given the total
-- size @w@ and the number of elements per thread.  Any other pattern
-- shape is an internal error.
splitSpace ::
  Pat LetDecMem ->
  SplitOrdering ->
  SubExp ->
  SubExp ->
  SubExp ->
  ImpM rep r op ()
splitSpace (Pat [size]) o w i elems_per_thread = do
  num_elements <- Imp.elements . TPrimExp <$> toExp w
  let i' = pe64 i
  elems_per_thread' <- Imp.elements . TPrimExp <$> toExp elems_per_thread
  computeThreadChunkSize o i' elems_per_thread' num_elements (mkTV (patElemName size) int64)
splitSpace pat _ _ _ _ =
  error $ "Invalid target for splitSpace: " ++ pretty pat

-- | Compile an in-kernel @UpdateAcc@: if the indices are in bounds,
-- write the values @vs@ at indices @is@ of the accumulator's arrays.
-- Without an operator the values are copied directly.  With an
-- operator, the update strategy comes from 'atomicUpdateLocking'.
-- The locking strategy needs locks registered in 'kernelLocks'.
updateAcc :: VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc acc is vs = sComment "UpdateAcc" $ do
  -- See the ImpGen implementation of UpdateAcc for general notes.
  let is' = map pe64 is
  (c, space, arrs, dims, op) <- lookupAcc acc is'
  sWhen (inBounds (Slice (map DimFix is')) dims) $
    case op of
      Nothing ->
        forM_ (zip arrs vs) $ \(arr, v) -> copyDWIMFix arr is' v []
      Just lam -> do
        dLParams $ lambdaParams lam
        -- The lambda takes the existing values first, followed by one
        -- parameter per new value; only the latter are bound here.
        let y_params = drop (length vs) $ map paramName $ lambdaParams lam
        forM_ (zip y_params vs) $ \(yp, v) -> copyDWIM yp [] v []
        atomics <- kernelAtomics <$> askEnv
        case atomicUpdateLocking atomics lam of
          AtomicPrim f -> f space arrs is'
          AtomicCAS f -> f space arrs is'
          AtomicLocking f -> do
            c_locks <- M.lookup c . kernelLocks <$> askEnv
            case c_locks of
              Just (Locks locks num_locks) -> do
                -- NOTE(review): the 0/1/0 arguments are presumably the
                -- unlocked, locking and unlocking values; confirm against
                -- the 'Locking' definition.  The flat index into the
                -- accumulator is reduced modulo the number of locks.
                let locking =
                      Locking locks 0 1 0 $
                        pure . (`rem` fromIntegral num_locks) . flattenIndex dims
                f locking space arrs is'
              Nothing ->
                error $ "Missing locks for " ++ pretty acc

-- | Thread-level expression compiler.  It handles 'Opaque',
-- 'ArrayLit' and 'UpdateAcc' specially and defers everything else to
-- 'defCompileExp'.
compileThreadExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadExp (Pat [pe]) (BasicOp (Opaque _ se)) =
  -- Cannot print in GPU code.
  copyDWIM (patElemName pe) [] se []
compileThreadExp (Pat [dest]) (BasicOp (ArrayLit es _)) =
  -- Write each literal element to its position in the destination.
  forM_ (zip [0 ..] es) $ \(i, e) ->
    copyDWIMFix (patElemName dest) [fromIntegral (i :: Int64)] e []
compileThreadExp _ (BasicOp (UpdateAcc acc is vs)) =
  updateAcc acc is vs
compileThreadExp dest e =
  defCompileExp dest e

-- | Assign iterations of a for-loop to all threads in the kernel.
-- The passed-in function is invoked with the (symbolic) iteration.
-- The body must contain thread-level code.  For multidimensional
-- loops, use 'groupCoverSpace'.
--
-- Arguments: this thread's ID, the number of threads, the number of
-- iterations, and the loop body.
kernelLoop ::
  IntExp t =>
  Imp.TExp t ->
  Imp.TExp t ->
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
kernelLoop tid num_threads n f =
  localOps threadOperations $
    if n == num_threads
      then -- Statically one iteration per thread: no loop needed.
        f tid
      else do
        -- Cover the iterations in chunks of num_threads; the last
        -- chunk may be partial, hence the bounds check below.
        num_chunks <- dPrimVE "num_chunks" $ n `divUp` num_threads
        sFor "chunk_i" num_chunks $ \chunk_i -> do
          i <- dPrimVE "i" $ chunk_i * num_threads + tid
          sWhen (i .<. n) $ f i

-- | Spread the iterations of a one-dimensional loop over the threads
-- of the current workgroup.  This delegates to 'kernelLoop', using the
-- local thread ID as the thread index and the group size as the
-- thread count.  The passed-in function receives the (symbolic)
-- iteration.  For multidimensional loops, use 'groupCoverSpace'.
groupLoop ::
  IntExp t =>
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
groupLoop n f = do
  env <- askEnv
  let constants = kernelConstants env
      -- Both the thread index and the thread count are converted to
      -- the integer type of the iteration count.
      ltid = kernelLocalThreadId constants `sExtAs` n
      num_threads = kernelGroupSize constants `sExtAs` n
  kernelLoop ltid num_threads n f

-- | Cover a multidimensional index space collectively, with all
-- threads of the group taking part.  The passed-in function is called
-- with a (symbolic) point of the space, one index per dimension.
groupCoverSpace ::
  IntExp t =>
  [Imp.TExp t] ->
  ([Imp.TExp t] -> InKernelGen ()) ->
  InKernelGen ()
groupCoverSpace ds f = do
  constants <- kernelConstants <$> askEnv
  case splitFromEnd 1 ds of
    -- Special case: the innermost dimension is equal to the group
    -- size.  Each thread then uses its local thread ID as the
    -- innermost index, and only the outer dimensions are looped over.
    (outer_ds, [inner_d])
      | inner_d == (kernelGroupSize constants `sExtAs` inner_d) ->
          sLoopSpace outer_ds $ \outer_is ->
            f (outer_is ++ [kernelLocalThreadId constants `sExtAs` inner_d])
    -- General case: flatten the space, let 'groupLoop' distribute the
    -- flat iterations, and turn each flat index back into a point.
    _ ->
      groupLoop (product ds) $ \flat_i -> f (unflattenIndex ds flat_i)

-- | Produce one local thread index per dimension of the given space.
-- If the kernel constants hold precomputed indices for exactly these
-- dimensions (in 'kernelLocalIdMap'), those indices are widened with
-- 'sExt64' and returned.  Otherwise new index variables named @ltid@
-- are declared with 'dIndexSpace'' from the flat local thread ID.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  constants <- kernelConstants <$> askEnv
  case M.lookup dims (kernelLocalIdMap constants) of
    Just precomputed ->
      pure $ map sExt64 precomputed
    Nothing ->
      dIndexSpace' "ltid" (map pe64 dims) $
        sExt64 (kernelLocalThreadId constants)

-- | Split the dimensions of a segment space in two.  The first
-- component holds the dimensions whose (zero-based) positions appear
-- in the 'SegSeqDims'; the second holds all the others.  Both
-- components keep the original dimension order.
partitionSeqDims :: SegSeqDims -> SegSpace -> ([(VName, SubExp)], [(VName, SubExp)])
partitionSeqDims (SegSeqDims seq_is) space =
  (map fst seq_dims, map fst par_dims)
  where
    numbered = zip (unSegSpace space) [0 ..]
    (seq_dims, par_dims) = partition (\(_, k) -> k `elem` seq_is) numbered

-- | Run the body @m@ once for every point of the given segment space,
-- with the threads of the group cooperating.  Before @m@ runs, the
-- names bound by the space are defined as the indices of the current
-- point.  How points are assigned to threads depends on the 'SegVirt'
-- mode:
--
-- * 'SegVirt': the space may be larger than the group, so threads
--   iterate over it in chunks of the group size.  A flat index past
--   the end of the space does not run @m@.
--
-- * 'SegNoVirt': each thread derives its point from its local thread
--   ID via 'localThreadIDs', and only threads for which 'isActive'
--   holds run @m@.
--
-- * 'SegNoVirtFull': the dimensions listed as sequential are looped
--   over by every thread.  The remaining dimensions are taken from the
--   local thread IDs, with no bounds check.
--
-- In every mode @m@ is compiled with thread-level operations.
groupCoverSegSpace :: SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace virt space m = do
  let (ltids, dims) = unzip $ unSegSpace space
      dims' = map pe64 dims

  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
  -- Maybe we can statically detect that this is actually a
  -- SegNoVirtFull and generate ever-so-slightly simpler code.
  let virt' = if dims' == [group_size] then SegNoVirtFull (SegSeqDims []) else virt
  case virt' of
    SegVirt -> do
      iters <- M.lookup dims . kernelChunkItersMap . kernelConstants <$> askEnv
      case iters of
        Nothing -> do
          iterations <- dPrimVE "iterations" $ product $ map sExt32 dims'
          -- 'groupLoop' (via 'kernelLoop') switches to thread-level
          -- operations itself.
          groupLoop iterations $ \i -> do
            dIndexSpace (zip ltids dims') $ sExt64 i
            m
        -- The number of chunks was precomputed for this space.  The
        -- body runs per thread under a bounds guard, so it must be
        -- compiled with thread-level operations, as in every other
        -- branch.  Using the caller's group-level operations here
        -- would emit collective code inside divergent control flow.
        Just num_chunks -> localOps threadOperations $ do
          let ltid = kernelLocalThreadId constants
          sFor "chunk_i" num_chunks $ \chunk_i -> do
            i <- dPrimVE "i" $ chunk_i * sExt32 group_size + ltid
            dIndexSpace (zip ltids dims') $ sExt64 i
            sWhen (inBounds (Slice (map (DimFix . le64) ltids)) dims') m
    SegNoVirt -> localOps threadOperations $ do
      zipWithM_ dPrimV_ ltids =<< localThreadIDs dims
      sWhen (isActive $ zip ltids dims) m
    SegNoVirtFull seq_dims -> do
      let ((ltids_seq, dims_seq), (ltids_par, dims_par)) =
            bimap unzip unzip $ partitionSeqDims seq_dims space
      sLoopNest (Shape dims_seq) $ \is_seq -> do
        zipWithM_ dPrimV_ ltids_seq is_seq
        localOps threadOperations $ do
          zipWithM_ dPrimV_ ltids_par =<< localThreadIDs dims_par
          m

compileGroupExp :: ExpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupExp :: ExpCompiler GPUMem KernelEnv KernelOp
compileGroupExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Opaque OpaqueOp
_ SubExp
se)) =
  -- Cannot print in GPU code.
  VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
pe) [] SubExp
se []
-- The static arrays stuff does not work inside kernels.
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (ArrayLit [SubExp]
es TypeBase Shape NoUniqueness
_)) =
  [(Int64, SubExp)]
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Int64] -> [SubExp] -> [(Int64, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Int64
0 ..] [SubExp]
es) (((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ())
-> ((Int64, SubExp) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Int64
i, SubExp
e) ->
    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
dest) [Int64 -> TPrimExp Int64 VName
forall a b. (Integral a, Num b) => a -> b
fromIntegral (Int64
i :: Int64)] SubExp
e []
compileGroupExp Pat (LetDec GPUMem)
_ (BasicOp (UpdateAcc VName
acc [SubExp]
is [SubExp]
vs)) =
  VName -> [SubExp] -> [SubExp] -> InKernelGen ()
updateAcc VName
acc [SubExp]
is [SubExp]
vs
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Replicate Shape
ds SubExp
se)) = do
  VName
flat <- String -> ImpM GPUMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"rep_flat"
  [VName]
is <- Int
-> ImpM GPUMem KernelEnv KernelOp VName
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (m :: * -> *) a. Applicative m => Int -> m a -> m [a]
replicateM (Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
ds) (String -> ImpM GPUMem KernelEnv KernelOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"rep_i")
  let is' :: [TPrimExp Int64 VName]
is' = (VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 [VName]
is
  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace SegVirt
SegVirt (VName -> [(VName, SubExp)] -> SegSpace
SegSpace VName
flat ([(VName, SubExp)] -> SegSpace) -> [(VName, SubExp)] -> SegSpace
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
is ([SubExp] -> [(VName, SubExp)]) -> [SubExp] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ Shape -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims Shape
ds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
dest) [TPrimExp Int64 VName]
is' SubExp
se []
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Rotate [SubExp]
rs VName
arr)) = do
  [TPrimExp Int64 VName]
ds <- (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness
-> [TPrimExp Int64 VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [TPrimExp Int64 VName])
-> ImpM GPUMem KernelEnv KernelOp (TypeBase Shape NoUniqueness)
-> InKernelGen [TPrimExp Int64 VName]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> ImpM GPUMem KernelEnv KernelOp (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr
  [TPrimExp Int64 VName]
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall t.
IntExp t =>
[TExp t] -> ([TExp t] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace [TPrimExp Int64 VName]
ds (([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ())
-> ([TPrimExp Int64 VName] -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \[TPrimExp Int64 VName]
is -> do
    [TPrimExp Int64 VName]
is' <- [ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)]
-> InKernelGen [TPrimExp Int64 VName]
forall (t :: * -> *) (m :: * -> *) a.
(Traversable t, Monad m) =>
t (m a) -> m (t a)
sequence ([ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)]
 -> InKernelGen [TPrimExp Int64 VName])
-> [ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)]
-> InKernelGen [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName
 -> SubExp
 -> TPrimExp Int64 VName
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> [TPrimExp Int64 VName]
-> [SubExp]
-> [TPrimExp Int64 VName]
-> [ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 TPrimExp Int64 VName
-> SubExp
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall {rep} {r} {op}.
TPrimExp Int64 VName
-> SubExp
-> TPrimExp Int64 VName
-> ImpM rep r op (TPrimExp Int64 VName)
rotate [TPrimExp Int64 VName]
ds [SubExp]
rs [TPrimExp Int64 VName]
is
    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
dest) [TPrimExp Int64 VName]
is (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName]
is'
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
  where
    rotate :: TPrimExp Int64 VName
-> SubExp
-> TPrimExp Int64 VName
-> ImpM rep r op (TPrimExp Int64 VName)
rotate TPrimExp Int64 VName
d SubExp
r TPrimExp Int64 VName
i = String
-> TPrimExp Int64 VName -> ImpM rep r op (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"rot_i" (TPrimExp Int64 VName -> ImpM rep r op (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> ImpM rep r op (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
rotateIndex TPrimExp Int64 VName
d (SubExp -> TPrimExp Int64 VName
pe64 SubExp
r) TPrimExp Int64 VName
i
compileGroupExp (Pat [PatElem (LetDec GPUMem)
dest]) (BasicOp (Iota SubExp
n SubExp
e SubExp
s IntType
it)) = do
  PrimExp VName
n' <- SubExp -> ImpM GPUMem KernelEnv KernelOp (PrimExp VName)
forall a rep r op. ToExp a => a -> ImpM rep r op (PrimExp VName)
toExp SubExp
n
  PrimExp VName
e' <- SubExp -> ImpM GPUMem KernelEnv KernelOp (PrimExp VName)
forall a rep r op. ToExp a => a -> ImpM rep r op (PrimExp VName)
toExp SubExp
e
  PrimExp VName
s' <- SubExp -> ImpM GPUMem KernelEnv KernelOp (PrimExp VName)
forall a rep r op. ToExp a => a -> ImpM rep r op (PrimExp VName)
toExp SubExp
s
  TPrimExp Int64 VName
-> (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
forall t.
IntExp t =>
TExp t -> (TExp t -> InKernelGen ()) -> InKernelGen ()
groupLoop (PrimExp VName -> TPrimExp Int64 VName
forall t v. PrimExp v -> TPrimExp t v
TPrimExp PrimExp VName
n') ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
i' -> do
    TV Any
x <-
      String -> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"x" (TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any))
-> TExp Any -> ImpM GPUMem KernelEnv KernelOp (TV Any)
forall a b. (a -> b) -> a -> b
$
        PrimExp VName -> TExp Any
forall t v. PrimExp v -> TPrimExp t v
TPrimExp (PrimExp VName -> TExp Any) -> PrimExp VName -> TExp Any
forall a b. (a -> b) -> a -> b
$
          BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) PrimExp VName
e' (PrimExp VName -> PrimExp VName) -> PrimExp VName -> PrimExp VName
forall a b. (a -> b) -> a -> b
$
            BinOp -> PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (IntType -> Overflow -> BinOp
Mul IntType
it Overflow
OverflowUndef) (TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped TPrimExp Int64 VName
i') PrimExp VName
s'
    VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
dest) [TPrimExp Int64 VName
i'] (VName -> SubExp
Var (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
x)) []
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal

-- When generating code for a scalar in-place update, we must make
-- sure that only one thread performs the write.  When writing an
-- array, the group-level copy code will take care of doing the right
-- thing.
compileGroupExp (Pat [PatElem (LetDec GPUMem)
pe]) (BasicOp (Update Safety
safety VName
_ Slice SubExp
slice SubExp
se))
  | [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([SubExp] -> Bool) -> [SubExp] -> Bool
forall a b. (a -> b) -> a -> b
$ Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
slice = do
      KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
      TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
ltid TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        case Safety
safety of
          Safety
Unsafe -> InKernelGen ()
write
          Safety
Safe -> TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (Slice (TPrimExp Int64 VName) -> [TPrimExp Int64 VName] -> TExp Bool
inBounds Slice (TPrimExp Int64 VName)
slice' [TPrimExp Int64 VName]
dims) InKernelGen ()
write
      KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
  where
    slice' :: Slice (TPrimExp Int64 VName)
slice' = (SubExp -> TPrimExp Int64 VName)
-> Slice SubExp -> Slice (TPrimExp Int64 VName)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap SubExp -> TPrimExp Int64 VName
pe64 Slice SubExp
slice
    dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatElem LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => PatElem dec -> TypeBase Shape NoUniqueness
patElemType PatElem (LetDec GPUMem)
PatElem LParamMem
pe
    write :: InKernelGen ()
write = VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem (LetDec GPUMem)
PatElem LParamMem
pe) (Slice (TPrimExp Int64 VName) -> [DimIndex (TPrimExp Int64 VName)]
forall d. Slice d -> [DimIndex d]
unSlice Slice (TPrimExp Int64 VName)
slice') SubExp
se []
compileGroupExp Pat (LetDec GPUMem)
dest Exp GPUMem
e =
  ExpCompiler GPUMem KernelEnv KernelOp
forall rep inner r op.
Mem rep inner =>
Pat (LetDec rep) -> Exp rep -> ImpM rep r op ()
defCompileExp Pat (LetDec GPUMem)
dest Exp GPUMem
e

-- | Assert that a nested SegOp is thread-level.  Thread-level ops are
-- accepted silently; a group-level op at this point is a compiler
-- bug, so compilation aborts with an internal error.
sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel = \case
  SegThread {} -> pure ()
  SegGroup {} -> error "compileGroupOp: unexpected group-level SegOp."

-- | Check that @lvl@ is a valid nested level (see 'sanityCheckLevel'),
-- then bind the flat index variable of @space@ to the local thread ID
-- of the current kernel.
compileFlatId :: SegLevel -> SegSpace -> InKernelGen ()
compileFlatId lvl space = do
  sanityCheckLevel lvl
  dPrimV_ (segFlat space) . kernelLocalThreadId . kernelConstants =<< askEnv

-- | Construct the per-operator update functions for an intra-group
-- histogram, creating a lock array in local memory when needed.
--
-- Operators are visited left to right.  An operator that can be
-- updated with a primitive atomic or a CAS loop needs no locks.  The
-- first operator that does need locking triggers allocation of a
-- single @int32@ lock array in local memory, with as many locks as
-- the group size, and initialises every lock to zero.  Every later
-- locking operator reuses that same lock array.
--
-- NOTE(review): the lock index function is built from the shape of
-- the first locking operator only, and later locking operators reuse
-- it as-is.
prepareIntraGroupSegHist ::
  Count GroupSize SubExp ->
  [HistOp GPUMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist group_size ops =
  snd <$> mapAccumLM onOp Nothing ops
  where
    local_space = Space "local"

    -- Produce the update function for one operator, threading the
    -- (possibly not yet created) lock array through the traversal.
    onOp ::
      Maybe Locking ->
      HistOp GPUMem ->
      InKernelGen (Maybe Locking, [Imp.TExp Int64] -> InKernelGen ())
    onOp prev_locking op = do
      atomicBinOp <- kernelAtomics <$> askEnv
      let subhistos = histDest op
      case atomicUpdateLocking atomicBinOp (histOp op) of
        AtomicPrim f -> pure (prev_locking, f local_space subhistos)
        AtomicCAS f -> pure (prev_locking, f local_space subhistos)
        AtomicLocking f -> do
          -- Reuse an existing lock array, or create one on first need.
          locking <- maybe (newLocks op) pure prev_locking
          pure (Just locking, f locking local_space subhistos)

    -- Allocate a fresh lock array in local memory, zero it, and return
    -- the 'Locking' descriptor that maps a bucket index to a lock.
    newLocks :: HistOp GPUMem -> InKernelGen Locking
    newLocks op = do
      constants <- kernelConstants <$> askEnv
      locks <- newVName "locks"

      let lock_count = unCount group_size
          num_locks = pe64 lock_count
          dims = map pe64 $ shapeDims $ histOpShape op <> histShape op
          locks_t = Array int32 (Shape [lock_count]) NoUniqueness
          -- Buckets are spread over the locks by flat index modulo the
          -- number of locks.
          locking =
            Locking locks 0 1 0 $
              pure . (`rem` num_locks) . flattenIndex dims

      locks_mem <- sAlloc "locks_mem" (typeSize locks_t) local_space
      dArray locks int32 (arrayShape locks_t) locks_mem $
        IxFun.iota $ map pe64 $ arrayDims locks_t

      sComment "All locks start out unlocked" $
        groupCoverSpace [kernelGroupSize constants] $ \is ->
          copyDWIMFix locks is (intConst Int32 0) []

      pure locking

-- | Which fence protects shared access to memory in the given space?
-- Local memory needs only a local fence; every other space gets a
-- global fence.
fenceForSpace :: Space -> Imp.Fence
fenceForSpace space =
  case space of
    Space "local" -> Imp.FenceLocal
    _ -> Imp.FenceGlobal

-- | If we are touching these arrays, which kind of fence do we need?
-- The result is the maximum (by 'Ord') of the fences required by the
-- memory spaces backing the arrays.  An empty list yields a local
-- fence.
fenceForArrays :: [VName] -> InKernelGen Imp.Fence
fenceForArrays arrs = foldl' max Imp.FenceLocal <$> mapM fenceFor arrs
  where
    -- Follow array -> memory block -> memory space -> fence.
    fenceFor arr = do
      arr_entry <- lookupArray arr
      mem_entry <- lookupMemory $ memLocName $ entryArrayLoc arr_entry
      pure $ fenceForSpace $ entryMemSpace mem_entry

-- | @groupChunkLoop w m@ splits the range @[0, w)@ into consecutive
-- chunks of at most the group size and runs @m@ once per chunk inside
-- a generated @for@ loop.  @m@ receives the chunk's start offset and a
-- freshly bound variable holding the chunk's length.  Only the final
-- chunk can be shorter than the group size.
groupChunkLoop ::
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> TV Int64 -> InKernelGen ()) ->
  InKernelGen ()
groupChunkLoop w m = do
  group_size <- sExt32 . kernelGroupSize . kernelConstants <$> askEnv
  num_chunks <- dPrimVE "num_chunks" $ w `divUp` group_size
  sFor "chunk_i" num_chunks $ \chunk_i -> do
    offset <- dPrimVE "chunk_start" $ chunk_i * group_size
    -- Clamp the end so the last chunk does not run past w.
    limit <- dPrimVE "chunk_end" $ sMin32 w (offset + group_size)
    len <- dPrimV "chunk_size" $ sExt64 $ limit - offset
    m offset len

-- | @sliceArray start size arr@ binds a new array that is a view of
-- @arr@ restricted along its outermost dimension to @size@ elements
-- beginning at @start@.  The new array shares @arr@'s memory block;
-- only the index function is sliced.
sliceArray :: Imp.TExp Int64 -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray start size arr = do
  MemLoc mem _ ixfun <- entryArrayLoc <$> lookupArray arr
  arr_t <- lookupType arr
  let arr_dims = map Imp.pe64 $ arrayDims arr_t
      outer_slice = DimSlice start (tvExp size) 1
      chunk_shape = setOuterDim (arrayShape arr_t) (Var (tvVar size))
  sArray (baseString arr <> "_chunk") (elemType arr_t) chunk_shape mem $
    IxFun.slice ixfun $ fullSliceNum arr_dims [outer_slice]

-- | @flattenArray k flat arr@ flattens the outer @k@ dimensions of
-- @arr@ into a single dimension of size @flat@.  It binds and returns
-- a new array name that lives in the same memory block as @arr@.  The
-- remaining inner dimensions are kept unchanged.
--
-- Make sure @flat@ is the /product/ of those outer @k@ dimensions, or
-- you'll have a bad time: the reshaped index function is not checked
-- here.
flattenArray :: Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray k flat arr = do
  ArrayEntry arr_loc pt <- lookupArray arr
  let flat_shape = Shape $ Var (tvVar flat) : drop k (memLocShape arr_loc)
  sArray (baseString arr ++ "_flat") pt flat_shape (memLocName arr_loc) $
    IxFun.reshape (memLocIxFun arr_loc) $
      map pe64 $
        shapeDims flat_shape

-- | @applyLambda lam dests args@ emits code that:
--
-- 1. Binds each parameter of @lam@ to the corresponding element of
--    @args@, interpreted as a (name,slice) pair (as in 'copyDWIM').
--    Use an empty list for a scalar.
--
-- 2. Executes the body of @lam@.
--
-- 3. Binds the t'SubExp's that are the 'Result' of @lam@ to the
-- provided @dest@s, again interpreted as the destination for a
-- 'copyDWIM'.
--
-- Extra parameters, arguments, destinations or results are silently
-- ignored, because the pairing is done with 'zip'.
applyLambda ::
  Mem rep inner =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyLambda lam dests args = do
  let params = lambdaParams lam
      body = lambdaBody lam
      results = map resSubExp $ bodyResult body
  dLParams params
  forM_ (zip params args) $ \(param, (arg, arg_slice)) ->
    copyDWIM (paramName param) [] arg arg_slice
  -- The result copies run inside the scope of the body's statements,
  -- so the result names are still bound.
  compileStms mempty (bodyStms body) $
    forM_ (zip dests results) $ \((dest, dest_slice), se) ->
      copyDWIM dest dest_slice se []

-- | As 'applyLambda', but first gives every name bound inside the
-- lambda a fresh name.  This makes it safe to apply the same lambda
-- in multiple places.  (It might be safe anyway, but you have to be
-- more careful; use this if you are in doubt.)
applyRenamedLambda ::
  Mem rep inner =>
  Lambda rep ->
  [(VName, [DimIndex (Imp.TExp Int64)])] ->
  [(SubExp, [DimIndex (Imp.TExp Int64)])] ->
  ImpM rep r op ()
applyRenamedLambda lam dests args =
  renameLambda lam >>= \fresh_lam -> applyLambda fresh_lam dests args

-- | Scan the first @w@ elements of @arrs@ with @lam@ by iterating over
-- chunks produced by 'groupChunkLoop'.  For every chunk after the
-- first, local thread 0 combines the last element of the previous
-- chunk (the carry) into the first element of the current chunk.  This
-- is skipped when the optional segment-flag function reports that the
-- two positions are separated by a segment boundary.  The chunk is
-- then sliced out, a local 'Imp.ErrorSync' is issued, and 'groupScan'
-- runs over the slice.
virtualisedGroupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
virtualisedGroupScan seg_flag w lam arrs =
  groupChunkLoop w scanChunk
  where
    -- Handles one chunk starting at 'chunk_start' and containing
    -- 'chunk_size' elements.
    scanChunk :: Imp.TExp Int32 -> TV Int64 -> InKernelGen ()
    scanChunk chunk_start chunk_size = do
      ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
      let start64 = sExt64 chunk_start
          -- Does a segment boundary lie between the previous chunk's
          -- last element and this chunk's first element?
          prev_in_other_segment =
            maybe
              false
              (\flag_true -> flag_true (sExt32 (chunk_start - 1)) (sExt32 chunk_start))
              seg_flag
          needs_carry =
            chunk_start .>. 0 .&&. ltid .==. 0 .&&. bNot prev_in_other_segment

      sComment "possibly incorporate carry" . sWhen needs_carry $ do
        carry_idx <- dPrimVE "carry_idx" $ start64 - 1
        let arr_vars = map Var arrs
            carry_args = [(v, [DimFix carry_idx]) | v <- arr_vars]
            first_args = [(v, [DimFix start64]) | v <- arr_vars]
        applyRenamedLambda
          lam
          [(arr, [DimFix start64]) | arr <- arrs]
          (carry_args ++ first_args)

      arrs_chunks <- mapM (sliceArray start64 chunk_size) arrs

      sOp $ Imp.ErrorSync Imp.FenceLocal

      groupScan seg_flag (sExt64 w) (tvExp chunk_size) lam arrs_chunks

compileGroupOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler GPUMem KernelEnv KernelOp
compileGroupOp Pat (LetDec GPUMem)
pat (Alloc SubExp
size Space
space) =
  Pat LParamMem -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pat (LetDec GPUMem)
Pat LParamMem
pat SubExp
size Space
space
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SizeOp (SplitSpace SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread))) =
  Pat LParamMem
-> SplitOrdering -> SubExp -> SubExp -> SubExp -> InKernelGen ()
forall rep r op.
Pat LParamMem
-> SplitOrdering -> SubExp -> SubExp -> SubExp -> ImpM rep r op ()
splitSpace Pat (LetDec GPUMem)
Pat LParamMem
pat SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase Shape NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (PatElem LParamMem -> KernelResult -> InKernelGen ())
-> [PatElem LParamMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LParamMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LParamMem
pat) ([KernelResult] -> InKernelGen ())
-> [KernelResult] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
scans [TypeBase Shape NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
      dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims

  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Pat LParamMem -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
Pat LParamMem
pat) ([KernelResult] -> [(VName, KernelResult)])
-> [KernelResult] -> [(VName, KernelResult)]
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix
          VName
dest
          ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids)
          (KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
          []

  Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays ([VName] -> InKernelGen Fence) -> [VName] -> InKernelGen Fence
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
Pat LParamMem
pat
  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
fence

  let segment_size :: TPrimExp Int64 VName
segment_size = [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a. [a] -> a
last [TPrimExp Int64 VName]
dims'
      crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
        (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall e. IntegralExp e => e -> e -> e
`rem` TPrimExp Int64 VName
segment_size)

  -- groupScan needs to treat the scan output as a one-dimensional
  -- array of scan elements, so we invent some new flattened arrays
  -- here.
  TV Int64
dims_flat <- String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"dims_flat" (TPrimExp Int64 VName -> ImpM GPUMem KernelEnv KernelOp (TV Int64))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims'
  let scan :: SegBinOp GPUMem
scan = [SegBinOp GPUMem] -> SegBinOp GPUMem
forall a. [a] -> a
head [SegBinOp GPUMem]
scans
      num_scan_results :: Int
num_scan_results = [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int) -> [SubExp] -> Int
forall a b. (a -> b) -> a -> b
$ SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral SegBinOp GPUMem
scan
  [VName]
arrs_flat <-
    (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Int -> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op. Int -> TV Int64 -> VName -> ImpM rep r op VName
flattenArray ([TPrimExp Int64 VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TPrimExp Int64 VName]
dims') TV Int64
dims_flat) ([VName] -> ImpM GPUMem KernelEnv KernelOp [VName])
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$
      Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take Int
num_scan_results ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$
        Pat LParamMem -> [VName]
forall dec. Pat dec -> [VName]
patNames Pat (LetDec GPUMem)
Pat LParamMem
pat

  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
virtualisedGroupScan
        ((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName -> TExp Int32)
-> TPrimExp Int64 VName -> TExp Int32
forall a b. (a -> b) -> a -> b
$ TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
dims_flat)
        (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
    SegVirt
_ ->
      Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan
        ((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
        ([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        ([TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TPrimExp Int64 VName]
dims')
        (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
scan)
        [VName]
arrs_flat
compileGroupOp Pat (LetDec GPUMem)
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp GPUMem]
ops [TypeBase Shape NoUniqueness]
_ KernelBody GPUMem
body))) = do
  SegLevel -> SegSpace -> InKernelGen ()
compileFlatId SegLevel
lvl SegSpace
space

  let dims' :: [TPrimExp Int64 VName]
dims' = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
dims
      mkTempArr :: TypeBase Shape NoUniqueness -> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr TypeBase Shape NoUniqueness
t =
        String
-> PrimType
-> Shape
-> Space
-> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
String -> PrimType -> Shape -> Space -> ImpM rep r op VName
sAllocArray String
"red_arr" (TypeBase Shape NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase Shape NoUniqueness
t) ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims Shape -> Shape -> Shape
forall a. Semigroup a => a -> a -> a
<> TypeBase Shape NoUniqueness -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase Shape NoUniqueness
t) (Space -> ImpM GPUMem KernelEnv KernelOp VName)
-> Space -> ImpM GPUMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"

  [VName]
tmp_arrs <- (TypeBase Shape NoUniqueness
 -> ImpM GPUMem KernelEnv KernelOp VName)
-> [TypeBase Shape NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TypeBase Shape NoUniqueness -> ImpM GPUMem KernelEnv KernelOp VName
mkTempArr ([TypeBase Shape NoUniqueness]
 -> ImpM GPUMem KernelEnv KernelOp [VName])
-> [TypeBase Shape NoUniqueness]
-> ImpM GPUMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (SegBinOp GPUMem -> [TypeBase Shape NoUniqueness])
-> [SegBinOp GPUMem] -> [TypeBase Shape NoUniqueness]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType (Lambda GPUMem -> [TypeBase Shape NoUniqueness])
-> (SegBinOp GPUMem -> Lambda GPUMem)
-> SegBinOp GPUMem
-> [TypeBase Shape NoUniqueness]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda) [SegBinOp GPUMem]
ops
  SegVirt -> SegSpace -> InKernelGen () -> InKernelGen ()
groupCoverSegSpace (SegLevel -> SegVirt
segVirt SegLevel
lvl) SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    Names -> Stms GPUMem -> InKernelGen () -> InKernelGen ()
forall rep r op.
Names -> Stms rep -> ImpM rep r op () -> ImpM rep r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody GPUMem -> Stms GPUMem
forall rep. KernelBody rep -> Stms rep
kernelBodyStms KernelBody GPUMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
      let ([KernelResult]
red_res, [KernelResult]
map_res) =
            Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody GPUMem -> [KernelResult]
forall rep. KernelBody rep -> [KernelResult]
kernelBodyResult KernelBody GPUMem
body
      [(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
dest ((VName -> TPrimExp Int64 VName)
-> [VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
Imp.le64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
      (PatElem LParamMem -> KernelResult -> InKernelGen ())
-> [PatElem LParamMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace -> PatElem LParamMem -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElem LParamMem]
map_pes [KernelResult]
map_res

  KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

  let tmps_for_ops :: [[VName]]
tmps_for_ops = [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp GPUMem -> Int) -> [SegBinOp GPUMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp GPUMem -> [SubExp]) -> SegBinOp GPUMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp GPUMem -> [SubExp]
forall rep. SegBinOp rep -> [SubExp]
segBinOpNeutral) [SegBinOp GPUMem]
ops) [VName]
tmp_arrs
  case SegLevel -> SegVirt
segVirt SegLevel
lvl of
    SegVirt
SegVirt -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
    SegVirt
_ -> [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
nonvirtCase [TPrimExp Int64 VName]
dims' [[VName]]
tmps_for_ops
  where
    ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
    ([PatElem LParamMem]
red_pes, [PatElem LParamMem]
map_pes) = Int
-> [PatElem LParamMem]
-> ([PatElem LParamMem], [PatElem LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp GPUMem] -> Int
forall rep. [SegBinOp rep] -> Int
segBinOpResults [SegBinOp GPUMem]
ops) ([PatElem LParamMem] -> ([PatElem LParamMem], [PatElem LParamMem]))
-> [PatElem LParamMem]
-> ([PatElem LParamMem], [PatElem LParamMem])
forall a b. (a -> b) -> a -> b
$ Pat LParamMem -> [PatElem LParamMem]
forall dec. Pat dec -> [PatElem dec]
patElems Pat (LetDec GPUMem)
Pat LParamMem
pat

    virtCase :: [TPrimExp Int64 VName] -> [[VName]] -> InKernelGen ()
virtCase [TPrimExp Int64 VName
dim'] [[VName]]
tmps_for_ops = do
      TExp Int32
ltid <- KernelConstants -> TExp Int32
kernelLocalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int32)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp (TExp Int32)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
      TExp Int32
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
groupChunkLoop (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
dim') ((TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> TV Int64 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
chunk_start TV Int64
chunk_size -> do
        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"possibly incorporate carry" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
chunk_start TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. TExp Int32
0 TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Int32
ltid TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            [(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) ->
              Lambda GPUMem
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
Lambda rep
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> ImpM rep r op ()
applyRenamedLambda
                (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op)
                ([VName]
-> [[DimIndex (TPrimExp Int64 VName)]]
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmps ([[DimIndex (TPrimExp Int64 VName)]]
 -> [(VName, [DimIndex (TPrimExp Int64 VName)])])
-> [[DimIndex (TPrimExp Int64 VName)]]
-> [(VName, [DimIndex (TPrimExp Int64 VName)])]
forall a b. (a -> b) -> a -> b
$ [DimIndex (TPrimExp Int64 VName)]
-> [[DimIndex (TPrimExp Int64 VName)]]
forall a. a -> [a]
repeat [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
                ( [SubExp]
-> [[DimIndex (TPrimExp Int64 VName)]]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. [a] -> [b] -> [(a, b)]
zip ((PatElem LParamMem -> SubExp) -> [PatElem LParamMem] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (PatElem LParamMem -> VName) -> PatElem LParamMem -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName) [PatElem LParamMem]
red_pes) ([DimIndex (TPrimExp Int64 VName)]
-> [[DimIndex (TPrimExp Int64 VName)]]
forall a. a -> [a]
repeat [])
                    [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a. [a] -> [a] -> [a]
++ [SubExp]
-> [[DimIndex (TPrimExp Int64 VName)]]
-> [(SubExp, [DimIndex (TPrimExp Int64 VName)])]
forall a b. [a] -> [b] -> [(a, b)]
zip ((VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
tmps) ([DimIndex (TPrimExp Int64 VName)]
-> [[DimIndex (TPrimExp Int64 VName)]]
forall a. a -> [a]
repeat [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start])
                )

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

        [(SegBinOp GPUMem, [VName])]
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp GPUMem] -> [[VName]] -> [(SegBinOp GPUMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp GPUMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ())
-> ((SegBinOp GPUMem, [VName]) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp GPUMem
op, [VName]
tmps) -> do
          [VName]
tmps_chunks <- (VName -> ImpM GPUMem KernelEnv KernelOp VName)
-> [VName] -> ImpM GPUMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (TPrimExp Int64 VName
-> TV Int64 -> VName -> ImpM GPUMem KernelEnv KernelOp VName
forall rep r op.
TPrimExp Int64 VName -> TV Int64 -> VName -> ImpM rep r op VName
sliceArray (TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start) TV Int64
chunk_size) [VName]
tmps
          TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduce (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TV Int64 -> TPrimExp Int64 VName
forall t. TV t -> TExp t
tvExp TV Int64
chunk_size)) (SegBinOp GPUMem -> Lambda GPUMem
forall rep. SegBinOp rep -> Lambda rep
segBinOpLambda SegBinOp GPUMem
op) [VName]
tmps_chunks

        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal

        [(PatElem LParamMem, VName)]
-> ((PatElem LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElem LParamMem] -> [VName] -> [(PatElem LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElem LParamMem]
red_pes ([VName] -> [(PatElem LParamMem, VName)])
-> [VName] -> [(PatElem LParamMem, VName)]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat [[VName]]
tmps_for_ops) (((PatElem LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElem LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElem LParamMem
pe, VName
arr) ->
          VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (PatElem LParamMem -> VName
forall dec. PatElem dec -> VName
patElemName PatElem LParamMem
pe) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
chunk_start]
    virtCase dims' tmps_for_ops = do
      -- Segmented case: the temporaries are viewed as flat arrays and a
      -- segmented scan is run over them.  Afterwards the last element of
      -- every segment holds the result for that segment.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          -- Flat positions 'from' and 'to' belong to different segments
          -- when 'to' lies further from 'from' than from the start of
          -- its own segment.
          crossesSegment :: Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        virtualisedGroupScan
          (Just crossesSegment)
          (sExt32 $ tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      -- Copy out the last element of each segment (index 'last dims' - 1'
      -- in the innermost dimension) into the corresponding pattern element.
      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIM
          (patElemName pe)
          []
          (Var arr)
          (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal

    nonvirtCase :: [Imp.TExp Int64] -> [[VName]] -> InKernelGen ()
    nonvirtCase [dim'] tmps_for_ops = do
      -- Nonsegmented case (or rather, a single segment) - this we can
      -- handle directly with a group-level reduction.
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupReduce (sExt32 dim') (segBinOpLambda op) tmps
      sOp $ Imp.ErrorSync Imp.FenceLocal
      -- The reduction result is left in element 0 of each temporary.
      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) [] (Var arr) [0]
    --
    nonvirtCase dims' tmps_for_ops = do
      -- Segmented intra-group reductions are turned into (regular)
      -- segmented scans.  It is possible that this can be done
      -- better, but at least this approach is simple.

      -- groupScan operates on flattened arrays.  This does not
      -- involve copying anything; merely playing with the index
      -- function.
      dims_flat <- dPrimV "dims_flat" $ product dims'
      let segment_size = last dims'
          -- Flat positions 'from' and 'to' belong to different segments
          -- when 'to' lies further from 'from' than from the start of
          -- its own segment.
          crossesSegment :: Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool
          crossesSegment from to =
            (sExt64 to - sExt64 from) .>. (sExt64 to `rem` sExt64 segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) -> do
        tmps_flat <- mapM (flattenArray (length dims') dims_flat) tmps
        -- Reuse the already-bound flat size instead of re-expanding
        -- 'product dims'' for both size arguments.
        groupScan
          (Just crossesSegment)
          (tvExp dims_flat)
          (tvExp dims_flat)
          (segBinOpLambda op)
          tmps_flat

      sOp $ Imp.ErrorSync Imp.FenceLocal

      -- After the scan, the last element of each segment holds that
      -- segment's reduction result.
      forM_ (zip red_pes $ concat tmps_for_ops) $ \(pe, arr) ->
        copyDWIM
          (patElemName pe)
          []
          (Var arr)
          (map (unitSlice 0) (init dims') ++ [DimFix $ last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  -- Group-level histogram.  Every thread evaluates the kernel body to
  -- obtain bucket indices and values, and then applies the per-operator
  -- update functions produced by 'prepareIntraGroupSegHist'.
  compileFlatId lvl space
  let (ltids, _dims) = unzip $ unSegSpace space

  -- We don't need the red_pes, because it is guaranteed by our type
  -- rules that they occupy the same memory as the destinations for
  -- the ops.
  let num_red_res = length ops + sum (map (length . histNeutral) ops)
      (_red_pes, map_pes) =
        splitAt num_red_res $ patElems pat

  ops' <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- Ensure that all locks have been initialised.
  sOp $ Imp.Barrier Imp.FenceLocal

  groupCoverSegSpace (segVirt lvl) space $
    compileStms mempty (kernelBodyStms kbody) $ do
      -- Body results: one bucket index per operator, then the values of
      -- all operators, then the map results.
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_shape _ _ _ shape lam) -> do
          let bin' = pe64 bin
              dest_shape' = map pe64 $ shapeDims dest_shape
              -- Out-of-bounds buckets are silently skipped below.
              bin_in_bounds = inBounds (Slice (map DimFix [bin'])) dest_shape'
              -- All thread indices except the innermost, followed by the bucket.
              bin_is = map Imp.le64 (init ltids) ++ [bin']
              -- The value parameters are the trailing lambda parameters.
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
compileGroupOp pat _ =
  -- Any other operation reaching group level indicates a compiler bug.
  compilerBugS $ "compileGroupOp: cannot compile rhs of binding " ++ pretty pat

-- | Operation compiler for thread-level kernel code.  Only memory
-- allocations and 'SplitSpace' size operations are supported; any
-- other operation is reported as a compiler bug.
compileThreadOp :: OpCompiler GPUMem KernelEnv Imp.KernelOp
compileThreadOp pat = \case
  Alloc size space ->
    kernelAlloc pat size space
  Inner (SizeOp (SplitSpace o w i elems_per_thread)) ->
    splitSpace pat o w i elems_per_thread
  _ ->
    compilerBugS $ "compileThreadOp: cannot compile rhs of binding " ++ pretty pat

-- | Locking strategy used for an atomic update.
data Locking = Locking
  { -- | Array containing the lock.
    lockingArray :: VName,
    -- | Value for us to consider the lock free.
    lockingIsUnlocked :: Imp.TExp Int32,
    -- | What to write when we lock it.
    lockingToLock :: Imp.TExp Int32,
    -- | What to write when we unlock it.
    lockingToUnlock :: Imp.TExp Int32,
    -- | A transformation from the logical lock index to the
    -- physical position in the array.  This can also be used
    -- to make the lock array smaller.
    lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }

-- | Code generator for one atomic update.  It is given the memory
-- space, the destination arrays and the bucket index.  Implementations
-- may assume that the bucket index is in bounds.
type DoAtomicUpdate rep env =
  Space ->
  [VName] ->
  [Imp.TExp Int64] ->
  ImpM rep env Imp.KernelOp ()

-- | How an atomic update is going to be performed.  This doubles as a
-- rough efficiency estimate: the constructors are listed from the most
-- efficient to the least efficient.
data AtomicUpdate rep env
  = -- | A primitive atomic instruction performs the update directly.
    AtomicPrim (DoAtomicUpdate rep env)
  | -- | The update is implemented with efficient compare-and-swap.
    AtomicCAS (DoAtomicUpdate rep env)
  | -- | The update must be protected by an explicit lock.
    AtomicLocking (Locking -> DoAtomicUpdate rep env)

-- | Looks up the atomic operation that implements a given t'BinOp',
-- yielding 'Nothing' if no such atomic exists.
type AtomicBinOp =
  BinOp ->
  Maybe
    ( VName ->
      VName ->
      Count Imp.Elements (Imp.TExp Int64) ->
      Imp.Exp ->
      Imp.AtomicOp
    )

-- | Do an atomic update corresponding to a binary operator lambda.
--
-- The equations below are tried in order:
--
-- 1. If the lambda is a (possibly vectorised) binary operator whose
--    components are all 32 or 64 bits wide, each component is updated
--    with a native atomic when 'AtomicBinOp' provides one, and with a
--    compare-and-swap loop otherwise.  The result is 'AtomicPrim' only
--    if every component has a native atomic, and 'AtomicCAS' otherwise.
--
-- 2. If the lambda returns a single 32- or 64-bit primitive and takes
--    exactly two parameters, its body is run inside a compare-and-swap
--    loop ('AtomicCAS').
--
-- 3. Otherwise the read-modify-write is guarded by a lock described by
--    the 'Locking' supplied to the generator ('AtomicLocking').
atomicUpdateLocking ::
  AtomicBinOp ->
  Lambda GPUMem ->
  AtomicUpdate GPUMem KernelEnv
atomicUpdateLocking atomicBinOp lam
  | Just ops_and_ts <- lamIsBinOp lam,
    all (\(_, t, _, _) -> primBitSize t `elem` [32, 64]) ops_and_ts =
      primOrCas ops_and_ts $ \space arrs bucket ->
        -- If the operator is a vectorised binary operator on 32/64-bit
        -- values, we can use a particularly efficient
        -- implementation. If the operator has an atomic implementation
        -- we use that, otherwise it is still a binary operator which
        -- can be implemented by atomic compare-and-swap if 32/64 bits.
        forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do
          -- Common variables.
          old <- dPrim "old" t

          (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket

          case opHasAtomicSupport space (tvVar old) arr' bucket_offset op of
            Just f -> sOp $ f $ Imp.var y t
            Nothing ->
              atomicUpdateCAS space t a (tvVar old) bucket x $
                x <~~ Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
  where
    -- The native atomic for 'bop' (if any), already applied to the
    -- result variable and destination, so only the operand remains.
    opHasAtomicSupport space old arr' bucket' bop = do
      let atomic f = Imp.Atomic space . f old arr' bucket'
      atomic <$> atomicBinOp bop

    -- Advertise 'AtomicPrim' only when every component has a native
    -- atomic; a single CAS-based component makes the whole update CAS.
    primOrCas ops
      | all isPrim ops = AtomicPrim
      | otherwise = AtomicCAS

    isPrim (op, _, _, _) = isJust $ atomicBinOp op

-- If the operator functions purely on single 32/64-bit values, we can
-- use an implementation based on CAS, no matter what the operator
-- does.
atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    primBitSize t `elem` [32, 64] = AtomicCAS $ \space arrs bucket ->
      -- A single-result lambda updates exactly one array; anything else
      -- is a caller bug, reported explicitly rather than through a
      -- lambda pattern-match failure.
      case arrs of
        [arr] -> do
          old <- dPrim "old" t
          -- The first parameter doubles as the accumulator: the CAS loop
          -- stores the observed value in it before running the body,
          -- and the body leaves the new value there.
          atomicUpdateCAS space t arr (tvVar old) bucket (paramName xp) $
            compileBody' [xp] $
              lambdaBody op
        _ ->
          compilerBugS $
            "atomicUpdateLocking: expected exactly one destination array, got "
              ++ show (length arrs)

-- General fallback: guard an arbitrary read-modify-write with a lock.
atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do
  old <- dPrim "old" int32
  continue <- dPrimVol "continue" Bool true

  -- Correctly index into locks.
  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  -- Critical section
  let try_acquire_lock =
        sOp $
          Imp.Atomic space $
            Imp.AtomicCmpXchg
              int32
              (tvVar old)
              locks'
              locks_offset
              (untyped $ lockingIsUnlocked locking)
              (untyped $ lockingToLock locking)
      lock_acquired = tvExp old .==. lockingIsUnlocked locking
      -- Even the releasing is done with an atomic rather than a
      -- simple write, for memory coherency reasons.
      release_lock =
        sOp $
          Imp.Atomic space $
            Imp.AtomicCmpXchg
              int32
              (tvVar old)
              locks'
              locks_offset
              (untyped $ lockingToLock locking)
              (untyped $ lockingToUnlock locking)
      break_loop = continue <-- false

  -- Preparing parameters. It is assumed that the caller has already
  -- filled the arr_params. We copy the current value to the
  -- accumulator parameters.
  --
  -- Note the use of 'everythingVolatile' when reading and writing the
  -- buckets.  This was necessary to ensure correct execution on a
  -- newer NVIDIA GPU (RTX 2080).  The 'volatile' modifiers likely
  -- make the writes pass through the (SM-local) L1 cache, which is
  -- necessary here, because we are really doing device-wide
  -- synchronisation without atomics (naughty!).
  let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op
      bind_acc_params =
        everythingVolatile $
          sComment "bind lhs" $
            forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
              copyDWIMFix (paramName acc_p) [] (Var arr) bucket

  let op_body =
        sComment "execute operation" $
          compileBody' acc_params $
            lambdaBody op

      do_hist =
        everythingVolatile $
          sComment "update global result" $
            zipWithM_ (writeArray bucket) arrs $
              map (Var . paramName) acc_params

      fence = sOp $ Imp.MemFence $ fenceForSpace space

  -- While-loop: Try to insert your value
  sWhile (tvExp continue) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      dLParams acc_params
      bind_acc_params
      op_body
      do_hist
      fence
      release_lock
      break_loop
    fence
  where
    writeArray bucket arr val = copyDWIMFix arr bucket val []

-- | Generate a compare-and-swap loop that atomically replaces element
-- @bucket@ of @arr@ with the value computed by @do_op@.
--
-- Arguments, in order:
--
-- * the memory space in which the atomic exchange is performed;
-- * the element type @t@;
-- * the array @arr@;
-- * a variable @old@ of type @t@ holding the most recently observed
--   value of the element;
-- * the index @bucket@;
-- * a variable @x@ of type @t@, which is set to the observed value
--   before @do_op@ runs and must contain the new value afterwards;
-- * the update @do_op@ itself.
--
-- Floating-point values are exchanged and compared through their
-- integer bit patterns (@to_bits*@ / @from_bits*@).  The exchange uses
-- int16 for 16-bit types, int32 for 32-bit types, and int64 for
-- everything else.  NOTE(review): that fallback would use the wrong
-- width for types narrower than 16 bits; the callers in
-- 'atomicUpdateLocking' only pass 32/64-bit types.
atomicUpdateCAS ::
  Space ->
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  InKernelGen () ->
  InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  -- Code generation target:
  --
  -- old = d_his[idx];
  -- do {
  --   assumed = old;
  --   x = do_op(assumed, y);
  --   old = atomicCAS(&d_his[idx], assumed, x);
  -- } while(assumed != old);
  assumed <- tvVar <$> dPrim "assumed" t
  run_loop <- dPrimV "run_loop" true

  -- XXX: CUDA may generate really bad code if this is not a volatile
  -- read.  Unclear why.  The later reads are volatile, so maybe
  -- that's it.
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket

  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  -- Atomic exchange operates on integers, so floats are reinterpreted
  -- as bit patterns of the same width.
  let (toBits, fromBits) =
        case t of
          FloatType Float16 ->
            ( \v -> Imp.FunExp "to_bits16" [v] int16,
              \v -> Imp.FunExp "from_bits16" [v] t
            )
          FloatType Float32 ->
            ( \v -> Imp.FunExp "to_bits32" [v] int32,
              \v -> Imp.FunExp "from_bits32" [v] t
            )
          FloatType Float64 ->
            ( \v -> Imp.FunExp "to_bits64" [v] int64,
              \v -> Imp.FunExp "from_bits64" [v] t
            )
          _ -> (id, id)

      -- Integer type used for the exchange itself.
      int
        | primBitSize t == 16 = int16
        | primBitSize t == 32 = int32
        | otherwise = int64

  -- While-loop: Try to insert your value
  sWhile (tvExp run_loop) $ do
    assumed <~~ Imp.var old t
    x <~~ Imp.var assumed t
    do_op
    old_bits_v <- newVName "old_bits"
    dPrim_ old_bits_v int
    let old_bits = Imp.var old_bits_v int
    sOp . Imp.Atomic space $
      Imp.AtomicCmpXchg
        int
        old_bits_v
        arr'
        bucket_offset
        (toBits (Imp.var assumed t))
        (toBits (Imp.var x t))
    old <~~ fromBits old_bits
    -- The swap succeeded iff the element still held the value we
    -- assumed; compare bit patterns so that e.g. NaNs compare equal.
    let won = CmpOpExp (CmpEq int) (toBits (Imp.var assumed t)) old_bits
    sWhen (isBool won) (run_loop <-- false)

-- | Compute the host values a kernel must be passed.
--
-- These are the variables free in @kernel_body@, excluding those listed
-- in @bound_in_kernel@, as classified by 'readsFromSet'.  The result
-- contains no duplicates.
computeKernelUses ::
  FreeIn a =>
  a ->
  [VName] ->
  CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel =
  -- Compute the variables that we need to pass to the kernel.
  nubOrd <$> readsFromSet actually_free
  where
    actually_free =
      freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel

-- | Classify each name in the set by how it must be passed to a kernel:
--
-- * arrays, accumulators, and memory blocks in @"local"@ space produce
--   no 'Imp.KernelUse';
-- * any other memory block becomes an 'Imp.MemoryUse';
-- * a scalar becomes an 'Imp.ConstUse' when 'isConstExp' recognises it
--   as a kernel constant, and an 'Imp.ScalarUse' otherwise.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet names = catMaybes <$> mapM classify (namesToList names)
  where
    classify :: VName -> CallKernelGen (Maybe Imp.KernelUse)
    classify var = do
      t <- lookupType var
      vtable <- getVTable
      case t of
        Array {} -> pure Nothing
        Acc {} -> pure Nothing
        Mem (Space "local") -> pure Nothing
        Mem {} -> pure $ Just $ Imp.MemoryUse var
        Prim bt ->
          Just . maybe (Imp.ScalarUse var bt) (Imp.ConstUse var)
            <$> isConstExp vtable (Imp.var var bt)

-- | Try to turn a host-side expression into a kernel constant
-- expression.
--
-- Every variable leaf of @size@ is resolved through @vtable@ to the
-- expression that was recorded for it (if any).  A 'GetSize' operation
-- becomes a 'Imp.SizeConst' leaf, keyed with the current entry point
-- via 'keyWithEntryPoint'.  Any other expression is converted with
-- 'primExpFromExp', resolving its own variables the same way.
-- The result is 'Nothing' as soon as any leaf cannot be resolved.
isConstExp ::
  VTable GPUMem ->
  Imp.Exp ->
  ImpM rep r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let -- Resolve a variable via its recorded defining expression.
      fromVar name = M.lookup name vtable >>= entryExp >>= fromExp
      fromExp (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      fromExp e = primExpFromExp fromVar e
  -- The leaf's type annotation is not needed to resolve it.
  pure $ replaceInPrimExpM (const . fromVar) size
  where
    -- The (optional) expression stored with each kind of vtable entry.
    entryExp (ArrayVar e _) = e
    entryExp (AccVar e _) = e
    entryExp (ScalarVar e _) = e
    entryExp (MemVar e _) = e

-- | Emit code that stores into @chunk_var@ how many elements the thread
-- with index @thread_index@ should process.
--
-- * 'SplitStrided' @stride@: the count is
--   @min(elements_per_thread, (num_elements - thread_index) `divUp` stride)@.
-- * 'SplitContiguous': the thread's run starts at
--   @thread_index * elements_per_thread@ (bound to a fresh
--   @starting_point@ variable).  The count is 0 when that start is at or
--   past @num_elements@.  It is the leftover
--   @num_elements - thread_index * elements_per_thread@ when the run
--   would overshoot the end, and @elements_per_thread@ otherwise.
computeThreadChunkSize ::
  SplitOrdering ->
  Imp.TExp Int64 ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  Imp.Count Imp.Elements (Imp.TExp Int64) ->
  TV Int64 ->
  ImpM rep r op ()
computeThreadChunkSize ordering thread_index elements_per_thread num_elements chunk_var =
  case ordering of
    SplitStrided stride ->
      chunk_var <-- sMin64 per_thread ((total - thread_index) `divUp` pe64 stride)
    SplitContiguous -> do
      starting_point <- dPrimV "starting_point" $ thread_index * per_thread
      remaining_elements <-
        dPrimV "remaining_elements" $ total - tvExp starting_point
      let start = tvExp starting_point
          -- These two predicates are equivalent, because
          -- remaining_elements = total - start.  Both are kept so the
          -- generated code is unchanged.
          nothing_left = tvExp remaining_elements .<=. 0
          past_end = total .<=. start
          is_last_thread = total .<. (thread_index + 1) * per_thread
          last_chunk = total - thread_index * per_thread
      sIf (nothing_left .||. past_end) (chunk_var <-- 0) $
        sIf is_last_thread (chunk_var <-- last_chunk) (chunk_var <-- per_thread)
  where
    per_thread = Imp.unCount elements_per_thread
    total = Imp.unCount num_elements

-- | Construct the 'KernelConstants' for a kernel launched with
-- @num_groups@ groups of @group_size@ threads. Also return the
-- in-kernel code that binds the variables those constants refer to.
--
-- The constants mention the variables @global_tid@, @local_tid@,
-- @group_id@ and @wave_size@, which are only bound by the returned
-- action.  That action must therefore run before the constants are used.
-- It declares the variables and fills them in with 'Imp.GetLocalId',
-- 'Imp.GetLocalSize', 'Imp.GetLockstepWidth' and 'Imp.GetGroupId'
-- (dimension 0).  It then derives the global thread ID as
-- @group_id * group_size + local_tid@ in 32-bit arithmetic.
-- 'kernelLocalIdMap' and 'kernelChunkItersMap' start out empty.
kernelInitialisationSimple ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple num_groups group_size = do
  global_tid <- newVName "global_tid"
  local_tid <- newVName "local_tid"
  group_id <- newVName "group_tid"
  wave_size <- newVName "wave_size"
  inner_group_size <- newVName "group_size"
  let num_groups' = Imp.pe64 (unCount num_groups)
      group_size' = Imp.pe64 (unCount group_size)
      constants =
        KernelConstants
          { kernelGlobalThreadId = Imp.le32 global_tid,
            kernelLocalThreadId = Imp.le32 local_tid,
            kernelGroupId = Imp.le32 group_id,
            kernelGlobalThreadIdVar = global_tid,
            kernelLocalThreadIdVar = local_tid,
            kernelNumGroupsCount = num_groups,
            kernelGroupSizeCount = group_size,
            kernelGroupIdVar = group_id,
            kernelNumGroups = num_groups',
            kernelGroupSize = group_size',
            kernelNumThreads = sExt32 (group_size' * num_groups'),
            kernelWaveSize = Imp.le32 wave_size,
            kernelLocalIdMap = mempty,
            kernelChunkItersMap = mempty
          }

  let set_constants :: InKernelGen ()
      set_constants = do
        dPrim_ local_tid int32
        dPrim_ inner_group_size int64
        dPrim_ wave_size int32
        dPrim_ group_id int32

        sOp (Imp.GetLocalId local_tid 0)
        sOp (Imp.GetLocalSize inner_group_size 0)
        sOp (Imp.GetLockstepWidth wave_size)
        sOp (Imp.GetGroupId group_id 0)
        -- inner_group_size is declared as int64 above.  Read it as a
        -- 64-bit leaf and narrow explicitly, instead of referring to it
        -- as an Int32 variable, which made the expression ill-typed.
        dPrimV_ global_tid $
          Imp.le32 group_id * sExt32 (Imp.le64 inner_group_size)
            + Imp.le32 local_tid

  pure (constants, set_constants)

-- | A condition that holds when every index variable is strictly less
-- than its paired bound, i.e. the conjunction of @i < w@ for each
-- @(i, w)@.  With no pairs it is simply 'true'.  Conjuncts are combined
-- left-nested, in list order.
isActive :: [(VName, SubExp)] -> Imp.TExp Bool
isActive limit =
  case map inBounds limit of
    [] -> true
    c : cs -> foldl' (.&&.) c cs
  where
    inBounds (i, w) = Imp.le64 i .<. pe64 w

-- | Change every memory block to be in the global address space,
-- except those that are in the local memory space.  This only affects
-- generated code - we still need to make sure that the memory is
-- actually present on the device (and declared as variables in the
-- kernel).
--
-- The ambient default space (via 'localDefaultSpace') is also set to
-- global for the duration of the action.  Memory entries that get
-- rewritten also lose whatever defining expression the symbol table
-- recorded for them.  All other entries are left untouched.
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal action =
  localDefaultSpace global $ localVTable (M.map toGlobal) action
  where
    global = Imp.Space "global"
    toGlobal = \case
      MemVar _ entry
        | entryMemSpace entry /= Space "local" ->
            MemVar Nothing entry {entryMemSpace = global}
      other -> other

-- | Group-level reduction of the arrays @arrs@ with the operator @lam@.
-- @w@ bounds which threads participate.  This declares a fresh 32-bit
-- @offset@ variable and delegates the work to 'groupReduceWithOffset'.
groupReduce ::
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupReduce w lam arrs =
  dPrim "offset" int32 >>= \offset ->
    groupReduceWithOffset offset w lam arrs

groupReduceWithOffset ::
  TV Int32 ->
  Imp.TExp Int32 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupReduceWithOffset :: TV Int32
-> TExp Int32 -> Lambda GPUMem -> [VName] -> InKernelGen ()
groupReduceWithOffset TV Int32
offset TExp Int32
w Lambda GPUMem
lam [VName]
arrs = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv

  let local_tid :: TExp Int32
local_tid = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
      global_tid :: TExp Int32
global_tid = KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants

      barrier :: InKernelGen ()
barrier
        | (TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
lam = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
        | Bool
otherwise = KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal

      readReduceArgument :: Param LParamMem -> VName -> InKernelGen ()
readReduceArgument Param LParamMem
param VName
arr
        | Prim PrimType
_ <- Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
param = do
            let i :: TExp Int32
i = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
param) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]
        | Bool
otherwise = do
            let i :: TExp Int32
i = TExp Int32
global_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
param) [] (VName -> SubExp
Var VName
arr) [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
i]

      writeReduceOpResult :: Param LParamMem -> VName -> InKernelGen ()
writeReduceOpResult Param LParamMem
param VName
arr
        | Prim PrimType
_ <- Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
param =
            VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
local_tid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
param) []
        | Bool
otherwise =
            () -> InKernelGen ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

  let ([Param LParamMem]
reduce_acc_params, [Param LParamMem]
reduce_arr_params) = Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  TV Int32
skip_waves <- String -> TExp Int32 -> ImpM GPUMem KernelEnv KernelOp (TV Int32)
forall t rep r op. String -> TExp t -> ImpM rep r op (TV t)
dPrimV String
"skip_waves" (TExp Int32
1 :: Imp.TExp Int32)
  [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams ([LParam GPUMem] -> InKernelGen ())
-> [LParam GPUMem] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
0 :: Imp.TExp Int32)

  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"participating threads read initial accumulator" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
w) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      (Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
readReduceArgument [Param LParamMem]
reduce_acc_params [VName]
arrs

  let do_reduce :: InKernelGen ()
do_reduce = do
        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"read array element" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
readReduceArgument [Param LParamMem]
reduce_arr_params [VName]
arrs
        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"apply reduction operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          [Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
reduce_acc_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam
        String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment String
"write result of operation" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          (Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
writeReduceOpResult [Param LParamMem]
reduce_acc_params [VName]
arrs
      in_wave_reduce :: InKernelGen ()
in_wave_reduce = InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile InKernelGen ()
do_reduce

      wave_size :: TExp Int32
wave_size = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      group_size :: TPrimExp Int64 VName
group_size = KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants
      wave_id :: TExp Int32
wave_id = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
wave_size
      in_wave_id :: TExp Int32
in_wave_id = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
wave_size
      num_waves :: TExp Int32
num_waves = (TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TPrimExp Int64 VName
group_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
wave_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
wave_size
      arg_in_bounds :: TExp Bool
arg_in_bounds = TExp Int32
local_tid TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
w

      doing_in_wave_reductions :: TExp Bool
doing_in_wave_reductions =
        TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
wave_size
      apply_in_in_wave_iteration :: TExp Bool
apply_in_in_wave_iteration =
        (TExp Int32
in_wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. (TExp Int32
2 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)) TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      in_wave_reductions :: InKernelGen ()
in_wave_reductions = do
        TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- (TExp Int32
1 :: Imp.TExp Int32)
        TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile TExp Bool
doing_in_wave_reductions (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen
            (TExp Bool
arg_in_bounds TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
apply_in_in_wave_iteration)
            InKernelGen ()
in_wave_reduce
          TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
offset TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2

      doing_cross_wave_reductions :: TExp Bool
doing_cross_wave_reductions =
        TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
num_waves
      is_first_thread_in_wave :: TExp Bool
is_first_thread_in_wave =
        TExp Int32
in_wave_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      wave_not_skipped :: TExp Bool
wave_not_skipped =
        (TExp Int32
wave_id TExp Int32 -> TExp Int32 -> TExp Int32
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp t v
.&. (TExp Int32
2 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)) TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
      apply_in_cross_wave_iteration :: TExp Bool
apply_in_cross_wave_iteration =
        TExp Bool
arg_in_bounds TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
is_first_thread_in_wave TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
wave_not_skipped
      cross_wave_reductions :: InKernelGen ()
cross_wave_reductions =
        TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhile TExp Bool
doing_cross_wave_reductions (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
          InKernelGen ()
barrier
          TV Int32
offset TV Int32 -> TExp Int32 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
wave_size
          TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen
            TExp Bool
apply_in_cross_wave_iteration
            InKernelGen ()
do_reduce
          TV Int32
skip_waves TV Int32 -> TExp Int32 -> InKernelGen ()
forall t rep r op. TV t -> TExp t -> ImpM rep r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_waves TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2

  InKernelGen ()
in_wave_reductions
  InKernelGen ()
cross_wave_reductions

groupScan ::
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  Lambda GPUMem ->
  [VName] ->
  InKernelGen ()
groupScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> Lambda GPUMem
-> [VName]
-> InKernelGen ()
groupScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
w Lambda GPUMem
lam [VName]
arrs = do
  KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM GPUMem KernelEnv KernelOp KernelEnv
-> ImpM GPUMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM GPUMem KernelEnv KernelOp KernelEnv
forall rep r op. ImpM rep r op r
askEnv
  Lambda GPUMem
renamed_lam <- Lambda GPUMem -> ImpM GPUMem KernelEnv KernelOp (Lambda GPUMem)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda GPUMem
lam

  let ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
      ltid :: TPrimExp Int64 VName
ltid = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
      ([Param LParamMem]
x_params, [Param LParamMem]
y_params) = Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam

  [LParam GPUMem] -> InKernelGen ()
forall rep inner r op.
Mem rep inner =>
[LParam rep] -> ImpM rep r op ()
dLParams (Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
lam [Param LParamMem] -> [Param LParamMem] -> [Param LParamMem]
forall a. [a] -> [a] -> [a]
++ Lambda GPUMem -> [LParam GPUMem]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPUMem
renamed_lam)

  TExp Bool
ltid_in_bounds <- String -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"ltid_in_bounds" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
ltid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
w

  Fence
fence <- [VName] -> InKernelGen Fence
fenceForArrays [VName]
arrs

  -- The scan works by splitting the group into blocks, which are
  -- scanned separately.  Typically, these blocks are smaller than
  -- the lockstep width, which enables barrier-free execution inside
  -- them.
  --
  -- We hardcode the block size here.  The only requirement is that
  -- it should not be less than the square root of the group size.
  -- With 32, we will work on groups of size 1024 or smaller, which
  -- fits every device Troels has seen.  Still, it would be nicer if
  -- it were a runtime parameter.  Some day.
  let block_size :: TExp Int32
block_size = TExp Int32
32
      simd_width :: TExp Int32
simd_width = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
      block_id :: TExp Int32
block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
block_size
      in_block_id :: TExp Int32
in_block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size
      doInBlockScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag' TExp Bool
active =
        KernelConstants
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TPrimExp Int64 VName
-> TExp Int32
-> TExp Int32
-> TExp Bool
-> [VName]
-> InKernelGen ()
-> Lambda GPUMem
-> InKernelGen ()
inBlockScan
          KernelConstants
constants
          Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag'
          TPrimExp Int64 VName
arrs_full_size
          TExp Int32
simd_width
          TExp Int32
block_size
          TExp Bool
active
          [VName]
arrs
          InKernelGen ()
barrier
      array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (TypeBase Shape NoUniqueness -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase Shape NoUniqueness] -> Bool)
-> [TypeBase Shape NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda GPUMem
lam
      barrier :: InKernelGen ()
barrier
        | Bool
array_scan =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
        | Bool
otherwise =
            KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
fence

      group_offset :: TPrimExp Int64 VName
group_offset = TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelGroupId KernelConstants
constants) TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* KernelConstants -> TPrimExp Int64 VName
kernelGroupSize KernelConstants
constants

      writeBlockResult :: Param LParamMem -> VName -> InKernelGen ()
writeBlockResult Param LParamMem
p VName
arr
        | TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p =
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
        | Bool
otherwise =
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []

      readPrevBlockResult :: Param LParamMem -> VName -> InKernelGen ()
readPrevBlockResult Param LParamMem
p VName
arr
        | TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p =
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]
        | Bool
otherwise =
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1]

  Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Bool
ltid_in_bounds Lambda GPUMem
lam
  InKernelGen ()
barrier

  let is_first_block :: TExp Bool
is_first_block = TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"save correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) []

    InKernelGen ()
barrier

  let last_in_block :: TExp Bool
last_in_block = TExp Int32
in_block_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"last thread of block 'i' writes its result to offset 'i'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Bool
last_in_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      InKernelGen () -> InKernelGen ()
forall rep r op a. ImpM rep r op a -> ImpM rep r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        (Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
writeBlockResult [Param LParamMem]
x_params [VName]
arrs

  InKernelGen ()
barrier

  let first_block_seg_flag :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag = do
        TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag
        (TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TExp Bool)
 -> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool))
-> (TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
          TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1)
  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
comment
    String
"scan the first block, after which offset 'i' contains carry-in for block 'i+1'"
    (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda GPUMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) Lambda GPUMem
renamed_lam

  InKernelGen ()
barrier

  Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"move correct values for first block back a block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM
              VName
arr
              [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]
              (VName -> SubExp
Var VName
arr)
              [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]

    InKernelGen ()
barrier

  TExp Bool
no_carry_in <- String -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"no_carry_in" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
ltid_in_bounds

  let read_carry_in :: InKernelGen ()
read_carry_in = TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
        [(Param LParamMem, Param LParamMem)]
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem] -> [(Param LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [Param LParamMem]
y_params) (((Param LParamMem, Param LParamMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y) ->
          VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) [] (VName -> SubExp
Var (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x)) []
        (Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
readPrevBlockResult [Param LParamMem]
x_params [VName]
arrs

      op_to_x :: InKernelGen ()
op_to_x
        | Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
Nothing <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag =
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
x_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam
        | Just TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag = do
            TExp Bool
inactive <-
              String -> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"inactive" (TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM GPUMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1) TExp Int32
ltid32
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> (((Param LParamMem, Param LParamMem) -> InKernelGen ())
    -> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> (((Param LParamMem, Param LParamMem) -> InKernelGen ())
    -> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(Param LParamMem, Param LParamMem)]
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem] -> [(Param LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [Param LParamMem]
y_params) (((Param LParamMem, Param LParamMem) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y) ->
              VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) [] (VName -> SubExp
Var (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y)) []
            -- The convoluted control flow is to ensure all threads
            -- hit this barrier (if applicable).
            Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
            TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> Body GPUMem -> InKernelGen ()
forall dec rep r op. [Param dec] -> Body rep -> ImpM rep r op ()
compileBody' [Param LParamMem]
x_params (Body GPUMem -> InKernelGen ()) -> Body GPUMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda GPUMem -> Body GPUMem
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPUMem
lam

      write_final_result :: InKernelGen ()
write_final_result =
        [(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
arr) ->
          Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase Shape NoUniqueness -> Bool)
-> TypeBase Shape NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
            VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []

  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"carry-in for every block except the first" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
    String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"read operands" InKernelGen ()
read_carry_in
    String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
    String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"write final result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sUnless TExp Bool
no_carry_in InKernelGen ()
write_final_result

  InKernelGen ()
barrier

  String -> InKernelGen () -> InKernelGen ()
forall rep r op. String -> ImpM rep r op () -> ImpM rep r op ()
sComment String
"restore correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
    TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
      [(Param LParamMem, Param LParamMem, VName)]
-> ((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem]
-> [VName]
-> [(Param LParamMem, Param LParamMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param LParamMem]
x_params [Param LParamMem]
y_params [VName]
arrs) (((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
 -> InKernelGen ())
-> ((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y, VName
arr) ->
        if TypeBase Shape NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LParamMem -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param LParamMem
y)
          then VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM VName
arr [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix TPrimExp Int64 VName
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) []
          else VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> InKernelGen ()
forall rep r op.
VName
-> [DimIndex (TPrimExp Int64 VName)]
-> SubExp
-> [DimIndex (TPrimExp Int64 VName)]
-> ImpM rep r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) [] (VName -> SubExp
Var VName
arr) [TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall d. d -> DimIndex d
DimFix (TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName))
-> TPrimExp Int64 VName -> DimIndex (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ TPrimExp Int64 VName
arrs_full_size TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
group_offset TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
ltid]

  InKernelGen ()
barrier

-- | Emit code that scans, inside a workgroup, blocks of @block_size@
-- consecutive threads.
--
-- The lambda parameters are split in two halves.  The second half (the
-- \"y\" parameters) holds the thread's own running value.  The first
-- half (the \"x\" parameters) receives the operand read from a thread
-- further back and, after the operator is applied, the combined result.
--
-- Algorithm:
--
-- * Every @active@ thread loads its element into the y parameters.
--   Threads at position 0 of their block also copy y into x (primitive
--   types only), because they never take part in a combining step.
-- * For @skip_threads = 1, 2, 4, ...@ below @block_size@, a thread takes
--   part in the step when it is @active@ and its position within the
--   block is at least @skip_threads@.  Such a thread:
--
--     1. reads the element @skip_threads@ positions behind it into x;
--     2. applies the operator into x;
--     3. writes x back to the array (primitive types only) and into y.
--
-- Arguments:
--
-- * @seg_flag@: @Nothing@ gives an unsegmented scan.  With @Just f@,
--   each thread evaluates @f (ltid - skip_threads) ltid@.  When that is
--   true, the thread copies y into x instead of applying the operator.
-- * @arrs_full_size@: added to the global-thread-based index when
--   reading non-primitive operands.
-- * @lockstep_width@: the runtime barrier between steps is only taken
--   once @skip_threads >= lockstep_width@.  Presumably threads below
--   that width are assumed to run in lockstep (TODO confirm).
--
-- The whole body is wrapped in 'everythingVolatile'.
inBlockScan ::
  KernelConstants ->
  Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
  Imp.TExp Int64 ->
  Imp.TExp Int32 ->
  Imp.TExp Int32 ->
  Imp.TExp Bool ->
  [VName] ->
  InKernelGen () ->
  Lambda GPUMem ->
  InKernelGen ()
inBlockScan constants seg_flag arrs_full_size lockstep_width block_size active arrs barrier scan_lam =
  everythingVolatile $ do
    skip_tv <- dPrim "skip_threads" int32
    let skip_threads = tvExp skip_tv

        -- Combine the operand in x with the thread's own value in y,
        -- leaving the result in x.  Segmented scans may instead keep
        -- the thread's own value when 'seg_flag' says so.
        opToX :: Imp.TExp Bool -> InKernelGen ()
        opToX in_block_thread_active =
          case seg_flag of
            Nothing ->
              sWhen in_block_thread_active applyLambda
            Just flag_true -> do
              inactive <-
                dPrimVE "inactive" $ flag_true (ltid32 - skip_threads) ltid32
              sWhen (in_block_thread_active .&&. inactive) $
                zipWithM_ assignParam x_params y_params
              -- The convoluted control flow is to ensure all threads
              -- hit this barrier (if applicable).
              when array_scan barrier
              sWhen in_block_thread_active $ sUnless inactive applyLambda

        -- Only synchronise once the stride has reached the lockstep width.
        maybeBarrier :: InKernelGen ()
        maybeBarrier = sWhen (lockstep_width .<=. skip_threads) barrier

    -- Set initial y values
    sComment "read input for in-block scan" $
      sWhen active $ do
        zipWithM_ readInitial y_params arrs
        -- Since the final result is expected to be in x_params, we may
        -- need to copy it there for the first thread in the block.
        sWhen (in_block_id .==. 0) yToX

    when array_scan barrier

    sComment "in-block scan (hopefully no barriers needed)" $ do
      skip_tv <-- 1
      sWhile (skip_threads .<. block_size) $ do
        thread_active <-
          dPrimVE "thread_active" $ (skip_threads .<=. in_block_id) .&&. active

        sWhen thread_active $
          sComment "read operands" $
            zipWithM_ (readParam $ sExt64 skip_threads) x_params arrs
        sComment "perform operation" $ opToX thread_active

        maybeBarrier

        sWhen thread_active $
          sComment "write result" $
            sequence_ $ zipWith3 writeResult x_params y_params arrs

        maybeBarrier

        skip_tv <-- skip_threads * 2
  where
    -- First half of the lambda parameters is x, second half is y.
    lambda_params = lambdaParams scan_lam
    (x_params, y_params) = splitAt (length lambda_params `div` 2) lambda_params

    -- Run the scan operator, binding its result to the x parameters.
    applyLambda :: InKernelGen ()
    applyLambda = compileBody' x_params $ lambdaBody scan_lam

    -- Copy the value of parameter 'src' into parameter 'dst'.
    assignParam :: Param LParamMem -> Param LParamMem -> InKernelGen ()
    assignParam dst src = copyDWIM (paramName dst) [] (Var $ paramName src) []

    -- Copy y into x for primitive-typed parameter pairs.
    yToX :: InKernelGen ()
    yToX =
      sequence_
        [ assignParam x y
          | (x, y) <- zip x_params y_params,
            primType $ paramType x
        ]

    ltid32 :: Imp.TExp Int32
    ltid32 = kernelLocalThreadId constants

    ltid :: Imp.TExp Int64
    ltid = sExt64 ltid32

    gtid :: Imp.TExp Int64
    gtid = sExt64 $ kernelGlobalThreadId constants

    -- Position of this thread within its block of 'block_size' threads.
    in_block_id :: Imp.TExp Int32
    in_block_id = ltid32 - (ltid32 `quot` block_size) * block_size

    -- True when at least one scanned value is not of primitive type.
    array_scan :: Bool
    array_scan = any (not . primType) $ lambdaReturnType scan_lam

    -- Load this thread's own element.  Primitive values are indexed by
    -- the local thread id, others by the global thread id.
    readInitial :: Param LParamMem -> VName -> InKernelGen ()
    readInitial p arr =
      copyDWIM (paramName p) [] (Var arr) [DimFix idx]
      where
        idx
          | primType $ paramType p = ltid
          | otherwise = gtid

    -- Load the element 'behind' positions before this thread's own.
    readParam :: Imp.TExp Int64 -> Param LParamMem -> VName -> InKernelGen ()
    readParam behind p arr =
      copyDWIM (paramName p) [] (Var arr) [DimFix idx]
      where
        idx
          | primType $ paramType p = ltid - behind
          | otherwise = gtid - behind + arrs_full_size

    -- Publish the combined value.  Primitive values are also stored in
    -- the array at the local thread id.  Every value is copied into y.
    writeResult :: Param LParamMem -> Param LParamMem -> VName -> InKernelGen ()
    writeResult x y arr = do
      when (primType $ paramType x) $
        copyDWIM arr [DimFix ltid] (Var $ paramName x) []
      assignParam y x

-- | Compute the launch geometry for a simple one-dimensional kernel
-- covering @kernel_size@ threads.
--
-- The group size is obtained through an 'Imp.GetSize' host operation
-- of class 'Imp.SizeGroup'.  It is keyed on the generated variable's
-- name, qualified with the current function via 'keyWithEntryPoint'.
--
-- Returns:
--
-- * the number of virtual groups, @kernel_size `divUp` group_size@,
--   converted to 32 bits with 'sExt32';
-- * the number of groups to launch: the virtual count capped at
--   @max_num_groups@;
-- * the group size.
simpleKernelGroups ::
  Imp.TExp Int64 ->
  Imp.TExp Int64 ->
  CallKernelGen (Imp.TExp Int32, Count NumGroups SubExp, Count GroupSize SubExp)
simpleKernelGroups max_num_groups kernel_size = do
  group_size_tv <- dPrim "group_size" int64
  entry <- askFunction
  let group_size_var = tvVar group_size_tv
      size_key = keyWithEntryPoint entry . nameFromString . pretty $ group_size_var
  sOp $ Imp.GetSize group_size_var size_key Imp.SizeGroup

  virt_groups <-
    dPrimVE "virt_num_groups" $ kernel_size `divUp` tvExp group_size_tv
  -- Never launch more than 'max_num_groups' physical groups.
  num_groups_tv <-
    dPrimV "num_groups" $ virt_groups `sMin64` max_num_groups

  let num_groups_count = Count $ tvSize num_groups_tv
      group_size_count = Count $ tvSize group_size_tv
  pure (sExt32 virt_groups, num_groups_count, group_size_count)

simpleKernelConstants ::
  Imp.TExp Int64 ->
  String ->
  CallKernelGen
    ( (Imp.TExp Int64 -> InKernelGen ()) -> InKernelGen (),
      KernelConstants
    )
simpleKernelConstants :: TPrimExp Int64 VName
-> String
-> CallKernelGen
     ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen (),
      KernelConstants)
simpleKernelConstants TPrimExp Int64 VName
kernel_size String
desc = do
  -- For performance reasons, codegen assumes that the thread count is
  -- never more than will fit in an i32.  This means we need to cap
  -- the number of groups here.  The cap is set much higher than any
  -- GPU will possibly need.  Feel free to come back and laugh at me
  -- in the future.
  let max_num_groups :: TPrimExp Int64 VName
max_num_groups = TPrimExp Int64 VName
1024 TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
1024
  VName
thread_gtid <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ImpM GPUMem HostEnv HostOp VName)
-> String -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String
desc String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_gtid"
  VName
thread_ltid <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ImpM GPUMem HostEnv HostOp VName)
-> String -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String
desc String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_ltid"
  VName
group_id <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> ImpM GPUMem HostEnv HostOp VName)
-> String -> ImpM GPUMem HostEnv HostOp VName
forall a b. (a -> b) -> a -> b
$ String
desc String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_gid"
  VName
inner_group_size <- String -> ImpM GPUMem HostEnv HostOp VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"group_size"
  (TExp Int32
virt_num_groups, Count NumGroups SubExp
num_groups, Count GroupSize SubExp
group_size) <-
    TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> CallKernelGen
     (TExp Int32, Count NumGroups SubExp, Count GroupSize SubExp)
simpleKernelGroups TPrimExp Int64 VName
max_num_groups TPrimExp Int64 VName
kernel_size
  let group_size' :: TPrimExp Int64 VName
group_size' = SubExp -> TPrimExp Int64 VName
Imp.pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ Count GroupSize SubExp -> SubExp
forall u e. Count u e -> e
unCount Count GroupSize SubExp
group_size
      num_groups' :: TPrimExp Int64 VName
num_groups' = SubExp -> TPrimExp Int64 VName
Imp.pe64 (SubExp -> TPrimExp Int64 VName) -> SubExp -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ Count NumGroups SubExp -> SubExp
forall u e. Count u e -> e
unCount Count NumGroups SubExp
num_groups

      constants :: KernelConstants
constants =
        KernelConstants :: TExp Int32
-> TExp Int32
-> TExp Int32
-> VName
-> VName
-> VName
-> Count NumGroups SubExp
-> Count GroupSize SubExp
-> TPrimExp Int64 VName
-> TPrimExp Int64 VName
-> TExp Int32
-> TExp Int32
-> Map [SubExp] [TExp Int32]
-> Map [SubExp] (TExp Int32)
-> KernelConstants
KernelConstants
          { kernelGlobalThreadId :: TExp Int32
kernelGlobalThreadId = VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
Imp.le32 VName
thread_gtid,
            kernelLocalThreadId :: TExp Int32
kernelLocalThreadId = VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
Imp.le32 VName
thread_ltid,
            kernelGroupId :: TExp Int32
kernelGroupId = VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
Imp.le32 VName
group_id,
            kernelGlobalThreadIdVar :: VName
kernelGlobalThreadIdVar = VName
thread_gtid,
            kernelLocalThreadIdVar :: VName
kernelLocalThreadIdVar = VName
thread_ltid,
            kernelGroupIdVar :: VName
kernelGroupIdVar = VName
group_id,
            kernelNumGroupsCount :: Count NumGroups SubExp
kernelNumGroupsCount = Count NumGroups SubExp
num_groups,
            kernelGroupSizeCount :: Count GroupSize SubExp
kernelGroupSizeCount = Count GroupSize SubExp
group_size,
            kernelNumGroups :: TPrimExp Int64 VName
kernelNumGroups = TPrimExp Int64 VName
num_groups',
            kernelGroupSize :: TPrimExp Int64 VName
kernelGroupSize = TPrimExp Int64 VName
group_size',
            kernelNumThreads :: TExp Int32
kernelNumThreads = TPrimExp Int64 VName -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 (TPrimExp Int64 VName
group_size' TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TPrimExp Int64 VName
num_groups'),
            kernelWaveSize :: TExp Int32
kernelWaveSize = TExp Int32
0,
            kernelLocalIdMap :: Map [SubExp] [TExp Int32]
kernelLocalIdMap = Map [SubExp] [TExp Int32]
forall a. Monoid a => a
mempty,
            kernelChunkItersMap :: Map [SubExp] (TExp Int32)
kernelChunkItersMap = Map [SubExp] (TExp Int32)
forall a. Monoid a => a
mempty
          }

      wrapKernel :: (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
wrapKernel TPrimExp Int64 VName -> InKernelGen ()
m = do
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
thread_ltid PrimType
int32
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
inner_group_size PrimType
int64
        VName -> PrimType -> InKernelGen ()
forall rep r op. VName -> PrimType -> ImpM rep r op ()
dPrim_ VName
group_id PrimType
int32
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetLocalId VName
thread_ltid Int
0)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetLocalSize VName
inner_group_size Int
0)
        KernelOp -> InKernelGen ()
forall op rep r. op -> ImpM rep r op ()
sOp (VName -> Int -> KernelOp
Imp.GetGroupId VName
group_id Int
0)
        VName -> TExp Int32 -> InKernelGen ()
forall t rep r op. VName -> TExp t -> ImpM rep r op ()
dPrimV_ VName
thread_gtid (TExp Int32 -> InKernelGen ()) -> TExp Int32 -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
group_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
inner_group_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
thread_ltid
        SegVirt
-> TExp Int32 -> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
virtualiseGroups SegVirt
SegVirt TExp Int32
virt_num_groups ((TExp Int32 -> InKernelGen ()) -> InKernelGen ())
-> (TExp Int32 -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TExp Int32
virt_group_id -> do
          TPrimExp Int64 VName
global_tid <-
            String
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"global_tid" (TPrimExp Int64 VName
 -> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem KernelEnv KernelOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$
              TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
virt_group_id TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
* TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (VName -> TExp Int32
forall a. a -> TPrimExp Int32 a
le32 VName
inner_group_size)
                TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TExp Int32 -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants)
          TPrimExp Int64 VName -> InKernelGen ()
m TPrimExp Int64 VName
global_tid

  ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen (),
 KernelConstants)
-> CallKernelGen
     ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen (),
      KernelConstants)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
wrapKernel, KernelConstants
constants)

-- | For many kernels, we may not have enough physical groups to cover
-- the logical iteration space.  Some groups thus have to perform
-- double duty; we put an outer loop to accomplish this.  The
-- advantage over just launching a bazillion threads is that the cost
-- of memory expansion should be proportional to the number of
-- *physical* threads (hardware parallelism), not the amount of
-- application parallelism.
--
-- With 'SegVirt', physical group @p@ runs the body for the virtual
-- group ids @p@, @p + G@, @p + 2G@, ... that are below the required
-- number of groups (@G@ being 'kernelNumGroups').  A global barrier
-- is issued after each virtual group.  Any other 'SegVirt' value runs
-- the body exactly once, passing the physical group id.
virtualiseGroups ::
  SegVirt ->
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> InKernelGen ()) ->
  InKernelGen ()
virtualiseGroups SegVirt required_groups body = do
  num_phys_groups <- sExt32 . kernelNumGroups . kernelConstants <$> askEnv
  phys_group_id <- dPrim "phys_group_id" int32
  sOp $ Imp.GetGroupId (tvVar phys_group_id) 0
  let phys = tvExp phys_group_id
  -- Number of virtual groups this physical group is responsible for.
  iterations <-
    dPrimVE "iterations" $ (required_groups - phys) `divUp` num_phys_groups

  sFor "i" iterations $ \i -> do
    virt_group_id <- dPrimV "virt_group_id" $ phys + i * num_phys_groups
    body $ tvExp virt_group_id
    -- Make sure the virtual group is actually done before we let
    -- another virtual group have its way with it.
    sOp $ Imp.Barrier Imp.FenceGlobal
virtualiseGroups _ _ body =
  askEnv >>= body . Imp.le32 . kernelGroupIdVar . kernelConstants

-- | Various extra configuration of the kernel being generated.  Use
-- 'defKernelAttrs' to obtain the default settings and override
-- individual fields with record update syntax.
data KernelAttrs = KernelAttrs
  { kAttrFailureTolerant :: Bool,
    -- ^ Can this kernel execute correctly even if previous kernels failed?
    kAttrCheckLocalMemory :: Bool,
    -- ^ Does whatever launch this kernel check for local memory capacity itself?
    kAttrNumGroups :: Count NumGroups SubExp,
    -- ^ Number of groups.
    kAttrGroupSize :: Count GroupSize SubExp
    -- ^ Group size.
  }

-- | The default kernel attributes for the given number of groups and
-- group size: the kernel is not failure tolerant
-- ('kAttrFailureTolerant' is 'False'), and local memory capacity is
-- checked ('kAttrCheckLocalMemory' is 'True').
defKernelAttrs ::
  Count NumGroups SubExp ->
  Count GroupSize SubExp ->
  KernelAttrs
defKernelAttrs num_groups group_size =
  KernelAttrs
    { kAttrNumGroups = num_groups,
      kAttrGroupSize = group_size,
      kAttrFailureTolerant = False,
      kAttrCheckLocalMemory = True
    }

-- | Generate a kernel launch.  The kernel constants are set up from
-- the number of groups and group size in the attributes.  The kernel
-- is named after the given string and the tag of @v@.  Inside the
-- kernel, the constants are initialised first, then @v@ is bound to
-- the result of applying @flatf@ to the constants, and finally the
-- body is run.
sKernel ::
  Operations GPUMem KernelEnv Imp.KernelOp ->
  (KernelConstants -> Imp.TExp Int32) ->
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernel ops flatf name v attrs f = do
  let num_groups = kAttrNumGroups attrs
      group_size = kAttrGroupSize attrs
  (constants, set_constants) <- kernelInitialisationSimple num_groups group_size
  name' <- nameForFun $ concat [name, "_", show (baseTag v)]
  let kernel_body = do
        set_constants
        dPrimV_ v $ flatf constants
        f
  sKernelOp attrs constants ops name' kernel_body

-- | Like 'sKernel', but the body is compiled with thread-level
-- operations.  The given 'VName' is bound to 'kernelGlobalThreadId'.
sKernelThread ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelThread name v attrs body =
  sKernel threadOperations kernelGlobalThreadId name v attrs body

-- | Like 'sKernel', but the body is compiled with group-level
-- operations.  The given 'VName' is bound to 'kernelGroupId'.
sKernelGroup ::
  String ->
  VName ->
  KernelAttrs ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup name v attrs body =
  sKernel groupOperations kernelGroupId name v attrs body

-- | Compile the kernel body with the given operations and emit an
-- 'Imp.CallKernel' for it.
--
-- The body runs in a 'KernelEnv' built from the host environment's
-- atomics and locks and the given constants, wrapped in
-- 'makeAllMemoryGlobal'.  The launch geometry comes from the
-- constants ('kernelNumGroups' and 'kernelGroupSize').  Of the
-- attributes, only the failure-tolerance and local-memory-check flags
-- are read here.
sKernelOp ::
  KernelAttrs ->
  KernelConstants ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelOp attrs constants ops name m = do
  host_env <- askEnv
  let kernel_env = KernelEnv (hostAtomics host_env) constants (hostLocks host_env)
  body <- makeAllMemoryGlobal $ subImpM_ kernel_env ops m
  uses <- computeKernelUses body mempty
  let kernel =
        Imp.Kernel
          { Imp.kernelBody = body,
            Imp.kernelUses = uses,
            Imp.kernelNumGroups = [untyped $ kernelNumGroups constants],
            Imp.kernelGroupSize = [untyped $ kernelGroupSize constants],
            Imp.kernelName = name,
            Imp.kernelFailureTolerant = kAttrFailureTolerant attrs,
            Imp.kernelCheckLocalMemory = kAttrCheckLocalMemory attrs
          }
  emit $ Imp.Op $ Imp.CallKernel kernel

-- | Emit a kernel via 'sKernelOp'.  Its attributes are the defaults
-- for the group count and group size recorded in the constants, with
-- failure tolerance set to the given flag.
sKernelFailureTolerant ::
  Bool ->
  Operations GPUMem KernelEnv Imp.KernelOp ->
  KernelConstants ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelFailureTolerant tol ops constants name m =
  sKernelOp attrs constants ops name m
  where
    default_attrs =
      defKernelAttrs
        (kernelNumGroupsCount constants)
        (kernelGroupSizeCount constants)
    attrs = default_attrs {kAttrFailureTolerant = tol}

-- | Copy compiler used for group-level code.
--
-- When both source and destination live in 'ScalarSpace', the
-- trailing dimensions covered by each scalar space are copied in full
-- (leading dimensions are fixed at index 0) with 'copyElementWise'.
--
-- Otherwise the iteration space (the shape of the source index
-- function) is spread across the group with 'groupCoverSpace', and a
-- barrier follows so the copied data is visible to the whole group.
-- The barrier only fences local memory when the destination is local
-- or scalar memory.  For any other destination (e.g. global memory) it
-- fences global memory, because a local fence alone does not order
-- global memory writes (OpenCL barrier semantics).
copyInGroup :: CopyCompiler GPUMem KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  dest_space <- entryMemSpace <$> lookupMemory (memLocName destloc)
  src_space <- entryMemSpace <$> lookupMemory (memLocName srcloc)

  let src_ixfun = memLocIxFun srcloc
      dims = IxFun.shape src_ixfun
      rank = length dims

  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) -> do
      let fullDim d = DimSlice 0 d 1
          -- Fix the leading dimensions not covered by the scalar space
          -- at index 0 and take the trailing ones in full.
          scalarSlice ds =
            Slice $
              replicate (rank - length ds) (DimFix 0)
                ++ takeLast (length ds) (map fullDim dims)
      copyElementWise
        pt
        (sliceMemLoc destloc (scalarSlice destds))
        (sliceMemLoc srcloc (scalarSlice srcds))
    _ -> do
      groupCoverSpace (map sExt32 dims) $ \is -> do
        let slice = Slice $ map (DimFix . sExt64) is
        copyElementWise pt (sliceMemLoc destloc slice) (sliceMemLoc srcloc slice)
      sOp $ Imp.Barrier $ fenceFor dest_space
  where
    -- Only writes to the destination must become visible to the group.
    fenceFor (Space "local") = Imp.FenceLocal
    fenceFor ScalarSpace {} = Imp.FenceLocal
    fenceFor _ = Imp.FenceGlobal

-- | In-kernel operations for thread-level and group-level code.  They
-- differ only in their op, copy and expression compilers.
threadOperations, groupOperations :: Operations GPUMem KernelEnv Imp.KernelOp
threadOperations =
  mkInKernelOperations compileThreadOp copyElementWise compileThreadExp
groupOperations =
  mkInKernelOperations compileGroupOp copyInGroup compileGroupExp

-- | Build in-kernel 'Operations' from an op compiler, a copy compiler
-- and an expression compiler.  Statements are compiled with
-- 'defCompileStms', ignoring the names argument and passing 'mempty'.
-- Allocations are supported only in the @"local"@ space, via
-- 'allocLocal'.
mkInKernelOperations ::
  OpCompiler GPUMem KernelEnv Imp.KernelOp ->
  CopyCompiler GPUMem KernelEnv Imp.KernelOp ->
  ExpCompiler GPUMem KernelEnv Imp.KernelOp ->
  Operations GPUMem KernelEnv Imp.KernelOp
mkInKernelOperations op_compiler copy_compiler exp_compiler =
  (defaultOperations op_compiler)
    { opsCopyCompiler = copy_compiler,
      opsExpCompiler = exp_compiler,
      opsStmsCompiler = const $ defCompileStms mempty,
      opsAllocCompilers = M.singleton (Space "local") allocLocal
    }

-- | Perform a Replicate with a kernel.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel :: VName -> SubExp -> ImpM GPUMem HostEnv HostOp ()
sReplicateKernel VName
arr SubExp
se = do
  TypeBase Shape NoUniqueness
t <- SubExp -> ImpM GPUMem HostEnv HostOp (TypeBase Shape NoUniqueness)
forall t (m :: * -> *).
HasScope t m =>
SubExp -> m (TypeBase Shape NoUniqueness)
subExpType SubExp
se
  [SubExp]
ds <- Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
dropLast (TypeBase Shape NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase Shape NoUniqueness
t) ([SubExp] -> [SubExp])
-> (TypeBase Shape NoUniqueness -> [SubExp])
-> TypeBase Shape NoUniqueness
-> [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (TypeBase Shape NoUniqueness -> [SubExp])
-> ImpM GPUMem HostEnv HostOp (TypeBase Shape NoUniqueness)
-> ImpM GPUMem HostEnv HostOp [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM GPUMem HostEnv HostOp (TypeBase Shape NoUniqueness)
forall rep (m :: * -> *).
HasScope rep m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
arr

  let dims :: [TPrimExp Int64 VName]
dims = (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> [TPrimExp Int64 VName])
-> [SubExp] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> a -> b
$ [SubExp]
ds [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ TypeBase Shape NoUniqueness -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims TypeBase Shape NoUniqueness
t
  TPrimExp Int64 VName
n <- String
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall t rep r op. String -> TExp t -> ImpM rep r op (TExp t)
dPrimVE String
"replicate_n" (TPrimExp Int64 VName
 -> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName))
-> TPrimExp Int64 VName
-> ImpM GPUMem HostEnv HostOp (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product ([TPrimExp Int64 VName] -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> TPrimExp Int64 VName
forall a b. (a -> b) -> a -> b
$ (TPrimExp Int64 VName -> TPrimExp Int64 VName)
-> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a b. (a -> b) -> [a] -> [b]
map TPrimExp Int64 VName -> TPrimExp Int64 VName
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 [TPrimExp Int64 VName]
dims
  ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
virtualise, KernelConstants
constants) <- TPrimExp Int64 VName
-> String
-> CallKernelGen
     ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen (),
      KernelConstants)
simpleKernelConstants TPrimExp Int64 VName
n String
"replicate"

  Maybe Name
fname <- ImpM GPUMem HostEnv HostOp (Maybe Name)
forall rep r op. ImpM rep r op (Maybe Name)
askFunction
  let name :: Name
name =
        Maybe Name -> Name -> Name
keyWithEntryPoint Maybe Name
fname (Name -> Name) -> Name -> Name
forall a b. (a -> b) -> a -> b
$
          String -> Name
nameFromString (String -> Name) -> String -> Name
forall a b. (a -> b) -> a -> b
$
            String
"replicate_" String -> String -> String
forall a. [a] -> [a] -> [a]
++ Int -> String
forall a. Show a => a -> String
show (VName -> Int
baseTag (VName -> Int) -> VName -> Int
forall a b. (a -> b) -> a -> b
$ KernelConstants -> VName
kernelGlobalThreadIdVar KernelConstants
constants)

  Bool
-> Operations GPUMem KernelEnv KernelOp
-> KernelConstants
-> Name
-> InKernelGen ()
-> ImpM GPUMem HostEnv HostOp ()
sKernelFailureTolerant Bool
True Operations GPUMem KernelEnv KernelOp
threadOperations KernelConstants
constants Name
name (InKernelGen () -> ImpM GPUMem HostEnv HostOp ())
-> InKernelGen () -> ImpM GPUMem HostEnv HostOp ()
forall a b. (a -> b) -> a -> b
$
    (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
virtualise ((TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ())
-> (TPrimExp Int64 VName -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \TPrimExp Int64 VName
gtid -> do
      [TPrimExp Int64 VName]
is' <- String
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> InKernelGen [TPrimExp Int64 VName]
forall rep r op.
String
-> [TPrimExp Int64 VName]
-> TPrimExp Int64 VName
-> ImpM rep r op [TPrimExp Int64 VName]
dIndexSpace' String
"rep_i" [TPrimExp Int64 VName]
dims TPrimExp Int64 VName
gtid
      TExp Bool -> InKernelGen () -> InKernelGen ()
forall rep r op. TExp Bool -> ImpM rep r op () -> ImpM rep r op ()
sWhen (TPrimExp Int64 VName
gtid TPrimExp Int64 VName -> TPrimExp Int64 VName -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TPrimExp Int64 VName
n) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
        VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> InKernelGen ()
forall rep r op.
VName
-> [TPrimExp Int64 VName]
-> SubExp
-> [TPrimExp Int64 VName]
-> ImpM rep r op ()
copyDWIMFix VName
arr [TPrimExp Int64 VName]
is' SubExp
se ([TPrimExp Int64 VName] -> InKernelGen ())
-> [TPrimExp Int64 VName] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
          Int -> [TPrimExp Int64 VName] -> [TPrimExp Int64 VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) [TPrimExp Int64 VName]
is'

-- | Base name (without the @builtin#@ prefix) of the shared replicate
-- function for a primitive type: @"replicate_"@ followed by the
-- pretty-printed type.
replicateName :: PrimType -> String
replicateName = ("replicate_" <>) . pretty

-- | Return the name of the shared builtin function that fills a
-- device array (with an 'IxFun.iota' index function) with a single
-- value of primitive type @bt@.  The function is only generated the
-- first time it is requested; later requests reuse it.
--
-- Its parameters are, in order: the destination memory block (in
-- @device@ space), the number of elements (@i64@), and the value to
-- write (of type @bt@).
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  already_defined <- hasFunction fname
  unless already_defined $ do
    -- Names are drawn left to right: mem, num_elems, val.
    (mem, num_elems, val) <-
      (,,) <$> newVName "mem" <*> newVName "num_elems" <*> newVName "val"
    let shape = Shape [Var num_elems]
        params =
          [ Imp.MemParam mem (Space "device"),
            Imp.ScalarParam num_elems int64,
            Imp.ScalarParam val bt
          ]
    function fname [] params $ do
      arr <- sArray "arr" bt shape mem . IxFun.iota . map pe64 $ shapeDims shape
      sReplicateKernel arr (Var val)
  pure fname
  where
    fname = nameFromString $ "builtin#" <> replicateName bt

-- | Decide whether replicating @v@ into @arr@ is a plain fill.  That is
-- the case when @v@ has a primitive type and @arr@'s index function is
-- linear (per 'IxFun.isLinear').  If it is, return an action that
-- performs the fill by calling the shared builtin from
-- 'replicateForType' on the whole memory block of @arr@.  Otherwise
-- return 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLoc arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  pure $ case v_t of
    Prim v_pt
      | IxFun.isLinear arr_ixfun -> Just $ callBuiltin arr_mem arr_shape v_pt
    _ -> Nothing
  where
    -- Arguments: memory block, total element count, fill value.
    callBuiltin :: VName -> [SubExp] -> PrimType -> CallKernelGen ()
    callBuiltin mem shape pt = do
      fname <- replicateForType pt
      emit . Imp.Call [] fname $
        [ Imp.MemArg mem,
          Imp.ExpArg . untyped . product $ map pe64 shape,
          Imp.ExpArg $ toExp' pt v
        ]

-- | Perform a Replicate of @se@ into @arr@.  If the replicate is of a
-- particularly common and simple form (morally a memset()/fill, as
-- decided by 'replicateIsFill'), then a shared builtin function is
-- called.  Otherwise a dedicated replicate kernel is generated.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se =
  fromMaybe (sReplicateKernel arr se) =<< replicateIsFill arr se

-- | Fill @arr@ with @n@ elements of integer type @et@ by launching a
-- dedicated kernel.  The element at index @i@ (for @i < n@) is set to
-- @x + i*s@, computed with wrapping arithmetic in @et@.
sIotaKernel ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLoc <$> lookupArray arr
  (virtualise, constants) <- simpleKernelConstants n "iota"
  entry <- askFunction
  let tag = show . baseTag $ kernelGlobalThreadIdVar constants
      kernel_name =
        keyWithEntryPoint entry . nameFromString $
          concat ["iota_", pretty et, "_", tag]
      -- x + i*s, where the Int64 thread index is sign-extended to et.
      valueAt i =
        BinOpExp
          (Add et OverflowWrap)
          (BinOpExp (Mul et OverflowWrap) (Imp.sExt et (untyped i)) s)
          x
  sKernelFailureTolerant True threadOperations constants kernel_name . virtualise $ \gtid ->
    sWhen (gtid .<. n) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]
      emit . Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
        valueAt gtid

-- | Base name (without the @builtin#@ prefix) of the shared iota
-- function for an integer type: @"iota_"@ followed by the
-- pretty-printed type.
iotaName :: IntType -> String
iotaName = ("iota_" <>) . pretty

-- | Return the name of the shared builtin function that performs an
-- iota on a device array (with an 'IxFun.iota' index function) of
-- integer type @bt@.  The function is only generated the first time it
-- is requested; later requests reuse it.
--
-- Its parameters are, in order: the destination memory block (in
-- @device@ space), the number of elements (@i64@), the start value and
-- the stride (both of type @bt@).
iotaForType :: IntType -> CallKernelGen Name
iotaForType bt = do
  let fname = nameFromString $ "builtin#" <> iotaName bt

  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    n <- newVName "n"
    x <- newVName "x"
    s <- newVName "s"

    let params =
          [ Imp.MemParam mem (Space "device"),
            -- The element count is read as an i64 below ('Imp.le64'),
            -- and callers pass an Int64 expression for it, so it must
            -- be declared as int64.  Declaring it as int32 (as before)
            -- mismatched the argument and could truncate large counts.
            Imp.ScalarParam n int64,
            Imp.ScalarParam x $ IntType bt,
            Imp.ScalarParam s $ IntType bt
          ]
        shape = Shape [Var n]
        n' = Imp.le64 n
        x' = Imp.var x $ IntType bt
        s' = Imp.var s $ IntType bt

    function fname [] params $ do
      arr <-
        sArray "arr" (IntType bt) shape mem $
          IxFun.iota $ map pe64 $ shapeDims shape
      -- n' is already an Int64 expression, so no extension is needed.
      sIotaKernel arr n' x' s' bt

  pure fname

-- | Perform an Iota: fill @arr@ with @n@ elements @x, x+s, x+2s, ...@
-- of integer type @et@.
--
-- When the array's index function is linear, the shared builtin from
-- 'iotaForType' is called.  Otherwise a dedicated kernel is generated
-- with 'sIotaKernel'.
sIota ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIota arr n x s et = do
  ArrayEntry (MemLoc arr_mem _ arr_ixfun) _ <- lookupArray arr
  if not (IxFun.isLinear arr_ixfun)
    then sIotaKernel arr n x s et
    else do
      builtin <- iotaForType et
      -- Argument order follows the builtin's parameters:
      -- memory block, element count, start value, stride.
      emit . Imp.Call [] builtin $
        Imp.MemArg arr_mem : map Imp.ExpArg [untyped n, x, s]

-- | GPU copy compiler: copies every element of @srcloc@ to @destloc@
-- in a dedicated kernel.  Each in-bounds thread reads one element of
-- type @pt@ into a temporary and then writes it to the destination.
sCopy :: CopyCompiler GPUMem HostEnv Imp.HostOp
sCopy pt destloc@(MemLoc destmem _ _) srcloc@(MemLoc srcmem srcdims _) = do
  -- The source and destination necessarily have the same shape, so
  -- the source dimensions describe the whole iteration space.
  let dims = map pe64 srcdims
      total = product dims
  (virtualise, constants) <- simpleKernelConstants total "copy"
  entry <- askFunction
  let kernel_name =
        keyWithEntryPoint entry . nameFromString . ("copy_" ++) . show . baseTag $
          kernelGlobalThreadIdVar constants
  sKernelFailureTolerant True threadOperations constants kernel_name . virtualise $ \gtid -> do
    is <- dIndexSpace' "copy_i" dims gtid
    (_, destspace, destidx) <- fullyIndexArray' destloc is
    (_, srcspace, srcidx) <- fullyIndexArray' srcloc is
    sWhen (gtid .<. total) $ do
      tmp <- tvVar <$> dPrim "tmp" pt
      emit $ Imp.Read tmp srcmem srcidx pt srcspace Imp.Nonvolatile
      emit . Imp.Write destmem destidx pt destspace Imp.Nonvolatile $ Imp.var tmp pt

-- | Perform a Rotate with a kernel.  Element @is@ of @dest@ is copied
-- from @src@ at the index obtained by applying 'rotateIndex' to each
-- coordinate of @is@, using the dimension of @src@ and the matching
-- offset in @rs@.
sRotateKernel :: VName -> [Imp.TExp Int64] -> VName -> CallKernelGen ()
sRotateKernel dest rs src = do
  dims <- map pe64 . arrayDims <$> lookupType src
  n <- dPrimVE "rotate_n" $ product dims
  (virtualise, constants) <- simpleKernelConstants n "rotate"
  entry <- askFunction
  let kernel_name =
        keyWithEntryPoint entry . nameFromString $
          "rotate_" ++ show (baseTag (kernelGlobalThreadIdVar constants))
  sKernelFailureTolerant True threadOperations constants kernel_name . virtualise $ \gtid ->
    sWhen (gtid .<. n) $ do
      dest_is <- dIndexSpace' "rep_i" dims gtid
      src_is <-
        traverse
          (\(d, r, i) -> dPrimVE "rot_i" $ rotateIndex d r i)
          (zip3 dims rs dest_is)
      copyDWIMFix dest dest_is (Var src) src_is

-- | Store a group-level kernel result into the destination pattern
-- element @pe@.  The equations differ in how the work is divided among
-- the threads of the group:
--
-- * 'TileReturns' with a single dimension: the group's tile @what@ is
--   copied to the output starting at @per_group_elems * group_id@, with
--   the threads of the group striding over the tile.  Positions at or
--   beyond @w@ in the output are skipped, and so is anything past the
--   end of @what@.
--
-- * 'TileReturns' with several dimensions: every thread writes the
--   element at its local thread index into the output position given by
--   the group's tile position plus that local index, if it is active
--   with respect to the output dimensions.
--
-- * 'RegTileReturns': every thread writes its own register tile into
--   the corresponding slice of the output, element by element, skipping
--   elements that fall outside the output dimensions.
--
-- * 'Returns': the value is written at the position given by the
--   segment space's indices (see the branch comments for who copies).
--
-- 'WriteReturns' and 'ConcatReturns' are reported as compiler
-- limitations.
compileGroupResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileGroupResult _ pe (TileReturns _ [(w, per_group_elems)] what) = do
  -- Number of elements in the group's tile.
  n <- pe64 . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId constants
      -- Where this group's tile starts in the output.
      offset =
        pe64 per_group_elems
          * sExt64 (kernelGroupId constants)

  -- Avoid loop for the common case where each thread is statically
  -- known to write at most one element.
  localOps threadOperations $
    if pe64 per_group_elems == kernelGroupSize constants
      then
        sWhen (ltid + offset .<. pe64 w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` kernelGroupSize constants) $ \i -> do
        j <- dPrimVE "j" $ kernelGroupSize constants * i + ltid
        -- In the last iteration j may be >= n when n is not a multiple
        -- of the group size.  Guard on it so we never read past the end
        -- of 'what' (and thereby clobber another group's output).
        sWhen ((j .<. n) .&&. (j + offset .<. pe64 w)) $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileGroupResult space pe (TileReturns _ dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (pe64 . snd) dims
      -- Origin of this group's tile in every dimension.
      group_is = zipWith (*) (map Imp.le64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is

  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileGroupResult space pe (RegTileReturns _ dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map pe64 group_tiles
      reg_tiles' = map pe64 reg_tiles

  -- Which group tile is this group responsible for?
  let group_tile_is = map Imp.le64 gids

  -- Within the group tile, which register tile is this thread
  -- responsible for?
  reg_tile_is <-
    dIndexSpace' "reg_tile_i" group_tiles' $ sExt64 $ kernelLocalThreadId constants

  -- Compute output array slice for the register tile belonging to
  -- this thread.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        pure $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    Slice
      <$> zipWithM
        regTileSliceDim
        (zip group_tiles' group_tile_is)
        (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      -- Skip elements of the register tile that fall outside the output.
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map pe64 dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
compileGroupResult space pe (Returns _ _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.le64 . fst) $ unSegSpace space

  if in_local_memory
    then
      -- If the result of the group is an array in local memory, we
      -- store it by collective copying among all the threads of the
      -- group.  TODO: also do this if the array is in global memory
      -- (but this is a bit more tricky, synchronisation-wise).
      copyDWIMFix (patElemName pe) gids what []
    else
      -- Otherwise only the first thread of the group performs the copy,
      -- using thread-level operations.
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."
compileGroupResult _ _ ConcatReturns {} =
  compilerLimitationS "compileGroupResult: ConcatReturns not handled yet."

-- | Store a thread-level kernel result into the destination pattern
-- element @pe@.
--
-- * 'Returns': the value is copied to the position given by the indices
--   of the segment space.
--
-- * 'ConcatReturns' with 'SplitContiguous': the thread's array is copied
--   to a contiguous slice starting at
--   @per_thread_elems * global_thread_id@.
--
-- * 'ConcatReturns' with 'SplitStrided': the thread's array is copied to
--   a slice that starts at the global thread id and steps by @stride@.
--
-- * 'WriteReturns': each @(slice, value)@ pair is written only if the
--   slice is in bounds of the destination dimensions.
--
-- 'RegTileReturns' is reported as a compiler limitation, and
-- 'TileReturns' as a compiler bug.
compileThreadResult ::
  SegSpace ->
  PatElem LetDecMem ->
  KernelResult ->
  InKernelGen ()
compileThreadResult space pe kres =
  case kres of
    RegTileReturns {} ->
      compilerLimitationS "compileThreadResult: RegTileReturns not yet handled."
    Returns _ _ what ->
      copyDWIMFix dest (map (Imp.le64 . fst) (unSegSpace space)) what []
    ConcatReturns _ SplitContiguous _ per_thread_elems what -> do
      gtid <- sExt64 . kernelGlobalThreadId . kernelConstants <$> askEnv
      copySliceOf what (pe64 per_thread_elems * gtid) 1
    ConcatReturns _ (SplitStrided stride) _ _ what -> do
      gtid <- sExt64 . kernelGlobalThreadId . kernelConstants <$> askEnv
      copySliceOf what gtid (pe64 stride)
    WriteReturns _ (Shape rws) _arr dests -> do
      let bounds = map pe64 rws
      forM_ dests $ \(slice, e) -> do
        let idx = fmap pe64 slice
        sWhen (inBounds idx bounds) $
          copyDWIM dest (unSlice idx) e []
    TileReturns {} ->
      compilerBugS "compileThreadResult: TileReturns unhandled."
  where
    dest = patElemName pe

    -- Copy all of array 'what' into the destination slice that starts at
    -- 'start' and advances by 'step'.  The length comes from the outer
    -- dimension of 'what'.
    copySliceOf what start step = do
      len <- pe64 . arraySize 0 <$> lookupType what
      copyDWIM dest [DimSlice start len step] (Var what) []

-- | Is this 'SubExp' an array variable whose memory block is in the
-- @"local"@ memory space?  Constants and variables that are not arrays
-- yield 'False'.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory Constant {} = pure False
arrayInLocalMemory (Var name) =
  lookupVar name >>= \case
    ArrayVar _ entry -> do
      mem <- lookupMemory $ memLocName $ entryArrayLoc entry
      pure $ entryMemSpace mem == Space "local"
    _ -> pure False