{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where
import Control.Monad.Identity
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.IR.Kernels as Out
import Futhark.IR.Kernels.Kernel hiding (HistOp)
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Util.Log
import Prelude hiding (log)
-- | Try to turn a kernel nest into a single group-level 'SegMap' whose
-- body exploits the parallelism of the given lambda at thread level.
--
-- The result is 'Nothing' when any of the parallel sizes reported for
-- the body depends on a name that is not available outside the kernel
-- (the @"Irregular parallelism"@ failure below).  Otherwise the result
-- is a tuple of:
--
-- * @(intra_min_par, intra_avail_par)@ -- currently the same value: the
--   'SMin' over the products of the non-empty entries of the
--   available-parallelism lists (constant 0 if there are none);
--
-- * the computed group size, as a variable;
--
-- * the log produced while transforming the body;
--
-- * the prelude statements that compute the sizes above and the number
--   of groups;
--
-- * the single statement binding the nest's pattern to the 'SegMap'.
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope Out.Kernels m) =>
  KernelNest ->
  Lambda ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Out.Stms Out.Kernels,
          Out.Stms Out.Kernels
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- The number of groups is the product of all dimensions of the
  -- flattened index space.
  (num_groups, w_stms) <-
    lift $
      runBinder $
        letSubExp "intra_num_groups"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  -- The group size is only bound further down (via 'letBindNames'),
  -- but its name is needed now to construct the thread-level SegLevel.
  group_size <- newVName "computed_group_size"
  let intra_lvl = SegThread (Count num_groups) (Count $ Var group_size) SegNoVirt

  (wss_min, wss_avail, log, kbody) <-
    lift $
      localScope (scopeOfLParams $ lambdaParams lam) $
        intraGroupParalleliseBody intra_lvl body

  outside_scope <- lift askScope
  -- A name counts as available only if it is in the enclosing scope
  -- and is not one of the kernel inputs.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <-
    lift $
      runBinder $ do
        -- Like 'foldBinOp', but takes the neutral element from the
        -- list itself and yields constant 0 for the empty list.
        let foldBinOp' _ [] = eSubExp $ intConst Int64 0
            foldBinOp' bop (x : xs) = foldBinOp bop x xs
        ws_min <-
          mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
            filter (not . null) wss_min
        ws_avail <-
          mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
            filter (not . null) wss_avail

        -- The available parallelism is the smallest of the
        -- per-entry products.
        intra_avail_par <-
          letSubExp "intra_avail_par" =<< foldBinOp' (SMin Int64) ws_avail

        -- The group size is the largest of the minimum sizes, or, when
        -- there are none, the available parallelism bounded by the
        -- maximum group size.
        letBindNames [group_size]
          =<< if null ws_min
            then
              eBinOp
                (SMin Int64)
                ( eSubExp
                    =<< letSubExp
                      "max_group_size"
                      (Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup)
                )
                (eSubExp intra_avail_par)
            else foldBinOp' (SMax Int64) ws_min

        -- Only read those kernel inputs that the body actually uses.
        let inputIsUsed input = kernelInputName input `nameIn` freeIn body
            used_inps = filter inputIsUsed inps

        addStms w_stms
        -- The reads are collected separately (not added to the
        -- prelude) so that they can be placed inside the kernel body.
        read_input_stms <- runBinder_ $ mapM_ readKernelInput used_inps
        space <- mkSegSpace ispace
        return (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPattern first_nest
      -- Per-group result types: the pattern types with one array
      -- dimension stripped per dimension of the index space.
      rts = map (length ispace `stripArray`) $ patternTypes nested_pat
      lvl = SegGroup (Count num_groups) (Count $ Var group_size) SegNoVirt
      kstm =
        Let nested_pat aux $
          Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  return
    ( (intra_min_par, intra_avail_par),
      Var group_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest
-- | The writer state accumulated by 'IntraGroupM'.
data Acc = Acc
  { -- | Lists of sizes recorded as minimum parallelism.
    -- NOTE(review): presumably one list of widths per parallel
    -- construct encountered -- confirm against the callers of 'tell'.
    accMinPar :: S.Set [SubExp],
    -- | Lists of sizes recorded as available parallelism.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages emitted via 'addLog'.
    accLog :: Log
  }
-- | Accumulators are combined field by field, using the 'Semigroup'
-- instance of each field (set union for the two size sets).
instance Semigroup Acc where
  Acc min_x avail_x log_x <> Acc min_y avail_y log_y =
    Acc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)
-- | The empty accumulator: every field is the 'mempty' of its type.
instance Monoid Acc where
  mempty = Acc mempty mempty mempty
-- | The monad in which bodies are transformed: a 'BinderT' producing
-- 'Out.Kernels' statements on top of an 'RWS' with no reader
-- environment, an 'Acc' writer, and a 'VNameSource' as state.
type IntraGroupM = BinderT Out.Kernels (RWS () Acc VNameSource)
-- | Log messages are stored in the 'accLog' field of the writer
-- accumulator; the other fields are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}
-- | Run an 'IntraGroupM' action in the scope of the calling monad,
-- threading the caller's name source through the computation.
--
-- Returns the accumulated 'Acc' together with the statements that the
-- action produced.
runIntraGroupM ::
  (MonadFreshNames m, HasScope Out.Kernels m) =>
  IntraGroupM () ->
  m (Acc, Out.Stms Out.Kernels)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  -- 'modifyNameSource' hands over the current name source and stores
  -- the one that 'runRWS' returns as its final state.
  modifyNameSource $ \src ->
    let (((), kstms), src', acc) = runRWS (runBinderT m scope) () src
     in ((acc, kstms), src')
-- | Record the given list of sizes in both the 'accMinPar' and the
-- 'accAvailPar' sets of the accumulator.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }
-- | Transform a SOACS body into a kernels body: the statements are
-- handed to 'intraGroupStms' and whatever it emits is collected, while
-- the body result is carried over unchanged.
intraGroupBody :: SegLevel -> Body -> IntraGroupM (Out.Body Out.Kernels)
intraGroupBody lvl body = do
  stms <- collectStms_ $ intraGroupStms lvl $ bodyStms body
  return $ mkBody stms $ bodyResult body
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm :: SegLevel -> Stm -> IntraGroupM ()
intraGroupStm SegLevel
lvl stm :: Stm
stm@(Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
Scope Kernels
scope <- BinderT Kernels (RWST () Acc VNameSource Identity) (Scope Kernels)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
let lvl' :: SegLevel
lvl' = Count NumGroups SubExp
-> Count GroupSize SubExp -> SegVirt -> SegLevel
SegThread (SegLevel -> Count NumGroups SubExp
segNumGroups SegLevel
lvl) (SegLevel -> Count GroupSize SubExp
segGroupSize SegLevel
lvl) SegVirt
SegNoVirt
case Exp SOACS
e of
DoLoop [(FParam SOACS, SubExp)]
ctx [(FParam SOACS, SubExp)]
val LoopForm SOACS
form BodyT SOACS
loopbody ->
Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (LoopForm Kernels -> Scope Kernels
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm Kernels
form') (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Scope Kernels -> IntraGroupM () -> IntraGroupM ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param DeclType] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams ([Param DeclType] -> Scope Kernels)
-> [Param DeclType] -> Scope Kernels
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst ([(Param DeclType, SubExp)] -> [Param DeclType])
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> a -> b
$ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
ctx [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
Body Kernels
loopbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
loopbody
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> Body Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
ctx [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
val LoopForm Kernels
form' Body Kernels
loopbody'
where
form' :: LoopForm Kernels
form' = case LoopForm SOACS
form of
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
inps -> VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam Kernels, VName)]
inps
WhileLoop VName
cond -> VName -> LoopForm Kernels
forall lore. VName -> LoopForm lore
WhileLoop VName
cond
If SubExp
cond BodyT SOACS
tbody BodyT SOACS
fbody IfDec (BranchType SOACS)
ifdec -> do
Body Kernels
tbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
tbody
Body Kernels
fbody' <- SegLevel -> BodyT SOACS -> IntraGroupM (Body Kernels)
intraGroupBody SegLevel
lvl BodyT SOACS
fbody
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body Kernels
-> Body Kernels
-> IfDec (BranchType Kernels)
-> ExpT Kernels
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cond Body Kernels
tbody' Body Kernels
fbody' IfDec (BranchType SOACS)
IfDec (BranchType Kernels)
ifdec
Op Op SOACS
soac
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl (Stms SOACS -> IntraGroupM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm -> Stm) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm -> Stm
forall lore. Certificates -> Stm lore -> Stm lore
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
(Stms SOACS -> IntraGroupM ())
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS ()
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Stms SOACS)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Pattern (Lore (BinderT SOACS (State VNameSource)))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC Pattern (Lore (BinderT SOACS (State VNameSource)))
Pattern SOACS
pat Op SOACS
SOAC (Lore (BinderT SOACS (State VNameSource)))
soac)
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form -> do
let loopnest :: LoopNesting
loopnest = PatternT Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting PatternT Type
Pattern SOACS
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam) [VName]
arrs
env :: DistEnv Kernels IntraGroupM
env =
DistEnv :: forall lore (m :: * -> *).
Nestings
-> Scope lore
-> (Stms SOACS -> DistNestT lore m (Stms lore))
-> (MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore))
-> (Stm -> Binder lore (Stms lore))
-> (Lambda SOACS -> Binder lore (Lambda lore))
-> MkSegLevel lore m
-> DistEnv lore m
DistEnv
{ distNest :: Nestings
distNest =
Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
distScope :: Scope Kernels
distScope =
PatternT Type -> Scope Kernels
forall lore dec. (LetDec lore ~ dec) => PatternT dec -> Scope lore
scopeOfPattern PatternT Type
Pattern SOACS
pat
Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope Kernels
scopeForKernels (Lambda SOACS -> Scope SOACS
forall lore a. Scoped lore a => a -> Scope lore
scopeOf Lambda SOACS
lam)
Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> Scope Kernels
scope,
distOnInnerMap :: MapLoop
-> DistAcc Kernels
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
distOnInnerMap =
MapLoop
-> DistAcc Kernels
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
MapLoop -> DistAcc lore -> DistNestT lore m (DistAcc lore)
distributeMap,
distOnTopLevelStms :: Stms SOACS -> DistNestT Kernels IntraGroupM (Stms Kernels)
distOnTopLevelStms =
BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> DistNestT Kernels IntraGroupM (Stms Kernels)
forall lore (m :: * -> *) a.
(LocalScope lore m, DistLore lore) =>
m a -> DistNestT lore m a
liftInner (BinderT Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> DistNestT Kernels IntraGroupM (Stms Kernels))
-> (Stms SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> Stms SOACS
-> DistNestT Kernels IntraGroupM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ (IntraGroupM ()
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> (Stms SOACS -> IntraGroupM ())
-> Stms SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl,
distSegLevel :: MkSegLevel Kernels IntraGroupM
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM () -> BinderT Kernels IntraGroupM ())
-> IntraGroupM () -> BinderT Kernels IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
SegLevel -> BinderT Kernels IntraGroupM SegLevel
forall (m :: * -> *) a. Monad m => a -> m a
return SegLevel
lvl,
distOnSOACSStms :: Stm -> BinderT Kernels (State VNameSource) (Stms Kernels)
distOnSOACSStms =
Stms Kernels -> BinderT Kernels (State VNameSource) (Stms Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms Kernels
-> BinderT Kernels (State VNameSource) (Stms Kernels))
-> (Stm -> Stms Kernels)
-> Stm
-> BinderT Kernels (State VNameSource) (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels)
-> (Stm -> Stm Kernels) -> Stm -> Stms Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> Stm Kernels
soacsStmToKernels,
distOnSOACSLambda :: Lambda SOACS -> Binder Kernels (Lambda Kernels)
distOnSOACSLambda =
Lambda Kernels -> Binder Kernels (Lambda Kernels)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda Kernels -> Binder Kernels (Lambda Kernels))
-> (Lambda SOACS -> Lambda Kernels)
-> Lambda SOACS
-> Binder Kernels (Lambda Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels
}
acc :: DistAcc Kernels
acc =
DistAcc :: forall lore. Targets -> Stms lore -> DistAcc lore
DistAcc
{ distTargets :: Targets
distTargets = Target -> Targets
singleTarget (PatternT Type
Pattern SOACS
pat, BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (BodyT SOACS -> [SubExp]) -> BodyT SOACS -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam),
distStms :: Stms Kernels
distStms = Stms Kernels
forall a. Monoid a => a
mempty
}
Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms
(Stms Kernels -> IntraGroupM ())
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv Kernels IntraGroupM
-> DistNestT Kernels IntraGroupM (DistAcc Kernels)
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadLogger m, DistLore lore) =>
DistEnv lore m -> DistNestT lore m (DistAcc lore) -> m (Stms lore)
runDistNestT DistEnv Kernels IntraGroupM
env (DistAcc Kernels
-> Stms SOACS -> DistNestT Kernels IntraGroupM (DistAcc Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, LocalScope lore m, DistLore lore) =>
DistAcc lore -> Stms SOACS -> DistNestT lore m (DistAcc lore)
distributeMapBodyStms DistAcc Kernels
acc (BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms (BodyT SOACS -> Stms SOACS) -> BodyT SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam))
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form,
Scan Lambda SOACS
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans -> do
let scanfun' :: Lambda Kernels
scanfun' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
scanfun
mapfun' :: Lambda Kernels
mapfun' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
mapfun
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segScan SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [Commutativity
-> Lambda Kernels -> [SubExp] -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
Noncommutative Lambda Kernels
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda Kernels
mapfun' [VName]
arrs [] []
[SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
| Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form,
Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes <- [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds -> do
let red_lam' :: Lambda Kernels
red_lam' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
red_lam
map_lam' :: Lambda Kernels
map_lam' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
map_lam
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segRed SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [Commutativity
-> Lambda Kernels -> [SubExp] -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> [SubExp] -> Shape -> SegBinOp lore
SegBinOp Commutativity
comm Lambda Kernels
red_lam' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda Kernels
map_lam' [VName]
arrs [] []
[SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
Op (Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
bucket_fun [VName]
arrs) -> do
[HistOp Kernels]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels])
-> (HistOp SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> BinderT
Kernels (RWST () Acc VNameSource Identity) [HistOp Kernels]
forall a b. (a -> b) -> a -> b
$ \(HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
(Lambda SOACS
op', [SubExp]
nes', Shape
shape) <- Lambda SOACS
-> [SubExp]
-> BinderT
Kernels
(RWST () Acc VNameSource Identity)
(Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBinder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
let op'' :: Lambda Kernels
op'' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
op'
HistOp Kernels
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (HistOp Kernels
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels))
-> HistOp Kernels
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (HistOp Kernels)
forall a b. (a -> b) -> a -> b
$ SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda Kernels
-> HistOp Kernels
forall lore.
SubExp
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda lore
-> HistOp lore
Out.HistOp SubExp
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda Kernels
op''
let bucket_fun' :: Lambda Kernels
bucket_fun' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
bucket_fun
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
Stms Kernels -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> IntraGroupM ())
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall lore (m :: * -> *).
(DistLore lore, MonadFreshNames m, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
segHist SegOpLevel Kernels
SegLevel
lvl' Pattern SOACS
Pattern Kernels
pat SubExp
w [] [] [HistOp Kernels]
ops' Lambda Kernels
bucket_fun' [VName]
arrs
[SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
Op (Stream SubExp
w [VName]
arrs StreamForm SOACS
Sequential [SubExp]
accs Lambda SOACS
lam)
| LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam -> do
Scope SOACS
types <- (Scope Kernels -> Scope SOACS)
-> BinderT Kernels (RWST () Acc VNameSource Identity) (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope
((), Stms SOACS
stream_bnds) <-
BinderT SOACS IntraGroupM ()
-> Scope SOACS
-> BinderT
Kernels (RWST () Acc VNameSource Identity) ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Pattern (Lore (BinderT SOACS IntraGroupM))
-> SubExp
-> [SubExp]
-> LambdaT (Lore (BinderT SOACS IntraGroupM))
-> [VName]
-> BinderT SOACS IntraGroupM ()
forall (m :: * -> *).
(MonadBinder m, Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> [SubExp] -> LambdaT (Lore m) -> [VName] -> m ()
sequentialStreamWholeArray Pattern (Lore (BinderT SOACS IntraGroupM))
Pattern SOACS
pat SubExp
w [SubExp]
accs LambdaT (Lore (BinderT SOACS IntraGroupM))
Lambda SOACS
lam [VName]
arrs) Scope SOACS
types
let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
replace SubExp
se = SubExp
se
replaceSets :: Acc -> Acc
replaceSets (Acc Set [SubExp]
x Set [SubExp]
y Log
log) =
Set [SubExp] -> Set [SubExp] -> Log -> Acc
Acc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
(Acc -> Acc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor Acc -> Acc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms SegLevel
lvl Stms SOACS
stream_bnds
Op (Scatter SubExp
w Lambda SOACS
lam [VName]
ivs [(Shape, Int, VName)]
dests) -> do
VName
write_i <- String -> BinderT Kernels (RWST () Acc VNameSource Identity) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
SegSpace
space <- [(VName, SubExp)]
-> BinderT Kernels (RWST () Acc VNameSource Identity) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]
let lam' :: Lambda Kernels
lam' = Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels Lambda SOACS
lam
([Shape]
dests_ws, [Int]
_, [VName]
_) = [(Shape, Int, VName)] -> ([Shape], [Int], [VName])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 [(Shape, Int, VName)]
dests
krets :: [KernelResult]
krets = do
(Shape
a_w, VName
a, [([SubExp], SubExp)]
is_vs) <-
[(Shape, Int, VName)]
-> [SubExp] -> [(Shape, VName, [([SubExp], SubExp)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests ([SubExp] -> [(Shape, VName, [([SubExp], SubExp)])])
-> [SubExp] -> [(Shape, VName, [([SubExp], SubExp)])]
forall a b. (a -> b) -> a -> b
$ Body Kernels -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult (Body Kernels -> [SubExp]) -> Body Kernels -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'
KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Shape
a_w VName
a [((SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
is, SubExp
v) | ([SubExp]
is, SubExp
v) <- [([SubExp], SubExp)]
is_vs]
inputs :: [KernelInput]
inputs = do
(Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda Kernels -> [LParam Kernels]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda Kernels
lam') [VName]
ivs
KernelInput -> [KernelInput]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p) (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]
Stms Kernels
kstms <- BinderT Kernels (State VNameSource) ()
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (BinderT Kernels (State VNameSource) ()
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels))
-> BinderT Kernels (State VNameSource) ()
-> BinderT
Kernels (RWST () Acc VNameSource Identity) (Stms Kernels)
forall a b. (a -> b) -> a -> b
$
Scope Kernels
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope (SegSpace -> Scope Kernels
forall lore. SegSpace -> Scope lore
scopeOfSegSpace SegSpace
space) (BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ())
-> BinderT Kernels (State VNameSource) ()
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
(KernelInput -> BinderT Kernels (State VNameSource) ())
-> [KernelInput] -> BinderT Kernels (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *).
(DistLore (Lore m), MonadBinder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ())
-> Stms (Lore (BinderT Kernels (State VNameSource)))
-> BinderT Kernels (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body Kernels -> Stms Kernels
forall lore. BodyT lore -> Stms lore
bodyStms (Body Kernels -> Stms Kernels) -> Body Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Lambda Kernels -> Body Kernels
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda Kernels
lam'
Certificates -> IntraGroupM () -> IntraGroupM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
let ts :: [Type]
ts = (Shape -> Type -> Type) -> [Shape] -> [Type] -> [Type]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (Int -> Type -> Type
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray (Int -> Type -> Type) -> (Shape -> Int) -> Shape -> Type -> Type
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Shape -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length) [Shape]
dests_ws ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [Type]
forall dec. Typed dec => PatternT dec -> [Type]
patternTypes PatternT Type
Pattern SOACS
pat
body :: KernelBody Kernels
body = BodyDec Kernels
-> Stms Kernels -> [KernelResult] -> KernelBody Kernels
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms Kernels
kstms [KernelResult]
krets
Pattern (Lore IntraGroupM)
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *).
MonadBinder m =>
Pattern (Lore m) -> Exp (Lore m) -> m ()
letBind Pattern (Lore IntraGroupM)
Pattern SOACS
pat (Exp (Lore IntraGroupM) -> IntraGroupM ())
-> Exp (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op Kernels -> ExpT Kernels
forall lore. Op lore -> ExpT lore
Op (Op Kernels -> ExpT Kernels) -> Op Kernels -> ExpT Kernels
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall lore op. SegOp SegLevel lore -> HostOp lore op
SegOp (SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels))
-> SegOp SegLevel Kernels -> HostOp Kernels (SOAC Kernels)
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace
-> [Type]
-> KernelBody Kernels
-> SegOp SegLevel Kernels
forall lvl lore.
lvl -> SegSpace -> [Type] -> KernelBody lore -> SegOp lvl lore
SegMap SegLevel
lvl' SegSpace
space [Type]
ts KernelBody Kernels
body
[SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
Exp SOACS
_ ->
Stm (Lore IntraGroupM) -> IntraGroupM ()
forall (m :: * -> *). MonadBinder m => Stm (Lore m) -> m ()
addStm (Stm (Lore IntraGroupM) -> IntraGroupM ())
-> Stm (Lore IntraGroupM) -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm -> Stm Kernels
soacsStmToKernels Stm
stm
-- | Translate a sequence of SOACS statements by running 'intraGroupStm'
-- on each statement in order, at the given segment level.  All output is
-- produced through the effects of 'IntraGroupM'; the result is unit.
intraGroupStms :: SegLevel -> Stms SOACS -> IntraGroupM ()
intraGroupStms lvl stms = forM_ stms (intraGroupStm lvl)
-- | Run 'intraGroupStms' over the statements of a SOACS body and package
-- the outcome as a kernel body.
--
-- The result tuple contains, in order:
--
-- * the first width set collected in the 'Acc' accumulator, as a list
--   (presumably the minimum parallel widths -- confirm against 'Acc');
--
-- * the second width set collected in the 'Acc' accumulator, as a list
--   (presumably the available parallel widths -- confirm against 'Acc');
--
-- * the log collected while translating the statements;
--
-- * a 'KernelBody' holding the translated statements, whose results are
--   the results of the original body, each wrapped as
--   @'Returns' 'ResultMaySimplify'@.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope Out.Kernels m) =>
  SegLevel ->
  Body ->
  m ([[SubExp]], [[SubExp]], Log, Out.KernelBody Out.Kernels)
intraGroupParalleliseBody lvl body = do
  (Acc min_widths avail_widths body_log, kernel_stms) <-
    runIntraGroupM (intraGroupStms lvl (bodyStms body))
  -- The results of the original body are returned unchanged by the kernel.
  let kernel_results = map (Returns ResultMaySimplify) (bodyResult body)
      kernel_body = KernelBody () kernel_stms kernel_results
  return
    ( S.toList min_widths,
      S.toList avail_widths,
      body_log,
      kernel_body
    )