{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.BlockedKernel
( DistRep,
MkSegLevel,
ThreadRecommendation (..),
segRed,
nonSegRed,
segScan,
segHist,
segMap,
mapKernel,
KernelInput (..),
readKernelInput,
mkSegSpace,
dummyDim,
)
where
import Control.Monad
import Control.Monad.Writer
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | The constraint bundle required of a representation @rep@ by the
-- kernel-construction helpers in this module: it must be buildable,
-- provide segmented operations, use plain 'Type' as its let-decoration
-- and unit as its expression/body decorations, and have an operation
-- type that can be aliased.
type DistRep rep =
  ( Buildable rep,
    HasSegOp rep,
    BuilderOps rep,
    LetDec rep ~ Type,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    CanBeAliased (Op rep)
  )
-- | A hint handed to a 'MkSegLevel' function when a segmented
-- operation is being constructed.
data ThreadRecommendation
  = -- | Chosen by 'mapKernel' when every result type is primitive.
    ManyThreads
  | -- | No particular recommendation; carries the 'SegVirt' setting
    -- to use.
    NoRecommendation SegVirt
-- | A function that constructs the level of a segmented operation.
-- 'mapKernel' calls it with the sizes of the iteration-space
-- dimensions, a descriptive string (@"segmap"@ there), and a
-- 'ThreadRecommendation'.
type MkSegLevel rep m =
  [SubExp] -> String -> ThreadRecommendation -> BuilderT rep m (SegOpLevel rep)
-- | Build a 'SegSpace' over the given dimensions (pairs of index
-- variable and size), using a freshly generated @phys_tid@ name as
-- the first field of the space.
mkSegSpace :: MonadFreshNames m => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = SegSpace <$> newVName "phys_tid" <*> pure dims
-- | Shared setup for 'segRed', 'segScan' and 'segMap': construct the
-- iteration space and the kernel body that applies @map_lam@.
--
-- The space is @ispace@ extended with one extra innermost dimension
-- of size @w@, indexed by a fresh @gtid@.  The body first reads all
-- of @inps@, then (under the certificates @cs@) binds each parameter
-- of @map_lam@ to the corresponding array in @arrs@ indexed at
-- @gtid@, then inlines the body of @map_lam@.  Every result of the
-- lambda becomes a 'Returns' with 'ResultMaySimplify'.
prepareRedOrScan ::
  (MonadBuilder m, DistRep (Rep m)) =>
  Certs ->
  SubExp ->
  Lambda (Rep m) ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, KernelBody (Rep m))
prepareRedOrScan cs w map_lam arrs ispace inps = do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  -- One kernel input per lambda parameter, pairing it with its array
  -- and indexing that array at the new innermost dimension.
  let paramInput p arr =
        KernelInput (paramName p) (paramType p) arr [Var gtid]
      param_inps = zipWith paramInput (lambdaParams map_lam) arrs
  -- 'runBuilder' yields (results, statements); flip them into the
  -- argument order of the 'KernelBody' constructor.
  kbody <- fmap (uncurry (flip (KernelBody ()))) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        certifying cs $ mapM_ readKernelInput param_inps
        res <- bodyBind (lambdaBody map_lam)
        forM res $ \(SubExpRes res_cs se) ->
          pure $ Returns ResultMaySimplify res_cs se
  pure (space, kbody)
-- | Produce the statements for a segmented reduction.
--
-- The iteration space and kernel body are built by
-- 'prepareRedOrScan' from @cs@, @w@, @map_lam@, @arrs@, @ispace@ and
-- @inps@; the pattern @pat@ is then bound to a 'SegRed' at level
-- @lvl@ with operators @ops@ and the return types of @map_lam@.
segRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segRed lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegRed lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Produce the statements for a segmented scan.
--
-- Identical in structure to 'segRed': the iteration space and kernel
-- body come from 'prepareRedOrScan', and @pat@ is bound to a
-- 'SegScan' at level @lvl@ with operators @ops@ and the return types
-- of @map_lam@.
segScan ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segScan lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegScan lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Produce the statements for a segmented map.
--
-- The iteration space and kernel body come from 'prepareRedOrScan',
-- called with no certificates ('mempty'); @pat@ is bound to a
-- 'SegMap' at level @lvl@ with the return types of @map_lam@.
segMap ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segMap lvl pat w map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan mempty w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegMap lvl kspace (lambdaReturnType map_lam) kbody
-- | Wrap a pattern in an extra outer dimension of size one.
--
-- Returns a triple of:
--
-- * a renamed copy of @pat@ in which every element type has been
--   turned into an array of one row of the original type;
--
-- * a single-dimension iteration space of size one, indexed by a
--   fresh @dummy@ name;
--
-- * an action that binds each original pattern name to index 0 of the
--   corresponding renamed name (or to the renamed variable itself
--   when its type is an accumulator).
dummyDim ::
  MonadBuilder m =>
  Pat Type ->
  m (Pat Type, [(VName, SubExp)], m ())
dummyDim pat = do
  let addDummyDim t = t `arrayOfRow` intConst Int64 1
  pat' <- fmap addDummyDim <$> renamePat pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int64 1)]
  pure
    ( pat',
      ispace,
      forM_ (zip (patNames pat') (patNames pat)) $ \(from, to) -> do
        from_t <- lookupType from
        letBindNames [to] . BasicOp $
          case from_t of
            -- Accumulators are passed through unchanged rather than
            -- indexed.
            Acc {} -> SubExp $ Var from
            _ -> Index from $ fullSlice from_t [DimFix $ intConst Int64 0]
    )
-- | Produce the statements for a non-segmented reduction.
--
-- Implemented as a 'segRed' over the size-one iteration space from
-- 'dummyDim', with no certificates and no extra kernel inputs,
-- followed by the 'dummyDim' read-back action that binds the names of
-- the original pattern @pat@.
nonSegRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
nonSegRed lvl pat w ops map_lam arrs = runBuilder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms =<< segRed lvl pat' mempty w ops map_lam arrs ispace []
  read_dummy
-- | Produce the statements for a segmented histogram.
--
-- The iteration space is @ispace@ extended with an innermost
-- dimension of size @arr_w@ indexed by a fresh @gtid@.  The kernel
-- body reads @inps@, binds each parameter of @lam@ to the
-- corresponding array in @arrs@ indexed at @gtid@ (a full slice with
-- the outermost dimension fixed), and then inlines the body of @lam@.
-- Finally @pat@ is bound to a 'SegHist' at level @lvl@ with
-- operators @ops@ and the return types of @lam@.
segHist ::
  (DistRep rep, MonadFreshNames m, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [HistOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
segHist lvl pat arr_w ispace inps ops lam arrs = runBuilder_ $ do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, arr_w)]
  -- 'runBuilder' yields (results, statements); flip them into the
  -- argument order of the 'KernelBody' constructor.
  kbody <- fmap (uncurry (flip $ KernelBody ())) $
    runBuilder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        forM_ (zip (lambdaParams lam) arrs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames [paramName p] $
            BasicOp $
              Index arr $
                fullSlice arr_t [DimFix $ Var gtid]
        res <- bodyBind (lambdaBody lam)
        forM res $ \(SubExpRes cs se) ->
          pure $ Returns ResultMaySimplify cs se
  letBind pat $ Op $ segOp $ SegHist lvl space ops (lambdaReturnType lam) kbody
-- | Build the common parts of a map kernel: the statements that read
-- every kernel input (via 'readKernelInput'), and the 'SegSpace' for
-- the iteration space @ispace@.
mapKernelSkeleton ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, Stms rep)
mapKernelSkeleton ispace inputs = do
  -- Only the emitted statements matter here, so the unit results are
  -- not collected.
  read_input_stms <- runBuilder_ $ mapM_ readKernelInput inputs
  space <- mkSegSpace ispace
  pure (space, read_input_stms)
-- | Construct a 'SegMap' from an existing kernel body.
--
-- The statements that read @inputs@ are prepended to the statements
-- of the given body.  The level is obtained from @mk_lvl@, which is
-- given the sizes of the @ispace@ dimensions, the string @"segmap"@,
-- and 'ManyThreads' if all of @rts@ are primitive types (otherwise
-- @'NoRecommendation' 'SegVirt'@).  Returns the 'SegMap' together
-- with any statements emitted while constructing it.
mapKernel ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  MkSegLevel rep m ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [Type] ->
  KernelBody rep ->
  m (SegOp (SegOpLevel rep) rep, Stms rep)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBuilderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs
  let kbody' = KernelBody () (read_input_stms <> kstms) krets
  let r = if all primType rts then ManyThreads else NoRecommendation SegVirt
  lvl <- mk_lvl (map snd ispace) "segmap" r
  pure $ SegMap lvl space rts kbody'
-- | Describes a value that a kernel body obtains by reading from a
-- variable in scope; see 'readKernelInput' for how it is turned into
-- a statement.
data KernelInput = KernelInput
  { -- | The name that the value is bound to inside the kernel.
    kernelInputName :: VName,
    -- | The type of the bound value.
    kernelInputType :: Type,
    -- | The variable that is read from (indexed, unless the input
    -- type is an accumulator).
    kernelInputArray :: VName,
    -- | The indices fixed on the leading dimensions of
    -- 'kernelInputArray'.
    kernelInputIndices :: [SubExp]
  }
  deriving (Show)
-- | Emit the statement that binds a 'KernelInput'.
--
-- If the input type is an accumulator, the name is bound directly to
-- the source variable.  Otherwise it is bound to an index into the
-- source array: the 'kernelInputIndices' are fixed, followed by a
-- 'sliceDim' slice for each dimension of the input type.
readKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readKernelInput inp = do
  let pe = PatElem (kernelInputName inp) $ kernelInputType inp
  letBind (Pat [pe]) . BasicOp $
    case kernelInputType inp of
      Acc {} ->
        SubExp $ Var $ kernelInputArray inp
      _ ->
        Index (kernelInputArray inp) . Slice $
          map DimFix (kernelInputIndices inp)
            ++ map sliceDim (arrayDims (kernelInputType inp))