{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.BlockedKernel
( DistRep,
MkSegLevel,
ThreadRecommendation (..),
segRed,
nonSegRed,
segScan,
segHist,
segMap,
mapKernel,
KernelInput (..),
readKernelInput,
mkSegSpace,
dummyDim,
)
where
import Control.Monad
import Control.Monad.Writer
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Aliases (AliasableRep)
import Futhark.IR.GPU.Op (SegVirt (..))
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | The constraints a representation must satisfy for the kernel
-- construction helpers in this module: it must be buildable, support
-- 'SegOp's and aliasing, use plain 'Type' let-decorations, and use unit
-- decorations on expressions and bodies.
type DistRep rep =
  ( Buildable rep,
    HasSegOp rep,
    BuilderOps rep,
    LetDec rep ~ Type,
    ExpDec rep ~ (),
    BodyDec rep ~ (),
    AliasableRep rep
  )
-- | A hint handed to a 'MkSegLevel' function when choosing the level of
-- a new kernel.  'mapKernel' passes 'ManyThreads' when every result type
-- of the kernel is primitive, and @'NoRecommendation' 'SegVirt'@
-- otherwise.  How the hint is interpreted is up to the 'MkSegLevel'
-- implementation.
data ThreadRecommendation
  = -- | Request many threads.
    ManyThreads
  | -- | Make no thread-count request; carries the 'SegVirt' value to use.
    NoRecommendation SegVirt
-- | A function that chooses the 'SegOpLevel' of a new kernel.  It is
-- given the sizes of the kernel's dimensions, a short description of the
-- kernel (e.g. @"segmap"@), and a 'ThreadRecommendation'.  Since it runs
-- in 'BuilderT', it may add statements of its own.
type MkSegLevel rep m =
  [SubExp] ->
  String ->
  ThreadRecommendation ->
  BuilderT rep m (SegOpLevel rep)
-- | Construct a 'SegSpace' over the given @(index variable, size)@
-- dimensions.  A fresh @phys_tid@ name is generated as the space's first
-- component (presumably the physical thread id — confirm against
-- 'SegSpace').
mkSegSpace :: MonadFreshNames m => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = do
  phys_tid <- newVName "phys_tid"
  pure $ SegSpace phys_tid dims
-- | Build the segment space and kernel body shared by 'segRed',
-- 'segScan' and 'segMap'.
--
-- The space is @ispace@ extended with a fresh innermost dimension
-- @gtid@ of size @w@.  The body:
--
--   1. reads every input in @inps@;
--   2. under the certificates @cs@, binds each parameter of @map_lam@ to
--      element @gtid@ of the corresponding array in @arrs@;
--   3. inlines the body of @map_lam@ and returns each of its results.
prepareRedOrScan ::
  (MonadBuilder m, DistRep (Rep m)) =>
  Certs ->
  SubExp ->
  Lambda (Rep m) ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, KernelBody (Rep m))
prepareRedOrScan cs w map_lam arrs ispace inps = do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  (kres, kstms) <- runBuilder . localScope (scopeOfSegSpace space) $ do
    mapM_ readKernelInput inps
    -- The certificates guard only the reads of the lambda parameters,
    -- not the extra inputs above nor the lambda body below.
    certifying cs . mapM_ readKernelInput $
      zipWith
        (\p arr -> KernelInput (paramName p) (paramType p) arr [Var gtid])
        (lambdaParams map_lam)
        arrs
    res <- bodyBind $ lambdaBody map_lam
    pure [Returns ResultMaySimplify res_cs se | SubExpRes res_cs se <- res]
  pure (space, KernelBody () kstms kres)
-- | Construct statements that bind @pat@ to a 'SegRed' at level @lvl@
-- with operators @ops@.
--
-- The kernel body comes from 'prepareRedOrScan'.  It iterates over
-- @ispace@ plus an innermost dimension of size @w@, reads @inps@, and
-- applies @map_lam@ to the current elements of @arrs@ under the
-- certificates @cs@.  The statements are returned, not added to @m@.
segRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segRed lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat . Op . segOp $
    SegRed lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Construct statements that bind @pat@ to a 'SegScan' at level @lvl@
-- with operators @ops@.
--
-- The kernel body comes from 'prepareRedOrScan'.  It iterates over
-- @ispace@ plus an innermost dimension of size @w@, reads @inps@, and
-- applies @map_lam@ to the current elements of @arrs@ under the
-- certificates @cs@.  The statements are returned, not added to @m@.
segScan ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  Certs ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segScan lvl pat cs w ops map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan cs w map_lam arrs ispace inps
  letBind pat . Op . segOp $
    SegScan lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Construct statements that bind @pat@ to a 'SegMap' at level @lvl@.
--
-- The kernel body comes from 'prepareRedOrScan', with no certificates.
-- It iterates over @ispace@ plus an innermost dimension of size @w@,
-- reads @inps@, and applies @map_lam@ to the current elements of
-- @arrs@.  The statements are returned, not added to @m@.
segMap ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat (LetDec rep) ->
  SubExp ->
  Lambda rep ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms rep)
segMap lvl pat w map_lam arrs ispace inps = runBuilder_ $ do
  (kspace, kbody) <- prepareRedOrScan mempty w map_lam arrs ispace inps
  letBind pat . Op . segOp $
    SegMap lvl kspace (lambdaReturnType map_lam) kbody
-- | Prepare to compute @pat@ through an extra outer dimension of size 1.
--
-- Returns a triple:
--
--   * a renamed copy of @pat@ in which every type has an extra outer
--     dimension of size 1;
--   * a one-dimensional index space @[(dummy, 1)]@ for that dimension;
--   * an action to run once the copy has been bound.  For each name in
--     the copy it binds the matching original name in @pat@: accumulator
--     values are copied directly, and everything else is indexed at
--     position 0 of the outer dimension.
dummyDim ::
  MonadBuilder m =>
  Pat Type ->
  m (Pat Type, [(VName, SubExp)], m ())
dummyDim pat = do
  pat' <- fmap addDummyDim <$> renamePat pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int64 1)]
      read_dummy =
        forM_ (zip (patNames pat') (patNames pat)) $ \(from, to) -> do
          from_t <- lookupType from
          letBindNames [to] . BasicOp $ case from_t of
            Acc {} -> SubExp $ Var from
            _ -> Index from $ fullSlice from_t [DimFix $ intConst Int64 0]
  pure (pat', ispace, read_dummy)
  where
    addDummyDim :: Type -> Type
    addDummyDim t = t `arrayOfRow` intConst Int64 1
-- | Construct statements that bind @pat@ to a non-segmented reduction
-- with operators @ops@ of @map_lam@ over the arrays @arrs@ of outer size
-- @w@.
--
-- This is done with 'dummyDim'.  First a 'segRed' is built into a copy
-- of @pat@ that has an extra outer dimension of size 1.  The results are
-- then read back into @pat@.  The statements are returned, not added to
-- @m@.
nonSegRed ::
  (MonadFreshNames m, DistRep rep, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [SegBinOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
nonSegRed lvl pat w ops map_lam arrs = runBuilder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms =<< segRed lvl pat' mempty w ops map_lam arrs ispace []
  read_dummy
-- | Construct statements that bind @pat@ to a 'SegHist' at level @lvl@
-- with histogram operators @ops@.
--
-- The space is @ispace@ extended with a fresh innermost dimension
-- @gtid@ of size @arr_w@.  The body reads @inps@ and binds each parameter
-- of @lam@ to element @gtid@ of the corresponding array in @arrs@.  It
-- then inlines the body of @lam@ and returns its results.  The
-- statements are returned, not added to @m@.
segHist ::
  (DistRep rep, MonadFreshNames m, HasScope rep m) =>
  SegOpLevel rep ->
  Pat Type ->
  SubExp ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [HistOp rep] ->
  Lambda rep ->
  [VName] ->
  m (Stms rep)
segHist lvl pat arr_w ispace inps ops lam arrs = runBuilder_ $ do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, arr_w)]
  (kres, kstms) <- runBuilder . localScope (scopeOfSegSpace space) $ do
    mapM_ readKernelInput inps
    -- NOTE(review): unlike 'readKernelInput', this always indexes, so an
    -- accumulator-typed parameter is not special-cased; confirm that
    -- cannot occur for histogram inputs.
    forM_ (zip (lambdaParams lam) arrs) $ \(p, arr) -> do
      arr_t <- lookupType arr
      letBindNames [paramName p] . BasicOp . Index arr $
        fullSlice arr_t [DimFix $ Var gtid]
    res <- bodyBind $ lambdaBody lam
    pure [Returns ResultMaySimplify cs se | SubExpRes cs se <- res]
  letBind pat . Op . segOp $
    SegHist lvl space ops (lambdaReturnType lam) (KernelBody () kstms kres)
-- | Construct the segment space for @ispace@, together with statements
-- that read each of @inputs@ via 'readKernelInput'.  The read statements
-- are returned, not added anywhere.
mapKernelSkeleton ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, Stms rep)
mapKernelSkeleton ispace inputs = do
  -- The per-input results are all (), so discard them with mapM_.
  read_input_stms <- runBuilder_ $ mapM_ readKernelInput inputs
  space <- mkSegSpace ispace
  pure (space, read_input_stms)
-- | Turn a kernel body into a 'SegMap' over @ispace@ with result types
-- @rts@.
--
-- Statements that read @inputs@ are prepended to the body.  The level is
-- chosen by @mk_lvl@, which receives:
--
--   * the dimension sizes of @ispace@;
--   * the description @"segmap"@;
--   * 'ManyThreads' if every type in @rts@ is primitive, and
--     @'NoRecommendation' 'SegVirt'@ otherwise.
--
-- Any statements that @mk_lvl@ adds while choosing the level are
-- returned alongside the 'SegOp'.
mapKernel ::
  (DistRep rep, HasScope rep m, MonadFreshNames m) =>
  MkSegLevel rep m ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [Type] ->
  KernelBody rep ->
  m (SegOp (SegOpLevel rep) rep, Stms rep)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBuilderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs
  let kbody' = KernelBody () (read_input_stms <> kstms) krets
      r
        | all primType rts = ManyThreads
        | otherwise = NoRecommendation SegVirt
  lvl <- mk_lvl (map snd ispace) "segmap" r
  pure $ SegMap lvl space rts kbody'
-- | A value to be bound inside a kernel body by 'readKernelInput'.
data KernelInput = KernelInput
  { -- | The name that is bound.
    kernelInputName :: VName,
    -- | The type given to the bound name.  For non-accumulators its
    -- dimensions are also sliced in full after the indices.
    kernelInputType :: Type,
    -- | The variable that is read from.
    kernelInputArray :: VName,
    -- | Leading indices applied to 'kernelInputArray'.
    kernelInputIndices :: [SubExp]
  }
  deriving (Show)
-- | Bind @kernelInputName inp@, with type @kernelInputType inp@.
--
-- If that type is an accumulator, the name is bound directly to the
-- variable @kernelInputArray inp@.  Otherwise it is bound to that
-- variable indexed by @kernelInputIndices inp@, followed by full slices
-- over every dimension of @kernelInputType inp@.
readKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readKernelInput inp =
  letBind (Pat [PatElem (kernelInputName inp) t]) . BasicOp $ case t of
    Acc {} -> SubExp $ Var arr
    _ ->
      Index arr . Slice $
        map DimFix (kernelInputIndices inp) ++ map sliceDim (arrayDims t)
  where
    t = kernelInputType inp
    arr = kernelInputArray inp