{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.Base
( KernelConstants (..),
keyWithEntryPoint,
CallKernelGen,
InKernelGen,
HostEnv (..),
Target (..),
KernelEnv (..),
computeThreadChunkSize,
groupReduce,
groupScan,
isActive,
sKernelThread,
sKernelGroup,
sReplicate,
sIota,
sCopy,
compileThreadResult,
compileGroupResult,
virtualiseGroups,
groupLoop,
kernelLoop,
groupCoverSpace,
precomputeSegOpIDs,
atomicUpdateLocking,
AtomicBinOp,
Locking (..),
AtomicUpdate (..),
DoAtomicUpdate,
)
where
import Control.Monad.Except
import Data.List (elemIndex, find, zip4)
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Error
import Futhark.IR.KernelsMem
import qualified Futhark.IR.Mem.IxFun as IxFun
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Util (chunks, dropLast, mapAccumLM, maybeNth, nubOrd, takeLast)
import Futhark.Util.IntegralExp (divUp, quot, rem)
import Prelude hiding (quot, rem)
-- | Code-generation target: either CUDA or OpenCL.
data Target
  = CUDA
  | OpenCL
-- | Reader environment for host-level code generation; it is the
-- environment type of 'CallKernelGen'.
data HostEnv = HostEnv
  { -- | The atomic binary operations available to generated code.
    hostAtomics :: AtomicBinOp,
    -- | The backend code is being generated for.
    hostTarget :: Target
  }
-- | Reader environment for in-kernel code generation; it is the
-- environment type of 'InKernelGen'.
data KernelEnv = KernelEnv
  { -- | The atomic binary operations available to kernel code.
    kernelAtomics :: AtomicBinOp,
    -- | Thread/group identifiers and sizes of the enclosing kernel.
    kernelConstants :: KernelConstants
  }
-- | Code generation monad for host-side code: operations are
-- 'Imp.HostOp' and the environment is 'HostEnv'.
type CallKernelGen = ImpM KernelsMem HostEnv Imp.HostOp

-- | Code generation monad for code inside a kernel: operations are
-- 'Imp.KernelOp' and the environment is 'KernelEnv'.
type InKernelGen = ImpM KernelsMem KernelEnv Imp.KernelOp
-- | Per-kernel constants that are available while generating code
-- inside a kernel ('kernelConstants' of 'KernelEnv').
data KernelConstants = KernelConstants
  { -- | Global thread ID, as an expression.
    kernelGlobalThreadId :: Imp.TExp Int32,
    -- | Thread ID within the group, as an expression.
    kernelLocalThreadId :: Imp.TExp Int32,
    -- | Group ID, as an expression.
    kernelGroupId :: Imp.TExp Int32,
    -- | Variable holding the global thread ID.
    kernelGlobalThreadIdVar :: VName,
    -- | Variable holding the local thread ID.
    kernelLocalThreadIdVar :: VName,
    -- | Variable holding the group ID.
    kernelGroupIdVar :: VName,
    -- | Number of groups.
    kernelNumGroups :: Imp.TExp Int64,
    -- | Number of threads per group.
    kernelGroupSize :: Imp.TExp Int64,
    -- | Total number of threads.
    kernelNumThreads :: Imp.TExp Int32,
    -- | Wave size.
    kernelWaveSize :: Imp.TExp Int32,
    -- | Whether the current thread is active.
    kernelThreadActive :: Imp.TExp Bool,
    -- | Maps the dimension sizes of a nested SegOp's index space to
    -- local thread IDs computed in advance (populated by
    -- 'precomputeSegOpIDs' and consulted by 'localThreadIDs').
    kernelLocalIdMap :: M.Map [SubExp] [Imp.TExp Int32]
  }
-- | Collect the distinct index-space shapes of all 'SegOp's in the
-- given statements. A shape is the list of dimension sizes from the
-- op's 'SegSpace'.
--
-- The search descends into both branches of @if@ expressions and into
-- loop bodies. It does not look inside the bodies of the 'SegOp's it
-- finds.
segOpSizes :: Stms KernelsMem -> S.Set [SubExp]
segOpSizes = foldMap (sizesInExp . stmExp)
  where
    sizesInExp e = case e of
      Op (Inner (SegOp op)) ->
        S.singleton [d | (_, d) <- unSegSpace (segSpace op)]
      If _ tbranch fbranch _ ->
        segOpSizes (bodyStms tbranch) <> segOpSizes (bodyStms fbranch)
      DoLoop _ _ _ loop_body ->
        segOpSizes (bodyStms loop_body)
      _ -> mempty
-- | Run @m@ in an environment whose 'kernelLocalIdMap' holds local
-- thread IDs computed in advance for every SegOp shape in @stms@ (see
-- 'segOpSizes').
--
-- For each shape, the local thread ID is unflattened over the shape's
-- dimensions (converted with 'sExt32'). Each resulting component is
-- bound to a fresh variable named @ltid_pre@.
precomputeSegOpIDs :: Stms KernelsMem -> InKernelGen a -> InKernelGen a
precomputeSegOpIDs stms m = do
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  let shapes = S.toList $ segOpSizes stms
      precompute dims = do
        let dims32 = map (sExt32 . toInt64Exp) dims
        ids <- traverse (dPrimVE "ltid_pre") (unflattenIndex dims32 ltid)
        pure (dims, ids)
  id_map <- M.fromList <$> traverse precompute shapes
  let withIds env =
        let consts = kernelConstants env
         in env {kernelConstants = consts {kernelLocalIdMap = id_map}}
  localEnv withIds m
-- | Qualify @key@ with an optional entry point name. The result is
-- @fname.key@ when an entry point is given, and just @key@ otherwise.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key =
  nameFromString $ prefix ++ nameToString key
  where
    prefix = case fname of
      Nothing -> ""
      Just entry -> nameToString entry ++ "."
-- | Allocation compiler that emits an 'Imp.LocalAlloc' of the given
-- byte count, bound to the given memory block.
allocLocal :: AllocCompiler KernelsMem r Imp.KernelOp
allocLocal mem = sOp . Imp.LocalAlloc mem
-- | Compile an allocation that occurs inside a kernel.
--
-- The pattern must bind exactly one value element. The behaviour then
-- depends on the memory space:
--
-- * 'ScalarSpace': no code is emitted.
-- * @"local"@: an 'allocLocal' of @size@ bytes is emitted.
-- * Any other space: a compiler limitation is reported.
--
-- Any other pattern is an internal error.
kernelAlloc ::
  Pattern KernelsMem ->
  SubExp ->
  Space ->
  InKernelGen ()
kernelAlloc (Pattern _ [mem]) size space =
  case space of
    ScalarSpace {} ->
      return ()
    Space "local" ->
      allocLocal (patElemName mem) $ Imp.bytes $ toInt64Exp size
    _ ->
      compilerLimitationS $
        "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
kernelAlloc dest _ _ =
  error $ "Invalid target for in-kernel allocation: " ++ show dest
-- | Compute the chunk size for thread @i@ and store it in the single
-- value element of the pattern (an @int64@ variable). The work is
-- delegated to 'computeThreadChunkSize', which receives the split
-- ordering, the thread index, the elements per thread and the total
-- element count @w@.
--
-- Any pattern other than one value element with no context is an
-- internal error.
splitSpace ::
  (ToExp w, ToExp i, ToExp elems_per_thread) =>
  Pattern KernelsMem ->
  SplitOrdering ->
  w ->
  i ->
  elems_per_thread ->
  ImpM lore r op ()
splitSpace (Pattern [] [size]) o w i elems_per_thread = do
  total <- toExp w
  per_thread <- toExp elems_per_thread
  computeThreadChunkSize
    o
    (toInt64Exp i)
    (Imp.elements $ TPrimExp per_thread)
    (Imp.elements $ TPrimExp total)
    (mkTV (patElemName size) int64)
splitSpace pat _ _ _ _ =
  error $ "Invalid target for splitSpace: " ++ pretty pat
-- | Expression compiler for thread-level code.
--
-- An array literal bound to a single pattern element is written one
-- element at a time with 'copyDWIMFix'. Every other expression is
-- passed to 'defCompileExp'.
compileThreadExp :: ExpCompiler KernelsMem KernelEnv Imp.KernelOp
compileThreadExp pat e =
  case (pat, e) of
    (Pattern _ [dest], BasicOp (ArrayLit es _)) ->
      sequence_
        [ copyDWIMFix (patElemName dest) [fromIntegral idx] x []
          | (idx, x) <- zip [0 :: Int64 ..] es
        ]
    _ ->
      defCompileExp pat e
-- | Spread the iterations @0 .. n-1@ of a loop over @num_threads@
-- threads, where the current thread has index @tid@. The body @f@
-- receives the (symbolic) iteration index and runs under
-- 'threadOperations'.
--
-- If @n@ and @num_threads@ are the same expression, each thread runs
-- @f tid@ once. The test uses Haskell's '==' on the expressions, so it
-- is decided while generating code, not at kernel run time.
-- Otherwise the thread runs @(n - tid) `divUp` num_threads@
-- iterations and visits indices @j * num_threads + tid@.
kernelLoop ::
  IntExp t =>
  Imp.TExp t ->
  Imp.TExp t ->
  Imp.TExp t ->
  (Imp.TExp t -> InKernelGen ()) ->
  InKernelGen ()
kernelLoop tid num_threads n f =
  localOps threadOperations body
  where
    body
      | n == num_threads = f tid
      | otherwise =
          let iters = (n - tid) `divUp` num_threads
           in sFor "i" iters $ \j -> f (j * num_threads + tid)
-- | Spread the iterations @0 .. n-1@ of a loop over the threads of the
-- current group. This is 'kernelLoop' with the local thread ID
-- (sign-extended to 64 bits) as the thread index and
-- 'kernelGroupSize' as the thread count.
groupLoop ::
  Imp.TExp Int64 ->
  (Imp.TExp Int64 -> InKernelGen ()) ->
  InKernelGen ()
groupLoop n f = do
  env <- askEnv
  let consts = kernelConstants env
      local_tid = sExt64 $ kernelLocalThreadId consts
      group_size = kernelGroupSize consts
  kernelLoop local_tid group_size n f
-- | Iterate over the multidimensional space with dimensions @ds@ using
-- the threads of the group. The flat iteration from 'groupLoop' is
-- unflattened into a point of the space before it is passed to @f@.
groupCoverSpace ::
  [Imp.TExp Int64] ->
  ([Imp.TExp Int64] -> InKernelGen ()) ->
  InKernelGen ()
groupCoverSpace ds f =
  groupLoop (product ds) $ \flat -> f (unflattenIndex ds flat)
-- | Expression compiler for group-level code. Each case below binds a
-- single pattern element:
--
-- * Array literal: each executing thread writes every element.
-- * @replicate@: the group's threads cover the output shape together
--   (via 'groupCoverSpace'), then a local barrier follows.
-- * @iota@: the group's threads cover @[0, n)@ together and write
--   @start + i * stride@ at each index @i@, then a local barrier follows.
-- * In-place update whose slice has no dimensions (a single element):
--   the write happens only in the thread with local ID 0, with a local
--   barrier before and after it.
--
-- Anything else, including updates whose slices have dimensions, is
-- passed to 'defCompileExp'.
compileGroupExp :: ExpCompiler KernelsMem KernelEnv Imp.KernelOp
compileGroupExp pat e =
  case (pat, e) of
    (Pattern _ [dest], BasicOp (ArrayLit es _)) ->
      sequence_
        [ copyDWIMFix (patElemName dest) [fromIntegral idx] x []
          | (idx, x) <- zip [0 :: Int64 ..] es
        ]
    (Pattern _ [dest], BasicOp (Replicate shape se)) -> do
      let dims = map toInt64Exp $ shapeDims shape
      groupCoverSpace dims $ \is ->
        copyDWIMFix (patElemName dest) is se (drop (shapeRank shape) is)
      localBarrier
    (Pattern _ [dest], BasicOp (Iota n start stride it)) -> do
      n' <- toExp n
      start' <- toExp start
      stride' <- toExp stride
      groupLoop (TPrimExp n') $ \i -> do
        let offset = BinOpExp (Mul it OverflowUndef) (untyped i) stride'
        x <- dPrimV "x" $ TPrimExp $ BinOpExp (Add it OverflowUndef) start' offset
        copyDWIMFix (patElemName dest) [i] (Var (tvVar x)) []
      localBarrier
    (Pattern _ [pe], BasicOp (Update _ slice se))
      | null (sliceDims slice) -> do
          localBarrier
          ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
          sWhen (ltid .==. 0) $
            copyDWIM (patElemName pe) (map (fmap toInt64Exp) slice) se []
          localBarrier
    _ ->
      defCompileExp pat e
  where
    -- Barrier with a local memory fence.
    localBarrier = sOp $ Imp.Barrier Imp.FenceLocal
-- | Accept thread-level SegOps. A group-level SegOp is an internal
-- error in this context.
sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel lvl =
  case lvl of
    SegThread {} -> return ()
    SegGroup {} -> error "compileGroupOp: unexpected group-level SegOp."
-- | Local thread IDs for an index space with the given dimensions.
-- If 'kernelLocalIdMap' holds IDs computed in advance for exactly
-- these dimensions, they are used (sign-extended to 64 bits).
-- Otherwise the local thread ID is unflattened over the dimensions.
localThreadIDs :: [SubExp] -> InKernelGen [Imp.TExp Int64]
localThreadIDs dims = do
  consts <- kernelConstants <$> askEnv
  let ltid = sExt64 $ kernelLocalThreadId consts
  pure $ case M.lookup dims (kernelLocalIdMap consts) of
    Just ids -> map sExt64 ids
    Nothing -> unflattenIndex (map toInt64Exp dims) ltid
-- | Bind the index variables of a SegOp executed at group level.
--
-- After checking the level with 'sanityCheckLevel', each dimension
-- variable is bound to the matching component from 'localThreadIDs'.
-- The flat index variable is bound to the local thread ID.
compileGroupSpace :: SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace lvl space = do
  sanityCheckLevel lvl
  let bindings = unSegSpace space
  local_ids <- localThreadIDs (map snd bindings)
  sequence_ [dPrimV_ v idx | (v, idx) <- zip (map fst bindings) local_ids]
  flat_ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  dPrimV_ (segFlat space) flat_ltid
-- | Prepare one update function per histogram operator for use inside
-- a group-level histogram.
--
-- Each returned function takes a bucket index and updates the
-- operator's destination arrays ('histDest'), addressed in
-- @Space "local"@.  The update strategy is chosen by
-- 'atomicUpdateLocking':
--
-- * 'AtomicPrim' and 'AtomicCAS' updaters are used directly.
--
-- * 'AtomicLocking' needs a lock array.  The first operator that needs
--   one allocates @group_size@ 32-bit locks in local memory and
--   zero-initialises them.  The resulting 'Locking' is threaded through
--   'mapAccumLM', so later locking operators reuse the same array.  A
--   bucket index is mapped to a lock by flattening it over
--   @histShape op ++ [histWidth op]@ and taking the result modulo the
--   number of locks.
--
-- NOTE(review): no barrier is emitted after the locks are initialised.
-- The caller is presumably responsible for synchronising the group
-- before the returned functions are used.
prepareIntraGroupSegHist ::
  Count GroupSize SubExp ->
  [HistOp KernelsMem] ->
  InKernelGen [[Imp.TExp Int64] -> InKernelGen ()]
prepareIntraGroupSegHist group_size =
  fmap snd . mapAccumLM onOp Nothing
  where
    onOp ::
      Maybe Locking ->
      HistOp KernelsMem ->
      InKernelGen (Maybe Locking, [Imp.TExp Int64] -> InKernelGen ())
    onOp l op = do
      constants <- kernelConstants <$> askEnv
      atomicBinOp <- kernelAtomics <$> askEnv

      let local_subhistos = histDest op

      case (l, atomicUpdateLocking atomicBinOp $ histOp op) of
        (_, AtomicPrim f) ->
          return (l, f (Space "local") local_subhistos)
        (_, AtomicCAS f) ->
          return (l, f (Space "local") local_subhistos)
        -- A lock array already exists from an earlier operator; reuse it.
        (Just l', AtomicLocking f) ->
          return (l, f l' (Space "local") local_subhistos)
        -- First operator needing locks: allocate and initialise them.
        (Nothing, AtomicLocking f) -> do
          locks <- newVName "locks"

          let num_locks = toInt64Exp $ unCount group_size
              dims =
                map toInt64Exp $
                  shapeDims (histShape op) ++ [histWidth op]
              -- A lock value of 0 means "unlocked"; see the
              -- initialisation below.
              l' =
                Locking locks 0 1 0 $
                  pure . (`rem` num_locks) . flattenIndex dims
              locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness

          locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
          dArray locks int32 (arrayShape locks_t) $
            ArrayIn locks_mem $
              IxFun.iota $
                map pe64 $ arrayDims locks_t

          sComment "All locks start out unlocked" $
            groupCoverSpace [kernelGroupSize constants] $ \is ->
              copyDWIMFix locks is (intConst Int32 0) []

          return (Just l', f l' (Space "local") local_subhistos)
-- | Run @m@ only in threads whose indices lie inside @space@.
--
-- No guard is emitted in two cases:
--
-- * the level's virtualisation is 'SegNoVirtFull';
--
-- * the space's dimensions, converted with 'toInt64Exp', compare equal
--   to the one-element list @[group_size]@.  Here @group_size@ is the
--   kernel's group-size expression, and the comparison is between
--   expressions, not runtime values.
--
-- Otherwise @m@ is wrapped in 'sWhen' with the 'isActive' condition.
whenActive :: SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive lvl space m
  | SegNoVirtFull <- segVirt lvl = m
  | otherwise = do
    group_size <- kernelGroupSize . kernelConstants <$> askEnv
    let space_dims = map (toInt64Exp . snd) $ unSegSpace space
    if [group_size] == space_dims
      then m
      else sWhen (isActive $ unSegSpace space) m
compileGroupOp :: OpCompiler KernelsMem KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler KernelsMem KernelEnv KernelOp
compileGroupOp Pattern KernelsMem
pat (Alloc SubExp
size Space
space) =
Pattern KernelsMem -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pattern KernelsMem
pat SubExp
size Space
space
compileGroupOp Pattern KernelsMem
pat (Inner (SizeOp (SplitSpace SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread))) =
Pattern KernelsMem
-> SplitOrdering -> SubExp -> SubExp -> SubExp -> InKernelGen ()
forall w i elems_per_thread lore r op.
(ToExp w, ToExp i, ToExp elems_per_thread) =>
Pattern KernelsMem
-> SplitOrdering -> w -> i -> elems_per_thread -> ImpM lore r op ()
splitSpace Pattern KernelsMem
pat SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread
compileGroupOp Pattern KernelsMem
pat (Inner (SegOp (SegMap SegLevel
lvl SegSpace
space [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody KernelsMem
body))) = do
InKernelGen () -> InKernelGen ()
forall (f :: * -> *) a. Functor f => f a -> f ()
void (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace SegLevel
lvl SegSpace
space
SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive SegLevel
lvl SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Operations KernelsMem KernelEnv KernelOp
-> InKernelGen () -> InKernelGen ()
forall lore r op a.
Operations lore r op -> ImpM lore r op a -> ImpM lore r op a
localOps Operations KernelsMem KernelEnv KernelOp
threadOperations (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(PatElemT LParamMem -> KernelResult -> InKernelGen ())
-> [PatElemT LParamMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace
-> PatElemT (LetDec KernelsMem) -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) (PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LParamMem
pat) ([KernelResult] -> InKernelGen ())
-> [KernelResult] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
body
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
compileGroupOp Pattern KernelsMem
pat (Inner (SegOp (SegScan SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
scans [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody KernelsMem
body))) = do
SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace SegLevel
lvl SegSpace
space
let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive SegLevel
lvl SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip (PatternT LParamMem -> [VName]
forall dec. PatternT dec -> [VName]
patternNames Pattern KernelsMem
PatternT LParamMem
pat) ([KernelResult] -> [(VName, KernelResult)])
-> [KernelResult] -> [(VName, KernelResult)]
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
body) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix
VName
dest
((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
ltids)
(KernelResult -> SubExp
kernelResultSubExp KernelResult
res)
[]
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
let segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64
segment_size)
TV Int64
dims_flat <- String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"dims_flat" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims'
let flattened :: PatElemT LParamMem -> ImpM KernelsMem KernelEnv KernelOp VName
flattened PatElemT LParamMem
pe = do
MemLocation VName
mem [SubExp]
_ IxFun (TExp Int64)
_ <-
ArrayEntry -> MemLocation
entryArrayLocation (ArrayEntry -> MemLocation)
-> ImpM KernelsMem KernelEnv KernelOp ArrayEntry
-> ImpM KernelsMem KernelEnv KernelOp MemLocation
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ImpM KernelsMem KernelEnv KernelOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
let pe_t :: TypeBase (ShapeBase SubExp) NoUniqueness
pe_t = PatElemT LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall t. Typed t => t -> TypeBase (ShapeBase SubExp) NoUniqueness
typeOf PatElemT LParamMem
pe
arr_dims :: [SubExp]
arr_dims = VName -> SubExp
Var (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dims_flat) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop ([TExp Int64] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [TExp Int64]
dims') (TypeBase (ShapeBase SubExp) NoUniqueness -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims TypeBase (ShapeBase SubExp) NoUniqueness
pe_t)
String
-> PrimType
-> ShapeBase SubExp
-> MemBind
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray
(VName -> String
baseString (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_flat")
(TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
pe_t)
([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
arr_dims)
(MemBind -> ImpM KernelsMem KernelEnv KernelOp VName)
-> MemBind -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ VName -> IxFun -> MemBind
ArrayIn VName
mem (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$ Shape (TPrimExp Int64 VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (TPrimExp Int64 VName) -> IxFun)
-> Shape (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 [SubExp]
arr_dims
num_scan_results :: Int
num_scan_results = [Int] -> Int
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
sum ([Int] -> Int) -> [Int] -> Int
forall a b. (a -> b) -> a -> b
$ (SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
scans
[VName]
arrs_flat <- (PatElemT LParamMem -> ImpM KernelsMem KernelEnv KernelOp VName)
-> [PatElemT LParamMem]
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM PatElemT LParamMem -> ImpM KernelsMem KernelEnv KernelOp VName
flattened ([PatElemT LParamMem]
-> ImpM KernelsMem KernelEnv KernelOp [VName])
-> [PatElemT LParamMem]
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a. Int -> [a] -> [a]
take Int
num_scan_results ([PatElemT LParamMem] -> [PatElemT LParamMem])
-> [PatElemT LParamMem] -> [PatElemT LParamMem]
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LParamMem
pat
[SegBinOp KernelsMem]
-> (SegBinOp KernelsMem -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [SegBinOp KernelsMem]
scans ((SegBinOp KernelsMem -> InKernelGen ()) -> InKernelGen ())
-> (SegBinOp KernelsMem -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \SegBinOp KernelsMem
scan -> do
let scan_op :: Lambda KernelsMem
scan_op = SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
scan
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> Lambda KernelsMem
-> [VName]
-> InKernelGen ()
groupScan ((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment) ([TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims') ([TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims') Lambda KernelsMem
scan_op [VName]
arrs_flat
compileGroupOp Pattern KernelsMem
pat (Inner (SegOp (SegRed SegLevel
lvl SegSpace
space [SegBinOp KernelsMem]
ops [TypeBase (ShapeBase SubExp) NoUniqueness]
_ KernelBody KernelsMem
body))) = do
SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace SegLevel
lvl SegSpace
space
let ([VName]
ltids, [SubExp]
dims) = [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, SubExp)] -> ([VName], [SubExp]))
-> [(VName, SubExp)] -> ([VName], [SubExp])
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
([PatElemT LParamMem]
red_pes, [PatElemT LParamMem]
map_pes) =
Int
-> [PatElemT LParamMem]
-> ([PatElemT LParamMem], [PatElemT LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
ops) ([PatElemT LParamMem]
-> ([PatElemT LParamMem], [PatElemT LParamMem]))
-> [PatElemT LParamMem]
-> ([PatElemT LParamMem], [PatElemT LParamMem])
forall a b. (a -> b) -> a -> b
$ PatternT LParamMem -> [PatElemT LParamMem]
forall dec. PatternT dec -> [PatElemT dec]
patternElements Pattern KernelsMem
PatternT LParamMem
pat
dims' :: [TExp Int64]
dims' = (SubExp -> TExp Int64) -> [SubExp] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp [SubExp]
dims
mkTempArr :: TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM KernelsMem KernelEnv KernelOp VName
mkTempArr TypeBase (ShapeBase SubExp) NoUniqueness
t =
String
-> PrimType
-> ShapeBase SubExp
-> Space
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> Space -> ImpM lore r op VName
sAllocArray String
"red_arr" (TypeBase (ShapeBase SubExp) NoUniqueness -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType TypeBase (ShapeBase SubExp) NoUniqueness
t) ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
dims ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> TypeBase (ShapeBase SubExp) NoUniqueness -> ShapeBase SubExp
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape TypeBase (ShapeBase SubExp) NoUniqueness
t) (Space -> ImpM KernelsMem KernelEnv KernelOp VName)
-> Space -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$ String -> Space
Space String
"local"
[VName]
tmp_arrs <- (TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM KernelsMem KernelEnv KernelOp VName)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM TypeBase (ShapeBase SubExp) NoUniqueness
-> ImpM KernelsMem KernelEnv KernelOp VName
mkTempArr ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM KernelsMem KernelEnv KernelOp [VName])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> ImpM KernelsMem KernelEnv KernelOp [VName]
forall a b. (a -> b) -> a -> b
$ (SegBinOp KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [SegBinOp KernelsMem]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (Lambda KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType (Lambda KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> (SegBinOp KernelsMem -> Lambda KernelsMem)
-> SegBinOp KernelsMem
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda) [SegBinOp KernelsMem]
ops
let tmps_for_ops :: [[VName]]
tmps_for_ops = [Int] -> [VName] -> [[VName]]
forall a. [Int] -> [a] -> [[a]]
chunks ((SegBinOp KernelsMem -> Int) -> [SegBinOp KernelsMem] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([SubExp] -> Int)
-> (SegBinOp KernelsMem -> [SubExp]) -> SegBinOp KernelsMem -> Int
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegBinOp KernelsMem -> [SubExp]
forall lore. SegBinOp lore -> [SubExp]
segBinOpNeutral) [SegBinOp KernelsMem]
ops) [VName]
tmp_arrs
SegLevel -> SegSpace -> InKernelGen () -> InKernelGen ()
whenActive SegLevel
lvl SegSpace
space (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Names -> Stms KernelsMem -> InKernelGen () -> InKernelGen ()
forall lore r op.
Names -> Stms lore -> ImpM lore r op () -> ImpM lore r op ()
compileStms Names
forall a. Monoid a => a
mempty (KernelBody KernelsMem -> Stms KernelsMem
forall lore. KernelBody lore -> Stms lore
kernelBodyStms KernelBody KernelsMem
body) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
let ([KernelResult]
red_res, [KernelResult]
map_res) =
Int -> [KernelResult] -> ([KernelResult], [KernelResult])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SegBinOp KernelsMem] -> Int
forall lore. [SegBinOp lore] -> Int
segBinOpResults [SegBinOp KernelsMem]
ops) ([KernelResult] -> ([KernelResult], [KernelResult]))
-> [KernelResult] -> ([KernelResult], [KernelResult])
forall a b. (a -> b) -> a -> b
$ KernelBody KernelsMem -> [KernelResult]
forall lore. KernelBody lore -> [KernelResult]
kernelBodyResult KernelBody KernelsMem
body
[(VName, KernelResult)]
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName] -> [KernelResult] -> [(VName, KernelResult)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
tmp_arrs [KernelResult]
red_res) (((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ())
-> ((VName, KernelResult) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
dest, KernelResult
res) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix VName
dest ((VName -> TExp Int64) -> [VName] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map VName -> TExp Int64
Imp.vi64 [VName]
ltids) (KernelResult -> SubExp
kernelResultSubExp KernelResult
res) []
(PatElemT LParamMem -> KernelResult -> InKernelGen ())
-> [PatElemT LParamMem] -> [KernelResult] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (SegSpace
-> PatElemT (LetDec KernelsMem) -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
space) [PatElemT LParamMem]
map_pes [KernelResult]
map_res
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
case [TExp Int64]
dims' of
[TExp Int64
dim'] -> do
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp KernelsMem
op, [VName]
tmps) ->
TExp Int32 -> Lambda KernelsMem -> [VName] -> InKernelGen ()
groupReduce (TExp Int64 -> TExp Int32
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int32 v
sExt32 TExp Int64
dim') (SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
op) [VName]
tmps
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
[(PatElemT LParamMem, VName)]
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem] -> [VName] -> [(PatElemT LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
red_pes [VName]
tmp_arrs) (((PatElemT LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, VName
arr) ->
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe) [] (VName -> SubExp
Var VName
arr) [TExp Int64
0]
[TExp Int64]
_ -> do
TV Int64
dims_flat <- String
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"dims_flat" (TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64))
-> TExp Int64 -> ImpM KernelsMem KernelEnv KernelOp (TV Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims'
let flatten :: VName -> ImpM KernelsMem KernelEnv KernelOp VName
flatten VName
arr = do
ArrayEntry MemLocation
arr_loc PrimType
pt <- VName -> ImpM KernelsMem KernelEnv KernelOp ArrayEntry
forall lore r op. VName -> ImpM lore r op ArrayEntry
lookupArray VName
arr
let flat_shape :: ShapeBase SubExp
flat_shape =
[SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape ([SubExp] -> ShapeBase SubExp) -> [SubExp] -> ShapeBase SubExp
forall a b. (a -> b) -> a -> b
$
VName -> SubExp
Var (TV Int64 -> VName
forall t. TV t -> VName
tvVar TV Int64
dims_flat) SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
:
Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
ltids) (MemLocation -> [SubExp]
memLocationShape MemLocation
arr_loc)
String
-> PrimType
-> ShapeBase SubExp
-> MemBind
-> ImpM KernelsMem KernelEnv KernelOp VName
forall lore r op.
String
-> PrimType -> ShapeBase SubExp -> MemBind -> ImpM lore r op VName
sArray String
"red_arr_flat" PrimType
pt ShapeBase SubExp
flat_shape (MemBind -> ImpM KernelsMem KernelEnv KernelOp VName)
-> MemBind -> ImpM KernelsMem KernelEnv KernelOp VName
forall a b. (a -> b) -> a -> b
$
VName -> IxFun -> MemBind
ArrayIn (MemLocation -> VName
memLocationName MemLocation
arr_loc) (IxFun -> MemBind) -> IxFun -> MemBind
forall a b. (a -> b) -> a -> b
$
Shape (TPrimExp Int64 VName) -> IxFun
forall num. IntegralExp num => Shape num -> IxFun num
IxFun.iota (Shape (TPrimExp Int64 VName) -> IxFun)
-> Shape (TPrimExp Int64 VName) -> IxFun
forall a b. (a -> b) -> a -> b
$ (SubExp -> TPrimExp Int64 VName)
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> TPrimExp Int64 VName
pe64 ([SubExp] -> Shape (TPrimExp Int64 VName))
-> [SubExp] -> Shape (TPrimExp Int64 VName)
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
flat_shape
let segment_size :: TExp Int64
segment_size = [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims'
crossesSegment :: TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment TExp Int32
from TExp Int32
to =
(TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
from) TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.>. (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
to TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`rem` TExp Int64 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int64
segment_size)
[(SegBinOp KernelsMem, [VName])]
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SegBinOp KernelsMem]
-> [[VName]] -> [(SegBinOp KernelsMem, [VName])]
forall a b. [a] -> [b] -> [(a, b)]
zip [SegBinOp KernelsMem]
ops [[VName]]
tmps_for_ops) (((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ())
-> ((SegBinOp KernelsMem, [VName]) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(SegBinOp KernelsMem
op, [VName]
tmps) -> do
[VName]
tmps_flat <- (VName -> ImpM KernelsMem KernelEnv KernelOp VName)
-> [VName] -> ImpM KernelsMem KernelEnv KernelOp [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ImpM KernelsMem KernelEnv KernelOp VName
flatten [VName]
tmps
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> Lambda KernelsMem
-> [VName]
-> InKernelGen ()
groupScan
((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just TExp Int32 -> TExp Int32 -> TExp Bool
crossesSegment)
([TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims')
([TExp Int64] -> TExp Int64
forall (t :: * -> *) a. (Foldable t, Num a) => t a -> a
product [TExp Int64]
dims')
(SegBinOp KernelsMem -> Lambda KernelsMem
forall lore. SegBinOp lore -> Lambda lore
segBinOpLambda SegBinOp KernelsMem
op)
[VName]
tmps_flat
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.ErrorSync Fence
Imp.FenceLocal
[(PatElemT LParamMem, VName)]
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([PatElemT LParamMem] -> [VName] -> [(PatElemT LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [PatElemT LParamMem]
red_pes [VName]
tmp_arrs) (((PatElemT LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((PatElemT LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(PatElemT LParamMem
pe, VName
arr) ->
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM
(PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT LParamMem
pe)
[]
(VName -> SubExp
Var VName
arr)
((TExp Int64 -> DimIndex (TExp Int64))
-> [TExp Int64] -> [DimIndex (TExp Int64)]
forall a b. (a -> b) -> [a] -> [b]
map (TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. Num d => d -> d -> DimIndex d
unitSlice TExp Int64
0) ([TExp Int64] -> [TExp Int64]
forall a. [a] -> [a]
init [TExp Int64]
dims') [DimIndex (TExp Int64)]
-> [DimIndex (TExp Int64)] -> [DimIndex (TExp Int64)]
forall a. [a] -> [a] -> [a]
++ [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ [TExp Int64] -> TExp Int64
forall a. [a] -> a
last [TExp Int64]
dims' TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
-TExp Int64
1])
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileGroupSpace lvl space

  -- The first num_red_res kernel results (one bucket index per histogram,
  -- followed by the values for every histogram) feed the histograms; the
  -- pattern elements and results after them are ordinary map results.
  let ltids = map fst $ unSegSpace space
      num_red_res = length ops + sum (map (length . histNeutral) ops)
      map_pes = drop num_red_res $ patternElements pat

  do_ops <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- NOTE(review): presumably makes whatever prepareIntraGroupSegHist set up
  -- (e.g. lock initialisation) visible to the whole group before use.
  sOp $ Imp.Barrier Imp.FenceLocal

  whenActive lvl space . compileStms mempty (kernelBodyStms kbody) $ do
    let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
        (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
        -- One chunk of values per histogram, sized by its destinations.
        vs_per_op = chunks (map (length . histDest) ops) red_vs

        -- Load this thread's values into the operator's trailing parameters
        -- and run the prepared update at the selected bucket, but only when
        -- the bucket lies within [0, dest_w).
        updateBucket (bin, op_vs, do_op, HistOp dest_w _ _ _ shape lam) = do
          let bin' = toInt64Exp bin
              bin_in_bounds = 0 .<=. bin' .&&. bin' .<. toInt64Exp dest_w
              bucket_is = map Imp.vi64 (init ltids) ++ [bin']
              vs_params = takeLast (length op_vs) $ lambdaParams lam
          sComment "perform atomic updates" . sWhen bin_in_bounds $ do
            dLParams $ lambdaParams lam
            sLoopNest shape $ \is -> do
              zipWithM_ (\p v -> copyDWIMFix (paramName p) [] v is) vs_params op_vs
              do_op (bucket_is ++ is)

    zipWithM_ (compileThreadResult space) map_pes map_res
    mapM_ updateBucket $ zip4 red_is vs_per_op do_ops ops

  sOp $ Imp.ErrorSync Imp.FenceLocal
-- Every other operation has no group-level compilation scheme.
compileGroupOp pat _ =
  compilerBugS ("compileGroupOp: cannot compile rhs of binding " <> pretty pat)
-- | Compile an 'Op' executed by a single thread.  Only allocations and
-- 'SplitSpace' size operations are supported; anything else is reported
-- as a compiler bug.
compileThreadOp :: OpCompiler KernelsMem KernelEnv Imp.KernelOp
compileThreadOp pat = \case
  Alloc size space ->
    kernelAlloc pat size space
  Inner (SizeOp (SplitSpace o w i elems_per_thread)) ->
    splitSpace pat o w i elems_per_thread
  _ ->
    compilerBugS $ "compileThreadOp: cannot compile rhs of binding " ++ pretty pat
-- | Description of the lock used to guard a non-atomic update.
data Locking = Locking
  { -- | Array holding the locks; indexed by the (mapped) bucket index.
    lockingArray :: VName,
    -- | Lock value meaning "free"; the lock counts as acquired when a
    -- compare-and-exchange observes this value.
    lockingIsUnlocked :: Imp.TExp Int32,
    -- | Value stored by the compare-and-exchange that acquires the lock.
    lockingToLock :: Imp.TExp Int32,
    -- | Value to store when releasing the lock (presumably consumed by the
    -- release step of 'atomicUpdateLocking').
    lockingToUnlock :: Imp.TExp Int32,
    -- | Maps a bucket index to the index of its lock in 'lockingArray'.
    lockingMapping :: [Imp.TExp Int64] -> [Imp.TExp Int64]
  }
-- | Code generator for one atomic update.  It receives the memory space,
-- the destination arrays, and the bucket index to update.
type DoAtomicUpdate lore r =
  Space ->
  [VName] ->
  [Imp.TExp Int64] ->
  ImpM lore r Imp.KernelOp ()
-- | The mechanism chosen for performing an atomic update.
data AtomicUpdate lore r
  = -- | Every operator has a directly supported atomic instruction.
    AtomicPrim (DoAtomicUpdate lore r)
  | -- | The update is performed via 'atomicUpdateCAS' (compare-and-swap).
    AtomicCAS (DoAtomicUpdate lore r)
  | -- | The update is guarded by an explicit lock described by 'Locking'.
    AtomicLocking (Locking -> DoAtomicUpdate lore r)
-- | Look up an atomic instruction implementing a 'BinOp', if the target
-- provides one.  The returned constructor takes the variable receiving the
-- old value, the array, the element offset, and the operand expression.
type AtomicBinOp =
  BinOp ->
  Maybe
    ( VName ->
      VName ->
      Count Imp.Elements (Imp.TExp Int64) ->
      Imp.Exp ->
      Imp.AtomicOp
    )
atomicUpdateLocking ::
AtomicBinOp ->
Lambda KernelsMem ->
AtomicUpdate KernelsMem KernelEnv
atomicUpdateLocking :: AtomicBinOp
-> Lambda KernelsMem -> AtomicUpdate KernelsMem KernelEnv
atomicUpdateLocking AtomicBinOp
atomicBinOp Lambda KernelsMem
lam
| Just [(BinOp, PrimType, VName, VName)]
ops_and_ts <- Lambda KernelsMem -> Maybe [(BinOp, PrimType, VName, VName)]
forall lore.
ASTLore lore =>
Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp Lambda KernelsMem
lam,
((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(BinOp
_, PrimType
t, VName
_, VName
_) -> PrimType -> Int
primBitSize PrimType
t Int -> [Int] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int
32, Int
64]) [(BinOp, PrimType, VName, VName)]
ops_and_ts =
[(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
primOrCas [(BinOp, PrimType, VName, VName)]
ops_and_ts (DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv)
-> DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Space
space [VName]
arrs [TExp Int64]
bucket ->
[(VName, (BinOp, PrimType, VName, VName))]
-> ((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([VName]
-> [(BinOp, PrimType, VName, VName)]
-> [(VName, (BinOp, PrimType, VName, VName))]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
arrs [(BinOp, PrimType, VName, VName)]
ops_and_ts) (((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
-> InKernelGen ())
-> ((VName, (BinOp, PrimType, VName, VName)) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(VName
a, (BinOp
op, PrimType
t, VName
x, VName
y)) -> do
TV Any
old <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"old" PrimType
t
(VName
arr', Space
_a_space, Count Elements (TExp Int64)
bucket_offset) <- VName
-> [TExp Int64]
-> ImpM
KernelsMem
KernelEnv
KernelOp
(VName, Space, Count Elements (TExp Int64))
forall lore r op.
VName
-> [TExp Int64]
-> ImpM lore r op (VName, Space, Count Elements (TExp Int64))
fullyIndexArray VName
a [TExp Int64]
bucket
case Space
-> VName
-> VName
-> Count Elements (TExp Int64)
-> BinOp
-> Maybe (PrimExp ExpLeaf -> KernelOp)
opHasAtomicSupport Space
space (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) VName
arr' Count Elements (TExp Int64)
bucket_offset BinOp
op of
Just PrimExp ExpLeaf -> KernelOp
f -> KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ PrimExp ExpLeaf -> KernelOp
f (PrimExp ExpLeaf -> KernelOp) -> PrimExp ExpLeaf -> KernelOp
forall a b. (a -> b) -> a -> b
$ VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
y PrimType
t
Maybe (PrimExp ExpLeaf -> KernelOp)
Nothing ->
Space
-> PrimType
-> VName
-> VName
-> [TExp Int64]
-> VName
-> InKernelGen ()
-> InKernelGen ()
atomicUpdateCAS Space
space PrimType
t VName
a (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) [TExp Int64]
bucket VName
x (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
x VName -> PrimExp ExpLeaf -> InKernelGen ()
forall lore r op. VName -> PrimExp ExpLeaf -> ImpM lore r op ()
<~~ BinOp -> PrimExp ExpLeaf -> PrimExp ExpLeaf -> PrimExp ExpLeaf
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
Imp.BinOpExp BinOp
op (VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
x PrimType
t) (VName -> PrimType -> PrimExp ExpLeaf
Imp.var VName
y PrimType
t)
where
opHasAtomicSupport :: Space
-> VName
-> VName
-> Count Elements (TExp Int64)
-> BinOp
-> Maybe (PrimExp ExpLeaf -> KernelOp)
opHasAtomicSupport Space
space VName
old VName
arr' Count Elements (TExp Int64)
bucket' BinOp
bop = do
let atomic :: (VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> PrimExp ExpLeaf -> KernelOp
atomic VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp
f = Space -> AtomicOp -> KernelOp
Imp.Atomic Space
space (AtomicOp -> KernelOp)
-> (PrimExp ExpLeaf -> AtomicOp) -> PrimExp ExpLeaf -> KernelOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp
f VName
old VName
arr' Count Elements (TExp Int64)
bucket'
(VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> PrimExp ExpLeaf -> KernelOp
atomic ((VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> PrimExp ExpLeaf -> KernelOp)
-> Maybe
(VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> Maybe (PrimExp ExpLeaf -> KernelOp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> AtomicBinOp
atomicBinOp BinOp
bop
primOrCas :: [(BinOp, PrimType, VName, VName)]
-> DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
primOrCas [(BinOp, PrimType, VName, VName)]
ops
| ((BinOp, PrimType, VName, VName) -> Bool)
-> [(BinOp, PrimType, VName, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (BinOp, PrimType, VName, VName) -> Bool
isPrim [(BinOp, PrimType, VName, VName)]
ops = DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
forall lore r. DoAtomicUpdate lore r -> AtomicUpdate lore r
AtomicPrim
| Bool
otherwise = DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
forall lore r. DoAtomicUpdate lore r -> AtomicUpdate lore r
AtomicCAS
isPrim :: (BinOp, PrimType, VName, VName) -> Bool
isPrim (BinOp
op, PrimType
_, VName
_, VName
_) = Maybe
(VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> Bool
forall a. Maybe a -> Bool
isJust (Maybe
(VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> Bool)
-> Maybe
(VName
-> VName
-> Count Elements (TExp Int64)
-> PrimExp ExpLeaf
-> AtomicOp)
-> Bool
forall a b. (a -> b) -> a -> b
$ AtomicBinOp
atomicBinOp BinOp
op
atomicUpdateLocking AtomicBinOp
_ Lambda KernelsMem
op
| [Prim PrimType
t] <- Lambda KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda KernelsMem
op,
[LParam KernelsMem
xp, LParam KernelsMem
_] <- Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
op,
PrimType -> Int
primBitSize PrimType
t Int -> [Int] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [Int
32, Int
64] = DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
forall lore r. DoAtomicUpdate lore r -> AtomicUpdate lore r
AtomicCAS (DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv)
-> DoAtomicUpdate KernelsMem KernelEnv
-> AtomicUpdate KernelsMem KernelEnv
forall a b. (a -> b) -> a -> b
$ \Space
space [VName
arr] [TExp Int64]
bucket -> do
TV Any
old <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Any)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"old" PrimType
t
Space
-> PrimType
-> VName
-> VName
-> [TExp Int64]
-> VName
-> InKernelGen ()
-> InKernelGen ()
atomicUpdateCAS Space
space PrimType
t VName
arr (TV Any -> VName
forall t. TV t -> VName
tvVar TV Any
old) [TExp Int64]
bucket (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName LParam KernelsMem
Param LParamMem
xp) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[Param LParamMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [LParam KernelsMem
Param LParamMem
xp] (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
op
-- Catch-all clause: protect each bucket with a lock.  The lock word for a
-- bucket is located by applying 'lockingMapping' to the bucket index and
-- indexing 'lockingArray' with the result.  Each thread spins until its
-- compare-and-swap on the lock word succeeds.  While holding the lock it
-- loads the current bucket values into the accumulator parameters, runs the
-- operator body, writes the accumulators back, and releases the lock.
atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do
  -- Receives the previous lock value from every compare-and-swap.
  old <- dPrim "old" int32
  -- Cleared once this thread has finished its critical section.
  continue <- dPrimVol "continue" Bool true

  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  let -- Atomically replace the lock word with 'desired' if it currently
      -- equals 'expected'; the previous value lands in 'old'.
      casLock expected desired =
        sOp . Imp.Atomic space $
          Imp.AtomicCmpXchg
            int32
            (tvVar old)
            locks'
            locks_offset
            (untyped expected)
            (untyped desired)
      try_acquire_lock = casLock (lockingIsUnlocked locking) (lockingToLock locking)
      -- Releasing is also done with a compare-and-swap, not a plain store.
      release_lock = casLock (lockingToLock locking) (lockingToUnlock locking)
      lock_acquired = tvExp old .==. lockingIsUnlocked locking

      -- The first (length arrs) lambda parameters are the accumulators.
      -- The remaining parameters are not bound here; presumably the caller
      -- has already bound them (TODO confirm).
      acc_params = take (length arrs) $ lambdaParams op

      fence =
        sOp . Imp.MemFence $
          case space of
            Space "local" -> Imp.FenceLocal
            _ -> Imp.FenceGlobal

  sWhile (tvExp continue) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      dLParams acc_params
      -- Bucket reads and writes are wrapped in 'everythingVolatile'.
      everythingVolatile . sComment "bind lhs" $
        forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
          copyDWIMFix (paramName acc_p) [] (Var arr) bucket
      sComment "execute operation" $
        compileBody' acc_params $ lambdaBody op
      everythingVolatile . sComment "update global result" $
        forM_ (zip arrs acc_params) $ \(arr, acc_p) ->
          copyDWIMFix arr bucket (Var $ paramName acc_p) []
      fence
      release_lock
      continue <-- false
    fence
-- | Emit a compare-and-swap retry loop that updates @arr[bucket]@.
--
-- Arguments, in order:
--
-- * the memory space of @arr@;
-- * the element type @t@;
-- * the array;
-- * a variable @old@, declared by the caller, that receives the bucket's
--   observed value;
-- * the bucket index;
-- * the variable @x@ whose value is installed by the CAS;
-- * @do_op@.
--
-- Each iteration assigns the last observed value to @x@, runs @do_op@
-- (which presumably updates @x@), and then attempts to swap @x@ into the
-- bucket.  The loop stops once the CAS observes exactly the value it
-- expected.
--
-- 'Float32' and 'Float64' values are reinterpreted as integers through
-- @to_bits*@ / @from_bits*@ for the CAS; other types are used unchanged.
-- The CAS works on int32 when @t@ is 32 bits wide and on int64 otherwise,
-- so callers presumably restrict @t@ to 32 or 64 bits (TODO confirm).
atomicUpdateCAS ::
  Space ->
  PrimType ->
  VName ->
  VName ->
  [Imp.TExp Int64] ->
  VName ->
  InKernelGen () ->
  InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  assumed <- tvVar <$> dPrim "assumed" t
  run_loop <- dPrimV "run_loop" true

  -- The initial read of the bucket goes through 'everythingVolatile'.
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket

  (arr', _a_space, bucket_offset) <- fullyIndexArray arr bucket

  let -- Integer type the compare-and-swap is performed at.
      bitsType
        | primBitSize t == 32 = int32
        | otherwise = int64

      -- Pair of conversions into and out of the integer representation.
      viaFuns to_fun bits_t from_fun =
        ( \v -> Imp.FunExp to_fun [v] bits_t,
          \v -> Imp.FunExp from_fun [v] t
        )
      (toBits, fromBits) = case t of
        FloatType Float32 -> viaFuns "to_bits32" int32 "from_bits32"
        FloatType Float64 -> viaFuns "to_bits64" int64 "from_bits64"
        _ -> (id, id)

  sWhile (tvExp run_loop) $ do
    assumed <~~ Imp.var old t
    x <~~ Imp.var assumed t
    do_op
    old_bits_v <- newVName "old_bits"
    dPrim_ old_bits_v bitsType
    let old_bits = Imp.var old_bits_v bitsType
        expected = toBits $ Imp.var assumed t
    sOp . Imp.Atomic space $
      Imp.AtomicCmpXchg
        bitsType
        old_bits_v
        arr'
        bucket_offset
        expected
        (toBits $ Imp.var x t)
    old <~~ fromBits old_bits
    -- The update took effect iff the CAS saw the value we based it on.
    sWhen (isBool $ CmpOpExp (CmpEq bitsType) expected old_bits) $
      run_loop <-- false
-- | Try to view a lambda as a tuple of independent scalar binary operators.
--
-- Let @n@ be the number of return values.  The lambda's parameters are then
-- @x_0 .. x_{n-1}, y_0 .. y_{n-1}@.  The split succeeds only if, for every
-- result position @i@:
--
-- * the result is a variable bound by a statement @x_i op_i y_i@;
-- * that statement's result type is primitive.
--
-- On success it returns @(op_i, t_i, x_i, y_i)@ for each position, in order.
-- Otherwise it returns 'Nothing'.
splitOp :: ASTLore lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = zipWithM splitStm [0 ..] $ bodyResult body
  where
    body = lambdaBody lam
    params = lambdaParams lam
    n = length $ lambdaReturnType lam

    -- Use the result's own position @i@ rather than searching the result
    -- list for the variable.  When one variable is returned at several
    -- positions, a search always finds the first one.  Later positions would
    -- then be paired with the first position's parameters, and the split
    -- would describe an operator the lambda does not compute.
    splitStm :: Int -> SubExp -> Maybe (BinOp, PrimType, VName, VName)
    splitStm i (Var res) = do
      Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
        find (([res] ==) . patternNames . stmPattern) $
          stmsToList $ bodyStms body
      xp <- maybeNth i params
      yp <- maybeNth (n + i) params
      guard $ paramName xp == x
      guard $ paramName yp == y
      Prim t <- Just $ patElemType pe
      return (op, t, paramName xp, paramName yp)
    splitStm _ _ = Nothing
-- | Compute the kernel uses of the free variables of @kernel_body@,
-- excluding the names in @bound_in_kernel@.  Each free variable is
-- classified by 'readsFromSet', and duplicate uses are removed.
computeKernelUses ::
  FreeIn a =>
  a ->
  [VName] ->
  CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel =
  fmap nubOrd . readsFromSet $
    freeIn kernel_body `namesSubtract` namesFromList bound_in_kernel
-- | Classify each name in a set by how it must be passed to a kernel:
--
-- * arrays and local-space memory blocks are not passed;
-- * other memory blocks become 'Imp.MemoryUse';
-- * scalars that 'isConstExp' can express become 'Imp.ConstUse';
-- * certificates without such a form are dropped;
-- * every other scalar becomes 'Imp.ScalarUse'.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free = do
  -- The vtable is only read here, so it is fetched once for all names.
  vtable <- getVTable
  let useOf var = do
        t <- lookupType var
        case t of
          Array {} -> pure Nothing
          Mem (Space "local") -> pure Nothing
          Mem {} -> pure . Just $ Imp.MemoryUse var
          Prim bt -> do
            const_e <- isConstExp vtable $ Imp.var var bt
            pure $ case const_e of
              Just ce -> Just $ Imp.ConstUse var ce
              Nothing
                | bt == Cert -> Nothing
                | otherwise -> Just $ Imp.ScalarUse var bt
  catMaybes <$> mapM useOf (namesToList free)
-- | Try to rewrite a host expression into a kernel constant expression.
--
-- * Scalar variables are resolved through the vtable.
-- * A variable defined by 'GetSize' becomes an 'Imp.SizeConst' whose key is
--   qualified by the current entry point via 'keyWithEntryPoint'.
-- * Other defining expressions are translated recursively with
--   'primExpFromExp'.
-- * 'Imp.SizeOf' leaves become their byte size as an int32 value.
--
-- An array index, or a variable without a usable defining expression, makes
-- the whole result 'Nothing'.
isConstExp ::
  VTable KernelsMem ->
  Imp.Exp ->
  ImpM lore r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let fromVTable name = M.lookup name vtable >>= entryExp >>= toConst

      toConst (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      toConst e = primExpFromExp fromVTable e

      leaf (Imp.ScalarVar name) _ = fromVTable name
      leaf (Imp.SizeOf pt) _ =
        Just . ValueExp . IntValue . Int32Value $ primByteSize pt
      leaf Imp.Index {} _ = Nothing
  pure $ replaceInPrimExpM leaf size
  where
    -- The defining expression recorded for a vtable entry, if any.
    entryExp (ArrayVar e _) = e
    entryExp (ScalarVar e _) = e
    entryExp (MemVar e _) = e
computeThreadChunkSize ::
SplitOrdering ->
Imp.TExp Int64 ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
Imp.Count Imp.Elements (Imp.TExp Int64) ->
TV Int64 ->
ImpM lore r op ()
computeThreadChunkSize :: forall lore r op.
SplitOrdering
-> TExp Int64
-> Count Elements (TExp Int64)
-> Count Elements (TExp Int64)
-> TV Int64
-> ImpM lore r op ()
computeThreadChunkSize (SplitStrided SubExp
stride) TExp Int64
thread_index Count Elements (TExp Int64)
elements_per_thread Count Elements (TExp Int64)
num_elements TV Int64
chunk_var =
TV Int64
chunk_var
TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int64 -> TExp Int64 -> TExp Int64
forall v. TPrimExp Int64 v -> TPrimExp Int64 v -> TPrimExp Int64 v
sMin64
(Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elements_per_thread)
((Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
thread_index) TExp Int64 -> TExp Int64 -> TExp Int64
forall e. IntegralExp e => e -> e -> e
`divUp` SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
stride)
computeThreadChunkSize SplitOrdering
SplitContiguous TExp Int64
thread_index Count Elements (TExp Int64)
elements_per_thread Count Elements (TExp Int64)
num_elements TV Int64
chunk_var = do
TV Int64
starting_point <-
String -> TExp Int64 -> ImpM lore r op (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"starting_point" (TExp Int64 -> ImpM lore r op (TV Int64))
-> TExp Int64 -> ImpM lore r op (TV Int64)
forall a b. (a -> b) -> a -> b
$
TExp Int64
thread_index TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elements_per_thread
TV Int64
remaining_elements <-
String -> TExp Int64 -> ImpM lore r op (TV Int64)
forall t lore r op. String -> TExp t -> ImpM lore r op (TV t)
dPrimV String
"remaining_elements" (TExp Int64 -> ImpM lore r op (TV Int64))
-> TExp Int64 -> ImpM lore r op (TV Int64)
forall a b. (a -> b) -> a -> b
$
Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
starting_point
let no_remaining_elements :: TExp Bool
no_remaining_elements = TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
remaining_elements TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int64
0
beyond_bounds :: TExp Bool
beyond_bounds = Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int64 -> TExp Int64
forall t. TV t -> TExp t
tvExp TV Int64
starting_point
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
(TExp Bool
no_remaining_elements TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool
beyond_bounds)
(TV Int64
chunk_var TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int64
0)
( TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
forall lore r op.
TExp Bool
-> ImpM lore r op () -> ImpM lore r op () -> ImpM lore r op ()
sIf
TExp Bool
is_last_thread
(TV Int64
chunk_var TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
last_thread_elements)
(TV Int64
chunk_var TV Int64 -> TExp Int64 -> ImpM lore r op ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elements_per_thread)
)
where
last_thread_elements :: Count Elements (TExp Int64)
last_thread_elements =
Count Elements (TExp Int64)
num_elements Count Elements (TExp Int64)
-> Count Elements (TExp Int64) -> Count Elements (TExp Int64)
forall a. Num a => a -> a -> a
- TExp Int64 -> Count Elements (TExp Int64)
forall a. a -> Count Elements a
Imp.elements TExp Int64
thread_index Count Elements (TExp Int64)
-> Count Elements (TExp Int64) -> Count Elements (TExp Int64)
forall a. Num a => a -> a -> a
* Count Elements (TExp Int64)
elements_per_thread
is_last_thread :: TExp Bool
is_last_thread =
Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
num_elements
TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. (TExp Int64
thread_index TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
1) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* Count Elements (TExp Int64) -> TExp Int64
forall u e. Count u e -> e
Imp.unCount Count Elements (TExp Int64)
elements_per_thread
-- | Set up a kernel launched with @num_groups@ workgroups of @group_size@
-- threads each.
--
-- Returns two things:
--
-- * the 'KernelConstants' for the kernel, whose thread-id, group-id and
--   wave-size fields refer to freshly named variables;
-- * the in-kernel action that declares those variables and assigns them
--   through the corresponding 'Imp.KernelOp' intrinsics.
--
-- The constants are only meaningful after that action has run inside the
-- kernel. The total thread count stored in the constants is
-- @group_size * num_groups@, converted to 32 bits with 'sExt32'.
kernelInitialisationSimple ::
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple (Count num_groups) (Count group_size) = do
  -- Fresh names are drawn in a fixed order; keep it stable.
  gtid_v <- newVName "global_tid"
  ltid_v <- newVName "local_tid"
  gid_v <- newVName "group_tid"
  wave_v <- newVName "wave_size"
  gsize_v <- newVName "group_size"

  let constants :: KernelConstants
      constants =
        KernelConstants
          (Imp.vi32 gtid_v)
          (Imp.vi32 ltid_v)
          (Imp.vi32 gid_v)
          gtid_v
          ltid_v
          gid_v
          num_groups
          group_size
          (sExt32 $ group_size * num_groups)
          (Imp.vi32 wave_v)
          true
          mempty

      -- Every variable is declared before any intrinsic assigns to it.
      declarations =
        [ (gtid_v, int32),
          (ltid_v, int32),
          (gsize_v, int64),
          (wave_v, int32),
          (gid_v, int32)
        ]

      intrinsics :: [Imp.KernelOp]
      intrinsics =
        [ Imp.GetGlobalId gtid_v 0,
          Imp.GetLocalId ltid_v 0,
          Imp.GetLocalSize gsize_v 0,
          Imp.GetLockstepWidth wave_v,
          Imp.GetGroupId gid_v 0
        ]

      set_constants :: InKernelGen ()
      set_constants = do
        mapM_ (uncurry dPrim_) declarations
        mapM_ sOp intrinsics

  pure (constants, set_constants)
-- | The conjunction of @i < bound@ over every @(i, bound)@ pair, where each
-- index variable is read as a 64-bit integer.
--
-- The conditions are combined with a left fold in list order. An empty
-- list yields 'true'.
isActive :: [(VName, SubExp)] -> Imp.TExp Bool
isActive limit =
  case map inBounds limit of
    [] -> true
    c : cs -> foldl (.&&.) c cs
  where
    inBounds :: (VName, SubExp) -> Imp.TExp Bool
    inBounds (i, bound) = Imp.vi64 i .<. toInt64Exp bound
-- | Run an action with two changes to its environment:
--
-- * the default space is set to @"global"@ via 'localDefaultSpace';
-- * every memory entry in the variable table whose space is not @"local"@
--   is rewritten to live in @"global"@ space, and its optional expression
--   is replaced by 'Nothing'.
--
-- Non-memory entries and memory already in @"local"@ space are passed
-- through unchanged.
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal m =
  localDefaultSpace (Imp.Space "global") $
    localVTable (M.map toGlobal) m
  where
    toGlobal :: VarEntry lore -> VarEntry lore
    toGlobal entry = case entry of
      MemVar _ mem
        | entryMemSpace mem /= Space "local" ->
          MemVar Nothing mem {entryMemSpace = Imp.Space "global"}
      _ -> entry
-- | Workgroup-level reduction of the first @w@ elements of @arrs@ with the
-- operator @lam@.
--
-- Allocates a fresh 32-bit @offset@ variable and hands the actual work to
-- 'groupReduceWithOffset'.
groupReduce ::
  Imp.TExp Int32 ->
  Lambda KernelsMem ->
  [VName] ->
  InKernelGen ()
groupReduce w lam arrs =
  dPrim "offset" int32 >>= \offset ->
    groupReduceWithOffset offset w lam arrs
-- | Tree reduction of the first @w@ elements of @arrs@ across a workgroup,
-- using the caller-supplied @offset@ variable as the current combining
-- distance.
--
-- The reduction runs in two phases:
--
-- 1. /In-wave/. @offset@ doubles from 1 while it is below the wave size.
--    In each step, a thread whose in-wave position is a multiple of
--    @2*offset@ (and whose partner index is below @w@) combines its
--    accumulator with the element @offset@ positions later. No barrier is
--    emitted in this phase, and the combine step is wrapped in
--    'everythingVolatile'. NOTE(review): this presumably relies on the
--    threads of a wave executing in lockstep — confirm.
-- 2. /Cross-wave/. @skip_waves@ doubles from 1 while it is below the
--    number of waves. Each step begins with a barrier, sets @offset@ to
--    @skip_waves * wave_size@, and lets the first thread of each selected
--    wave combine.
--
-- Primitive operands are read at @local_tid + offset@ and their results
-- written back at @local_tid@. Non-primitive operands are read at
-- @global_tid + offset@ and are not written back here.
groupReduceWithOffset ::
  TV Int32 ->
  Imp.TExp Int32 ->
  Lambda KernelsMem ->
  [VName] ->
  InKernelGen ()
groupReduceWithOffset offset w lam arrs = do
  constants <- kernelConstants <$> askEnv

  let ltid, gtid, wave_size, cur_offset :: Imp.TExp Int32
      ltid = kernelLocalThreadId constants
      gtid = kernelGlobalThreadId constants
      wave_size = kernelWaveSize constants
      cur_offset = tvExp offset

      group_size :: Imp.TExp Int64
      group_size = kernelGroupSize constants

      -- A local fence suffices when every result is primitive; otherwise
      -- a global fence is used.
      barrier :: InKernelGen ()
      barrier =
        sOp . Imp.Barrier $
          if all primType (lambdaReturnType lam)
            then Imp.FenceLocal
            else Imp.FenceGlobal

      isPrimParam p = case paramType p of
        Prim _ -> True
        _ -> False

      -- Primitive operands are indexed from the local thread id, all others
      -- from the global thread id; both are shifted by the current offset.
      readReduceArgument p arr =
        let base = if isPrimParam p then ltid else gtid
         in copyDWIMFix (paramName p) [] (Var arr) [sExt64 (base + cur_offset)]

      -- Only primitive results are written back, into this thread's slot.
      writeReduceOpResult p arr =
        when (isPrimParam p) $
          copyDWIMFix arr [sExt64 ltid] (Var $ paramName p) []

      (acc_params, arr_params) = splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrimV "skip_waves" (1 :: Imp.TExp Int32)
  dLParams $ lambdaParams lam

  offset <-- (0 :: Imp.TExp Int32)

  comment "participating threads read initial accumulator" $
    sWhen (ltid .<. w) $
      zipWithM_ readReduceArgument acc_params arrs

  let cur_skip :: Imp.TExp Int32
      cur_skip = tvExp skip_waves

      wave_id, in_wave_id, num_waves :: Imp.TExp Int32
      wave_id = ltid `quot` wave_size
      in_wave_id = ltid - wave_id * wave_size
      num_waves = (sExt32 group_size + wave_size - 1) `quot` wave_size

      arg_in_bounds :: Imp.TExp Bool
      arg_in_bounds = ltid + cur_offset .<. w

      -- Mask test for "x is a multiple of 2*s"; s only takes power-of-two
      -- values here because both strides double starting from 1.
      multipleOfTwice :: Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool
      multipleOfTwice x s = (x .&. (2 * s - 1)) .==. 0

      do_reduce :: InKernelGen ()
      do_reduce = do
        comment "read array element" $
          zipWithM_ readReduceArgument arr_params arrs
        comment "apply reduction operation" $
          compileBody' acc_params $ lambdaBody lam
        comment "write result of operation" $
          zipWithM_ writeReduceOpResult acc_params arrs

      in_wave_reductions, cross_wave_reductions :: InKernelGen ()
      in_wave_reductions = do
        offset <-- (1 :: Imp.TExp Int32)
        sWhile (cur_offset .<. wave_size) $ do
          sWhen (arg_in_bounds .&&. multipleOfTwice in_wave_id cur_offset) $
            everythingVolatile do_reduce
          offset <-- cur_offset * 2

      cross_wave_reductions =
        sWhile (cur_skip .<. num_waves) $ do
          barrier
          offset <-- cur_skip * wave_size
          sWhen
            ( arg_in_bounds
                .&&. (in_wave_id .==. 0)
                .&&. multipleOfTwice wave_id cur_skip
            )
            do_reduce
          skip_waves <-- cur_skip * 2

  in_wave_reductions
  cross_wave_reductions
groupScan ::
Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
Imp.TExp Int64 ->
Imp.TExp Int64 ->
Lambda KernelsMem ->
[VName] ->
InKernelGen ()
groupScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int64
-> Lambda KernelsMem
-> [VName]
-> InKernelGen ()
groupScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Int64
arrs_full_size TExp Int64
w Lambda KernelsMem
lam [VName]
arrs = do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
Lambda KernelsMem
renamed_lam <- Lambda KernelsMem
-> ImpM KernelsMem KernelEnv KernelOp (Lambda KernelsMem)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda KernelsMem
lam
let ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
ltid :: TExp Int64
ltid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
([Param LParamMem]
x_params, [Param LParamMem]
y_params) = Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param LParamMem] -> ([Param LParamMem], [Param LParamMem]))
-> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
lam
[LParam KernelsMem] -> InKernelGen ()
forall lore r op. Mem lore => [LParam lore] -> ImpM lore r op ()
dLParams (Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
lam [Param LParamMem] -> [Param LParamMem] -> [Param LParamMem]
forall a. [a] -> [a] -> [a]
++ Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
renamed_lam)
TExp Bool
ltid_in_bounds <- String
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"ltid_in_bounds" (TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int64
w
let block_size :: TExp Int32
block_size = TExp Int32
32
simd_width :: TExp Int32
simd_width = KernelConstants -> TExp Int32
kernelWaveSize KernelConstants
constants
block_id :: TExp Int32
block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
block_size
in_block_id :: TExp Int32
in_block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size
doInBlockScan :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag' TExp Bool
active =
KernelConstants
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int32
-> TExp Int32
-> TExp Bool
-> [VName]
-> InKernelGen ()
-> Lambda KernelsMem
-> InKernelGen ()
inBlockScan
KernelConstants
constants
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag'
TExp Int64
arrs_full_size
TExp Int32
simd_width
TExp Int32
block_size
TExp Bool
active
[VName]
arrs
InKernelGen ()
barrier
array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda KernelsMem
lam
barrier :: InKernelGen ()
barrier
| Bool
array_scan =
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
| Bool
otherwise =
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
group_offset :: TExp Int64
group_offset = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelGroupId KernelConstants
constants) TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* KernelConstants -> TExp Int64
kernelGroupSize KernelConstants
constants
writeBlockResult :: Param LParamMem -> VName -> InKernelGen ()
writeBlockResult Param LParamMem
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
p =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
| Bool
otherwise =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
readPrevBlockResult :: Param LParamMem -> VName -> InKernelGen ()
readPrevBlockResult Param LParamMem
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
p =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
| Bool
otherwise =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_id TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
1]
Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Bool
ltid_in_bounds Lambda KernelsMem
lam
InKernelGen ()
barrier
let is_first_block :: TExp Bool
is_first_block = TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
arrs_full_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) []
InKernelGen ()
barrier
let last_in_block :: TExp Bool
last_in_block = TExp Int32
in_block_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
1
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"last thread of block 'i' writes its result to offset 'i'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Bool
last_in_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
writeBlockResult [Param LParamMem]
x_params [VName]
arrs
InKernelGen ()
barrier
let first_block_seg_flag :: Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag = do
TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag
(TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a. a -> Maybe a
Just ((TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool))
-> (TExp Int32 -> TExp Int32 -> TExp Bool)
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
forall a b. (a -> b) -> a -> b
$ \TExp Int32
from TExp Int32
to ->
TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
from TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
-TExp Int32
1) (TExp Int32
to TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
+ TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
-TExp Int32
1)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment
String
"scan the first block, after which offset 'i' contains carry-in for block 'i+1'"
(InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Bool -> Lambda KernelsMem -> InKernelGen ()
doInBlockScan Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
first_block_seg_flag (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
ltid_in_bounds) Lambda KernelsMem
renamed_lam
InKernelGen ()
barrier
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"move correct values for first block back a block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM
VName
arr
[TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
arrs_full_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
ltid]
(VName -> SubExp
Var VName
arr)
[TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
arrs_full_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
block_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
ltid]
InKernelGen ()
barrier
let read_carry_in :: InKernelGen ()
read_carry_in = do
[(Param LParamMem, Param LParamMem)]
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem] -> [(Param LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [Param LParamMem]
y_params) (((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y) ->
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) [] (VName -> SubExp
Var (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x)) []
(Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
readPrevBlockResult [Param LParamMem]
x_params [VName]
arrs
y_to_x :: InKernelGen ()
y_to_x = [(Param LParamMem, Param LParamMem)]
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem] -> [(Param LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [Param LParamMem]
y_params) (((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
x)) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) [] (VName -> SubExp
Var (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y)) []
op_to_x :: InKernelGen ()
op_to_x
| Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
Nothing <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag =
[Param LParamMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LParamMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
lam
| Just TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag = do
TExp Bool
inactive <-
String
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"inactive" (TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$ TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
-TExp Int32
1) TExp Int32
ltid32
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
inactive InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LParamMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
lam
write_final_result :: InKernelGen ()
write_final_result =
[(Param LParamMem, VName)]
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem] -> [VName] -> [(Param LParamMem, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [VName]
arrs) (((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ())
-> ((Param LParamMem, VName) -> InKernelGen ()) -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
p, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"carry-in for every block except the first" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless (TExp Bool
is_first_block TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.||. TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v
bNot TExp Bool
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read operands" InKernelGen ()
read_carry_in
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write final result" InKernelGen ()
write_final_result
InKernelGen ()
barrier
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"restore correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[(Param LParamMem, Param LParamMem, VName)]
-> ((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem]
-> [VName]
-> [(Param LParamMem, Param LParamMem, VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param LParamMem]
x_params [Param LParamMem]
y_params [VName]
arrs) (((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, Param LParamMem, VName) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y, VName
arr) ->
if TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
y)
then VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) []
else VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
arrs_full_size TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
group_offset TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
ltid]
InKernelGen ()
barrier
inBlockScan ::
KernelConstants ->
Maybe (Imp.TExp Int32 -> Imp.TExp Int32 -> Imp.TExp Bool) ->
Imp.TExp Int64 ->
Imp.TExp Int32 ->
Imp.TExp Int32 ->
Imp.TExp Bool ->
[VName] ->
InKernelGen () ->
Lambda KernelsMem ->
InKernelGen ()
inBlockScan :: KernelConstants
-> Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
-> TExp Int64
-> TExp Int32
-> TExp Int32
-> TExp Bool
-> [VName]
-> InKernelGen ()
-> Lambda KernelsMem
-> InKernelGen ()
inBlockScan KernelConstants
constants Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag TExp Int64
arrs_full_size TExp Int32
lockstep_width TExp Int32
block_size TExp Bool
active [VName]
arrs InKernelGen ()
barrier Lambda KernelsMem
scan_lam = InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
skip_threads <- String -> PrimType -> ImpM KernelsMem KernelEnv KernelOp (TV Int32)
forall lore r op t. String -> PrimType -> ImpM lore r op (TV t)
dPrim String
"skip_threads" PrimType
int32
let in_block_thread_active :: TExp Bool
in_block_thread_active =
TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TExp Int32
in_block_id
actual_params :: [LParam KernelsMem]
actual_params = Lambda KernelsMem -> [LParam KernelsMem]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda KernelsMem
scan_lam
([Param LParamMem]
x_params, [Param LParamMem]
y_params) =
Int -> [Param LParamMem] -> ([Param LParamMem], [Param LParamMem])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param LParamMem] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LParam KernelsMem]
[Param LParamMem]
actual_params Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) [LParam KernelsMem]
[Param LParamMem]
actual_params
y_to_x :: InKernelGen ()
y_to_x =
[(Param LParamMem, Param LParamMem)]
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param LParamMem]
-> [Param LParamMem] -> [(Param LParamMem, Param LParamMem)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param LParamMem]
x_params [Param LParamMem]
y_params) (((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ())
-> ((Param LParamMem, Param LParamMem) -> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param LParamMem
x, Param LParamMem
y) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
x)) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) [] (VName -> SubExp
Var (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y)) []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read input for in-block scan" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
active (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
(Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param LParamMem -> VName -> InKernelGen ()
readInitial [Param LParamMem]
y_params [VName]
arrs
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Int32
in_block_id TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TExp Int32
0) InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
let op_to_x :: InKernelGen ()
op_to_x
| Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
Nothing <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag =
[Param LParamMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LParamMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_lam
| Just TExp Int32 -> TExp Int32 -> TExp Bool
flag_true <- Maybe (TExp Int32 -> TExp Int32 -> TExp Bool)
seg_flag = do
TExp Bool
inactive <-
String
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall t lore r op. String -> TExp t -> ImpM lore r op (TExp t)
dPrimVE String
"inactive" (TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool))
-> TExp Bool -> ImpM KernelsMem KernelEnv KernelOp (TExp Bool)
forall a b. (a -> b) -> a -> b
$
TExp Int32 -> TExp Int32 -> TExp Bool
flag_true (TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads) TExp Int32
ltid32
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen TExp Bool
inactive InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sUnless TExp Bool
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param LParamMem] -> BodyT KernelsMem -> InKernelGen ()
forall dec lore r op. [Param dec] -> Body lore -> ImpM lore r op ()
compileBody' [Param LParamMem]
x_params (BodyT KernelsMem -> InKernelGen ())
-> BodyT KernelsMem -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> BodyT KernelsMem
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda KernelsMem
scan_lam
maybeBarrier :: InKernelGen ()
maybeBarrier =
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen
(TExp Int32
lockstep_width TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<=. TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads)
InKernelGen ()
barrier
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"in-block scan (hopefully no barriers needed)" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
TV Int32
skip_threads TV Int32 -> TExp Int32 -> InKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TExp Int32
1
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhile (TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TExp Bool
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.<. TExp Int32
block_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Bool
in_block_thread_active TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
active) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read operands" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem] -> [VName] -> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (TExp Int64 -> Param LParamMem -> VName -> InKernelGen ()
readParam (TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads)) [Param LParamMem]
x_params [VName]
arrs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
InKernelGen ()
maybeBarrier
TExp Bool -> InKernelGen () -> InKernelGen ()
forall lore r op.
TExp Bool -> ImpM lore r op () -> ImpM lore r op ()
sWhen (TExp Bool
in_block_thread_active TExp Bool -> TExp Bool -> TExp Bool
forall v. TPrimExp Bool v -> TPrimExp Bool v -> TPrimExp Bool v
.&&. TExp Bool
active) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[InKernelGen ()] -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_ ([InKernelGen ()] -> InKernelGen ())
-> [InKernelGen ()] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ (Param LParamMem -> Param LParamMem -> VName -> InKernelGen ())
-> [Param LParamMem]
-> [Param LParamMem]
-> [VName]
-> [InKernelGen ()]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Param LParamMem -> Param LParamMem -> VName -> InKernelGen ()
writeResult [Param LParamMem]
x_params [Param LParamMem]
y_params [VName]
arrs
InKernelGen ()
maybeBarrier
TV Int32
skip_threads TV Int32 -> TExp Int32 -> InKernelGen ()
forall t lore r op. TV t -> TExp t -> ImpM lore r op ()
<-- TV Int32 -> TExp Int32
forall t. TV t -> TExp t
tvExp TV Int32
skip_threads TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
2
where
block_id :: TExp Int32
block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall e. IntegralExp e => e -> e -> e
`quot` TExp Int32
block_size
in_block_id :: TExp Int32
in_block_id = TExp Int32
ltid32 TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
- TExp Int32
block_id TExp Int32 -> TExp Int32 -> TExp Int32
forall a. Num a => a -> a -> a
* TExp Int32
block_size
ltid32 :: TExp Int32
ltid32 = KernelConstants -> TExp Int32
kernelLocalThreadId KernelConstants
constants
ltid :: TExp Int64
ltid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 TExp Int32
ltid32
gtid :: TExp Int64
gtid = TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64) -> TExp Int32 -> TExp Int64
forall a b. (a -> b) -> a -> b
$ KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants
array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda KernelsMem -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda KernelsMem
scan_lam
readInitial :: Param LParamMem -> VName -> InKernelGen ()
readInitial Param LParamMem
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
p =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
ltid]
| Bool
otherwise =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
gtid]
readParam :: TExp Int64 -> Param LParamMem -> VName -> InKernelGen ()
readParam TExp Int64
behind Param LParamMem
p VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
p =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
ltid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
behind]
| Bool
otherwise =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
p) [] (VName -> SubExp
Var VName
arr) [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ TExp Int64
gtid TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
- TExp Int64
behind TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
+ TExp Int64
arrs_full_size]
writeResult :: Param LParamMem -> Param LParamMem -> VName -> InKernelGen ()
writeResult Param LParamMem
x Param LParamMem
y VName
arr
| TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> TypeBase (ShapeBase SubExp) NoUniqueness
forall dec.
Typed dec =>
Param dec -> TypeBase (ShapeBase SubExp) NoUniqueness
paramType Param LParamMem
x = do
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM VName
arr [TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> DimIndex d
DimFix TExp Int64
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) []
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) []
| Bool
otherwise =
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param LParamMem -> VName
forall dec. Param dec -> VName
paramName Param LParamMem
x) []
-- | Pick the launch geometry for a flat, map-like kernel that must cover
-- @kernel_size@ threads.
--
-- The group size is obtained through an 'Imp.GetSize' host operation of
-- class 'Imp.SizeGroup'. Its key is built from the name of the group-size
-- variable, combined with the current function name (if any) via
-- 'keyWithEntryPoint'. The number of groups is @kernel_size@ divided by
-- the group size, rounded up.
--
-- Returns @(num_groups, group_size)@.
computeMapKernelGroups :: Imp.TExp Int64 -> CallKernelGen (Imp.TExp Int64, Imp.TExp Int64)
computeMapKernelGroups kernel_size = do
  gsize <- dPrim "group_size" int64
  entry <- askFunction
  let gsize_var = tvVar gsize
      gsize_key = keyWithEntryPoint entry $ nameFromString $ pretty gsize_var
  sOp $ Imp.GetSize gsize_var gsize_key Imp.SizeGroup
  ngroups <- dPrimV "num_groups" $ kernel_size `divUp` tvExp gsize
  pure (tvExp ngroups, tvExp gsize)
-- | Build 'KernelConstants' for a simple one-dimensional kernel covering
-- @kernel_size@ threads. Also return the in-kernel action that declares
-- and initialises the index variables those constants refer to.
--
-- @desc@ is used as the prefix of the generated variable names
-- (@desc_gtid@, @desc_ltid@, @desc_gid@). The number of groups and the
-- group size come from 'computeMapKernelGroups'. The thread count stored
-- in the constants is their product, truncated to 32 bits. A thread
-- counts as active when its global thread id is below @kernel_size@.
--
-- The returned action must run at the start of the kernel body, before
-- any of the constants are used.
simpleKernelConstants ::
  Imp.TExp Int64 ->
  String ->
  CallKernelGen (KernelConstants, InKernelGen ())
simpleKernelConstants kernel_size desc = do
  thread_gtid <- newVName $ desc ++ "_gtid"
  thread_ltid <- newVName $ desc ++ "_ltid"
  group_id <- newVName $ desc ++ "_gid"
  (num_groups, group_size) <- computeMapKernelGroups kernel_size
  let set_constants = do
        dPrim_ thread_gtid int32
        dPrim_ thread_ltid int32
        dPrim_ group_id int32
        sOp (Imp.GetGlobalId thread_gtid 0)
        sOp (Imp.GetLocalId thread_ltid 0)
        sOp (Imp.GetGroupId group_id 0)
  return
    ( KernelConstants
        (Imp.vi32 thread_gtid)
        (Imp.vi32 thread_ltid)
        (Imp.vi32 group_id)
        thread_gtid
        thread_ltid
        group_id
        num_groups
        group_size
        (sExt32 (group_size * num_groups))
        0
        -- thread_gtid is declared as int32 above. Read it at that type and
        -- sign-extend it before comparing with the 64-bit kernel size.
        -- Reading it with 'Imp.vi64' would produce an int64 leaf for an
        -- int32 variable, which is an ill-typed expression.
        (sExt64 (Imp.vi32 thread_gtid) .<. kernel_size)
        mempty,
      set_constants
    )
-- | Run @m@ for group ids according to the virtualisation mode.
--
-- With 'SegVirt', each physical group loops over the virtual ids
-- @phys_group_id + i * num_groups@ that lie below @required_groups@. It
-- calls @m@ on each one and emits a global barrier after every iteration.
--
-- For any other 'SegVirt' value, @m@ is called once with the kernel's
-- physical group id and @required_groups@ is ignored.
virtualiseGroups ::
  SegVirt ->
  Imp.TExp Int32 ->
  (Imp.TExp Int32 -> InKernelGen ()) ->
  InKernelGen ()
virtualiseGroups SegVirt required_groups m = do
  constants <- kernelConstants <$> askEnv
  phys_group_id <- dPrim "phys_group_id" int32
  sOp $ Imp.GetGroupId (tvVar phys_group_id) 0
  let group_count = sExt32 (kernelNumGroups constants)
      phys = tvExp phys_group_id
      iterations = (required_groups - phys) `divUp` group_count
  sFor "i" iterations $ \i -> do
    virt_group_id <- dPrimV "virt_group_id" $ phys + i * group_count
    m (tvExp virt_group_id)
    sOp $ Imp.Barrier Imp.FenceGlobal
virtualiseGroups _ _ m =
  m . Imp.vi32 . kernelGroupIdVar . kernelConstants =<< askEnv
-- | Launch a kernel through 'sKernel' using 'threadOperations'. Inside
-- the kernel, the given 'VName' is bound to the global thread id.
sKernelThread ::
  String ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  VName ->
  InKernelGen () ->
  CallKernelGen ()
sKernelThread name num_groups group_size v body =
  sKernel threadOperations kernelGlobalThreadId name num_groups group_size v body
-- | Launch a kernel through 'sKernel' using 'groupOperations'. Inside
-- the kernel, the given 'VName' is bound to the group id.
sKernelGroup ::
  String ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  VName ->
  InKernelGen () ->
  CallKernelGen ()
sKernelGroup name num_groups group_size v body =
  sKernel groupOperations kernelGroupId name num_groups group_size v body
-- | Compile @m@ as the body of a kernel named @name@ and emit a host-side
-- 'Imp.CallKernel' for it.
--
-- The body is generated with 'subImpM_' inside 'makeAllMemoryGlobal',
-- using the operations @ops@. Its 'KernelEnv' carries the host
-- environment's atomics and the given @constants@. The kernel's uses are
-- computed from the generated body with 'computeKernelUses'. The launch
-- geometry is taken from 'kernelNumGroups' and 'kernelGroupSize' of
-- @constants@. @tol@ is recorded as 'Imp.kernelFailureTolerant'.
sKernelFailureTolerant ::
  Bool ->
  Operations KernelsMem KernelEnv Imp.KernelOp ->
  KernelConstants ->
  Name ->
  InKernelGen () ->
  CallKernelGen ()
sKernelFailureTolerant tol ops constants name m = do
  HostEnv {hostAtomics = atomics} <- askEnv
  body <- makeAllMemoryGlobal $ subImpM_ (KernelEnv atomics constants) ops m
  uses <- computeKernelUses body mempty
  let kernel =
        Imp.Kernel
          { Imp.kernelBody = body,
            Imp.kernelUses = uses,
            Imp.kernelNumGroups = [untyped $ kernelNumGroups constants],
            Imp.kernelGroupSize = [untyped $ kernelGroupSize constants],
            Imp.kernelName = name,
            Imp.kernelFailureTolerant = tol
          }
  emit $ Imp.Op $ Imp.CallKernel kernel
-- | Generate and launch a non-failure-tolerant kernel whose constants come
-- from 'kernelInitialisationSimple'.
--
-- The kernel name is @name@ followed by @_@ and the tag of @v@, passed
-- through 'nameForFun'. Inside the kernel, the constants are set up first.
-- Then @v@ is bound to the index that @flatf@ computes from the constants,
-- and finally the body @f@ runs.
sKernel ::
  Operations KernelsMem KernelEnv Imp.KernelOp ->
  (KernelConstants -> Imp.TExp Int32) ->
  String ->
  Count NumGroups (Imp.TExp Int64) ->
  Count GroupSize (Imp.TExp Int64) ->
  VName ->
  InKernelGen () ->
  CallKernelGen ()
sKernel ops flatf name num_groups group_size v f = do
  (constants, set_constants) <- kernelInitialisationSimple num_groups group_size
  name' <- nameForFun $ concat [name, "_", show (baseTag v)]
  let kernel_body = do
        set_constants
        dPrimV_ v $ flatf constants
        f
  sKernelFailureTolerant False ops constants name' kernel_body
-- | Copy compiler used for group-level code.
--
-- When both the destination and the source live in 'ScalarSpace', each
-- slice is trimmed to the rank of its scalar space: the innermost
-- dimensions are kept, and any excess outer dimensions are pinned to index
-- 0. The copy is then done with 'copyElementWise'.
--
-- Otherwise, the index space of the destination slice is covered with
-- 'groupCoverSpace'. Each covered index is copied with 'copyElementWise',
-- and a local-memory barrier is issued once the covering loop is done.
copyInGroup :: CopyCompiler KernelsMem KernelEnv Imp.KernelOp
copyInGroup pt destloc destslice srcloc srcslice = do
  dest_space <- memSpaceOf destloc
  src_space <- memSpaceOf srcloc
  case (dest_space, src_space) of
    (ScalarSpace destds _, ScalarSpace srcds _) ->
      copyElementWise
        pt
        destloc
        (fitToScalarSpace destds destslice)
        srcloc
        (fitToScalarSpace srcds srcslice)
    _ -> do
      groupCoverSpace (sliceDims destslice) $ \is ->
        copyElementWise
          pt
          destloc
          (fixAt destslice is)
          srcloc
          (fixAt srcslice is)
      sOp $ Imp.Barrier Imp.FenceLocal
  where
    -- Memory space of the block backing a memory location.
    memSpaceOf loc = entryMemSpace <$> lookupMemory (memLocationName loc)

    -- Keep only as many innermost dimensions as the scalar space has;
    -- the remaining outer dimensions are fixed at index 0.
    fitToScalarSpace ds slice =
      let keep = length ds
       in replicate (length slice - keep) (DimFix 0) ++ takeLast keep slice

    -- Turn a slice plus an index into that slice into a fully fixed slice.
    fixAt slice is = map DimFix $ fixSlice slice is
-- | Compilation operations for code run by a single thread
-- ('threadOperations') and for code run by a whole workgroup
-- ('groupOperations').
--
-- Both compile statements with 'defCompileStms' (the names argument given
-- to the statement compiler is ignored), and both register 'allocLocal'
-- as the allocator for @Space "local"@. They differ in the op compiler,
-- the copy compiler ('copyElementWise' vs. 'copyInGroup') and the
-- expression compiler they install.
threadOperations, groupOperations :: Operations KernelsMem KernelEnv Imp.KernelOp
threadOperations =
  (defaultOperations compileThreadOp)
    { opsCopyCompiler = copyElementWise,
      opsExpCompiler = compileThreadExp,
      opsStmsCompiler = const $ defCompileStms mempty,
      opsAllocCompilers = M.singleton (Space "local") allocLocal
    }
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler = copyInGroup,
      opsExpCompiler = compileGroupExp,
      opsStmsCompiler = const $ defCompileStms mempty,
      opsAllocCompilers = M.singleton (Space "local") allocLocal
    }
-- | Replicate @se@ into @arr@ by launching a dedicated kernel.
--
-- The outer dimensions of @arr@ that are not covered by the rank of
-- @se@'s type are the replication dimensions. The kernel is sized by the
-- product of the replication dimensions and @se@'s own dimensions.
--
-- Each active thread turns its global thread ID into a full index into
-- @arr@. It then copies the element of @se@ selected by the
-- non-replicated (trailing) part of that index.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  se_t <- subExpType se
  arr_dims <- arrayDims <$> lookupType arr
  let rep_ds = dropLast (arrayRank se_t) arr_dims
      dims = map toInt64Exp (rep_ds ++ arrayDims se_t)
      kernel_size = product (map sExt64 dims)
  (constants, set_constants) <- simpleKernelConstants kernel_size "replicate"
  fname <- askFunction
  let tid_tag = baseTag (kernelGlobalThreadIdVar constants)
      name = keyWithEntryPoint fname (nameFromString ("replicate_" ++ show tid_tag))
      is' = unflattenIndex dims (sExt64 (kernelGlobalThreadId constants))
      se_is = drop (length rep_ds) is'
  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants
    sWhen (kernelThreadActive constants) $
      copyDWIMFix arr is' se se_is
-- | Base name of the builtin replicate function for an element type:
-- @"replicate_"@ followed by the pretty-printed type.
replicateName :: PrimType -> String
replicateName = ("replicate_" ++) . pretty
-- | Return the name of a builtin host function that fills a
-- one-dimensional array in @"device"@ memory with copies of one value of
-- type @bt@.
--
-- The function is generated only the first time it is requested (checked
-- with 'hasFunction'). Its name is @"builtin#"@ followed by
-- 'replicateName'. Its parameters are:
--
-- * the memory block;
-- * the number of elements, as a 64-bit integer;
-- * the value to write.
--
-- The body performs the fill with 'sReplicateKernel'.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString $ "builtin#" <> replicateName bt

  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    num_elems <- newVName "num_elems"
    val <- newVName "val"

    let params =
          [ Imp.MemParam mem (Space "device"),
            -- The element count must be 64-bit. Call sites pass an int64
            -- product of the array shape, and the body below reads the
            -- parameter with 'pe64'. Declaring it int32 could truncate
            -- arrays with 2^31 or more elements.
            Imp.ScalarParam num_elems int64,
            Imp.ScalarParam val bt
          ]
        shape = Shape [Var num_elems]
    function fname [] params $ do
      arr <-
        sArray "arr" bt shape $
          ArrayIn mem $
            IxFun.iota $
              map pe64 $ shapeDims shape
      sReplicateKernel arr $ Var val

  return fname
-- | Decide whether replicating @v@ into @arr@ is a plain fill. That is the
-- case when @v@ has a primitive type and @arr@'s index function is
-- linear.
--
-- If it is, return an action that calls the builtin replicate function for
-- that element type (see 'replicateForType'). The call passes @arr@'s
-- memory block, the product of @arr@'s shape (as an int64 expression) and
-- the value. Otherwise return 'Nothing'.
--
-- The lookups run immediately; the returned action is not run here.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLocation arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  let callFill elem_t = do
        fname <- replicateForType elem_t
        let num_elems = product $ map toInt64Exp arr_shape
            args =
              [ Imp.MemArg arr_mem,
                Imp.ExpArg $ untyped num_elems,
                Imp.ExpArg $ toExp' elem_t v
              ]
        emit $ Imp.Call [] fname args
  pure $ case v_t of
    Prim elem_t | IxFun.isLinear arr_ixfun -> Just $ callFill elem_t
    _ -> Nothing
-- | Replicate @se@ into @arr@.
--
-- If 'replicateIsFill' recognises the operation as a simple fill, the
-- builtin fill function is called. Otherwise a dedicated kernel is
-- launched with 'sReplicateKernel'.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se =
  replicateIsFill arr se >>= fromMaybe (sReplicateKernel arr se)
-- | Perform an iota with a dedicated kernel.
--
-- The kernel is sized by @n@. Each active thread, with global thread ID
-- @i@, writes @x + i * s@ to position @i@ of @arr@. The value is computed
-- in integer type @et@ with wrapping overflow.
sIotaKernel ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIotaKernel arr n x s et = do
  destloc <- entryArrayLocation <$> lookupArray arr
  (constants, set_constants) <- simpleKernelConstants n "iota"
  fname <- askFunction
  let tid_tag = baseTag $ kernelGlobalThreadIdVar constants
      name =
        keyWithEntryPoint fname . nameFromString $
          concat ["iota_", pretty et, "_", show tid_tag]
  sKernelFailureTolerant True threadOperations constants name $ do
    set_constants
    let gtid = sExt64 $ kernelGlobalThreadId constants
        -- i * s, with the thread index converted to the element type first.
        scaled = BinOpExp (Mul et OverflowWrap) (Imp.sExt et $ untyped gtid) s
        elem_val = BinOpExp (Add et OverflowWrap) scaled x
    sWhen (kernelThreadActive constants) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]
      emit $ Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile elem_val
-- | Base name of the builtin iota function for an integer type:
-- @"iota_"@ followed by the pretty-printed type.
iotaName :: IntType -> String
iotaName = ("iota_" ++) . pretty
-- | Return the name of a builtin host function that performs an iota of
-- integer type @bt@ into a one-dimensional array in @"device"@ memory.
--
-- The function is generated only the first time it is requested (checked
-- with 'hasFunction'). Its name is @"builtin#"@ followed by 'iotaName'.
-- Its parameters are:
--
-- * the memory block;
-- * the element count @n@, as a 64-bit integer;
-- * the start value @x@;
-- * the stride @s@.
--
-- The body delegates to 'sIotaKernel'.
iotaForType :: IntType -> CallKernelGen Name
iotaForType bt = do
  let fname = nameFromString $ "builtin#" <> iotaName bt

  exists <- hasFunction fname
  unless exists $ do
    mem <- newVName "mem"
    n <- newVName "n"
    x <- newVName "x"
    s <- newVName "s"

    let params =
          [ Imp.MemParam mem (Space "device"),
            -- The element count must be 64-bit. The body reads it with
            -- 'Imp.vi64', and call sites pass an int64 size. Declaring it
            -- int32 could truncate arrays with 2^31 or more elements.
            Imp.ScalarParam n int64,
            Imp.ScalarParam x $ IntType bt,
            Imp.ScalarParam s $ IntType bt
          ]
        shape = Shape [Var n]
        n' = Imp.vi64 n
        x' = Imp.var x $ IntType bt
        s' = Imp.var s $ IntType bt

    function fname [] params $ do
      arr <-
        sArray "arr" (IntType bt) shape $
          ArrayIn mem $
            IxFun.iota $
              map pe64 $ shapeDims shape
      -- n' is already a 64-bit expression; no extension needed.
      sIotaKernel arr n' x' s' bt

  return fname
-- | Perform an iota into @arr@: element @i@ becomes @x + i * s@, in
-- integer type @et@.
--
-- If @arr@'s index function is linear, the shared builtin function from
-- 'iotaForType' is called with the memory block, @n@, @x@ and @s@.
-- Otherwise a dedicated kernel is launched with 'sIotaKernel'.
sIota ::
  VName ->
  Imp.TExp Int64 ->
  Imp.Exp ->
  Imp.Exp ->
  IntType ->
  CallKernelGen ()
sIota arr n x s et = do
  ArrayEntry (MemLocation arr_mem _ arr_ixfun) _ <- lookupArray arr
  let builtin_args = Imp.MemArg arr_mem : map Imp.ExpArg [untyped n, x, s]
  if not (IxFun.isLinear arr_ixfun)
    then sIotaKernel arr n x s et
    else do
      fname <- iotaForType et
      emit $ Imp.Call [] fname builtin_args
-- | Copy compiler that emits a standalone kernel performing the copy.
--
-- The number of elements in @srcslice@ determines the kernel size.  Each
-- thread whose global ID lies below that count turns its flat ID into an
-- index vector, applies it to both slices, and moves one element of
-- primitive type @bt@ from the source memory block to the destination
-- memory block.  The same index vector is used for both sides, so the two
-- slices presumably must have the same shape (NOTE(review): confirm with
-- callers).
sCopy :: CopyCompiler KernelsMem HostEnv Imp.HostOp
sCopy bt destloc destslice srcloc srcslice = do
  (constants, set_constants) <- simpleKernelConstants num_elems "copy"
  fname <- askFunction
  -- The kernel is named after the tag of its global-thread-ID variable,
  -- qualified by the enclosing entry point.
  let kernel_tag = baseTag $ kernelGlobalThreadIdVar constants
      kernel_name =
        keyWithEntryPoint fname $ nameFromString $ "copy_" <> show kernel_tag
  sKernelFailureTolerant True threadOperations constants kernel_name $ do
    set_constants
    let flat_i = sExt64 $ kernelGlobalThreadId constants
        elem_is = unflattenIndex elem_shape flat_i
    -- Destination is indexed before source, as in the original ordering.
    (_, dst_space, dst_off) <- fullyIndexArray' destloc $ fixSlice destslice elem_is
    (_, src_space, src_off) <- fullyIndexArray' srcloc $ fixSlice srcslice elem_is
    let loaded = Imp.index src_mem src_off bt src_space Imp.Nonvolatile
    -- Threads whose ID is not below the element count do nothing.
    sWhen (flat_i .<. num_elems) $
      emit $ Imp.Write dst_mem dst_off bt dst_space Imp.Nonvolatile loaded
  where
    MemLocation dst_mem _ _ = destloc
    MemLocation src_mem _ _ = srcloc
    elem_shape = sliceDims srcslice
    num_elems = product elem_shape
-- | Write the result of a group-level kernel body into the pattern
-- element @pe@.
--
-- * 'TileReturns' with a single dimension: the group's tile @what@ is
--   copied into @pe@ starting at @per_group_elems * groupId@.  Positions
--   at or beyond @w@ are skipped.
-- * 'TileReturns' with several dimensions: each thread copies the tile
--   element at its local thread indices.  It writes it at those indices
--   offset by the group indices of @space@ scaled by the tile sizes.
-- * 'RegTileReturns': each thread copies its register tile into the
--   corresponding slice of @pe@, skipping out-of-range positions.
-- * 'Returns': the value is written at the position given by the indices
--   of @space@.
--
-- 'WriteReturns' and 'ConcatReturns' are reported as compiler limitations.
compileGroupResult ::
  SegSpace ->
  PatElem KernelsMem ->
  KernelResult ->
  InKernelGen ()
compileGroupResult _ pe (TileReturns [(w, per_group_elems)] what) = do
  -- Outer size of the tile held by this group.
  n <- toInt64Exp . arraySize 0 <$> lookupType what

  constants <- kernelConstants <$> askEnv
  let group_size = kernelGroupSize constants
      ltid = sExt64 $ kernelLocalThreadId constants
      offset =
        toInt64Exp per_group_elems
          * sExt64 (kernelGroupId constants)

  localOps threadOperations $
    -- When the tile size is syntactically the group size, every thread
    -- writes at most one element, so no loop is generated.
    if toInt64Exp per_group_elems == group_size
      then
        sWhen (ltid + offset .<. toInt64Exp w) $
          copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
      else sFor "i" (n `divUp` group_size) $ \i -> do
        j <- dPrimVE "j" $ group_size * i + ltid
        -- The iteration count is rounded up.  In the last round, j can
        -- therefore reach or exceed n, the outer size of 'what'.  Guard
        -- the read of what[j] as well as the write.  Otherwise a
        -- non-final group would read out of bounds and clobber the next
        -- group's output region.
        sWhen ((j .<. n) .&&. (j + offset .<. toInt64Exp w)) $
          copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]
compileGroupResult space pe (TileReturns dims what) = do
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (toInt64Exp . snd) dims
      group_is = zipWith (*) (map Imp.vi64 gids) out_tile_sizes
  local_is <- localThreadIDs $ map snd dims
  -- Output index of this thread in each dimension.
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $
      zipWith (+) group_is local_is

  localOps threadOperations $
    sWhen (isActive $ zip (map tvVar is_for_thread) $ map fst dims) $
      copyDWIMFix (patElemName pe) (map tvExp is_for_thread) (Var what) local_is
compileGroupResult space pe (RegTileReturns dims_n_tiles what) = do
  constants <- kernelConstants <$> askEnv

  let gids = map fst $ unSegSpace space
      (dims, group_tiles, reg_tiles) = unzip3 dims_n_tiles
      group_tiles' = map toInt64Exp group_tiles
      reg_tiles' = map toInt64Exp reg_tiles

  -- Index of the group tile this group is responsible for.
  let group_tile_is = map Imp.vi64 gids

  -- Index, within the group tile, of the register tile this thread is
  -- responsible for, derived from the flat local thread ID.
  reg_tile_is <-
    mapM (dPrimVE "reg_tile_i") $
      unflattenIndex group_tiles' $ sExt64 $ kernelLocalThreadId constants

  -- Output slice (one dimension at a time) covered by this thread's
  -- register tile.
  let regTileSliceDim (group_tile, group_tile_i) (reg_tile, reg_tile_i) = do
        tile_dim_start <-
          dPrimVE "tile_dim_start" $
            reg_tile * (group_tile * group_tile_i + reg_tile_i)
        return $ DimSlice tile_dim_start reg_tile 1
  reg_tile_slices <-
    zipWithM
      regTileSliceDim
      (zip group_tiles' group_tile_is)
      (zip reg_tiles' reg_tile_is)

  localOps threadOperations $
    sLoopNest (Shape reg_tiles) $ \is_in_reg_tile -> do
      let dest_is = fixSlice reg_tile_slices is_in_reg_tile
          src_is = reg_tile_is ++ is_in_reg_tile
      -- Skip positions that fall outside the result array.
      sWhen (foldl1 (.&&.) $ zipWith (.<.) dest_is $ map toInt64Exp dims) $
        copyDWIMFix (patElemName pe) dest_is (Var what) src_is
compileGroupResult space pe (Returns _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.vi64 . fst) $ unSegSpace space

  if not in_local_memory
    then -- Only local thread 0 performs the write, using thread operations.
      localOps threadOperations $
        sWhen (kernelLocalThreadId constants .==. 0) $
          copyDWIMFix (patElemName pe) gids what []
    else -- The value lives in local memory: the copy is emitted unguarded
    -- under the ambient (non-thread) operations, so presumably all threads
    -- of the group take part in it.
      copyDWIMFix (patElemName pe) gids what []
compileGroupResult _ _ WriteReturns {} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."
compileGroupResult _ _ ConcatReturns {} =
  compilerLimitationS "compileGroupResult: ConcatReturns not handled yet."
compileThreadResult ::
SegSpace ->
PatElem KernelsMem ->
KernelResult ->
InKernelGen ()
compileThreadResult :: SegSpace
-> PatElemT (LetDec KernelsMem) -> KernelResult -> InKernelGen ()
compileThreadResult SegSpace
_ PatElemT (LetDec KernelsMem)
_ RegTileReturns {} =
String -> InKernelGen ()
forall a. String -> a
compilerLimitationS String
"compileThreadResult: RegTileReturns not yet handled."
compileThreadResult SegSpace
space PatElemT (LetDec KernelsMem)
pe (Returns ResultManifest
_ SubExp
what) = do
let is :: [TExp Int64]
is = ((VName, SubExp) -> TExp Int64)
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> TExp Int64
Imp.vi64 (VName -> TExp Int64)
-> ((VName, SubExp) -> VName) -> (VName, SubExp) -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, SubExp) -> VName
forall a b. (a, b) -> a
fst) ([(VName, SubExp)] -> [TExp Int64])
-> [(VName, SubExp)] -> [TExp Int64]
forall a b. (a -> b) -> a -> b
$ SegSpace -> [(VName, SubExp)]
unSegSpace SegSpace
space
VName -> [TExp Int64] -> SubExp -> [TExp Int64] -> InKernelGen ()
forall lore r op.
VName
-> [TExp Int64] -> SubExp -> [TExp Int64] -> ImpM lore r op ()
copyDWIMFix (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec KernelsMem)
PatElemT LParamMem
pe) [TExp Int64]
is SubExp
what []
compileThreadResult SegSpace
_ PatElemT (LetDec KernelsMem)
pe (ConcatReturns SplitOrdering
SplitContiguous SubExp
_ SubExp
per_thread_elems VName
what) = do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
let offset :: TExp Int64
offset =
SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
per_thread_elems
TExp Int64 -> TExp Int64 -> TExp Int64
forall a. Num a => a -> a -> a
* TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (KernelConstants -> TExp Int32
kernelGlobalThreadId KernelConstants
constants)
TExp Int64
n <- SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
0 (TypeBase (ShapeBase SubExp) NoUniqueness -> TExp Int64)
-> ImpM
KernelsMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
-> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> ImpM
KernelsMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase (ShapeBase SubExp) NoUniqueness)
lookupType VName
what
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec KernelsMem)
PatElemT LParamMem
pe) [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice TExp Int64
offset TExp Int64
n TExp Int64
1] (VName -> SubExp
Var VName
what) []
compileThreadResult SegSpace
_ PatElemT (LetDec KernelsMem)
pe (ConcatReturns (SplitStrided SubExp
stride) SubExp
_ SubExp
_ VName
what) = do
TExp Int64
offset <- TExp Int32 -> TExp Int64
forall t v. IntExp t => TPrimExp t v -> TPrimExp Int64 v
sExt64 (TExp Int32 -> TExp Int64)
-> (KernelEnv -> TExp Int32) -> KernelEnv -> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelConstants -> TExp Int32
kernelGlobalThreadId (KernelConstants -> TExp Int32)
-> (KernelEnv -> KernelConstants) -> KernelEnv -> TExp Int32
forall b c a. (b -> c) -> (a -> b) -> a -> c
. KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> TExp Int64)
-> ImpM KernelsMem KernelEnv KernelOp KernelEnv
-> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM KernelsMem KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
TExp Int64
n <- SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp (SubExp -> TExp Int64)
-> (TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp)
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TExp Int64
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int -> TypeBase (ShapeBase SubExp) NoUniqueness -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
0 (TypeBase (ShapeBase SubExp) NoUniqueness -> TExp Int64)
-> ImpM
KernelsMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
-> ImpM KernelsMem KernelEnv KernelOp (TExp Int64)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName
-> ImpM
KernelsMem
KernelEnv
KernelOp
(TypeBase (ShapeBase SubExp) NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase (ShapeBase SubExp) NoUniqueness)
lookupType VName
what
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> InKernelGen ()
forall lore r op.
VName
-> [DimIndex (TExp Int64)]
-> SubExp
-> [DimIndex (TExp Int64)]
-> ImpM lore r op ()
copyDWIM (PatElemT LParamMem -> VName
forall dec. PatElemT dec -> VName
patElemName PatElemT (LetDec KernelsMem)
PatElemT LParamMem
pe) [TExp Int64 -> TExp Int64 -> TExp Int64 -> DimIndex (TExp Int64)
forall d. d -> d -> d -> DimIndex d
DimSlice TExp Int64
offset TExp Int64
n (TExp Int64 -> DimIndex (TExp Int64))
-> TExp Int64 -> DimIndex (TExp Int64)
forall a b. (a -> b) -> a -> b
$ SubExp -> TExp Int64
forall a. ToExp a => a -> TExp Int64
toInt64Exp SubExp
stride] (VName -> SubExp
Var VName
what) []
-- Thread-level in-place writes ('WriteReturns').  Each @(slice, e)@ entry of
-- @dests@ copies @e@ into the array bound to @pe@ at @slice@.  The copy
-- happens only when the thread is active ('kernelThreadActive') and every
-- index of the slice is within the matching dimension of @rws@; otherwise
-- that write is skipped.
compileThreadResult _ pe (WriteReturns (Shape rws) _arr dests) = do
  constants <- kernelConstants <$> askEnv
  let rws' = map toInt64Exp rws
  forM_ dests $ \(slice, e) -> do
    let slice' = map (fmap toInt64Exp) slice

        -- Is index @x@ inside a dimension of size @d@?
        inDim :: Imp.TExp Int64 -> Imp.TExp Int64 -> Imp.TExp Bool
        inDim x d = 0 .<=. x .&&. x .<. d

        condInBounds :: DimIndex (Imp.TExp Int64) -> Imp.TExp Int64 -> Imp.TExp Bool
        condInBounds (DimFix i) rw = inDim i rw
        -- A slice touches i, i+s, ..., i+(n-1)*s.  The old check used
        -- i+n*s (one stride past the end) and so dropped in-bounds writes
        -- ending exactly at the array edge.  Checking both endpoints also
        -- covers negative strides.  An empty slice (n == 0) may be rejected,
        -- which at most skips a copy of zero elements.
        condInBounds (DimSlice i n s) rw =
          let lastIdx = i + (n - 1) * s
           in inDim i rw .&&. inDim lastIdx rw

        write =
          foldl (.&&.) (kernelThreadActive constants) $
            zipWith condInBounds slice' rws'
    sWhen write $ copyDWIM (patElemName pe) slice' e []
-- 'TileReturns' is not supported for thread-level results; reaching this
-- equation is reported through 'compilerBugS'.
compileThreadResult _space _pe TileReturns {} = compilerBugS unhandledMsg
  where
    unhandledMsg = "compileThreadResult: TileReturns unhandled."
-- | Does this 'SubExp' name an array whose memory block is in the
-- @Space "local"@ memory space?  Constants, and variables that are not
-- arrays, always yield 'False'.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory Constant {} = return False
arrayInLocalMemory (Var name) =
  lookupVar name >>= \case
    ArrayVar _ entry -> do
      let mem = memLocationName $ entryArrayLocation entry
      memEntry <- lookupMemory mem
      return $ entryMemSpace memEntry == Space "local"
    _ -> return False