{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.CodeGen.ImpGen.Kernels.Base
( KernelConstants (..)
, keyWithEntryPoint
, CallKernelGen
, InKernelGen
, HostEnv (..)
, KernelEnv (..)
, computeThreadChunkSize
, groupReduce
, groupScan
, isActive
, sKernelThread
, sKernelGroup
, sReplicate
, sIota
, sCopy
, compileThreadResult
, compileGroupResult
, virtualiseGroups
, groupLoop
, kernelLoop
, groupCoverSpace
, atomicUpdateLocking
, AtomicBinOp
, Locking(..)
, AtomicUpdate(..)
, DoAtomicUpdate
)
where
import Control.Monad.Except
import Data.Maybe
import qualified Data.Map.Strict as M
import Data.List (elemIndex, find, nub, zip4)
import Prelude hiding (quot, rem)
import Futhark.Error
import Futhark.MonadFreshNames
import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpGen
import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
import Futhark.Util (chunks, maybeNth, mapAccumLM, takeLast, dropLast)
-- | Environment threaded through host-side ('CallKernelGen') code
-- generation.
newtype HostEnv = HostEnv
  { hostAtomics :: AtomicBinOp
    -- ^ The 'AtomicBinOp' in effect for code generated from the host.
  }
-- | Environment threaded through in-kernel ('InKernelGen') code
-- generation.
data KernelEnv = KernelEnv
  { kernelAtomics :: AtomicBinOp
    -- ^ Passed to 'atomicUpdateLocking' (e.g. by
    -- 'prepareIntraGroupSegHist') to choose how updates are done.
  , kernelConstants :: KernelConstants
    -- ^ Thread/group identifiers and sizes of the enclosing kernel.
  }
-- | Code generation monad for host-side code; emits 'Imp.HostOp's and
-- carries a 'HostEnv'.
type CallKernelGen = ImpM ExplicitMemory HostEnv Imp.HostOp
-- | Code generation monad for code inside a kernel; emits
-- 'Imp.KernelOp's and carries a 'KernelEnv'.
type InKernelGen = ImpM ExplicitMemory KernelEnv Imp.KernelOp
-- | Values describing the executing thread and the shape of the kernel
-- launch, available to all in-kernel code via 'kernelConstants'.
--
-- The @...Var@ fields are presumably the variables to which the
-- corresponding ID expressions are bound (TODO confirm at the binding
-- site).
data KernelConstants = KernelConstants
  { kernelGlobalThreadId :: Imp.Exp
    -- ^ Thread ID across the whole kernel.
  , kernelLocalThreadId :: Imp.Exp
    -- ^ Thread ID within its group (used as the per-thread offset by
    -- 'groupLoop').
  , kernelGroupId :: Imp.Exp
    -- ^ ID of the thread's group.
  , kernelGlobalThreadIdVar :: VName
  , kernelLocalThreadIdVar :: VName
  , kernelGroupIdVar :: VName
  , kernelNumGroups :: Imp.Exp
    -- ^ Number of groups in the kernel.
  , kernelGroupSize :: Imp.Exp
    -- ^ Number of threads per group (the stride used by 'groupLoop').
  , kernelNumThreads :: Imp.Exp
    -- ^ Total number of threads in the kernel.
  , kernelWaveSize :: Imp.Exp
    -- ^ Wave size of the target device.
  , kernelThreadActive :: Imp.Exp
    -- ^ Presumably a boolean expression that is true for threads that
    -- should do work (TODO confirm where it is constructed).
  }
-- | Qualify @key@ by an optional entry point name: @Just f@ yields
-- @f.key@, while 'Nothing' leaves @key@ unchanged.
keyWithEntryPoint :: Maybe Name -> Name -> Name
keyWithEntryPoint fname key = nameFromString $ prefix ++ nameToString key
  where
    prefix = case fname of
      Nothing -> ""
      Just f  -> nameToString f ++ "."
-- | Allocation compiler that emits an 'Imp.LocalAlloc' of the given
-- number of bytes for the named memory block.
allocLocal :: AllocCompiler ExplicitMemory r Imp.KernelOp
allocLocal mem = sOp . Imp.LocalAlloc mem
-- | Compile an allocation that occurs inside a kernel.
--
-- The pattern must bind exactly one value element (the memory block).
-- Scalar-space blocks produce no code here, @local@ blocks go through
-- 'allocLocal', and any other space is reported as a compiler
-- limitation.  Any other pattern shape is an internal error.
kernelAlloc :: Pattern ExplicitMemory
            -> SubExp -> Space
            -> InKernelGen ()
kernelAlloc dest size space =
  case (dest, space) of
    (Pattern _ [_], ScalarSpace{}) ->
      -- Nothing is emitted for scalar-space blocks at this point.
      return ()
    (Pattern _ [mem], Space "local") ->
      allocLocal (patElemName mem) . Imp.bytes =<< toExp size
    (Pattern _ [mem], _) ->
      compilerLimitationS $
        "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
    _ ->
      error $ "Invalid target for in-kernel allocation: " ++ show dest
-- | Compile a 'SplitSpace' size operation.
--
-- The pattern must have no context elements and exactly one value
-- element, whose name is handed to 'computeThreadChunkSize' as the
-- destination.  @w@ is the total element count, @i@ the thread index
-- and @elems_per_thread@ the per-thread element count; they are
-- converted with 'toExp' in that order.
splitSpace :: (ToExp w, ToExp i, ToExp elems_per_thread) =>
              Pattern ExplicitMemory -> SplitOrdering -> w -> i -> elems_per_thread
           -> ImpM lore r op ()
splitSpace (Pattern [] [size]) o w i elems_per_thread = do
  total <- toExp w
  thread_index <- toExp i
  per_thread <- toExp elems_per_thread
  computeThreadChunkSize o thread_index
    (Imp.elements per_thread) (Imp.elements total)
    (patElemName size)
splitSpace pat _ _ _ _ =
  error $ "Invalid target for splitSpace: " ++ pretty pat
-- | Expression compiler for thread-level code inside a kernel.
--
-- An array literal bound to a single pattern element is written one
-- element at a time with 'copyDWIMFix'.  Every other expression goes
-- to 'defCompileExp'.
compileThreadExp :: ExpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileThreadExp (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
  zipWithM_ writeElem [0 ..] es
  where
    writeElem :: Int32 -> SubExp -> InKernelGen ()
    writeElem i x = copyDWIMFix (patElemName dest) [fromIntegral i] x []
compileThreadExp pat e =
  defCompileExp pat e
-- | Distribute the iterations @0 .. n-1@ over @num_threads@ threads.
--
-- The thread with index @tid@ runs @f@ on @tid@, @tid + num_threads@,
-- @tid + 2*num_threads@, ...; its iteration count is
-- @(n - tid) `quotRoundingUp` num_threads@.  When @n@ is
-- (structurally) the same expression as @num_threads@, no loop is
-- generated and @f tid@ is emitted directly.
kernelLoop :: Imp.Exp -> Imp.Exp -> Imp.Exp
           -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
kernelLoop tid num_threads n f
  | n == num_threads = f tid
  | otherwise =
      sFor "i" my_iterations $ \j -> f $ j * num_threads + tid
  where
    my_iterations = (n - tid) `quotRoundingUp` num_threads
-- | Spread @n@ iterations over the threads of the current group.
--
-- Uses 'kernelLoop' with the local thread ID as offset and the group
-- size as stride.  For multidimensional spaces, see 'groupCoverSpace'.
groupLoop :: Imp.Exp
          -> (Imp.Exp -> InKernelGen ()) -> InKernelGen ()
groupLoop n f =
  askEnv >>= \env ->
    let consts = kernelConstants env
    in kernelLoop (kernelLocalThreadId consts) (kernelGroupSize consts) n f
-- | Iterate collectively over the index space with dimensions @ds@.
--
-- The flat space of size @product ds@ is split among the group's
-- threads by 'groupLoop', and each flat index is unflattened into a
-- multidimensional index before it is passed to @f@.
groupCoverSpace :: [Imp.Exp]
                -> ([Imp.Exp] -> InKernelGen ()) -> InKernelGen ()
groupCoverSpace ds f =
  groupLoop (product ds) $ \flat_i -> f $ unflattenIndex ds flat_i
-- | Cooperatively copy @from[from_is ++ is]@ to @to[to_is ++ is]@ for
-- every index @is@ of the remaining dimensions of @from@.  All threads
-- of the group take part (via 'groupCoverSpace').
--
-- The iteration space is @from@'s shape without the leading
-- @length from_is@ dimensions, which @from_is@ already fixes.  This
-- keeps every source index at exactly @from@'s rank.
--
-- No barrier is emitted here; callers that need the copied data to be
-- visible to other threads must synchronise themselves.
groupCopy :: VName -> [Imp.Exp] -> SubExp -> [Imp.Exp] -> InKernelGen ()
groupCopy to to_is from from_is = do
  from_dims <- arrayDims <$> subExpType from
  -- Do not iterate over dimensions that from_is already indexes;
  -- otherwise from_is ++ is would exceed the rank of 'from'.
  ds <- mapM toExp $ drop (length from_is) from_dims
  groupCoverSpace ds $ \is ->
    copyDWIMFix to (to_is ++ is) from (from_is ++ is)
-- | Expression compiler for group-level code inside a kernel.
--
-- * Array literals are written element by element with 'copyDWIMFix'.
--   The elements are not split among threads.
-- * 'Copy' and 'Manifest' are performed cooperatively with
--   'groupCopy'.
-- * 'Replicate' is performed cooperatively over the destination
--   shape with 'groupCoverSpace'.
-- * 'Iota' is performed cooperatively with 'groupLoop'.
--
-- The four cooperative cases are followed by
-- @Imp.Barrier Imp.FenceLocal@.  Anything else is delegated to
-- 'defCompileExp'.
compileGroupExp :: ExpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileGroupExp pat e =
  case (pat, e) of
    (Pattern _ [dest], BasicOp (ArrayLit es _)) ->
      zipWithM_
        (\i x -> copyDWIMFix (patElemName dest) [fromIntegral (i :: Int32)] x [])
        [0 ..] es
    (Pattern _ [dest], BasicOp (Copy arr)) ->
      thenBarrier $ groupCopy (patElemName dest) [] (Var arr) []
    (Pattern _ [dest], BasicOp (Manifest _ arr)) ->
      thenBarrier $ groupCopy (patElemName dest) [] (Var arr) []
    (Pattern _ [dest], BasicOp (Replicate ds se)) ->
      thenBarrier $ do
        ds' <- mapM toExp $ shapeDims ds
        -- Leading indices select the replica; the rest index into se.
        groupCoverSpace ds' $ \is ->
          copyDWIMFix (patElemName dest) is se (drop (shapeRank ds) is)
    (Pattern _ [dest], BasicOp (Iota n start stride _)) ->
      thenBarrier $ do
        n' <- toExp n
        start' <- toExp start
        stride' <- toExp stride
        groupLoop n' $ \i -> do
          x <- dPrimV "x" $ start' + i * stride'
          copyDWIMFix (patElemName dest) [i] (Var x) []
    _ ->
      defCompileExp pat e
  where
    -- Run the cooperative operation, then emit a group barrier with a
    -- local-memory fence.
    thenBarrier m = m >> sOp (Imp.Barrier Imp.FenceLocal)
-- | Only thread-level 'SegOp's may appear inside group-level code.
-- A group-level one is an internal compiler error.
sanityCheckLevel :: SegLevel -> InKernelGen ()
sanityCheckLevel lvl =
  case lvl of
    SegThread{} -> pure ()
    SegGroup{} -> error "compileGroupOp: unexpected group-level SegOp."
-- | Bind the index variables of a 'SegSpace' executed by a group.
--
-- The local thread ID is unflattened over the space's dimensions, and
-- each component is bound to the corresponding index variable.  The
-- space's flat index variable is bound to the local thread ID itself.
-- The level is first checked with 'sanityCheckLevel'.
compileGroupSpace :: SegLevel -> SegSpace -> InKernelGen ()
compileGroupSpace lvl space = do
  sanityCheckLevel lvl
  let dim_bindings = unSegSpace space
  dims <- mapM (toExp . snd) dim_bindings
  ltid <- kernelLocalThreadId . kernelConstants <$> askEnv
  sequence_ $ zipWith dPrimV_ (map fst dim_bindings) (unflattenIndex dims ltid)
  dPrimV_ (segFlat space) ltid
-- | Build, for each intra-group histogram operator, the function that
-- atomically updates its subhistograms (its 'histDest') in local
-- memory.
--
-- How updates are done is decided by 'atomicUpdateLocking' applied to
-- the operator's lambda.  At most one lock array is created: the
-- first operator that needs locking allocates and zero-initialises it.
-- Later locking operators reuse the same 'Locking'.
--
-- NOTE(review): no barrier is emitted after the locks are initialised;
-- presumably the caller inserts one before the locks are used.
prepareIntraGroupSegHist :: Count GroupSize SubExp
                         -> [HistOp ExplicitMemory]
                         -> InKernelGen [[Imp.Exp] -> InKernelGen ()]
prepareIntraGroupSegHist group_size ops =
  snd <$> mapAccumLM onOp Nothing ops
  where
    onOp :: Maybe Locking -> HistOp ExplicitMemory
         -> InKernelGen (Maybe Locking, [Imp.Exp] -> InKernelGen ())
    onOp l op = do
      atomicBinOp <- kernelAtomics <$> askEnv
      let update_in_local f = f (Space "local") (histDest op)
      case (atomicUpdateLocking atomicBinOp (histOp op), l) of
        (AtomicPrim f, _) ->
          return (l, update_in_local f)
        (AtomicCAS f, _) ->
          return (l, update_in_local f)
        (AtomicLocking f, Just locking) ->
          return (l, update_in_local (f locking))
        (AtomicLocking f, Nothing) -> do
          locking <- mkLocalLocks op
          return (Just locking, update_in_local (f locking))

    -- Allocate a lock array in local memory, sized by the group size,
    -- set every lock to 0, and map a histogram index to a lock by
    -- flattening it and taking the remainder modulo the lock count.
    mkLocalLocks :: HistOp ExplicitMemory -> InKernelGen Locking
    mkLocalLocks op = do
      constants <- kernelConstants <$> askEnv
      locks <- newVName "locks"
      num_locks <- toExp $ unCount group_size
      let hist_dims = map (toExp' int32) $
                      shapeDims (histShape op) ++ [histWidth op]
          to_lock = pure . (`rem` num_locks) . flattenIndex hist_dims
          locking = Locking locks 0 1 0 to_lock
          locks_t = Array int32 (Shape [unCount group_size]) NoUniqueness
      locks_mem <- sAlloc "locks_mem" (typeSize locks_t) $ Space "local"
      dArray locks int32 (arrayShape locks_t) $
        ArrayIn locks_mem $ IxFun.iota $
        map (primExpFromSubExp int32) $ arrayDims locks_t
      sComment "All locks start out unlocked" $
        groupCoverSpace [kernelGroupSize constants] $ \is ->
          copyDWIMFix locks is (intConst Int32 0) []
      return locking
compileGroupOp :: OpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileGroupOp :: OpCompiler ExplicitMemory KernelEnv KernelOp
compileGroupOp Pattern ExplicitMemory
pat (Alloc size space) =
Pattern ExplicitMemory -> SubExp -> Space -> InKernelGen ()
kernelAlloc Pattern ExplicitMemory
pat SubExp
size Space
space
compileGroupOp Pattern ExplicitMemory
pat (Inner (SizeOp (SplitSpace o w i elems_per_thread))) =
Pattern ExplicitMemory
-> SplitOrdering -> SubExp -> SubExp -> SubExp -> InKernelGen ()
forall w i elems_per_thread lore r op.
(ToExp w, ToExp i, ToExp elems_per_thread) =>
Pattern ExplicitMemory
-> SplitOrdering -> w -> i -> elems_per_thread -> ImpM lore r op ()
splitSpace Pattern ExplicitMemory
pat SplitOrdering
o SubExp
w SubExp
i SubExp
elems_per_thread
-- Group-level map: every thread for which 'isActive' holds runs the
-- kernel body and writes its results into the pattern through
-- 'compileThreadResult'.  The clause ends with a local-memory
-- synchronisation point (Imp.ErrorSync FenceLocal).
compileGroupOp pat (Inner (SegOp (SegMap lvl space _ body))) = do
  compileGroupSpace lvl space

  sWhen (isActive $ unSegSpace space) $
    compileStms mempty (kernelBodyStms body) $
      zipWithM_ (compileThreadResult space) (patternElements pat) $
        kernelBodyResult body

  sOp $ Imp.ErrorSync Imp.FenceLocal
-- Group-level scan.  Each active thread writes its body results into the
-- destination arrays at its own local index.  After a local-memory
-- synchronisation point, the whole group scans the destinations in place
-- with 'groupScan'.  The last dimension of the space delimits segments.
compileGroupOp pat (Inner (SegOp (SegScan lvl space scan_op _ _ body))) = do
  compileGroupSpace lvl space
  let (ltids, dims) = unzip $ unSegSpace space
  dims' <- mapM toExp dims

  sWhen (isActive $ unSegSpace space) $
    compileStms mempty (kernelBodyStms body) $
      forM_ (zip (patternNames pat) (kernelBodyResult body)) $ \(dest, res) ->
        copyDWIMFix dest (map (`Imp.var` int32) ltids) (kernelResultSubExp res) []

  sOp $ Imp.ErrorSync Imp.FenceLocal

  let segment_size :: Imp.Exp
      segment_size = last dims'
      -- For from <= to this holds exactly when 'from' lies before the
      -- first position of the segment containing 'to'.
      crossesSegment :: Imp.Exp -> Imp.Exp -> Imp.Exp
      crossesSegment from to = (to - from) .>. (to `rem` segment_size)
  groupScan (Just crossesSegment) (product dims') (product dims') scan_op $
    patternNames pat
-- Group-level reduction.
--
-- Each active thread writes its reduction results into local-memory
-- scratch arrays ("red_arr", one per operator return value, with the
-- segment space as outer dimensions).  It writes its map results
-- straight into the pattern.
--
-- The scratch arrays are then combined.  With a single dimension this is
-- done by 'groupReduce'.  Otherwise a segmented 'groupScan' runs over the
-- last dimension, and the final element of each segment is copied out
-- as that segment's result.
compileGroupOp pat (Inner (SegOp (SegRed lvl space ops _ body))) = do
  compileGroupSpace lvl space

  let (ltids, dims) = unzip $ unSegSpace space
      (red_pes, map_pes) =
        splitAt (segRedResults ops) $ patternElements pat

  dims' <- mapM toExp dims

  let mkTempArr t =
        sAllocArray "red_arr" (elemType t) (Shape dims <> arrayShape t) $
          Space "local"
  tmp_arrs <- mapM mkTempArr $ concatMap (lambdaReturnType . segRedLambda) ops
  -- Group the scratch arrays by the operator they belong to.
  let tmps_for_ops :: [[VName]]
      tmps_for_ops = chunks (map (length . segRedNeutral) ops) tmp_arrs

  sWhen (isActive $ unSegSpace space) $
    compileStms mempty (kernelBodyStms body) $ do
      let (red_res, map_res) =
            splitAt (segRedResults ops) $ kernelBodyResult body
      forM_ (zip tmp_arrs red_res) $ \(dest, res) ->
        copyDWIMFix dest (map (`Imp.var` int32) ltids) (kernelResultSubExp res) []
      zipWithM_ (compileThreadResult space) map_pes map_res

  sOp $ Imp.ErrorSync Imp.FenceLocal

  case dims' of
    -- One dimension (a single segment): reduce directly; the result
    -- ends up at index 0 of each scratch array.
    [dim'] -> do
      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupReduce dim' (segRedLambda op) tmps

      sOp $ Imp.ErrorSync Imp.FenceLocal

      forM_ (zip red_pes tmp_arrs) $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) [] (Var arr) [0]

    -- Several dimensions: express the segmented reduction as a segmented
    -- scan and read off the last element of every segment.
    _ -> do
      let segment_size :: Imp.Exp
          segment_size = last dims'
          -- For from <= to this holds exactly when 'from' lies before
          -- the first position of the segment containing 'to'.
          crossesSegment :: Imp.Exp -> Imp.Exp -> Imp.Exp
          crossesSegment from to = (to - from) .>. (to `rem` segment_size)

      forM_ (zip ops tmps_for_ops) $ \(op, tmps) ->
        groupScan (Just crossesSegment) (product dims') (product dims')
          (segRedLambda op) tmps

      sOp $ Imp.ErrorSync Imp.FenceLocal

      let segment_is :: [Imp.Exp]
          segment_is = map Imp.vi32 $ init ltids
      forM_ (zip red_pes tmp_arrs) $ \(pe, arr) ->
        copyDWIMFix (patElemName pe) segment_is (Var arr)
          (segment_is ++ [last dims' - 1])

      sOp $ Imp.Barrier Imp.FenceLocal
-- Group-level histogram.
--
-- Each active thread produces one bin index per operator, followed by
-- the values for every operator.  Values whose bin lies within
-- [0, dest_w) are applied through the per-operator update functions
-- built by 'prepareIntraGroupSegHist'.  Out-of-bounds bins are skipped.
compileGroupOp pat (Inner (SegOp (SegHist lvl space ops _ kbody))) = do
  compileGroupSpace lvl space
  let ltids :: [VName]
      ltids = map fst $ unSegSpace space

  let num_red_res :: Int
      num_red_res = length ops + sum (map (length . histNeutral) ops)
      -- NOTE(review): the reduction pattern elements are never written
      -- here; presumably they share memory with the histogram
      -- destinations updated by the update functions — confirm.
      (_red_pes, map_pes) =
        splitAt num_red_res $ patternElements pat

  ops' <- prepareIntraGroupSegHist (segGroupSize lvl) ops

  -- Make the state set up by prepareIntraGroupSegHist (e.g. the lock
  -- initialisation) visible group-wide before any thread updates a bin.
  sOp $ Imp.Barrier Imp.FenceLocal

  sWhen (isActive $ unSegSpace space) $
    compileStms mempty (kernelBodyStms kbody) $ do
      let (red_res, map_res) = splitAt num_red_res $ kernelBodyResult kbody
          (red_is, red_vs) = splitAt (length ops) $ map kernelResultSubExp red_res
      zipWithM_ (compileThreadResult space) map_pes map_res

      let vs_per_op = chunks (map (length . histDest) ops) red_vs

      forM_ (zip4 red_is vs_per_op ops' ops) $
        \(bin, op_vs, do_op, HistOp dest_w _ _ _ shape lam) -> do
          let bin' = toExp' int32 bin
              dest_w' = toExp' int32 dest_w
              bin_in_bounds = 0 .<=. bin' .&&. bin' .<. dest_w'
              bin_is = map (`Imp.var` int32) (init ltids) ++ [bin']
              -- The trailing lambda parameters receive this thread's values.
              vs_params = takeLast (length op_vs) $ lambdaParams lam

          sComment "perform atomic updates" $
            sWhen bin_in_bounds $ do
              dLParams $ lambdaParams lam
              sLoopNest shape $ \is -> do
                forM_ (zip vs_params op_vs) $ \(p, v) ->
                  copyDWIMFix (paramName p) [] v is
                do_op (bin_is ++ is)

  sOp $ Imp.ErrorSync Imp.FenceLocal
-- Any other operation is rejected as a compiler bug.
compileGroupOp pat _ =
  compilerBugS $ "compileGroupOp: cannot compile rhs of binding " ++ pretty pat
-- | Compile an 'Op' at thread level.  Only memory allocations (via
-- 'kernelAlloc') and space splits (via 'splitSpace') are supported.
-- Anything else, including every 'SegOp', is reported as a compiler bug.
compileThreadOp :: OpCompiler ExplicitMemory KernelEnv Imp.KernelOp
compileThreadOp pat (Alloc size space) =
  kernelAlloc pat size space
compileThreadOp pat (Inner (SizeOp (SplitSpace o w i elems_per_thread))) =
  splitSpace pat o w i elems_per_thread
compileThreadOp pat _ =
  compilerBugS $ "compileThreadOp: cannot compile rhs of binding " ++ pretty pat
-- | Description of a lock-based atomic update strategy.
--
-- NOTE(review): the field descriptions below are inferred from the
-- field names; confirm them against the locking code in
-- 'atomicUpdateLocking'.
data Locking =
  Locking { lockingArray :: VName
            -- ^ The array holding the locks.
          , lockingIsUnlocked :: Imp.Exp
            -- ^ The value a lock holds while it is free.
          , lockingToLock :: Imp.Exp
            -- ^ The value written to a lock to acquire it.
          , lockingToUnlock :: Imp.Exp
            -- ^ The value written to a lock to release it.
          , lockingMapping :: [Imp.Exp] -> [Imp.Exp]
            -- ^ Maps the index being updated to the index of the lock
            -- guarding it.
          }
-- | Code generator for a single atomic update.  Its arguments are, in
-- order: the memory 'Space' holding the target arrays, the target arrays
-- themselves, and the index (bucket) to update.
type DoAtomicUpdate lore r
  =  Space
  -> [VName]
  -> [Imp.Exp]
  -> ImpM lore r Imp.KernelOp ()
-- | The strategy selected for an atomic update, each carrying the code
-- generator that implements it.  'AtomicLocking' still needs a
-- 'Locking' description before it can generate code.
--
-- NOTE(review): by their names, 'AtomicPrim' presumably relies on native
-- atomic instructions and 'AtomicCAS' on compare-and-swap; confirm in
-- 'atomicUpdateLocking'.
data AtomicUpdate lore r
  = AtomicPrim    (DoAtomicUpdate lore r)
  | AtomicCAS     (DoAtomicUpdate lore r)
  | AtomicLocking (Locking -> DoAtomicUpdate lore r)
-- | Looks up a native atomic operation for a binary operator, or
-- 'Nothing' when none is available.  In this module the returned
-- constructor is applied to: the variable that receives the previous
-- value, the array, the element offset, and the operand.
type AtomicBinOp
  = BinOp
  -> Maybe (VName -> VName -> Count Imp.Elements Imp.Exp -> Imp.Exp -> Imp.AtomicOp)
-- | Choose how to apply the operator @lam@ to a set of arrays at a
-- bucket index without interference from other threads.
--
-- The strategies are tried in order:
--
-- 1. If 'splitOp' can split @lam@ into one 'BinOp' per result, and every
--    element type is 32 bits wide, each array is updated on its own.
--    A native atomic is used when @atomicBinOp@ knows the operator;
--    otherwise 'atomicUpdateCAS' is used.  The result is 'AtomicPrim'
--    if every operator has a native atomic, and 'AtomicCAS' otherwise.
--
-- 2. If @lam@ returns a single 32-bit primitive and takes exactly two
--    parameters, the whole lambda body runs inside one
--    'atomicUpdateCAS' loop ('AtomicCAS').
--
-- 3. Otherwise the update is guarded by a lock ('AtomicLocking').  A
--    thread spins until a compare-and-swap on its lock succeeds.  It
--    then runs the lambda on volatile copies of the current bucket
--    values, writes the results back, and releases the lock.
atomicUpdateLocking :: AtomicBinOp -> Lambda ExplicitMemory
                    -> AtomicUpdate ExplicitMemory KernelEnv
atomicUpdateLocking atomicBinOp lam
  | Just ops_and_ts <- splitOp lam,
    all (\(_, t, _, _) -> primBitSize t == 32) ops_and_ts =
      primOrCas ops_and_ts $ \space arrs bucket ->
        forM_ (zip arrs ops_and_ts) $ \(a, (op, t, x, y)) -> do
          old <- dPrim "old" t
          (arr', _a_space, bucket_offset) <- fullyIndexArray a bucket
          case opHasAtomicSupport space old arr' bucket_offset op of
            Just f ->
              sOp $ f $ Imp.var y t
            Nothing ->
              -- There is no native atomic for this operator, so retry
              -- the plain binary operation in a compare-and-swap loop.
              atomicUpdateCAS space t a old bucket x $
                x <-- Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
  where
    -- Build the native atomic for @bop@, if 'atomicBinOp' provides one.
    opHasAtomicSupport :: Space -> VName -> VName -> Count Imp.Elements Imp.Exp
                       -> BinOp -> Maybe (Imp.Exp -> Imp.KernelOp)
    opHasAtomicSupport space old arr' bucket' bop =
      let atomic f = Imp.Atomic space . f old arr' bucket'
      in atomic <$> atomicBinOp bop

    -- Native atomics only if every operator has one.
    primOrCas :: [(BinOp, PrimType, VName, VName)]
              -> DoAtomicUpdate ExplicitMemory KernelEnv
              -> AtomicUpdate ExplicitMemory KernelEnv
    primOrCas ops
      | all isPrim ops = AtomicPrim
      | otherwise = AtomicCAS

    isPrim (op, _, _, _) = isJust $ atomicBinOp op

atomicUpdateLocking _ op
  | [Prim t] <- lambdaReturnType op,
    [xp, _] <- lambdaParams op,
    primBitSize t == 32 = AtomicCAS $ \space arrs bucket ->
      case arrs of
        [arr] -> do
          old <- dPrim "old" t
          atomicUpdateCAS space t arr old bucket (paramName xp) $
            compileBody' [xp] $ lambdaBody op
        _ ->
          -- A lambda with a single result can only update a single
          -- array.  Report misuse explicitly rather than failing with an
          -- opaque pattern-match error.
          compilerBugS $
            "atomicUpdateLocking: expected exactly one destination array, got "
            ++ show (length arrs)

atomicUpdateLocking _ op = AtomicLocking $ \locking space arrs bucket -> do
  old <- dPrim "old" int32
  continue <- newVName "continue"
  dPrimVol_ continue Bool
  continue <-- true

  -- Locate the lock that guards this bucket.
  (locks', _locks_space, locks_offset) <-
    fullyIndexArray (lockingArray locking) $ lockingMapping locking bucket

  let try_acquire_lock :: InKernelGen ()
      try_acquire_lock =
        sOp $ Imp.Atomic space $
        Imp.AtomicCmpXchg int32 old locks' locks_offset
        (lockingIsUnlocked locking) (lockingToLock locking)
      -- The lock is ours if the swap observed the "unlocked" value.
      lock_acquired :: Imp.Exp
      lock_acquired = Imp.var old int32 .==. lockingIsUnlocked locking
      -- Releasing also uses a compare-and-swap, not a plain store.
      release_lock :: InKernelGen ()
      release_lock =
        sOp $ Imp.Atomic space $
        Imp.AtomicCmpXchg int32 old locks' locks_offset
        (lockingToLock locking) (lockingToUnlock locking)
      break_loop :: InKernelGen ()
      break_loop = continue <-- false

  -- The first (length arrs) parameters act as accumulators and receive
  -- the current bucket values.  The remaining parameters are not bound
  -- here; presumably the caller has already set them (TODO confirm).
  let (acc_params, _arr_params) = splitAt (length arrs) $ lambdaParams op
      bind_acc_params :: InKernelGen ()
      bind_acc_params =
        everythingVolatile $
        sComment "bind lhs" $
        forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
        copyDWIMFix (paramName acc_p) [] (Var arr) bucket

  let op_body :: InKernelGen ()
      op_body = sComment "execute operation" $
                compileBody' acc_params $ lambdaBody op

      do_hist :: InKernelGen ()
      do_hist =
        everythingVolatile $
        sComment "update global result" $
        zipWithM_ (writeArray bucket) arrs $ map (Var . paramName) acc_params

      -- The fence scope follows the memory space being updated.
      fence :: InKernelGen ()
      fence = case space of
                Space "local" -> sOp $ Imp.MemFence Imp.FenceLocal
                _ -> sOp $ Imp.MemFence Imp.FenceGlobal

  -- Spin until this thread has held the lock and applied its update.
  sWhile (Imp.var continue Bool) $ do
    try_acquire_lock
    sWhen lock_acquired $ do
      dLParams acc_params
      bind_acc_params
      op_body
      do_hist
      fence
      release_lock
      break_loop
    fence
  where writeArray bucket arr val = copyDWIMFix arr bucket val []
-- | Emit a compare-and-swap retry loop that updates the element of
-- @arr@ at index @bucket@.
--
-- @old@ is first loaded from the array with volatile access.  Each
-- iteration copies @old@ into @x@ and runs @do_op@, which is expected
-- to leave the new value in @x@.  It then issues a 32-bit
-- compare-and-swap from the previously observed value to @x@.  The
-- value found in memory becomes the new @old@, and the loop stops once
-- it matches what was expected.  'Float32' values are reinterpreted as
-- their bit patterns around the swap; other types are used unchanged.
atomicUpdateCAS :: Space -> PrimType
                -> VName -> VName
                -> [Imp.Exp] -> VName
                -> InKernelGen ()
                -> InKernelGen ()
atomicUpdateCAS space t arr old bucket x do_op = do
  expected <- dPrim "assumed" t
  keep_going <- dPrimV "run_loop" 1
  everythingVolatile $ copyDWIMFix old [] (Var arr) bucket
  (arr', _, elem_offset) <- fullyIndexArray arr bucket

  let expected_bits = toBits $ Imp.var expected t

      attempt :: InKernelGen ()
      attempt = do
        expected <-- Imp.var old t
        x <-- Imp.var expected t
        do_op
        seen_bits <- dPrim "old_bits" int32
        let seen = Imp.var seen_bits int32
        sOp $ Imp.Atomic space $
          Imp.AtomicCmpXchg int32 seen_bits arr' elem_offset
            expected_bits (toBits $ Imp.var x t)
        old <-- fromBits seen
        -- Success: memory held exactly the bits we expected.
        sWhen (expected_bits .==. seen) $
          keep_going <-- 0

  sWhile (Imp.var keep_going int32) attempt
  where
    -- Convert between the element type and the int32 representation
    -- that the compare-and-swap operates on.
    (toBits, fromBits) =
      case t of
        FloatType Float32 ->
          ( \v -> Imp.FunExp "to_bits32" [v] int32
          , \v -> Imp.FunExp "from_bits32" [v] t )
        _ -> (id, id)
-- | Try to view a lambda as a tuple of independent binary operators.
--
-- For a lambda with @n@ results, the result at position @i@ must be a
-- variable bound by a statement of the form @r = x op y@.  Here @x@
-- must be parameter @i@ and @y@ must be parameter @n+i@.  On success,
-- each result yields its operator, its primitive result type, and the
-- names of those two parameters.  Any other shape gives 'Nothing'.
--
-- Each result is checked at its own position, not at the first
-- position where the same variable occurs.  A variable returned at
-- several positions therefore cannot be paired with another position's
-- parameters, and such lambdas are rejected.
splitOp :: Attributes lore => Lambda lore -> Maybe [(BinOp, PrimType, VName, VName)]
splitOp lam = zipWithM splitStm [0..] $ bodyResult $ lambdaBody lam
  where n = length $ lambdaReturnType lam
        params = lambdaParams lam
        stms = stmsToList $ bodyStms $ lambdaBody lam

        splitStm :: Int -> SubExp -> Maybe (BinOp, PrimType, VName, VName)
        splitStm i (Var res) = do
          Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
            find (([res]==) . patternNames . stmPattern) stms
          xp <- maybeNth i params
          yp <- maybeNth (n+i) params
          guard $ paramName xp == x
          guard $ paramName yp == y
          Prim t <- Just $ patElemType pe
          return (op, t, paramName xp, paramName yp)
        splitStm _ _ = Nothing
-- | Compute what a kernel must be given: the kernel uses of every
-- variable free in @kernel_body@ that is not listed in
-- @bound_in_kernel@.  Duplicates are removed with 'nub'.
computeKernelUses :: FreeIn a =>
                     a -> [VName]
                  -> CallKernelGen [Imp.KernelUse]
computeKernelUses kernel_body bound_in_kernel =
  nub <$> readsFromSet free_outside
  where free_outside =
          namesSubtract (freeIn kernel_body) (namesFromList bound_in_kernel)
-- | Classify each name in the set by how a kernel receives it:
--
-- * arrays and local-memory blocks are dropped;
-- * other memory blocks become 'Imp.MemoryUse';
-- * scalars that 'isConstExp' recognises become 'Imp.ConstUse';
-- * remaining certificates are dropped;
-- * all other scalars become 'Imp.ScalarUse'.
readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
readsFromSet free =
  catMaybes <$> mapM classify (namesToList free)
  where
    classify :: VName -> CallKernelGen (Maybe Imp.KernelUse)
    classify var = do
      var_t <- lookupType var
      vtable <- getVTable
      case var_t of
        Array {} -> return Nothing
        Mem (Space "local") -> return Nothing
        Mem {} -> return $ Just $ Imp.MemoryUse var
        Prim bt -> do
          const_e <- isConstExp vtable (Imp.var var bt)
          return $ case const_e of
            Just ce -> Just $ Imp.ConstUse var ce
            Nothing
              | bt == Cert -> Nothing
              | otherwise -> Just $ Imp.ScalarUse var bt
-- | Try to rewrite an Imp expression as a kernel constant expression.
--
-- Leaves are handled as follows:
--
-- * an 'Imp.SizeOf' leaf becomes the byte size of its type;
-- * a scalar variable is followed through the vtable to its recorded
--   defining expression, which must itself be constant (a 'GetSize', which
--   becomes a size constant keyed on the current entry point, or a
--   primitive expression over constant variables);
-- * an 'Imp.Index' leaf is never constant.
--
-- Returns 'Nothing' if any leaf cannot be made constant.
isConstExp :: VTable ExplicitMemory -> Imp.Exp
           -> ImpM lore r op (Maybe Imp.KernelConstExp)
isConstExp vtable size = do
  fname <- askFunction
  let -- A GetSize turns into a named size constant for this entry point.
      constExp (Op (Inner (SizeOp (GetSize key _)))) =
        Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
      constExp e =
        primExpFromExp lookupConstExp e

      -- Find the variable's recorded defining expression (if any) and
      -- try to make that constant.
      lookupConstExp name = do
        entry <- M.lookup name vtable
        def <- definingExp entry
        constExp def

      onLeaf leaf _ =
        case leaf of
          Imp.ScalarVar name -> lookupConstExp name
          Imp.SizeOf pt      -> Just $ primByteSize pt
          Imp.Index{}        -> Nothing

  pure $ replaceInPrimExpM onLeaf size
  where definingExp :: VarEntry lore -> Maybe (Exp lore)
        definingExp (ArrayVar e _)  = e
        definingExp (ScalarVar e _) = e
        definingExp (MemVar e _)    = e
-- | Assign to @chunk_var@ the number of elements that thread
-- @thread_index@ handles, given @elements_per_thread@ and a total of
-- @num_elements@.
--
-- * 'SplitStrided': the chunk is
--   @min(elements_per_thread, ceil((num_elements - thread_index) / stride))@.
--   This is the count of indices @thread_index, thread_index + stride, ...@
--   that lie below @num_elements@, capped at @elements_per_thread@.
-- * 'SplitContiguous': the thread's range starts at
--   @thread_index * elements_per_thread@. The chunk is 0 if that start is
--   not below @num_elements@. It is @num_elements - start@ when the range
--   would run past the end, and @elements_per_thread@ otherwise.
computeThreadChunkSize :: SplitOrdering
                       -> Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> Imp.Count Imp.Elements Imp.Exp
                       -> VName
                       -> ImpM lore r op ()
computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
  stride' <- toExp stride
  let per_thread = Imp.unCount elements_per_thread
      still_to_cover = Imp.unCount num_elements - thread_index
  chunk_var <--
    Imp.BinOpExp (SMin Int32) per_thread (still_to_cover `quotRoundingUp` stride')

computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
  starting_point <- dPrimV "starting_point" $
    thread_index * Imp.unCount elements_per_thread
  remaining_elements <- dPrimV "remaining_elements" $
    Imp.unCount num_elements - Imp.var starting_point int32

  let total = Imp.unCount num_elements
      per_thread = Imp.unCount elements_per_thread
      no_remaining_elements = Imp.var remaining_elements int32 .<=. 0
      beyond_bounds = total .<=. Imp.var starting_point int32
      -- This thread's nominal range extends past the end of the input.
      is_last_thread = total .<. (thread_index + 1) * per_thread
      last_thread_elements =
        num_elements - Imp.elements thread_index * elements_per_thread

  sIf (no_remaining_elements .||. beyond_bounds)
    (chunk_var <-- 0)
    (sIf is_last_thread
       (chunk_var <-- Imp.unCount last_thread_elements)
       (chunk_var <-- per_thread))
-- | Create fresh names for the basic thread/group constants of a kernel.
--
-- Returns two things:
--
-- * the 'KernelConstants' record that refers to those names;
-- * an in-kernel action that declares the variables as int32 and fills
--   them in using the dimension-0 id/size kernel operations and
--   'Imp.GetLockstepWidth'.
--
-- The total thread count is @group_size * num_groups@. The final constant
-- is @true@; presumably this is the thread-activity flag (confirm against
-- the 'KernelConstants' definition).
--
-- The local size is read into a separate variable (named @group_size@),
-- but that variable is not used in the returned constants.
kernelInitialisationSimple :: Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
                           -> CallKernelGen (KernelConstants, InKernelGen ())
kernelInitialisationSimple (Count num_groups) (Count group_size) = do
  global_tid       <- newVName "global_tid"
  local_tid        <- newVName "local_tid"
  group_id         <- newVName "group_tid"
  wave_size        <- newVName "wave_size"
  inner_group_size <- newVName "group_size"

  let i32 v = Imp.var v int32

      constants =
        KernelConstants
          (i32 global_tid) (i32 local_tid) (i32 group_id)
          global_tid local_tid group_id
          num_groups group_size (group_size * num_groups)
          (i32 wave_size)
          true

      set_constants :: InKernelGen ()
      set_constants = do
        mapM_ (`dPrim_` int32)
          [global_tid, local_tid, inner_group_size, wave_size, group_id]
        sOp $ Imp.GetGlobalId global_tid 0
        sOp $ Imp.GetLocalId local_tid 0
        sOp $ Imp.GetLocalSize inner_group_size 0
        sOp $ Imp.GetLockstepWidth wave_size
        sOp $ Imp.GetGroupId group_id 0

  pure (constants, set_constants)
-- | Build the condition @i_1 < w_1 && ... && i_n < w_n@ for the given
-- (index variable, bound) pairs. With no pairs the result is the constant
-- @true@.
--
-- The index variables are read as int32, and the bounds are converted at
-- the same type so that both sides of each comparison agree.
isActive :: [(VName, SubExp)] -> Imp.Exp
isActive limit = case actives of
                   [] -> Imp.ValueExp $ BoolValue True
                   x:xs -> foldl (.&&.) x xs
  where (is, ws) = unzip limit
        -- Bounds are int32 sizes compared against int32 indices. Converting
        -- them at 'Bool' would tag any variable bound as a Bool leaf.
        actives = zipWith active is $ map (toExp' int32) ws
        active i = (Imp.var i int32 .<.)
-- | Run an action with the default memory space set to @"global"@.
--
-- While the action runs, every memory entry in the vtable whose space is
-- not @"local"@ is relabelled as @"global"@, and its recorded defining
-- expression is dropped. Local memory entries and non-memory entries are
-- left unchanged.
makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
makeAllMemoryGlobal action =
  localDefaultSpace (Imp.Space "global") $
  localVTable (M.map toGlobal) action
  where toGlobal :: VarEntry lore -> VarEntry lore
        toGlobal entry =
          case entry of
            MemVar _ mem
              | entryMemSpace mem /= Space "local" ->
                  MemVar Nothing mem { entryMemSpace = Imp.Space "global" }
            _ -> entry
-- | Work-group reduction of @arrs@ with operator @lam@, where @w@ bounds
-- the participating threads.
--
-- This is 'groupReduceWithOffset' with a freshly declared int32 @offset@
-- scratch variable.
groupReduce :: Imp.Exp
            -> Lambda ExplicitMemory
            -> [VName]
            -> InKernelGen ()
groupReduce w lam arrs =
  dPrim "offset" int32 >>= \offset ->
    groupReduceWithOffset offset w lam arrs
-- | Work-group reduction of @arrs@ with operator @lam@.
--
-- @offset@ is a caller-declared int32 variable used as the partner
-- distance. Threads with @local_tid < w@ first load their accumulator
-- parameters; @offset@ is 0 at that point. The reduction then runs in two
-- phases.
--
-- 1. In-wave phase. For @offset = 1, 2, 4, ...@ while
--    @offset < wave_size@, a thread combines with the element at
--    @local_tid + offset@ when:
--
--    * @in_wave_id .&. (2*offset - 1) == 0@, and
--    * that partner index is below @w@.
--
--    The combine step runs inside 'everythingVolatile', and no barrier is
--    emitted in this loop.
--
-- 2. Cross-wave phase. For @skip_waves = 1, 2, 4, ...@ while
--    @skip_waves < num_waves@, each iteration first emits a barrier. It
--    then sets @offset = skip_waves * wave_size@. The first thread of each
--    wave with @wave_id .&. (2*skip_waves - 1) == 0@ combines with the
--    element at @local_tid + offset@, provided that index is below @w@.
--
-- Where elements are read and written depends on the parameter type:
--
-- * Primitive (scalar) parameters are read from @arr[local_tid + offset]@,
--   and the combined result is written back to @arr[local_tid]@.
-- * Non-primitive parameters are read from @arr[global_tid + offset]@, and
--   nothing is written back by this function.
groupReduceWithOffset :: VName
                      -> Imp.Exp
                      -> Lambda ExplicitMemory
                      -> [VName]
                      -> InKernelGen ()
groupReduceWithOffset offset w lam arrs = do
  constants <- kernelConstants <$> askEnv

  let local_tid  = kernelLocalThreadId constants
      global_tid = kernelGlobalThreadId constants
      offset_e   = Imp.var offset int32

      -- A local fence suffices when every result is a scalar.
      barrier
        | all primType $ lambdaReturnType lam = sOp $ Imp.Barrier Imp.FenceLocal
        | otherwise                           = sOp $ Imp.Barrier Imp.FenceGlobal

      readReduceArgument param arr =
        let base = case paramType param of
                     Prim _ -> local_tid
                     _      -> global_tid
        in copyDWIMFix (paramName param) [] (Var arr) [base + Imp.vi32 offset]

      writeReduceOpResult param arr =
        case paramType param of
          Prim _ -> copyDWIMFix arr [local_tid] (Var $ paramName param) []
          _      -> return ()

      (reduce_acc_params, reduce_arr_params) =
        splitAt (length arrs) $ lambdaParams lam

  skip_waves <- dPrim "skip_waves" int32
  dLParams $ lambdaParams lam

  offset <-- 0

  comment "participating threads read initial accumulator" $
    sWhen (local_tid .<. w) $
    zipWithM_ readReduceArgument reduce_acc_params arrs

  let skip_e = Imp.var skip_waves int32

      do_reduce = do
        comment "read array element" $
          zipWithM_ readReduceArgument reduce_arr_params arrs
        comment "apply reduction operation" $
          compileBody' reduce_acc_params $ lambdaBody lam
        comment "write result of operation" $
          zipWithM_ writeReduceOpResult reduce_acc_params arrs
      in_wave_reduce = everythingVolatile do_reduce

      wave_size     = kernelWaveSize constants
      group_size    = kernelGroupSize constants
      wave_id       = local_tid `quot` wave_size
      in_wave_id    = local_tid - wave_id * wave_size
      num_waves     = (group_size + wave_size - 1) `quot` wave_size
      arg_in_bounds = local_tid + offset_e .<. w

      -- Phase 1: pairwise combining within a wave, doubling the distance.
      doing_in_wave_reductions   = offset_e .<. wave_size
      apply_in_in_wave_iteration = (in_wave_id .&. (2 * offset_e - 1)) .==. 0
      in_wave_reductions = do
        offset <-- 1
        sWhile doing_in_wave_reductions $ do
          sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration) in_wave_reduce
          offset <-- offset_e * 2

      -- Phase 2: combine wave leaders, doubling the number of waves skipped.
      doing_cross_wave_reductions = skip_e .<. num_waves
      is_first_thread_in_wave     = in_wave_id .==. 0
      wave_not_skipped            = (wave_id .&. (2 * skip_e - 1)) .==. 0
      apply_in_cross_wave_iteration =
        arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
      cross_wave_reductions = do
        skip_waves <-- 1
        sWhile doing_cross_wave_reductions $ do
          barrier
          offset <-- skip_e * wave_size
          sWhen apply_in_cross_wave_iteration do_reduce
          skip_waves <-- skip_e * 2

  in_wave_reductions
  cross_wave_reductions
groupScan :: Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
-> Imp.Exp
-> Imp.Exp
-> Lambda ExplicitMemory
-> [VName]
-> InKernelGen ()
groupScan :: Maybe (Exp -> Exp -> Exp)
-> Exp -> Exp -> Lambda ExplicitMemory -> [VName] -> InKernelGen ()
groupScan Maybe (Exp -> Exp -> Exp)
seg_flag Exp
arrs_full_size Exp
w Lambda ExplicitMemory
lam [VName]
arrs = do
KernelConstants
constants <- KernelEnv -> KernelConstants
kernelConstants (KernelEnv -> KernelConstants)
-> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
-> ImpM ExplicitMemory KernelEnv KernelOp KernelConstants
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ImpM ExplicitMemory KernelEnv KernelOp KernelEnv
forall lore r op. ImpM lore r op r
askEnv
Lambda ExplicitMemory
renamed_lam <- Lambda ExplicitMemory
-> ImpM ExplicitMemory KernelEnv KernelOp (Lambda ExplicitMemory)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda Lambda ExplicitMemory
lam
let ltid :: Exp
ltid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
([Param (MemInfo SubExp NoUniqueness MemBind)]
x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) = Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
arrs) ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)]))
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
lam
[LParam ExplicitMemory] -> InKernelGen ()
forall lore r op.
ExplicitMemorish lore =>
[LParam lore] -> ImpM lore r op ()
dLParams (Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
lam[Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
forall a. [a] -> [a] -> [a]
++Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
renamed_lam)
let block_size :: PrimExp v
block_size = PrimValue -> PrimExp v
forall v. PrimValue -> PrimExp v
Imp.ValueExp (PrimValue -> PrimExp v) -> PrimValue -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ Int32 -> IntValue
Int32Value Int32
32
simd_width :: Exp
simd_width = KernelConstants -> Exp
kernelWaveSize KernelConstants
constants
block_id :: Exp
block_id = Exp
ltid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
forall v. PrimExp v
block_size
in_block_id :: Exp
in_block_id = Exp
ltid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
forall v. PrimExp v
block_size
doInBlockScan :: Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda ExplicitMemory -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
seg_flag' Exp
active =
KernelConstants
-> Maybe (Exp -> Exp -> Exp)
-> Exp
-> Exp
-> Exp
-> Exp
-> [VName]
-> InKernelGen ()
-> Lambda ExplicitMemory
-> InKernelGen ()
inBlockScan KernelConstants
constants Maybe (Exp -> Exp -> Exp)
seg_flag' Exp
arrs_full_size
Exp
simd_width Exp
forall v. PrimExp v
block_size Exp
active [VName]
arrs InKernelGen ()
barrier
ltid_in_bounds :: Exp
ltid_in_bounds = Exp
ltid Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
w
array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda ExplicitMemory
lam
barrier :: InKernelGen ()
barrier | Bool
array_scan =
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceGlobal
| Bool
otherwise =
KernelOp -> InKernelGen ()
forall op lore r. op -> ImpM lore r op ()
sOp (KernelOp -> InKernelGen ()) -> KernelOp -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Fence -> KernelOp
Imp.Barrier Fence
Imp.FenceLocal
group_offset :: Exp
group_offset = KernelConstants -> Exp
kernelGroupId KernelConstants
constants Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* KernelConstants -> Exp
kernelGroupSize KernelConstants
constants
writeBlockResult :: Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
writeBlockResult Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []
| Bool
otherwise =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
block_id] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []
readPrevBlockResult :: Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
readPrevBlockResult Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
| Bool
otherwise =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1]
Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda ExplicitMemory -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
seg_flag Exp
ltid_in_bounds Lambda ExplicitMemory
lam
InKernelGen ()
barrier
let is_first_block :: Exp
is_first_block = Exp
block_id Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"save correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
InKernelGen ()
barrier
let last_in_block :: Exp
last_in_block = Exp
in_block_id Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
1
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"last thread of block 'i' writes its result to offset 'i'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
last_in_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
writeBlockResult [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs
InKernelGen ()
barrier
let first_block_seg_flag :: Maybe (Exp -> Exp -> Exp)
first_block_seg_flag = do
Exp -> Exp -> Exp
flag_true <- Maybe (Exp -> Exp -> Exp)
seg_flag
(Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a. a -> Maybe a
Just ((Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp))
-> (Exp -> Exp -> Exp) -> Maybe (Exp -> Exp -> Exp)
forall a b. (a -> b) -> a -> b
$ \Exp
from Exp
to ->
Exp -> Exp -> Exp
flag_true (Exp
fromExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1) (Exp
toExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
+Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1)
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
comment
String
"scan the first block, after which offset 'i' contains carry-in for block 'i+1'" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Maybe (Exp -> Exp -> Exp)
-> Exp -> Lambda ExplicitMemory -> InKernelGen ()
doInBlockScan Maybe (Exp -> Exp -> Exp)
first_block_seg_flag (Exp
is_first_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
ltid_in_bounds) Lambda ExplicitMemory
renamed_lam
InKernelGen ()
barrier
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"move correct values for first block back a block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM
VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]
(VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
forall v. PrimExp v
block_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]
InKernelGen ()
barrier
let read_carry_in :: InKernelGen ()
read_carry_in = do
[(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) (((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x,Param (MemInfo SubExp NoUniqueness MemBind)
y) ->
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) [] (VName -> SubExp
Var (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x)) []
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
readPrevBlockResult [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs
y_to_x :: InKernelGen ()
y_to_x = [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) (((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x,Param (MemInfo SubExp NoUniqueness MemBind)
y) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x)) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) [] (VName -> SubExp
Var (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y)) []
op_to_x :: InKernelGen ()
op_to_x
| Maybe (Exp -> Exp -> Exp)
Nothing <- Maybe (Exp -> Exp -> Exp)
seg_flag =
[Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params (Body ExplicitMemory -> InKernelGen ())
-> Body ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> Body ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
lam
| Just Exp -> Exp -> Exp
flag_true <- Maybe (Exp -> Exp -> Exp)
seg_flag = do
Exp
inactive <-
String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"inactive" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$ Exp -> Exp -> Exp
flag_true (Exp
block_idExp -> Exp -> Exp
forall a. Num a => a -> a -> a
*Exp
forall v. PrimExp v
block_sizeExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-Exp
1) Exp
ltid
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
inactive InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params (Body ExplicitMemory -> InKernelGen ())
-> Body ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> Body ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
lam
write_final_result :: InKernelGen ()
write_final_result =
[(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs) (((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
p, VName
arr) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) []
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"carry-in for every block except the first" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless (Exp
is_first_block Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.||. UnOp -> Exp -> Exp
forall v. UnOp -> PrimExp v -> PrimExp v
Imp.UnOpExp UnOp
Not Exp
ltid_in_bounds) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read operands" InKernelGen ()
read_carry_in
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write final result" InKernelGen ()
write_final_result
InKernelGen ()
barrier
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"restore correct values for first block" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
is_first_block (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$[(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind), VName)]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind), VName)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params [VName]
arrs) (((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind), VName)
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x, Param (MemInfo SubExp NoUniqueness MemBind)
y, VName
arr) ->
if Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
y)
then VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) []
else VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
arrs_full_size Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
group_offset Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
ltid]
InKernelGen ()
barrier
-- | Generate code that scans the arrays @arrs@ in place within blocks of
-- threads, using the binary operator @scan_lam@.  All code generated by
-- the definition is wrapped in 'everythingVolatile'.
--
-- Parameters, in the order they are bound by the definition:
--
-- * @constants@: the kernel constants of the enclosing kernel.
--
-- * @seg_flag@: optional segment-flag predicate.  When it is @Just f@,
--   each thread evaluates @f (ltid - skip_threads) ltid@ before applying
--   the operator.  If that holds, the thread is treated as inactive for
--   that step and copies its (primitive) y parameters into its x
--   parameters instead of combining.  With 'Nothing', the operator body
--   is applied unconditionally.
--
-- * @arrs_full_size@: NOTE(review): not used in the part of the
--   definition visible here.  The caller in this module uses the same
--   value as an offset into @arrs@ for stashing non-primitive values;
--   confirm its role in this function.
--
-- * @lockstep_width@: the caller in this module passes the kernel's wave
--   size ('kernelWaveSize').  Presumably this is the number of threads
--   that execute in lockstep (TODO confirm).
--
-- * @block_size@: the caller in this module passes the constant 32.  It
--   uses the same value to compute each thread's block id and in-block
--   index.
--
-- * @active@: only threads for which this expression holds read their
--   initial operands from @arrs@ (via @readInitial@).
--
-- * @arrs@: the arrays being scanned.
--
-- * @barrier@: the barrier action emitted between steps.  In the visible
--   part of the definition it is only emitted when @array_scan@ holds.
--
-- * @scan_lam@: the scan operator.  Its parameter list is split in half:
--   the first half are the x parameters, into which the operator's
--   results are written, and the second half are the y parameters.
inBlockScan :: KernelConstants
-> Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
-> Imp.Exp
-> Imp.Exp
-> Imp.Exp
-> Imp.Exp
-> [VName]
-> InKernelGen ()
-> Lambda ExplicitMemory
-> InKernelGen ()
inBlockScan :: KernelConstants
-> Maybe (Exp -> Exp -> Exp)
-> Exp
-> Exp
-> Exp
-> Exp
-> [VName]
-> InKernelGen ()
-> Lambda ExplicitMemory
-> InKernelGen ()
inBlockScan KernelConstants
constants Maybe (Exp -> Exp -> Exp)
seg_flag Exp
arrs_full_size Exp
lockstep_width Exp
block_size Exp
active [VName]
arrs InKernelGen ()
barrier Lambda ExplicitMemory
scan_lam = InKernelGen () -> InKernelGen ()
forall lore r op a. ImpM lore r op a -> ImpM lore r op a
everythingVolatile (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
VName
skip_threads <- String -> PrimType -> ImpM ExplicitMemory KernelEnv KernelOp VName
forall lore r op. String -> PrimType -> ImpM lore r op VName
dPrim String
"skip_threads" PrimType
int32
let in_block_thread_active :: Exp
in_block_thread_active =
VName -> PrimType -> Exp
Imp.var VName
skip_threads PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. Exp
in_block_id
actual_params :: [LParam ExplicitMemory]
actual_params = Lambda ExplicitMemory -> [LParam ExplicitMemory]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda ExplicitMemory
scan_lam
([Param (MemInfo SubExp NoUniqueness MemBind)]
x_params, [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) =
Int
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> ([Param (MemInfo SubExp NoUniqueness MemBind)],
[Param (MemInfo SubExp NoUniqueness MemBind)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Param (MemInfo SubExp NoUniqueness MemBind)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LParam ExplicitMemory]
[Param (MemInfo SubExp NoUniqueness MemBind)]
actual_params Int -> Int -> Int
forall a. Integral a => a -> a -> a
`div` Int
2) [LParam ExplicitMemory]
[Param (MemInfo SubExp NoUniqueness MemBind)]
actual_params
y_to_x :: InKernelGen ()
y_to_x =
[(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [(Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params) (((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ())
-> ((Param (MemInfo SubExp NoUniqueness MemBind),
Param (MemInfo SubExp NoUniqueness MemBind))
-> InKernelGen ())
-> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ \(Param (MemInfo SubExp NoUniqueness MemBind)
x,Param (MemInfo SubExp NoUniqueness MemBind)
y) ->
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x)) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) [] (VName -> SubExp
Var (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y)) []
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
active (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
readInitial [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params [VName]
arrs
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
in_block_id Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.==. Exp
0) InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
let op_to_x :: InKernelGen ()
op_to_x
| Maybe (Exp -> Exp -> Exp)
Nothing <- Maybe (Exp -> Exp -> Exp)
seg_flag =
[Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params (Body ExplicitMemory -> InKernelGen ())
-> Body ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> Body ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_lam
| Just Exp -> Exp -> Exp
flag_true <- Maybe (Exp -> Exp -> Exp)
seg_flag = do
Exp
inactive <- String -> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall lore r op. String -> Exp -> ImpM lore r op Exp
dPrimVE String
"inactive" (Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp)
-> Exp -> ImpM ExplicitMemory KernelEnv KernelOp Exp
forall a b. (a -> b) -> a -> b
$
Exp -> Exp -> Exp
flag_true (Exp
ltidExp -> Exp -> Exp
forall a. Num a => a -> a -> a
-VName -> PrimType -> Exp
Imp.var VName
skip_threads PrimType
int32) Exp
ltid
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen Exp
inactive InKernelGen ()
y_to_x
Bool -> InKernelGen () -> InKernelGen ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
array_scan InKernelGen ()
barrier
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sUnless Exp
inactive (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ [Param (MemInfo SubExp NoUniqueness MemBind)]
-> Body ExplicitMemory -> InKernelGen ()
forall attr lore r op.
[Param attr] -> Body lore -> ImpM lore r op ()
compileBody' [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params (Body ExplicitMemory -> InKernelGen ())
-> Body ExplicitMemory -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> Body ExplicitMemory
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda ExplicitMemory
scan_lam
maybeBarrier :: InKernelGen ()
maybeBarrier = Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
lockstep_width Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<=. VName -> PrimType -> Exp
Imp.var VName
skip_threads PrimType
int32)
InKernelGen ()
barrier
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"in-block scan (hopefully no barriers needed)" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
VName
skip_threads VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- Exp
1
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhile (VName -> PrimType -> Exp
Imp.var VName
skip_threads PrimType
int32 Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.<. Exp
block_size) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
in_block_thread_active Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
active) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ do
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"read operands" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
(Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> InKernelGen ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ (Exp
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> InKernelGen ()
readParam (VName -> Exp
Imp.vi32 VName
skip_threads)) [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [VName]
arrs
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"perform operation" InKernelGen ()
op_to_x
InKernelGen ()
maybeBarrier
Exp -> InKernelGen () -> InKernelGen ()
forall lore r op. Exp -> ImpM lore r op () -> ImpM lore r op ()
sWhen (Exp
in_block_thread_active Exp -> Exp -> Exp
forall v. PrimExp v -> PrimExp v -> PrimExp v
.&&. Exp
active) (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
String -> InKernelGen () -> InKernelGen ()
forall lore r op. String -> ImpM lore r op () -> ImpM lore r op ()
sComment String
"write result" (InKernelGen () -> InKernelGen ())
-> InKernelGen () -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$
[InKernelGen ()] -> InKernelGen ()
forall (t :: * -> *) (m :: * -> *) a.
(Foldable t, Monad m) =>
t (m a) -> m ()
sequence_ ([InKernelGen ()] -> InKernelGen ())
-> [InKernelGen ()] -> InKernelGen ()
forall a b. (a -> b) -> a -> b
$ (Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> InKernelGen ())
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [Param (MemInfo SubExp NoUniqueness MemBind)]
-> [VName]
-> [InKernelGen ()]
forall a b c d. (a -> b -> c -> d) -> [a] -> [b] -> [c] -> [d]
zipWith3 Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> InKernelGen ()
writeResult [Param (MemInfo SubExp NoUniqueness MemBind)]
x_params [Param (MemInfo SubExp NoUniqueness MemBind)]
y_params [VName]
arrs
InKernelGen ()
maybeBarrier
VName
skip_threads VName -> Exp -> InKernelGen ()
forall lore r op. VName -> Exp -> ImpM lore r op ()
<-- VName -> PrimType -> Exp
Imp.var VName
skip_threads PrimType
int32 Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
2
where block_id :: Exp
block_id = Exp
ltid Exp -> Exp -> Exp
forall num. IntegralExp num => num -> num -> num
`quot` Exp
block_size
in_block_id :: Exp
in_block_id = Exp
ltid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
block_id Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
* Exp
block_size
ltid :: Exp
ltid = KernelConstants -> Exp
kernelLocalThreadId KernelConstants
constants
gtid :: Exp
gtid = KernelConstants -> Exp
kernelGlobalThreadId KernelConstants
constants
array_scan :: Bool
array_scan = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda ExplicitMemory -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType Lambda ExplicitMemory
scan_lam
readInitial :: Param (MemInfo SubExp NoUniqueness MemBind)
-> VName -> InKernelGen ()
readInitial Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid]
| Bool
otherwise =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
gtid]
readParam :: Exp
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> InKernelGen ()
readParam Exp
behind Param (MemInfo SubExp NoUniqueness MemBind)
p VName
arr
| Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
p =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
ltid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
behind]
| Bool
otherwise =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
p) [] (VName -> SubExp
Var VName
arr) [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix (Exp -> DimIndex Exp) -> Exp -> DimIndex Exp
forall a b. (a -> b) -> a -> b
$ Exp
gtid Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
- Exp
behind Exp -> Exp -> Exp
forall a. Num a => a -> a -> a
+ Exp
arrs_full_size]
writeResult :: Param (MemInfo SubExp NoUniqueness MemBind)
-> Param (MemInfo SubExp NoUniqueness MemBind)
-> VName
-> InKernelGen ()
writeResult Param (MemInfo SubExp NoUniqueness MemBind)
x Param (MemInfo SubExp NoUniqueness MemBind)
y VName
arr
| Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType (Type -> Bool) -> Type -> Bool
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> Type
forall attr. Typed attr => Param attr -> Type
paramType Param (MemInfo SubExp NoUniqueness MemBind)
x = do
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM VName
arr [Exp -> DimIndex Exp
forall d. d -> DimIndex d
DimFix Exp
ltid] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
| Bool
otherwise =
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> InKernelGen ()
forall lore r op.
VName
-> [DimIndex Exp] -> SubExp -> [DimIndex Exp] -> ImpM lore r op ()
copyDWIM (Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
y) [] (VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (MemInfo SubExp NoUniqueness MemBind) -> VName
forall attr. Param attr -> VName
paramName Param (MemInfo SubExp NoUniqueness MemBind)
x) []
-- | Choose the launch geometry for a flat kernel covering @kernel_size@
-- threads.
--
-- The group size is read at run time via an 'Imp.GetSize' operation
-- (size class 'Imp.SizeGroup'), keyed by the pretty-printed name of the
-- fresh group-size variable qualified with the current entry point.
-- The number of groups is @kernel_size@ divided by that group size,
-- rounded up.
--
-- Returns @(number of groups, group size)@ as 32-bit variable expressions.
computeMapKernelGroups :: Imp.Exp -> CallKernelGen (Imp.Exp, Imp.Exp)
computeMapKernelGroups kernel_size = do
  gsize_v <- dPrim "group_size" int32
  entry <- askFunction
  let gsize_e = Imp.var gsize_v int32
      gsize_key = keyWithEntryPoint entry (nameFromString (pretty gsize_v))
  sOp (Imp.GetSize gsize_v gsize_key Imp.SizeGroup)
  -- NOTE(review): SExt Int32 Int32 is a same-width conversion and looks
  -- like a no-op; it is kept so the generated code is unchanged.
  ngroups_v <- dPrimV "num_groups"
    (kernel_size `quotRoundingUp` Imp.ConvOpExp (SExt Int32 Int32) gsize_e)
  pure (Imp.var ngroups_v int32, gsize_e)
-- | Build 'KernelConstants' for a one-dimensional kernel covering
-- @kernel_size@ threads.
--
-- Also returns the in-kernel action that declares the thread/group ID
-- variables and fills them from dimension 0 of the launch.
--
-- @desc@ prefixes the names of the fresh ID variables.  The launch
-- geometry comes from 'computeMapKernelGroups'.  A thread counts as
-- active when its global thread ID is below @kernel_size@.
simpleKernelConstants :: Imp.Exp -> String
                      -> CallKernelGen (KernelConstants, InKernelGen ())
simpleKernelConstants kernel_size desc = do
  let fresh suffix = newVName (desc ++ suffix)
  gtid_v <- fresh "_gtid"
  ltid_v <- fresh "_ltid"
  gid_v <- fresh "_gid"
  (ngroups, gsize) <- computeMapKernelGroups kernel_size

  let declare_and_fetch :: InKernelGen ()
      declare_and_fetch = do
        mapM_ (`dPrim_` int32) [gtid_v, ltid_v, gid_v]
        sOp $ Imp.GetGlobalId gtid_v 0
        sOp $ Imp.GetLocalId ltid_v 0
        sOp $ Imp.GetGroupId gid_v 0

      gtid = Imp.var gtid_v int32

      -- Positional construction.  The trailing three arguments are
      -- gsize * ngroups, a literal 0, and the thread-active predicate.
      -- NOTE(review): presumably these fill the total-thread-count,
      -- wave-size and thread-active fields; confirm against the
      -- KernelConstants declaration.
      constants =
        KernelConstants
          gtid (Imp.var ltid_v int32) (Imp.var gid_v int32)
          gtid_v ltid_v gid_v
          ngroups gsize (gsize * ngroups) 0
          (gtid .<. kernel_size)

  pure (constants, declare_and_fetch)
-- | Run @m@ for every group ID that the current physical group handles.
--
-- * 'SegNoVirt': @m@ runs once, with the group-ID variable from the
--   kernel constants.
-- * 'SegVirt': a physical group with ID @p@ iterates @i@ from 0 up to
--   @(required_groups - p) `quotRoundingUp` num_groups@.  On each
--   iteration it binds @p + i * num_groups@ to a fresh @virt_group_id@
--   variable, runs @m@ on it, and then emits a global barrier.
virtualiseGroups :: SegVirt
                 -> Imp.Exp
                 -> (VName -> InKernelGen ())
                 -> InKernelGen ()
virtualiseGroups virt required_groups m =
  case virt of
    SegNoVirt ->
      m . kernelGroupIdVar . kernelConstants =<< askEnv
    SegVirt -> do
      num_groups <- kernelNumGroups . kernelConstants <$> askEnv
      phys_gid <- dPrim "phys_group_id" int32
      sOp $ Imp.GetGroupId phys_gid 0
      let phys = Imp.vi32 phys_gid
          iterations = (required_groups - phys) `quotRoundingUp` num_groups
      sFor "i" iterations $ \i -> do
        virt_gid <- dPrimV "virt_group_id" $ phys + i * num_groups
        m virt_gid
        -- Presumably keeps one virtual group's work from overlapping
        -- with the next one run by the same physical group.
        sOp $ Imp.Barrier Imp.FenceGlobal
-- | Compile and launch a kernel whose body is compiled with
-- 'threadOperations'.
--
-- Inside the kernel, @v@ is bound to the global thread ID before @body@
-- runs.
sKernelThread :: String
              -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
              -> VName
              -> InKernelGen ()
              -> CallKernelGen ()
sKernelThread desc num_groups group_size v body =
  sKernel threadOperations kernelGlobalThreadId desc num_groups group_size v body
-- | Compile and launch a kernel whose body is compiled with
-- 'groupOperations'.
--
-- Inside the kernel, @v@ is bound to the group ID before @body@ runs.
sKernelGroup :: String
             -> Count NumGroups Imp.Exp -> Count GroupSize Imp.Exp
             -> VName
             -> InKernelGen ()
             -> CallKernelGen ()
sKernelGroup desc num_groups group_size v body =
  sKernel groupOperations kernelGroupId desc num_groups group_size v body
-- | Compile @m@ into a kernel called @name@ and emit a host operation
-- that launches it.
--
-- The body is generated with 'subImpM_' under 'makeAllMemoryGlobal'.
-- It runs in a 'KernelEnv' made of the host's atomics and @constants@,
-- using the operations @ops@.  The kernel uses are computed with
-- 'computeKernelUses'.  The launch is one-dimensional, with
-- 'kernelNumGroups' groups of 'kernelGroupSize' threads.  @tol@ is stored
-- as the kernel's failure-tolerance flag.
sKernelFailureTolerant :: Bool
                       -> Operations ExplicitMemory KernelEnv Imp.KernelOp
                       -> KernelConstants
                       -> Name
                       -> InKernelGen ()
                       -> CallKernelGen ()
sKernelFailureTolerant tol ops constants name m = do
  atomics <- hostAtomics <$> askEnv
  let kenv = KernelEnv { kernelAtomics = atomics
                       , kernelConstants = constants
                       }
  body <- makeAllMemoryGlobal (subImpM_ kenv ops m)
  uses <- computeKernelUses body mempty
  let kernel = Imp.Kernel
        { Imp.kernelBody = body
        , Imp.kernelUses = uses
        , Imp.kernelNumGroups = [kernelNumGroups constants]
        , Imp.kernelGroupSize = [kernelGroupSize constants]
        , Imp.kernelName = name
        , Imp.kernelFailureTolerant = tol
        }
  emit (Imp.Op (Imp.CallKernel kernel))
-- | Shared driver for 'sKernelThread' and 'sKernelGroup'.
--
-- Sets up the kernel constants with 'kernelInitialisationSimple' for the
-- given group count and size.  The kernel is named @name ++ "_" ++ tag@,
-- where @tag@ is the numeric tag of @v@, and is launched as not failure
-- tolerant.  Inside the kernel, the constants are initialised, @v@ is
-- bound to @flatf constants@, and then @f@ runs.
sKernel :: Operations ExplicitMemory KernelEnv Imp.KernelOp
        -> (KernelConstants -> Imp.Exp)
        -> String
        -> Count NumGroups Imp.Exp
        -> Count GroupSize Imp.Exp
        -> VName
        -> InKernelGen ()
        -> CallKernelGen ()
sKernel ops flatf name num_groups group_size v f = do
  (constants, set_constants) <- kernelInitialisationSimple num_groups group_size
  let kernel_name = nameFromString (concat [name, "_", show (baseTag v)])
      kernel_body = do
        set_constants
        dPrimV_ v (flatf constants)
        f
  sKernelFailureTolerant False ops constants kernel_name kernel_body
-- | Copy compiler used by 'groupOperations'.
--
-- When both the source and destination memory blocks live in a
-- 'ScalarSpace', the copy is a single assignment of the source variable
-- (read at type @pt@) to the destination variable.  Otherwise it falls
-- back to 'copyElementWise'.
copyInGroup :: CopyCompiler ExplicitMemory KernelEnv Imp.KernelOp
copyInGroup pt destloc srcloc = do
  let spaceOf loc = entryMemSpace <$> lookupMemory (memLocationName loc)
  dest_space <- spaceOf destloc
  src_space <- spaceOf srcloc
  case (dest_space, src_space) of
    (ScalarSpace{}, ScalarSpace{}) ->
      memLocationName destloc <-- Imp.var (memLocationName srcloc) pt
    _ ->
      copyElementWise pt destloc srcloc
-- | Operation tables for compiling kernel code.
--
-- 'threadOperations' compiles ops and expressions at thread level and
-- copies with 'copyElementWise'.  'groupOperations' compiles them at
-- group level and copies with 'copyInGroup'.
--
-- Both tables:
--
-- * compile statements with 'defCompileStms', ignoring the supplied name
--   set and passing 'mempty' instead;
-- * handle allocations in the @local@ space with 'allocLocal'.
threadOperations, groupOperations :: Operations ExplicitMemory KernelEnv Imp.KernelOp
threadOperations =
  (defaultOperations compileThreadOp)
    { opsCopyCompiler   = copyElementWise
    , opsExpCompiler    = compileThreadExp
    , opsStmsCompiler   = const $ defCompileStms mempty
    , opsAllocCompilers = M.singleton (Space "local") allocLocal
    }
groupOperations =
  (defaultOperations compileGroupOp)
    { opsCopyCompiler   = copyInGroup
    , opsExpCompiler    = compileGroupExp
    , opsStmsCompiler   = const $ defCompileStms mempty
    , opsAllocCompilers = M.singleton (Space "local") allocLocal
    }
-- | Emit a failure-tolerant kernel that fills @arr@ with copies of @se@.
--
-- The iteration space is the outer dimensions of @arr@ (all but the last
-- @rank se@ dimensions), followed by the dimensions of @se@'s type.  The
-- space uses one thread per point, and each active thread unflattens its
-- global ID into an index @is@.  It then copies the element of @se@ at
-- the trailing (inner) part of @is@ into @arr@ at @is@.
sReplicateKernel :: VName -> SubExp -> CallKernelGen ()
sReplicateKernel arr se = do
  se_t <- subExpType se
  arr_t <- lookupType arr
  let outer_dims = dropLast (arrayRank se_t) (arrayDims arr_t)
  dims <- traverse toExp (outer_dims ++ arrayDims se_t)
  (constants, set_constants) <-
    simpleKernelConstants (product dims) "replicate"

  let gtid_v = kernelGlobalThreadIdVar constants
      indices = unflattenIndex dims (kernelGlobalThreadId constants)
      value_indices = drop (length outer_dims) indices
      kernel_name = nameFromString ("replicate_" ++ show (baseTag gtid_v))

  sKernelFailureTolerant True threadOperations constants kernel_name $ do
    set_constants
    sWhen (kernelThreadActive constants) $
      copyDWIMFix arr indices se value_indices
-- | Build a function that replicates a scalar of type @bt@ into a memory
-- block.  Its parameters, in order, are:
--
-- * the destination memory block, in @\"device\"@ space;
-- * the number of elements, as an @int32@;
-- * the value to store.
--
-- The body views the block as a one-dimensional array with an @iota@
-- index function and fills it with 'sReplicateKernel'.
replicateFunction :: PrimType -> CallKernelGen Imp.Function
replicateFunction bt = do
  -- Fresh names are drawn in this order.
  mem <- newVName "mem"
  num_elems <- newVName "num_elems"
  val <- newVName "val"
  let params =
        [ Imp.MemParam mem (Space "device")
        , Imp.ScalarParam num_elems int32
        , Imp.ScalarParam val bt
        ]
      shape = Shape [Var num_elems]
      ixfun = IxFun.iota (map (primExpFromSubExp int32) (shapeDims shape))
  function [] params $ do
    arr <- sArray "arr" bt shape (ArrayIn mem ixfun)
    sReplicateKernel arr (Var val)
-- | Name fragment for the replicate helper specialised to a primitive type.
-- The result is @\"replicate_\"@ followed by the pretty-printed type.
replicateName :: PrimType -> String
replicateName = ("replicate_" ++) . pretty
-- | Return the name of the builtin replicate function for @bt@.  The
-- function's definition is emitted the first time it is requested; later
-- requests reuse the existing definition.
replicateForType :: PrimType -> CallKernelGen Name
replicateForType bt = do
  let fname = nameFromString ("builtin#" <> replicateName bt)
  already_defined <- hasFunction fname
  unless already_defined $ do
    fun <- replicateFunction bt
    emitFunction fname fun
  pure fname
-- | Decide whether replicating @v@ into @arr@ is a plain fill.  A plain fill
-- means @v@ has primitive type and @arr@'s index function is linear.
--
-- If so, return an action that calls the shared builtin replicate function
-- on @arr@'s memory block.  The call passes the memory block, the product
-- of @arr@'s dimensions, and the value.  Otherwise return 'Nothing'.
replicateIsFill :: VName -> SubExp -> CallKernelGen (Maybe (CallKernelGen ()))
replicateIsFill arr v = do
  ArrayEntry (MemLocation arr_mem arr_shape arr_ixfun) _ <- lookupArray arr
  v_t <- subExpType v
  let num_elems = product (map (toExp' int32) arr_shape)
      callFill elem_t = do
        fname <- replicateForType elem_t
        emit $ Imp.Call [] fname
          [ Imp.MemArg arr_mem
          , Imp.ExpArg num_elems
          , Imp.ExpArg (toExp' elem_t v)
          ]
  pure $ case v_t of
    Prim v_t' | IxFun.isLinear arr_ixfun -> Just (callFill v_t')
    _ -> Nothing
-- | Replicate @se@ into @arr@.  When 'replicateIsFill' recognises the
-- operation as a plain fill, the shared builtin function is used.
-- Otherwise a dedicated kernel is generated by 'sReplicateKernel'.
sReplicate :: VName -> SubExp -> CallKernelGen ()
sReplicate arr se =
  fromMaybe (sReplicateKernel arr se) =<< replicateIsFill arr se
-- | Fill @arr@ with @x, x + s, x + 2*s, ...@ as values of integer type @et@.
-- The kernel is sized by @n@ with one thread per element, and each active
-- thread writes the element at its global thread id.
sIota :: VName -> Imp.Exp -> Imp.Exp -> Imp.Exp -> IntType
      -> CallKernelGen ()
sIota arr n x s et = do
  destloc <- entryArrayLocation <$> lookupArray arr
  (constants, set_constants) <- simpleKernelConstants n "iota"
  let gtid = kernelGlobalThreadId constants
      kernel_name = nameFromString $
        "iota_" ++ show (baseTag (kernelGlobalThreadIdVar constants))
      -- The thread id is converted from int32 to @et@ (via 'SExt') before
      -- the arithmetic.
      value = Imp.ConvOpExp (SExt Int32 et) gtid * s + x
  sKernelFailureTolerant True threadOperations constants kernel_name $ do
    set_constants
    sWhen (kernelThreadActive constants) $ do
      (destmem, destspace, destidx) <- fullyIndexArray' destloc [gtid]
      emit $
        Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile value
-- | Copy the array at @srcloc@ to @destloc@ element by element, using a
-- kernel with one thread per element.
--
-- Each thread unflattens its global id over the source shape.  It then
-- uses the same multi-dimensional index for both the source and the
-- destination, so the destination is assumed to have the source's shape.
sCopy :: PrimType
      -> MemLocation
      -> MemLocation
      -> CallKernelGen ()
sCopy bt destloc@(MemLocation destmem _ _) srcloc@(MemLocation srcmem srcshape _) = do
  let shape = map (toExp' int32) srcshape
      kernel_size = product shape
  (constants, set_constants) <- simpleKernelConstants kernel_size "copy"
  let gtid = kernelGlobalThreadId constants
      is = unflattenIndex shape gtid
      kernel_name = nameFromString $
        "copy_" ++ show (baseTag (kernelGlobalThreadIdVar constants))
  sKernelFailureTolerant True threadOperations constants kernel_name $ do
    set_constants
    (_, destspace, destidx) <- fullyIndexArray' destloc is
    (_, srcspace, srcidx) <- fullyIndexArray' srcloc is
    let elem_read = Imp.index srcmem srcidx bt srcspace Imp.Nonvolatile
    sWhen (gtid .<. kernel_size) $
      emit $ Imp.Write destmem destidx bt destspace Imp.Nonvolatile elem_read
-- | Write the result of a group-level kernel body into the destination
-- named by @pe@.
--
-- * 'TileReturns' with one dimension @(w, per_group_elems)@: the group's
--   tile @what@ is written starting at offset @per_group_elems * group_id@
--   of the @w@-element destination.  All threads in the group take part.
-- * 'TileReturns' with several dimensions: each thread writes the tile
--   element addressed by its (unflattened) local thread id.  The target
--   global index is @group_id * tile_size + local_index@ per dimension,
--   and the write happens only if that index is within the bounds in
--   @dims@.
-- * 'Returns': a result held in local memory is copied with 'groupCopy'.
--   Otherwise only local thread 0 performs the write.
-- * 'WriteReturns' and 'ConcatReturns' are not supported.
compileGroupResult :: SegSpace
                   -> PatElem ExplicitMemory -> KernelResult
                   -> InKernelGen ()

compileGroupResult _ pe (TileReturns [(w, per_group_elems)] what) = do
  n <- toExp . arraySize 0 =<< lookupType what

  constants <- kernelConstants <$> askEnv
  let ltid = kernelLocalThreadId constants
      group_size = kernelGroupSize constants
      w' = toExp' int32 w
      offset = toExp' int32 per_group_elems * kernelGroupId constants

  -- Avoid the loop in the common case where each thread is statically
  -- known to write at most one element.  Note that '==' compares the two
  -- expressions structurally at compile time, not their run-time values.
  if toExp' int32 per_group_elems == group_size
    then sWhen ((offset + ltid) .<. w') $
           copyDWIMFix (patElemName pe) [ltid + offset] (Var what) [ltid]
    else sFor "i" (n `quotRoundingUp` group_size) $ \i -> do
      j <- fmap Imp.vi32 $ dPrimV "j" $ group_size * i + ltid
      -- The guard has two parts.  The read must stay within the tile
      -- (@j < n@).  The write must stay within the @w@-element
      -- destination (@j + offset < w@), exactly as in the single-element
      -- branch above; the tile size alone does not guarantee this.
      sWhen ((j .<. n) .&&. ((j + offset) .<. w')) $
        copyDWIMFix (patElemName pe) [j + offset] (Var what) [j]

compileGroupResult space pe (TileReturns dims what) = do
  constants <- kernelConstants <$> askEnv
  let gids = map fst $ unSegSpace space
      out_tile_sizes = map (toExp' int32 . snd) dims
      local_is = unflattenIndex out_tile_sizes $ kernelLocalThreadId constants
      group_is = zipWith (*) (map Imp.vi32 gids) out_tile_sizes
  is_for_thread <-
    mapM (dPrimV "thread_out_index") $ zipWith (+) group_is local_is

  -- Only write if the thread's global index is within the result bounds.
  sWhen (isActive $ zip is_for_thread $ map fst dims) $
    copyDWIMFix (patElemName pe) (map Imp.vi32 is_for_thread) (Var what) local_is

compileGroupResult space pe (Returns _ what) = do
  constants <- kernelConstants <$> askEnv
  in_local_memory <- arrayInLocalMemory what
  let gids = map (Imp.vi32 . fst) $ unSegSpace space

  if not in_local_memory
    then sWhen (kernelLocalThreadId constants .==. 0) $
           copyDWIMFix (patElemName pe) gids what []
    -- A result in local memory is copied via 'groupCopy'.
    else groupCopy (patElemName pe) gids what []

compileGroupResult _ _ WriteReturns{} =
  compilerLimitationS "compileGroupResult: WriteReturns not handled yet."

compileGroupResult _ _ ConcatReturns{} =
  compilerLimitationS "compileGroupResult: ConcatReturns not handled yet."
-- | Write one thread's contribution to a kernel result into @pe@.
--
-- * 'Returns': the value is written at the thread's segment indices.
-- * 'ConcatReturns' with 'SplitContiguous': the thread's chunk is written to
--   a contiguous slice starting at @per_thread_elems * global_thread_id@.
-- * 'ConcatReturns' with 'SplitStrided': the chunk is written to a slice
--   starting at the global thread id, with the given stride.
-- * 'WriteReturns': each @(indices, value)@ pair is written if the thread
--   is active and every index is within @[0, rw)@ for its dimension.
-- * 'TileReturns' is a compiler bug here.
compileThreadResult :: SegSpace
                    -> PatElem ExplicitMemory -> KernelResult
                    -> InKernelGen ()

compileThreadResult space pe (Returns _ what) =
  copyDWIMFix (patElemName pe) thread_is what []
  where thread_is = map (Imp.vi32 . fst) (unSegSpace space)

compileThreadResult _ pe (ConcatReturns SplitContiguous _ per_thread_elems what) = do
  gtid <- kernelGlobalThreadId . kernelConstants <$> askEnv
  let offset = toExp' int32 per_thread_elems * gtid
  n <- toExp' int32 . arraySize 0 <$> lookupType what
  copyDWIM (patElemName pe) [DimSlice offset n 1] (Var what) []

compileThreadResult _ pe (ConcatReturns (SplitStrided stride) _ _ what) = do
  offset <- kernelGlobalThreadId . kernelConstants <$> askEnv
  n <- toExp' int32 . arraySize 0 <$> lookupType what
  let slice = DimSlice offset n (toExp' int32 stride)
  copyDWIM (patElemName pe) [slice] (Var what) []

compileThreadResult _ pe (WriteReturns rws _arr dests) = do
  active <- kernelThreadActive . kernelConstants <$> askEnv
  rws' <- mapM toExp rws
  forM_ dests $ \(is, e) -> do
    is' <- mapM toExp is
    let inBounds i rw = 0 .<=. i .&&. i .<. rw
        -- Start from the thread's activity flag and conjoin one bounds
        -- check per dimension.
        write = foldl (.&&.) active (zipWith inBounds is' rws')
    sWhen write $
      copyDWIMFix (patElemName pe) (map (toExp' int32) is) e []

compileThreadResult _ _ TileReturns{} =
  compilerBugS "compileThreadResult: TileReturns unhandled."
-- | Determine whether a subexpression names an array whose backing memory
-- block lives in the @"local"@ memory space.
--
-- For a variable, the variable's entry is looked up. If it is an array,
-- the memory block named by the array's location is looked up and its
-- space is compared against @Space "local"@. Any non-array variable and
-- any constant yields 'False'.
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory Constant{} = return False
arrayInLocalMemory (Var name) =
  lookupVar name >>= \case
    ArrayVar _ entry -> do
      -- The array entry only records where the array lives; the memory
      -- entry for that location carries the space it was allocated in.
      mem <- lookupMemory $ memLocationName $ entryArrayLocation entry
      return $ Space "local" == entryMemSpace mem
    _ -> return False