{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels (extractKernels) where
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import Futhark.IR.GPU
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intragroup
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)
-- | The kernel extraction pass: turns a 'SOACS' program into a 'GPU'
-- program by running 'transformProg' on it.
extractKernels :: Pass SOACS GPU
extractKernels =
  Pass
    { passName = "extract kernels",
      passDescription = "Perform kernel extraction",
      passFunction = transformProg
    }
-- | Transform an entire program.
--
-- The program constants are transformed first, starting from an empty
-- ('mempty') kernel path.  Every function is then transformed with the
-- scope of the /transformed/ constants available, so function bodies
-- may refer to them.
transformProg :: Prog SOACS -> PassM (Prog GPU)
transformProg prog = do
  consts' <- runDistribM $ transformStms mempty $ stmsToList $ progConsts prog
  funs' <- mapM (transformFunDef $ scopeOf consts') $ progFuns prog
  -- Both representation-dependent fields are replaced in a single
  -- record update, which is what lets the result be a @Prog GPU@.
  pure $
    prog
      { progConsts = consts',
        progFuns = funs'
      }
-- | The mutable state carried by 'DistribM'.
data State = State
  { -- | The name source exposed through the 'MonadFreshNames'
    -- instance of 'DistribM'.
    stateNameSource :: VNameSource,
    -- | Counter used by 'cmpSizeLe' to number the size keys it
    -- creates; kept separate from the name source.
    stateThresholdCounter :: Int
  }
-- | The monad in which kernel extraction runs: an 'RWS' computation
-- that reads the current GPU 'Scope', writes a 'Log', and threads a
-- 'State' (name source plus threshold counter).
--
-- All instances are obtained from the underlying 'RWS' monad.
newtype DistribM a = DistribM (RWS (Scope GPU) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope GPU,
      LocalScope GPU,
      MonadState State,
      MonadLogger
    )
-- | Fresh names are drawn from the 'stateNameSource' field of the
-- 'DistribM' state; the threshold counter is left untouched.
instance MonadFreshNames DistribM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}
-- | Run a 'DistribM' action inside any monad that can supply fresh
-- names and accept log messages.
--
-- The action starts with an empty scope, the caller's current name
-- source, and a threshold counter of zero.  Afterwards the final name
-- source is handed back to the caller's monad and the accumulated log
-- is forwarded with 'addLog'.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM m) = do
  (x, msgs) <- modifyNameSource $ \src ->
    let (x', s, msgs') = runRWS m mempty (State src 0)
     in ((x', msgs'), stateNameSource s)
  addLog msgs
  pure x
-- | Transform a single function definition.
--
-- The body is transformed in its own 'runDistribM' run, starting from
-- an empty ('mempty') kernel path, with the supplied scope and the
-- function's own parameters in scope.  Every other component of the
-- 'FunDef' is passed through unchanged.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope GPU ->
  FunDef SOACS ->
  m (FunDef GPU)
transformFunDef scope (FunDef entry attrs name rettype params body) = runDistribM $ do
  body' <-
    localScope (scope <> scopeOfFParams params) $
      transformBody mempty body
  pure $ FunDef entry attrs name rettype params body'
-- | Shorthand for a sequence of statements in the 'GPU' representation.
type GPUStms = Stms GPU
-- | Transform a body by transforming its statements with
-- 'transformStms' and rebuilding the body (via 'mkBody') around the
-- original, unchanged result.
transformBody :: KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody path body = do
  stms <- transformStms path $ stmsToList $ bodyStms body
  pure $ mkBody stms $ bodyResult body
-- | Transform a list of statements in order.
--
-- Each statement is first offered to 'sequentialisedUnbalancedStm'.
-- If that produces replacement statements, they are pushed back onto
-- the front of the work list and processed again.  Otherwise the
-- statement is handed to 'transformStm', and the remaining statements
-- are transformed with its results in scope.
transformStms :: KernelPath -> [Stm SOACS] -> DistribM GPUStms
transformStms _ [] =
  pure mempty
transformStms path (stm : stms) =
  sequentialisedUnbalancedStm stm >>= \case
    Nothing -> do
      stm' <- transformStm path stm
      inScopeOf stm' $
        (stm' <>) <$> transformStms path stms
    Just stms' ->
      transformStms path $ stmsToList stms' <> stms
-- | Decide whether a lambda is "unbalanced".
--
-- The check walks the lambda body while tracking the names bound by
-- the lambda parameters and by the enclosing bodies.  A statement
-- counts as unbalanced when it is:
--
-- * a @Stream@ or @Screma@ whose width is one of those bound names;
--
-- * a @WithAcc@ whose lambda body contains an unbalanced statement;
--
-- * a @Match@ where some scrutinee is a bound name /and/ some branch
--   (or the default body) contains an unbalanced statement;
--
-- * an @Apply@ of a function that is not built in.
--
-- Loops, basic operations and every other @Op@ are never unbalanced.
unbalancedLambda :: Lambda SOACS -> Bool
unbalancedLambda orig_lam =
  unbalancedBody (namesFromList $ map paramName $ lambdaParams orig_lam) $
    lambdaBody orig_lam
  where
    -- Is this operand a variable from the set of bound names?
    subExpBound :: SubExp -> Names -> Bool
    subExpBound (Var i) bound = i `nameIn` bound
    subExpBound (Constant _) _ = False

    -- A body is unbalanced if any of its statements is; names bound in
    -- the body itself are added to the bound set first.
    unbalancedBody bound body =
      any (unbalancedStm (bound <> boundInBody body) . stmExp) $
        bodyStms body

    unbalancedStm bound (Op (Stream w _ _ _)) =
      w `subExpBound` bound
    unbalancedStm bound (Op (Screma w _ _)) =
      w `subExpBound` bound
    unbalancedStm _ Op {} =
      False
    unbalancedStm _ DoLoop {} = False
    unbalancedStm bound (WithAcc _ lam) =
      unbalancedBody bound (lambdaBody lam)
    unbalancedStm bound (Match ses cases defbody _) =
      any (`subExpBound` bound) ses
        && ( any (unbalancedBody bound . caseBody) cases
               || unbalancedBody bound defbody
           )
    unbalancedStm _ (BasicOp _) =
      False
    unbalancedStm _ (Apply fname _ _ _) =
      not $ isBuiltInFunction fname
-- | Possibly replace a statement by a first-order (sequential) version.
--
-- This applies only to a @Screma@ that 'isRedomapSOAC' recognises and
-- whose lambda (as returned by 'isRedomapSOAC') is both
-- 'unbalancedLambda' and 'lambdaContainsParallelism'.  In that case
-- the SOAC is expanded with 'FOT.transformSOAC' and the resulting
-- statements are returned.  For every other statement the result is
-- 'Nothing'.
sequentialisedUnbalancedStm :: Stm SOACS -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ _ form)))
  | Just (_, lam2) <- isRedomapSOAC form,
    unbalancedLambda lam2,
    lambdaContainsParallelism lam2 = do
      -- The builder works on SOACS, so the GPU scope is converted first.
      types <- asksScope scopeForSOACs
      Just . snd <$> runBuilderT (FOT.transformSOAC pat soac) types
sequentialisedUnbalancedStm _ =
  pure Nothing
-- | Construct a @CmpSizeLe@ size comparison.
--
-- The size key is @desc@ followed by an underscore and the current
-- value of the threshold counter, which is then incremented.  The
-- operands in @to_what@ are multiplied together (as 'Int64', starting
-- from 1), and the product is what the @CmpSizeLe@ operation is given.
--
-- Returns the comparison result paired with the size key, together
-- with the statements that compute them.
cmpSizeLe ::
  String ->
  SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Stms GPU)
cmpSizeLe desc size_class to_what = do
  x <- gets stateThresholdCounter
  modify $ \s -> s {stateThresholdCounter = x + 1}
  let size_key = nameFromString $ desc ++ "_" ++ show x
  runBuilder $ do
    to_what' <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <- letSubExp desc $ Op $ SizeOp $ CmpSizeLe size_key size_class to_what'
    pure (cmp_res, size_key)
-- | Produce statements that bind @pat@ to one of several alternative
-- bodies.
--
-- With no alternatives, the default body is bound and each of its
-- results is copied (under its certificates) to the corresponding
-- name of @pat@.
--
-- Otherwise a @Match@ on the first condition is built: if the
-- condition is true its body is used; if not, the fallback branch
-- contains the recursively built code for the remaining alternatives,
-- bound to a freshly named copy of the pattern.
kernelAlternatives ::
  (MonadFreshNames m, HasScope GPU m) =>
  Pat Type ->
  Body GPU ->
  [(SubExp, Body GPU)] ->
  m (Stms GPU)
kernelAlternatives pat default_body [] = runBuilder_ $ do
  ses <- bodyBind default_body
  forM_ (zip (patNames pat) ses) $ \(name, SubExpRes cs se) ->
    certifying cs $ letBindNames [name] $ BasicOp $ SubExp se
kernelAlternatives pat default_body ((cond, alt) : alts) = runBuilder_ $ do
  -- The nested alternatives must not bind the same names as the outer
  -- pattern, so give every pattern element a fresh name.
  alts_pat <- fmap Pat . forM (patElems pat) $ \pe -> do
    name <- newVName $ baseString $ patElemName pe
    pure pe {patElemName = name}

  alt_stms <- kernelAlternatives alts_pat default_body alts
  let alt_body = mkBody alt_stms $ varsRes $ patNames alts_pat

  letBind pat . Match [cond] [Case [Just $ BoolValue True] alt] alt_body $
    MatchDec (staticShapes (patTypes pat)) MatchEquiv
-- | Transform a lambda by transforming its body with 'transformBody'
-- while the lambda parameters are in scope.  The parameters and the
-- return types are kept as they are.
transformLambda :: KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda path (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody path body)
    <*> pure ret
transformStm :: KernelPath -> Stm SOACS -> DistribM GPUStms
transformStm :: KernelPath -> Stm SOACS -> DistribM (Stms GPU)
transformStm KernelPath
_ Stm SOACS
stm
| Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` forall dec. StmAux dec -> Attrs
stmAuxAttrs (forall {k} (rep :: k). Stm rep -> StmAux (ExpDec rep)
stmAux Stm SOACS
stm) =
forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
(Transformer m, LetDec (Rep m) ~ LetDec SOACS) =>
Stm SOACS -> m ()
FOT.transformStmRecursively Stm SOACS
stm
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac))
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux (ExpDec SOACS)
aux =
KernelPath -> [Stm SOACS] -> DistribM (Stms GPU)
transformStms KernelPath
path forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Stms rep -> [Stm rep]
stmsToList forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (rep :: k). Certs -> Stm rep -> Stm rep
certify (forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec SOACS)
aux))
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec SOACS)
pat Op SOACS
soac)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Match [SubExp]
c [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
rt)) = do
[Case (Body GPU)]
cases' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse forall a b. (a -> b) -> a -> b
$ KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path) [Case (Body SOACS)]
cases
Body GPU
defbody' <- KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
defbody
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
c [Case (Body GPU)]
cases' Body GPU
defbody' MatchDec (BranchType SOACS)
rt
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam)) =
forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (forall {k} (rep :: k). [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc (forall a b. (a -> b) -> [a] -> [b]
map forall {f :: * -> *} {p :: * -> * -> *} {a} {b} {c}.
(Functor f, Bifunctor p) =>
(a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput [WithAccInput SOACS]
inputs) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Lambda SOACS -> DistribM (Lambda GPU)
transformLambda KernelPath
path Lambda SOACS
lam)
where
transformInput :: (a, b, f (p (Lambda SOACS) c)) -> (a, b, f (p (Lambda GPU) c))
transformInput (a
shape, b
arrs, f (p (Lambda SOACS) c)
op) =
(a
shape, b
arrs, forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Lambda SOACS -> Lambda GPU
soacsLambdaToGPU) f (p (Lambda SOACS) c)
op)
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (DoLoop [(FParam SOACS, SubExp)]
merge LoopForm SOACS
form Body SOACS
body)) =
forall {k} (rep :: k) (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall {k1} {k2} (fromrep :: k1) (torep :: k2).
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope (forall {k} (rep :: k) a. Scoped rep a => a -> Scope rep
scopeOf LoopForm SOACS
form) forall a. Semigroup a => a -> a -> a
<> forall {k} (rep :: k) dec.
(FParamInfo rep ~ dec) =>
[Param dec] -> Scope rep
scopeOfFParams [Param DeclType]
params) forall a b. (a -> b) -> a -> b
$
forall {k} (rep :: k). Stm rep -> Stms rep
oneStm forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k).
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k).
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(FParam SOACS, SubExp)]
merge LoopForm GPU
form' forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Body SOACS -> DistribM (Body GPU)
transformBody KernelPath
path Body SOACS
body
where
params :: [Param DeclType]
params = forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
merge
form' :: LoopForm GPU
form' = case LoopForm SOACS
form of
WhileLoop VName
cond ->
forall {k} (rep :: k). VName -> LoopForm rep
WhileLoop VName
cond
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
ps ->
forall {k} (rep :: k).
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
ps
transformStm KernelPath
path (Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just Lambda SOACS
lam <- forall {k} (rep :: k). ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form =
KernelPath -> MapLoop -> DistribM (Stms GPU)
onMap KernelPath
path forall a b. (a -> b) -> a -> b
$ Pat Type
-> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> MapLoop
MapLoop Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux SubExp
w Lambda SOACS
lam [VName]
arrs
transformStm KernelPath
path (Let Pat (LetDec SOACS)
res_pat (StmAux Certs
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just [Scan SOACS]
scans <- forall {k} (rep :: k). ScremaForm rep -> Maybe [Scan rep]
isScanSOAC ScremaForm SOACS
form,
Scan Lambda SOACS
scan_lam [SubExp]
nes <- forall {k} (rep :: k). Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans,
Just BuilderT SOACS DistribM ()
do_iswim <- forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp -> Lambda SOACS -> [(SubExp, VName)] -> Maybe (m ())
iswim Pat (LetDec SOACS)
res_pat SubExp
w Lambda SOACS
scan_lam forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM (Stms GPU)
transformStms KernelPath
path forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Stms rep -> [Stm rep]
stmsToList forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs BuilderT SOACS DistribM ()
do_iswim) Scope SOACS
types
| Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- forall {k} (rep :: k).
ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = forall {k1} {k2} (m :: * -> *) (somerep :: k1) (rep :: k2) a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ forall a b. (a -> b) -> a -> b
$ do
[SegBinOp GPU]
scan_ops <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan SOACS]
scans forall a b. (a -> b) -> a -> b
$ \(Scan Lambda SOACS
scan_lam [SubExp]
nes) -> do
(Lambda SOACS
scan_lam', [SubExp]
nes', ShapeBase SubExp
shape) <- forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS
-> [SubExp] -> m (Lambda SOACS, [SubExp], ShapeBase SubExp)
determineReduceOp Lambda SOACS
scan_lam [SubExp]
nes
let scan_lam'' :: Lambda GPU
scan_lam'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scan_lam'
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k).
Commutativity
-> Lambda rep -> [SubExp] -> ShapeBase SubExp -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scan_lam'' [SubExp]
nes' ShapeBase SubExp
shape
let map_lam_sequential :: Lambda GPU
map_lam_sequential = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
SegLevel
lvl <- forall (m :: * -> *). MonadFreshNames m => MkSegLevel GPU m
segThreadCapped [SubExp
w] [Char]
"segscan" forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (forall {k} (rep :: k). Certs -> Stm rep -> Stm rep
certify Certs
cs)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegLevel
lvl Pat (LetDec SOACS)
res_pat forall a. Monoid a => a
mempty SubExp
w [SegBinOp GPU]
scan_ops Lambda GPU
map_lam_sequential [VName]
arrs [] []
transformStm KernelPath
path (Let Pat (LetDec SOACS)
res_pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just [Reduce Commutativity
comm Lambda SOACS
red_fun [SubExp]
nes] <- forall {k} (rep :: k). ScremaForm rep -> Maybe [Reduce rep]
isReduceSOAC ScremaForm SOACS
form,
let comm' :: Commutativity
comm'
| forall {k} (rep :: k). Lambda rep -> Bool
commutativeLambda Lambda SOACS
red_fun = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm,
Just BuilderT SOACS DistribM ()
do_irwim <- forall (m :: * -> *).
(MonadBuilder m, Rep m ~ SOACS) =>
Pat Type
-> SubExp
-> Commutativity
-> Lambda SOACS
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pat (LetDec SOACS)
res_pat SubExp
w Commutativity
comm' Lambda SOACS
red_fun forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
nes [VName]
arrs = do
Scope SOACS
types <- forall {k} (rep :: k) (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
scopeForSOACs
Stms SOACS
stms <- forall a b. (a, b) -> a
fst forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall {k} (m :: * -> *) (rep :: k) a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (Stms SOACS)
simplifyStms forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec SOACS)
aux BuilderT SOACS DistribM ()
do_irwim)) Scope SOACS
types
KernelPath -> [Stm SOACS] -> DistribM (Stms GPU)
transformStms KernelPath
path forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Stms rep -> [Stm rep]
stmsToList Stms SOACS
stms
-- Screma that 'isRedomapSOAC' recognises as a redomap.
--
-- Two code versions are considered:
--
-- * "outer": a single non-segmented reduction ('nonSegRed') over the
--   outer dimension @w@, with the map lambda converted directly to the
--   GPU representation.
--
-- * "inner": the redomap is split into a separate map and reduce
--   ('redomapToMapAndReduce'), simplified, and fed back through
--   'transformStms' so that parallelism inside the map lambda can be
--   handled by the other clauses.
--
-- If the map lambda contains no parallelism, or the statement carries
-- the @sequential_inner@ attribute, only the outer version is emitted.
-- Otherwise both are generated and combined with 'kernelAlternatives',
-- guarded by a 'sufficientParallelism' threshold on @w@.
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Screma w arrs form)))
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      let res = varsRes (patNames pat)

          -- Run a statement generator, wrap its output as a body that
          -- returns the pattern names, and pass it through 'renameBody'.
          renamedBody :: DistribM (Stms GPU) -> DistribM (Body GPU)
          renamedBody mk_stms = do
            body_stms <- mk_stms
            renameBody (mkBody body_stms res)

          paralleliseOuter :: DistribM (Stms GPU)
          paralleliseOuter = runBuilder_ $ do
            red_ops <- forM reds $ \(Reduce comm red_lam nes) -> do
              (red_lam', nes', shape) <- determineReduceOp red_lam nes
              -- Upgrade to Commutative when 'commutativeLambda' says so;
              -- otherwise keep the commutativity declared on the Reduce.
              let comm'
                    | commutativeLambda red_lam' = Commutative
                    | otherwise = comm
              pure $ SegBinOp comm' (soacsLambdaToGPU red_lam') nes' shape
            lvl <- segThreadCapped [w] "segred" $ NoRecommendation SegNoVirt
            red_stms <-
              nonSegRed lvl pat w red_ops (soacsLambdaToGPU map_lam) arrs
            addStms (certify cs <$> red_stms)

          paralleliseInner :: KernelPath -> DistribM (Stms GPU)
          paralleliseInner path' = do
            (mapstm, redstm) <-
              redomapToMapAndReduce pat (w, reds, map_lam, arrs)
            types <- asksScope scopeForSOACs
            split_stms <-
              runBuilderT_
                ( addStms
                    =<< simplifyStms
                      (stmsFromList (map (certify cs) [mapstm, redstm]))
                )
                types
            transformStms path' (stmsToList split_stms)

      if not (lambdaContainsParallelism map_lam)
        || "sequential_inner" `inAttrs` stmAuxAttrs aux
        then paralleliseOuter
        else do
          ((outer_suff, outer_suff_key), suff_stms) <-
            sufficientParallelism "suff_outer_redomap" [w] path Nothing

          outer_body <- renamedBody paralleliseOuter
          -- The inner version is generated under a path extended with
          -- (outer_suff_key, False).
          inner_body <-
            renamedBody $ paralleliseInner ((outer_suff_key, False) : path)

          -- 'inner_body' is passed as the default body and 'outer_body'
          -- as the alternative guarded by 'outer_suff'.
          (suff_stms <>)
            <$> kernelAlternatives pat inner_body [(outer_suff, outer_body)]
-- Any other Screma (no earlier clause matched): 'dissectScrema' rewrites
-- it into a sequence of SOACS statements, which are certified with the
-- original statement's certificates and fed back through
-- 'transformStms'.
transformStm path (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) = do
  scope <- asksScope scopeForSOACs
  dissected <- runBuilderT_ (dissectScrema pat w form arrs) scope
  transformStms path $ map (certify cs) $ stmsToList dissected
-- Stream: 'sequentialStreamWholeArray' replaces the stream with plain
-- SOACS statements, which are then transformed like any other
-- statements.
--
-- The certificates of the original statement are re-attached to every
-- generated statement (as the generic Screma clause does); previously
-- the statement's 'StmAux' was ignored and the certificates were lost.
transformStm path (Let pat (StmAux cs _ _) (Op (Stream w arrs nes fold_fun))) = do
  types <- asksScope scopeForSOACs
  stream_stms <-
    runBuilderT_ (sequentialStreamWholeArray pat w nes fold_fun arrs) types
  transformStms path $ map (certify cs) $ stmsToList stream_stms
-- Scatter: emitted directly as one kernel built by 'mapKernel' over a
-- single fresh index @write_i@ ranging over @w@.  The kernel body is the
-- (GPU-converted) scatter lambda body, and its results are turned into
-- 'WriteReturns', one per group produced by 'groupScatterResults'.
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w ivs lam as))) = runBuilder_ $ do
  let lam' = soacsLambdaToGPU lam
      lam_body = lambdaBody lam'
  write_i <- newVName "write_i"
  let (as_ws, _, _) = unzip3 as

      krets :: [KernelResult]
      krets =
        [ WriteReturns res_cs a_w a is_vs'
          | (a_w, a, is_vs) <- groupScatterResults as (bodyResult lam_body),
            let -- Certificates of all index and value results of this group.
                res_cs =
                  foldMap (foldMap resCerts . fst) is_vs
                    <> foldMap (resCerts . snd) is_vs
                -- Each write is a fully fixed slice paired with its value.
                is_vs' =
                  [ (Slice (map (DimFix . resSubExp) is), resSubExp v)
                    | (is, v) <- is_vs
                  ]
        ]

      -- Every lambda parameter reads its input array at index write_i.
      inputs :: [KernelInput]
      inputs =
        [ KernelInput (paramName p) (paramType p) p_a [Var write_i]
          | (p, p_a) <- zip (lambdaParams lam') ivs
        ]

      body = KernelBody () (bodyStms lam_body) krets

      -- Pattern types with as many outer dimensions stripped as the
      -- corresponding destination shape has.
      ret_ts = zipWith (stripArray . length) as_ws (patTypes pat)

  (kernel, stms) <-
    mapKernel segThreadCapped [(write_i, w)] inputs ret_ts body

  certifying cs $ do
    addStms stms
    letBind pat $ Op $ SegOp kernel
-- Hist: delegated to 'histKernel' at a thread level obtained from
-- 'segThreadCapped' for the outer size @w@, with no enclosing nesting
-- (both the index/size list and the kernel-input list are empty).
transformStm _ (Let orig_pat (StmAux cs _ _) (Op (Hist w imgs ops bucket_fun))) =
  runBuilder_ $ do
    lvl <- segThreadCapped [w] "seghist" $ NoRecommendation SegNoVirt
    hist_stms <-
      histKernel
        onLambda
        lvl
        orig_pat
        []
        []
        cs
        w
        ops
        (soacsLambdaToGPU bucket_fun)
        imgs
    addStms hist_stms
  where
    -- Lambda conversion handed to 'histKernel'; pure, no extra statements.
    onLambda = pure . soacsLambdaToGPU
-- Fallback for every statement not matched by an earlier clause: hand it
-- to 'FOT.transformStmRecursively' and collect the statements it emits.
transformStm _ stm =
  runBuilder_ $ FOT.transformStmRecursively stm
-- | Emit a size comparison, via 'cmpSizeLe', of the given sizes against a
-- 'SizeThreshold' built from the current kernel path and an optional
-- default value.
--
-- Returns the 'cmpSizeLe' result unchanged: a pair of the comparison
-- outcome ('SubExp') and a 'Name' (used by callers as a key when
-- extending the 'KernelPath'), together with the statements computing
-- it.
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Stms GPU)
sufficientParallelism desc ws path def =
  cmpSizeLe desc (SizeThreshold path def) ws
-- | Heuristic used when deciding whether to attempt intra-group
-- parallelisation of a map lambda: true when the "interest" score of
-- the lambda body is strictly greater than 1.
--
-- The score of a body is the sum of the scores of its statements:
--
-- * a statement with the @sequential@ attribute scores 0;
--
-- * a map or a scatter scores 0 if it has @sequential_inner@, otherwise
--   the larger of its width score and the score of its lambda body;
--
-- * a loop scores ten times its body;
--
-- * a match scores the maximum over all its branches;
--
-- * any other Screma scores its width score plus its lambda body;
--
-- * a stream scores its lambda body;
--
-- * everything else scores 0.
--
-- The width score is 0 for a constant width below 32 and 1 otherwise
-- (including every non-constant width).
worthIntraGroup :: Lambda SOACS -> Bool
worthIntraGroup lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest body = sum $ interest <$> bodyStms body

    interest stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma w _ form) <- stmExp stm,
        Just lam' <- isMapSOAC form =
          mapLike w lam'
      | Op (Scatter w _ lam' _) <- stmExp stm =
          mapLike w lam'
      | DoLoop _ _ body <- stmExp stm =
          bodyInterest body * 10
      | Match _ cases defbody _ <- stmExp stm =
          -- The list always contains the default branch, so 'maximum'
          -- is total here.
          maximum $ bodyInterest defbody : map (bodyInterest . caseBody) cases
      | Op (Screma w _ (ScremaForm _ _ lam')) <- stmExp stm =
          zeroIfTooSmall w + bodyInterest (lambdaBody lam')
      | Op (Stream _ _ _ lam') <- stmExp stm =
          bodyInterest $ lambdaBody lam'
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- Only a statically known width below 32 counts as too small.
        zeroIfTooSmall (Constant (IntValue x))
          | intToInt64 x < 32 = 0
        zeroIfTooSmall _ = 1

        mapLike w lam'
          | sequential_inner = 0
          | otherwise =
              max (zeroIfTooSmall w) (bodyInterest (lambdaBody lam'))
-- | Heuristic: true when the "interest" score of the lambda body is
-- strictly greater than 1.
--
-- The score of a body is the sum of the scores of its statements:
--
-- * a statement with the @sequential@ attribute scores 0;
--
-- * a map scores 0 if it has @sequential_inner@, otherwise the score of
--   its lambda body;
--
-- * a scatter scores 0;
--
-- * a for-loop scores ten times its body (other loop forms fall through
--   to the default);
--
-- * a 'WithAcc' scores its lambda body;
--
-- * any other Screma scores 1 plus its lambda body, plus 1 more if
--   'isRedomapSOAC' recognises it;
--
-- * everything else scores 0.
worthSequentialising :: Lambda SOACS -> Bool
worthSequentialising lam = bodyInterest (lambdaBody lam) > 1
  where
    bodyInterest body = sum $ interest <$> bodyStms body

    interest stm
      | "sequential" `inAttrs` attrs =
          0 :: Int
      | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm,
        isJust $ isMapSOAC form =
          if sequential_inner
            then 0
            else bodyInterest (lambdaBody lam')
      | Op Scatter {} <- stmExp stm =
          0
      | DoLoop _ ForLoop {} body <- stmExp stm =
          bodyInterest body * 10
      | WithAcc _ withacc_lam <- stmExp stm =
          bodyInterest (lambdaBody withacc_lam)
      | Op (Screma _ _ form@(ScremaForm _ _ lam')) <- stmExp stm =
          1 + bodyInterest (lambdaBody lam') + redomapBonus form
      | otherwise =
          0
      where
        attrs = stmAuxAttrs $ stmAux stm
        sequential_inner = "sequential_inner" `inAttrs` attrs

        -- One extra point when the Screma is a redomap.
        redomapBonus form =
          case isRedomapSOAC form of
            Just _ -> 1
            Nothing -> 0
-- | Handler for top-level SOACS statements met during distribution:
-- runs 'transformStms' on them in the underlying 'DistribM' monad and
-- lifts the result into 'DistNestT' with 'liftInner'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT GPU DistribM GPUStms
onTopLevelStms path stms =
  liftInner $ transformStms path $ stmsToList stms
-- | Transform a single map loop.
--
-- Builds the loop nesting and distribution environment for the map, and
-- prepares two statement generators which are handed to 'onMap'' (which
-- decides how to use or combine them):
--
-- * @exploitOuterParallelism@ (passed as @mk_seq_stms@): converts the
--   lambda to the GPU representation, adds its entire body to the
--   accumulator, and calls 'distribute'.
--
-- * @exploitInnerParallelism@ (passed as @mk_par_stms@): runs
--   'distributeMapBodyStms' on the SOACS statements of the lambda body.
--
-- Each generator takes the 'KernelPath' under which it should run.
onMap :: KernelPath -> MapLoop -> DistribM GPUStms
onMap path (MapLoop pat aux w lam arrs) = do
  types <- askScope
  let loopnest = MapNesting pat aux w $ zip (lambdaParams lam) arrs

      -- Distribution environment; only the callbacks depend on the path.
      env path' =
        DistEnv
          { distNest = singleNesting (Nesting mempty loopnest),
            distScope =
              scopeOfPat pat
                <> scopeForGPU (scopeOf lam)
                <> types,
            distOnInnerMap = onInnerMap path',
            distOnTopLevelStms = onTopLevelStms path',
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToGPU,
            distOnSOACSLambda = pure . soacsLambdaToGPU
          }

      exploitInnerParallelism path' =
        runDistNestT (env path') $
          distributeMapBodyStms acc (bodyStms $ lambdaBody lam)

      exploitOuterParallelism path' = do
        let lam' = soacsLambdaToGPU lam
        runDistNestT (env path') $
          distribute $
            addStmsToAcc (bodyStms $ lambdaBody lam') acc

  onMap'
    (newKernel loopnest)
    path
    exploitOuterParallelism
    exploitInnerParallelism
    pat
    lam
  where
    -- Initial accumulator: the map's pattern and lambda result as the
    -- single target, and no statements yet.
    acc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult $ lambdaBody lam),
          distStms = mempty
        }
-- | True when the attributes contain
-- @incremental_flattening(only_intra)@.
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra attrs =
  AttrComp "incremental_flattening" ["only_intra"] `inAttrs` attrs
-- | True unless the attributes contain
-- @incremental_flattening(no_outer)@ or
-- @incremental_flattening(only_inner)@.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs =
  not $
    any
      (`inAttrs` attrs)
      [ AttrComp "incremental_flattening" ["no_outer"],
        AttrComp "incremental_flattening" ["only_inner"]
      ]
-- | True unless the attributes contain
-- @incremental_flattening(no_intra)@ or
-- @incremental_flattening(only_inner)@.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs =
  not $
    any
      (`inAttrs` attrs)
      [ AttrComp "incremental_flattening" ["no_intra"],
        AttrComp "incremental_flattening" ["only_inner"]
      ]
-- | Minimum-inner-parallelism constant for intra-group parallelisation.
--
-- NOTE(review): its use site is outside this chunk, and the value is
-- presumably chosen to match a GPU warp size — confirm both.
intraMinInnerPar :: Int64
intraMinInnerPar = 32
-- | Produce the statements for one map nest, possibly in several
-- guarded versions.
--
-- Arguments, in order: the kernel nest; the current kernel path; an
-- action producing the sequentialised statements for a given path; an
-- action producing the inner-parallel statements for a given path; the
-- pattern whose names become the result of every generated body; and
-- the lambda, which is inspected by the @worth*@ predicates and passed
-- to 'intraGroupParallelise'.
--
-- Every version is wrapped in a renamed body and handed to
-- 'kernelAlternatives'.  Which versions are produced depends on the
-- attributes of the innermost nesting and on whether
-- 'intraGroupParallelise' yields a result.
onMap' ::
  KernelNest ->
  KernelPath ->
  (KernelPath -> DistribM (Stms GPU)) ->
  (KernelPath -> DistribM (Stms GPU)) ->
  Pat Type ->
  Lambda SOACS ->
  DistribM (Stms GPU)
onMap' loopnest path mk_seq_stms mk_par_stms pat lam = do
  scope <- askScope

  let only_intra = onlyExploitIntra attrs
      may_intra = worthIntraGroup lam && mayExploitIntra attrs

  -- The intra-group attempt runs before any branching, so its effects
  -- happen even when a later branch ignores the result.
  intra <-
    if only_intra || may_intra
      then runReaderT (intraGroupParallelise loopnest lam) scope
      else pure Nothing

  case intra of
    -- "sequential_inner" overrides everything else: only the
    -- sequentialised body is emitted, with no alternatives.
    _
      | "sequential_inner" `inAttrs` attrs -> do
          seq_body <- renamedBody =<< mk_seq_stms path
          kernelAlternatives pat seq_body []
    -- No intra-group version is available.
    Nothing
      | only_intra -> parallelOnly
      | otherwise -> maybe parallelOnly outerOrParallel mkSeqAlts
    -- An intra-group version is available.
    Just intra'@(_, _, intra_log, intra_prelude, intra_stms) -> do
      addLog intra_log
      if only_intra
        then do
          group_par_body <- renamedBody intra_stms
          group_alts <- kernelAlternatives pat group_par_body []
          pure (intra_prelude <> group_alts)
        else case mkSeqAlts of
          Nothing -> do
            (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
              checkSuffIntraPar path intra'
            par_body <-
              renamedBody =<< mk_par_stms ((intra_suff_key, False) : path)
            alts <-
              kernelAlternatives pat par_body [(intra_ok, group_par_body)]
            pure (intra_suff_stms <> alts)
          Just mk_outer -> do
            (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- mk_outer
            (group_par_body, intra_ok, intra_suff_key, intra_suff_stms) <-
              checkSuffIntraPar ((outer_suff_key, False) : path) intra'
            par_body <-
              renamedBody
                =<< mk_par_stms
                  ((outer_suff_key, False) : (intra_suff_key, False) : path)
            alts <-
              kernelAlternatives
                pat
                par_body
                [(outer_suff, seq_body), (intra_ok, group_par_body)]
            pure ((outer_suff_stms <> intra_suff_stms) <> alts)
  where
    nest_ws :: [SubExp]
    nest_ws = kernelNestWidths loopnest

    -- Every generated body returns the names bound by the pattern.
    res :: Result
    res = varsRes (patNames pat)

    -- Attributes of the innermost nesting of the kernel nest.
    attrs :: Attrs
    attrs = stmAuxAttrs (loopNestingAux (innermostKernelNesting loopnest))

    -- Wrap statements in a body returning 'res', then rename it.
    renamedBody :: Stms GPU -> DistribM (Body GPU)
    renamedBody stms = renameBody (mkBody stms res)

    -- Only the inner-parallel version, on the unextended path.
    parallelOnly :: DistribM (Stms GPU)
    parallelOnly = do
      par_body <- renamedBody =<< mk_par_stms path
      kernelAlternatives pat par_body []

    -- Inner-parallel version first, with the sequentialised version as
    -- an alternative guarded by the outer-parallelism check.
    outerOrParallel ::
      DistribM (SubExp, Name, Stms GPU, Body GPU) -> DistribM (Stms GPU)
    outerOrParallel mk_outer = do
      (outer_suff, outer_suff_key, outer_suff_stms, seq_body) <- mk_outer
      par_body <- renamedBody =<< mk_par_stms ((outer_suff_key, False) : path)
      alts <- kernelAlternatives pat par_body [(outer_suff, seq_body)]
      pure (outer_suff_stms <> alts)

    -- The action building the "suff_outer_par" check and the
    -- sequentialised body, or Nothing when the lambda is not worth
    -- sequentialising or the attributes forbid exploiting the outer
    -- level.
    mkSeqAlts :: Maybe (DistribM (SubExp, Name, Stms GPU, Body GPU))
    mkSeqAlts
      | worthSequentialising lam && mayExploitOuter attrs = Just $ do
          ((outer_suff, outer_suff_key), outer_suff_stms) <-
            sufficientParallelism "suff_outer_par" nest_ws path Nothing
          seq_body <-
            renamedBody =<< mk_seq_stms ((outer_suff_key, True) : path)
          pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
      | otherwise = Nothing

    -- Build the condition guarding the intra-group version: the
    -- "suff_intra_par" check combined (LogAnd) with a check that the
    -- group size is <= the maximum group size.  Returns the renamed
    -- intra-group body, the condition, the threshold key, and the
    -- statements computing the condition (prelude included).
    checkSuffIntraPar ::
      KernelPath ->
      ((SubExp, SubExp), SubExp, Log, Stms GPU, Stms GPU) ->
      DistribM (Body GPU, SubExp, Name, Stms GPU)
    checkSuffIntraPar
      path'
      ((_, intra_avail_par), group_size, _, intra_prelude, intra_stms) = do
        ((intra_suff, suff_key), check_suff_stms) <-
          sufficientParallelism
            "suff_intra_par"
            [intra_avail_par]
            path'
            (Just intraMinInnerPar)

        (intra_ok, intra_suff_stms) <- runBuilder $ do
          addStms intra_prelude
          max_group_size <-
            letSubExp "max_group_size" $ Op $ SizeOp $ GetSizeMax SizeGroup
          fits <-
            letSubExp "fits" . BasicOp $
              CmpOp (CmpSle Int64) group_size max_group_size
          addStms check_suff_stms
          letSubExp "intra_suff_and_fits" . BasicOp $
            BinOp LogAnd fits intra_suff

        group_par_body <- renamedBody intra_stms
        pure (group_par_body, intra_ok, suff_key, intra_suff_stms)
-- | Drop those pattern elements (and the lambda body results paired
-- with them by position) whose names do not occur free in the given
-- result.
--
-- Returns 'Nothing' unless the remaining pattern names, as variables,
-- are a permutation of the sub-expressions of the given result (as
-- decided by 'isPermutationOf').  Otherwise returns that permutation,
-- the pruned pattern, and the lambda with its body result pruned
-- accordingly.
removeUnusedMapResults ::
  Pat Type ->
  [SubExpRes] ->
  Lambda rep ->
  Maybe ([Int], Pat Type, Lambda rep)
removeUnusedMapResults (Pat pes) res lam = do
  perm <- map (Var . patElemName) kept_pes `isPermutationOf` map resSubExp res
  pure (perm, Pat kept_pes, lam {lambdaBody = body {bodyResult = kept_res}})
  where
    body = lambdaBody lam
    res_free = freeIn res
    isUsed pe = patElemName pe `nameIn` res_free
    -- Pattern elements are paired with body results by position; 'zip'
    -- silently truncates to the shorter list.
    (kept_pes, kept_res) =
      unzip [(pe, r) | (pe, r) <- zip pes (bodyResult body), isUsed pe]
-- | Handler for a map loop encountered while distributing.
--
-- * If the lambda is both unbalanced and contains parallelism, the map
--   statement is simply added to the accumulator.
--
-- * Otherwise the statement is distributed on its own.  When that
--   succeeds and 'removeUnusedMapResults' accepts the pattern, the
--   post-statements are recorded and multiple versions are generated
--   through 'onMap''.
--
-- * In every other case the work is left to 'distributeMap'.
onInnerMap ::
  KernelPath ->
  MapLoop ->
  DistAcc GPU ->
  DistNestT GPU DistribM (DistAcc GPU)
onInnerMap path maploop@(MapLoop pat aux w lam arrs) acc
  | unbalancedLambda lam && lambdaContainsParallelism lam =
      addStmToAcc map_stm acc
  | otherwise = do
      distributed <- distributeSingleStm acc map_stm
      case distributed of
        Just (post_kernels, res, nest, acc')
          | Just (perm, pat', lam') <- removeUnusedMapResults pat res lam -> do
              addPostStms post_kernels
              multiVersion perm nest acc' pat' lam'
        _ -> distributeMap maploop acc
  where
    map_stm = mapLoopStm maploop

    -- Replace the targets of an accumulator by a single empty target.
    discardTargets acc' =
      acc' {distTargets = singleTarget (mempty, mempty)}

    multiVersion ::
      [Int] ->
      KernelNest ->
      DistAcc GPU ->
      Pat Type ->
      Lambda SOACS ->
      DistNestT GPU DistribM (DistAcc GPU)
    multiVersion perm nest acc' pat' lam' = do
      dist_env <- ask
      let extra_scope = targetsScope (distTargets acc')

      stms <- liftInner $
        localScope extra_scope $ do
          let inner_maploop = MapLoop pat' aux w lam' arrs

              -- Inner-parallel version: distribute the pruned map loop
              -- inside the nest, with the handlers re-targeted at the
              -- extended path.  Note that this starts from the outer
              -- 'acc' (with its statements cleared), not from @acc'@.
              exploitInner path' =
                let env' =
                      dist_env
                        { distOnTopLevelStms = onTopLevelStms path',
                          distOnInnerMap = onInnerMap path'
                        }
                 in runDistNestT env' $
                      inNesting nest $
                        localScope extra_scope $
                          discardTargets
                            <$> distributeMap
                              inner_maploop
                              (acc {distStms = mempty})

              -- Put the lambda results back in the order given by the
              -- inverse of the permutation.
              lam_res' =
                rearrangeShape
                  (rearrangeInverse perm)
                  (bodyResult (lambdaBody lam'))
              lam'' =
                lam' {lambdaBody = (lambdaBody lam') {bodyResult = lam_res'}}

              map_nesting =
                MapNesting pat' aux w (zip (lambdaParams lam') arrs)
              nest' =
                pushInnerKernelNesting (pat', lam_res') map_nesting nest

          -- Sequentialised version: one kernel over the extended nest
          -- whose body is the lambda body converted to GPU.
          (seq_kernel, nestw_stms) <-
            localScope extra_scope $
              constructKernel
                segThreadCapped
                nest'
                (lambdaBody (soacsLambdaToGPU lam''))

          versions <-
            onMap'
              nest'
              path
              (\_ -> pure (oneStm seq_kernel))
              exploitInner
              (loopNestingPat (fst nest))
              lam''
          pure (nestw_stms <> versions)

      postStm stms
      pure acc'