{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels (extractKernels) where
import Control.Monad.Identity
import Control.Monad.RWS.Strict
import Control.Monad.Reader
import Data.Bifunctor (first)
import Data.Maybe
import qualified Futhark.IR.Kernels as Out
import Futhark.IR.Kernels.Kernel
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyStms)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ISRWIM
import Futhark.Pass.ExtractKernels.Intragroup
import Futhark.Pass.ExtractKernels.StreamKernel
import Futhark.Pass.ExtractKernels.ToKernels
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename
import Futhark.Util.Log
import Prelude hiding (log)
-- | The kernel-extraction pass.  It is a 'Pass' from the @SOACS@
-- representation to the @Kernels@ representation whose work is done
-- entirely by 'transformProg'.
extractKernels :: Pass SOACS Out.Kernels
extractKernels =
  Pass
    { passName = "extract kernels",
      passDescription = "Perform kernel extraction",
      passFunction = transformProg
    }
-- | Transform a whole program.
--
-- The program constants are transformed first, inside a single
-- 'runDistribM' session.  Every function definition is then
-- transformed with the scope of the /transformed/ constants, so that
-- function bodies can refer to them.
transformProg :: Prog SOACS -> PassM (Prog Out.Kernels)
transformProg (Prog consts funs) = do
  consts' <- runDistribM $ transformStms mempty $ stmsToList consts
  funs' <- mapM (transformFunDef $ scopeOf consts') funs
  return $ Prog consts' funs'
-- | Mutable state carried by 'DistribM'.
data State = State
  { -- | Source of fresh names; read and written through the
    -- 'MonadFreshNames' instance of 'DistribM'.
    stateNameSource :: VNameSource,
    -- | Counter that 'cmpSizeLe' reads and increments to build a
    -- distinct size key for each comparison it emits.
    stateThresholdCounter :: Int
  }
-- | The monad in which kernel extraction runs: a strict 'RWS' whose
-- reader environment is the current @Kernels@ scope, whose writer
-- output is the compiler 'Log', and whose state is 'State'.
--
-- All instances are obtained by newtype deriving from the underlying
-- 'RWS'.
newtype DistribM a = DistribM (RWS (Scope Out.Kernels) Log State a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope Out.Kernels,
      LocalScope Out.Kernels,
      MonadState State,
      MonadLogger
    )
-- | Fresh names are drawn from the 'stateNameSource' field of the
-- monad's 'State'.
instance MonadFreshNames DistribM where
  getNameSource = gets stateNameSource
  putNameSource src = modify $ \s -> s {stateNameSource = src}
-- | Run a 'DistribM' action inside any monad that can supply fresh
-- names and accept log messages.
--
-- The action starts with an empty scope and a threshold counter of
-- zero.  Its name source is taken from the enclosing monad and the
-- final name source is written back, so names stay unique across
-- invocations.  Log output produced by the action is forwarded with
-- 'addLog'.
runDistribM ::
  (MonadLogger m, MonadFreshNames m) =>
  DistribM a ->
  m a
runDistribM (DistribM m) = do
  (x, msgs) <- modifyNameSource $ \src ->
    -- The primed names avoid shadowing the results bound just above.
    let (x', s, msgs') = runRWS m mempty (State src 0)
     in ((x', msgs'), stateNameSource s)
  addLog msgs
  return x
-- | Transform a single function definition.
--
-- The body is transformed with an empty kernel path, in a scope made
-- of the supplied outer scope plus the function's own parameters.
-- Entry point information, attributes, name, return type and
-- parameters are carried over unchanged.
transformFunDef ::
  (MonadFreshNames m, MonadLogger m) =>
  Scope Out.Kernels ->
  FunDef SOACS ->
  m (Out.FunDef Out.Kernels)
transformFunDef scope (FunDef entry attrs name rettype params body) = runDistribM $ do
  body' <-
    localScope (scope <> scopeOfFParams params) $
      transformBody mempty body
  return $ FunDef entry attrs name rettype params body'
-- | Shorthand for a sequence of statements in the @Kernels@
-- representation.
type KernelsStms = Stms Out.Kernels
-- | Transform a @SOACS@ body into a @Kernels@ body.
--
-- The statements are transformed with 'transformStms'; the result
-- list of the original body is reused unchanged and the new body is
-- assembled with 'mkBody'.
transformBody :: KernelPath -> Body -> DistribM (Out.Body Out.Kernels)
transformBody path body = do
  bnds <- transformStms path $ stmsToList $ bodyStms body
  return $ mkBody bnds $ bodyResult body
-- | Transform a list of @SOACS@ statements, in order.
--
-- Each statement is first offered to 'sequentialisedUnbalancedStm':
--
-- * If that returns 'Nothing', the statement is transformed with
--   'transformStm', and the remaining statements are transformed in
--   the scope of the statements just produced.
--
-- * If it returns replacement statements, they are put in front of
--   the remaining statements and transformation continues from
--   there, so the replacements are themselves examined again.
transformStms :: KernelPath -> [Stm] -> DistribM KernelsStms
transformStms _ [] =
  return mempty
transformStms path (bnd : bnds) =
  sequentialisedUnbalancedStm bnd >>= \case
    Nothing -> do
      bnd' <- transformStm path bnd
      inScopeOf bnd' $
        (bnd' <>) <$> transformStms path bnds
    Just bnds' ->
      transformStms path $ stmsToList bnds' <> bnds
-- | Does the lambda contain a statement that this module classifies
-- as /unbalanced/?
--
-- The check tracks a set of "bound" names: the lambda's parameters
-- plus everything bound inside the body being inspected.  A statement
-- counts as unbalanced when:
--
-- * it is a @Stream@ or @Screma@ whose width is a variable in the
--   bound set;
--
-- * it is a @WithAcc@ whose lambda body is unbalanced;
--
-- * it is an @If@ whose condition is a variable in the bound set and
--   at least one of whose branches is unbalanced;
--
-- * it is an @Apply@ of a function that is not built in.
--
-- Any other @Op@, every @DoLoop@ and every @BasicOp@ is treated as
-- balanced.
unbalancedLambda :: Lambda -> Bool
unbalancedLambda orig_lam =
  unbalancedBody (namesFromList $ map paramName $ lambdaParams orig_lam) $
    lambdaBody orig_lam
  where
    -- Only variables can be members of the bound set; constants
    -- never are.
    subExpBound (Var i) bound = i `nameIn` bound
    subExpBound (Constant _) _ = False

    -- A body is unbalanced if any of its statements is, judged with
    -- the body's own bindings added to the bound set.
    unbalancedBody bound body =
      any (unbalancedStm (bound <> boundInBody body) . stmExp) $
        bodyStms body

    unbalancedStm bound (Op (Stream w _ _ _ _)) =
      w `subExpBound` bound
    unbalancedStm bound (Op (Screma w _ _)) =
      w `subExpBound` bound
    unbalancedStm _ Op {} =
      False
    unbalancedStm _ DoLoop {} = False
    unbalancedStm bound (WithAcc _ lam) =
      unbalancedBody bound (lambdaBody lam)
    unbalancedStm bound (If cond tbranch fbranch _) =
      cond `subExpBound` bound
        && (unbalancedBody bound tbranch || unbalancedBody bound fbranch)
    unbalancedStm _ (BasicOp _) =
      False
    unbalancedStm _ (Apply fname _ _ _) =
      not $ isBuiltInFunction fname
-- | Decide whether a statement should be replaced by a sequentialised
-- version before kernel extraction.
--
-- This applies only to a @Screma@ that 'isRedomapSOAC' recognises,
-- and only when the lambda it returns both is unbalanced (per
-- 'unbalancedLambda') and contains parallelism.  In that case the
-- SOAC is rewritten with 'FOT.transformSOAC' and the statements it
-- produces are returned.  For every other statement the result is
-- 'Nothing'.
sequentialisedUnbalancedStm :: Stm -> DistribM (Maybe (Stms SOACS))
sequentialisedUnbalancedStm (Let pat _ (Op soac@(Screma _ _ form)))
  | Just (_, lam2) <- isRedomapSOAC form,
    unbalancedLambda lam2,
    lambdaContainsParallelism lam2 = do
    -- The monad's scope is in the Kernels representation; convert it
    -- so the SOACS binder can use it.
    types <- asksScope scopeForSOACs
    Just . snd <$> runBinderT (FOT.transformSOAC pat soac) types
sequentialisedUnbalancedStm _ =
  return Nothing
-- | Emit a @CmpSizeLe@ size comparison.
--
-- The size key is @desc@ followed by an underscore and the current
-- value of the threshold counter, which is then incremented.  The
-- value compared against is the 'Int64' product of @to_what@ (folded
-- from the constant 1), bound under the name hint @"comparatee"@.
--
-- Returns the comparison result paired with the size key, together
-- with the statements that compute it.
cmpSizeLe ::
  String ->
  Out.SizeClass ->
  [SubExp] ->
  DistribM ((SubExp, Name), Out.Stms Out.Kernels)
cmpSizeLe desc size_class to_what = do
  x <- gets stateThresholdCounter
  modify $ \s -> s {stateThresholdCounter = x + 1}
  let size_key = nameFromString $ desc ++ "_" ++ show x
  runBinder $ do
    to_what' <-
      letSubExp "comparatee"
        =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) to_what
    cmp_res <- letSubExp desc $ Op $ SizeOp $ CmpSizeLe size_key size_class to_what'
    return (cmp_res, size_key)
-- | Build a chain of @If@ expressions that selects between several
-- alternative bodies, all binding the same pattern.
--
-- With no alternatives, the default body is inlined and each of its
-- results is bound to the corresponding name of the pattern.
--
-- Otherwise the first alternative becomes the true branch of an @If@
-- on its condition.  The false branch holds the chain built
-- recursively from the remaining alternatives; that chain binds a
-- copy of the pattern whose elements have fresh names, and the false
-- branch returns those fresh value names.
kernelAlternatives ::
  (MonadFreshNames m, HasScope Out.Kernels m) =>
  Out.Pattern Out.Kernels ->
  Out.Body Out.Kernels ->
  [(SubExp, Out.Body Out.Kernels)] ->
  m (Out.Stms Out.Kernels)
kernelAlternatives pat default_body [] = runBinder_ $ do
  ses <- bodyBind default_body
  forM_ (zip (patternNames pat) ses) $ \(name, se) ->
    letBindNames [name] $ BasicOp $ SubExp se
kernelAlternatives pat default_body ((cond, alt) : alts) = runBinder_ $ do
  -- The nested chain needs its own names: 'pat' itself is bound by
  -- the enclosing If emitted below.
  alts_pat <- fmap (Pattern []) $
    forM (patternElements pat) $ \pe -> do
      name <- newVName $ baseString $ patElemName pe
      return pe {patElemName = name}

  alt_stms <- kernelAlternatives alts_pat default_body alts
  let alt_body = mkBody alt_stms $ map Var $ patternValueNames alts_pat

  letBind pat $
    If cond alt alt_body $
      IfDec (staticShapes (patternTypes pat)) IfEquiv
-- | Rebuild a SOACS lambda as a Kernels lambda.
--
-- The parameter list and the return types are reused unchanged; only the
-- body is rewritten, by calling 'transformBody' with the given 'KernelPath'.
transformLambda :: KernelPath -> Lambda -> DistribM (Out.Lambda Out.Kernels)
transformLambda :: KernelPath -> Lambda -> DistribM (Lambda Kernels)
transformLambda KernelPath
path (Lambda [LParam SOACS]
params BodyT SOACS
body [Type]
ret) =
[LParam Kernels] -> Body Kernels -> [Type] -> Lambda Kernels
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda [LParam SOACS]
[LParam Kernels]
params
(Body Kernels -> [Type] -> Lambda Kernels)
-> DistribM (Body Kernels) -> DistribM ([Type] -> Lambda Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
-- The body is transformed inside a scope extended with the lambda parameters.
<$> Scope Kernels -> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope Kernels
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type]
[LParam SOACS]
params) (KernelPath -> BodyT SOACS -> DistribM (Body Kernels)
transformBody KernelPath
path BodyT SOACS
body)
DistribM ([Type] -> Lambda Kernels)
-> DistribM [Type] -> DistribM (Lambda Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> DistribM [Type]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Type]
ret
-- | Translate one SOACS statement into a sequence of Kernels statements.
--
-- The equations of this function are tried in order. This first equation
-- handles any statement whose attributes contain @"sequential"@: the
-- 'KernelPath' is ignored and the whole statement is handed to
-- 'FOT.transformStmRecursively' (from "Futhark.Transform.FirstOrderTransform"),
-- with the statements it emits collected by 'runBinder_'.
transformStm :: KernelPath -> Stm -> DistribM KernelsStms
transformStm :: KernelPath -> Stm SOACS -> DistribM (Stms Kernels)
transformStm KernelPath
_ Stm SOACS
stm
| Attr
"sequential" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs (Stm SOACS -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm SOACS
stm) =
Binder Kernels () -> DistribM (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Binder Kernels () -> DistribM (Stms Kernels))
-> Binder Kernels () -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Binder Kernels ()
forall (m :: * -> *).
(Transformer m, LetDec (Lore m) ~ LetDec SOACS) =>
Stm SOACS -> m ()
FOT.transformStmRecursively Stm SOACS
stm
-- 'transformStm' equation for a SOAC statement whose attributes contain
-- @"sequential_outer"@. The SOAC is first rewritten by 'FOT.transformSOAC'
-- inside a SOACS binder; every resulting statement is given the certificates
-- of the original statement ('certify' with 'stmAuxCerts'), and the list of
-- statements is then passed to 'transformStms' under the same path.
-- NOTE(review): presumably only the outermost SOAC becomes sequential and any
-- nested SOACs are still picked up by 'transformStms' - confirm against FOT.
transformStm KernelPath
path (Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux (Op Op SOACS
soac))
| Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux =
KernelPath -> [Stm SOACS] -> DistribM (Stms Kernels)
transformStms KernelPath
path ([Stm SOACS] -> DistribM (Stms Kernels))
-> (Stms SOACS -> [Stm SOACS])
-> Stms SOACS
-> DistribM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall lore. Stms lore -> [Stm lore]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify (StmAux () -> Certificates
forall dec. StmAux dec -> Certificates
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
(Stms SOACS -> DistribM (Stms Kernels))
-> DistribM (Stms SOACS) -> DistribM (Stms Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Binder SOACS () -> DistribM (Stms SOACS)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Pattern (Lore (BinderT SOACS (State VNameSource)))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> Binder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pattern (Lore m) -> SOAC (Lore m) -> m ()
FOT.transformSOAC Pattern (Lore (BinderT SOACS (State VNameSource)))
Pattern SOACS
pat Op SOACS
SOAC (Lore (BinderT SOACS (State VNameSource)))
soac)
-- 'transformStm' equation for an @If@ expression. Both branch bodies are
-- rewritten with 'transformBody' under the same path; the result is a single
-- @If@ statement that keeps the original pattern, auxiliary information,
-- condition and branch-type annotation.
transformStm KernelPath
path (Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux (If SubExp
c BodyT SOACS
tb BodyT SOACS
fb IfDec (BranchType SOACS)
rt)) = do
Body Kernels
tb' <- KernelPath -> BodyT SOACS -> DistribM (Body Kernels)
transformBody KernelPath
path BodyT SOACS
tb
Body Kernels
fb' <- KernelPath -> BodyT SOACS -> DistribM (Body Kernels)
transformBody KernelPath
path BodyT SOACS
fb
Stms Kernels -> DistribM (Stms Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms Kernels -> DistribM (Stms Kernels))
-> Stms Kernels -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels) -> Stm Kernels -> Stms Kernels
forall a b. (a -> b) -> a -> b
$ Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern Kernels
pat StmAux (ExpDec SOACS)
StmAux (ExpDec Kernels)
aux (ExpT Kernels -> Stm Kernels) -> ExpT Kernels -> Stm Kernels
forall a b. (a -> b) -> a -> b
$ SubExp
-> Body Kernels
-> Body Kernels
-> IfDec (BranchType Kernels)
-> ExpT Kernels
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
c Body Kernels
tb' Body Kernels
fb' IfDec (BranchType SOACS)
IfDec (BranchType Kernels)
rt
-- 'transformStm' equation for a @WithAcc@ expression. The result is a single
-- @WithAcc@ statement with the original pattern and auxiliary information.
-- The main lambda goes through 'transformLambda' (so its body is transformed
-- under the current path), whereas the optional lambda attached to each input
-- is converted with 'soacsLambdaToKernels' by the local 'transformInput'.
transformStm KernelPath
path (Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux (WithAcc [(Shape, [VName], Maybe (Lambda, Result))]
inputs Lambda
lam)) =
Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels)
-> (ExpT Kernels -> Stm Kernels) -> ExpT Kernels -> Stms Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern Kernels
pat StmAux (ExpDec SOACS)
StmAux (ExpDec Kernels)
aux
(ExpT Kernels -> Stms Kernels)
-> DistribM (ExpT Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ([(Shape, [VName], Maybe (Lambda Kernels, Result))]
-> Lambda Kernels -> ExpT Kernels
forall lore.
[(Shape, [VName], Maybe (Lambda lore, Result))]
-> Lambda lore -> ExpT lore
WithAcc (((Shape, [VName], Maybe (Lambda, Result))
-> (Shape, [VName], Maybe (Lambda Kernels, Result)))
-> [(Shape, [VName], Maybe (Lambda, Result))]
-> [(Shape, [VName], Maybe (Lambda Kernels, Result))]
forall a b. (a -> b) -> [a] -> [b]
map (Shape, [VName], Maybe (Lambda, Result))
-> (Shape, [VName], Maybe (Lambda Kernels, Result))
forall {f :: * -> *} {p :: * -> * -> *} {a} {b} {c}.
(Functor f, Bifunctor p) =>
(a, b, f (p Lambda c)) -> (a, b, f (p (Lambda Kernels) c))
transformInput [(Shape, [VName], Maybe (Lambda, Result))]
inputs) (Lambda Kernels -> ExpT Kernels)
-> DistribM (Lambda Kernels) -> DistribM (ExpT Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> Lambda -> DistribM (Lambda Kernels)
transformLambda KernelPath
path Lambda
lam)
where
-- Keep the shape and the arrays of an input as they are; if the third
-- component holds a (lambda, result) pair, convert only the lambda.
transformInput :: (a, b, f (p Lambda c)) -> (a, b, f (p (Lambda Kernels) c))
transformInput (a
shape, b
arrs, f (p Lambda c)
op) =
(a
shape, b
arrs, (p Lambda c -> p (Lambda Kernels) c)
-> f (p Lambda c) -> f (p (Lambda Kernels) c)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((Lambda -> Lambda Kernels) -> p Lambda c -> p (Lambda Kernels) c
forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first Lambda -> Lambda Kernels
soacsLambdaToKernels) f (p Lambda c)
op)
-- 'transformStm' equation for a @DoLoop@ expression. The loop body is
-- rewritten with 'transformBody' under the same path, inside a scope extended
-- with the names bound by the loop form and by the merge parameters. The
-- result is a single @DoLoop@ statement with the original pattern, auxiliary
-- information and merge parameters, and with the loop form rebuilt as form'.
transformStm KernelPath
path (Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux (DoLoop [(FParam SOACS, SubExp)]
ctx [(FParam SOACS, SubExp)]
val LoopForm SOACS
form BodyT SOACS
body)) =
Scope Kernels -> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope
( Scope SOACS -> Scope Kernels
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope (LoopForm SOACS -> Scope SOACS
forall lore a. Scoped lore a => a -> Scope lore
scopeOf LoopForm SOACS
form)
Scope Kernels -> Scope Kernels -> Scope Kernels
forall a. Semigroup a => a -> a -> a
<> [Param DeclType] -> Scope Kernels
forall lore dec.
(FParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfFParams [Param DeclType]
mergeparams
)
(DistribM (Stms Kernels) -> DistribM (Stms Kernels))
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stm Kernels -> Stms Kernels
forall lore. Stm lore -> Stms lore
oneStm (Stm Kernels -> Stms Kernels)
-> (Body Kernels -> Stm Kernels) -> Body Kernels -> Stms Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pattern Kernels
-> StmAux (ExpDec Kernels) -> ExpT Kernels -> Stm Kernels
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern Kernels
pat StmAux (ExpDec SOACS)
StmAux (ExpDec Kernels)
aux (ExpT Kernels -> Stm Kernels)
-> (Body Kernels -> ExpT Kernels) -> Body Kernels -> Stm Kernels
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(FParam Kernels, SubExp)]
-> [(FParam Kernels, SubExp)]
-> LoopForm Kernels
-> Body Kernels
-> ExpT Kernels
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
ctx [(FParam SOACS, SubExp)]
[(FParam Kernels, SubExp)]
val LoopForm Kernels
form' (Body Kernels -> Stms Kernels)
-> DistribM (Body Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> BodyT SOACS -> DistribM (Body Kernels)
transformBody KernelPath
path BodyT SOACS
body
where
-- The parameters of all merge variables: context ones followed by value ones.
mergeparams :: [Param DeclType]
mergeparams = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst ([(Param DeclType, SubExp)] -> [Param DeclType])
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> a -> b
$ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
ctx [(Param DeclType, SubExp)]
-> [(Param DeclType, SubExp)] -> [(Param DeclType, SubExp)]
forall a. [a] -> [a] -> [a]
++ [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val
-- The same loop form with the same fields, rebuilt at the Kernels lore type.
form' :: LoopForm Kernels
form' = case LoopForm SOACS
form of
WhileLoop VName
cond ->
VName -> LoopForm Kernels
forall lore. VName -> LoopForm lore
WhileLoop VName
cond
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
ps ->
VName
-> IntType
-> SubExp
-> [(LParam Kernels, VName)]
-> LoopForm Kernels
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
i IntType
it SubExp
bound [(LParam SOACS, VName)]
[(LParam Kernels, VName)]
ps
-- 'transformStm' equation for a @Screma@ that 'isMapSOAC' recognises as a
-- map. The pattern, auxiliary information, width, map lambda and input arrays
-- are packed into a 'MapLoop' and handed to 'onMap' with the current path.
transformStm KernelPath
path (Let Pattern SOACS
pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just Lambda
lam <- ScremaForm SOACS -> Maybe Lambda
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form =
KernelPath -> MapLoop -> DistribM (Stms Kernels)
onMap KernelPath
path (MapLoop -> DistribM (Stms Kernels))
-> MapLoop -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Pattern SOACS
-> StmAux () -> SubExp -> Lambda -> [VName] -> MapLoop
MapLoop Pattern SOACS
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w Lambda
lam [VName]
arrs
-- 'transformStm' equation for scan-like @Screma@ statements. It has two
-- guarded alternatives, tried in order:
--
-- 1. a plain scan ('isScanSOAC') for which 'iswim' returns a rewrite;
-- 2. a scan combined with a map ('isScanomapSOAC'), turned into a 'segScan'.
--
-- Only the certificates of the statement are used here; the attributes and
-- the expression decoration are ignored by the pattern.
transformStm KernelPath
path (Let Pattern SOACS
res_pat (StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
-- Alternative 1: the scans are merged by 'singleScan', and 'iswim' must
-- succeed on the merged operator with the neutral elements zipped with the
-- input arrays.
| Just [Scan SOACS]
scans <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall lore. ScremaForm lore -> Maybe [Scan lore]
isScanSOAC ScremaForm SOACS
form,
Scan Lambda
scan_lam Result
nes <- [Scan SOACS] -> Scan SOACS
forall lore. Bindable lore => [Scan lore] -> Scan lore
singleScan [Scan SOACS]
scans,
Just BinderT SOACS DistribM ()
do_iswim <- Pattern SOACS
-> SubExp
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern SOACS
-> SubExp -> Lambda -> [(SubExp, VName)] -> Maybe (m ())
iswim Pattern SOACS
res_pat SubExp
w Lambda
scan_lam ([(SubExp, VName)] -> Maybe (BinderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistribM (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
-- Run the rewrite under the original certificates in a SOACS binder, then
-- feed the SOACS statements it produced back into 'transformStms'.
KernelPath -> [Stm SOACS] -> DistribM (Stms Kernels)
transformStms KernelPath
path ([Stm SOACS] -> DistribM (Stms Kernels))
-> (((), Stms SOACS) -> [Stm SOACS])
-> ((), Stms SOACS)
-> DistribM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall lore. Stms lore -> [Stm lore]
stmsToList (Stms SOACS -> [Stm SOACS])
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> [Stm SOACS]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd (((), Stms SOACS) -> DistribM (Stms Kernels))
-> DistribM ((), Stms SOACS) -> DistribM (Stms Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT SOACS DistribM ()
-> Scope SOACS -> DistribM ((), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Certificates
-> BinderT SOACS DistribM () -> BinderT SOACS DistribM ()
forall (m :: * -> *) a. MonadBinder m => Certificates -> m a -> m a
certifying Certificates
cs BinderT SOACS DistribM ()
do_iswim) Scope SOACS
types
-- Alternative 2: every scan operator becomes a 'SegBinOp' and the map lambda
-- is converted with 'soacsLambdaToKernels' (not 'transformLambda').
| Just ([Scan SOACS]
scans, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form = Binder Kernels () -> DistribM (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Binder Kernels () -> DistribM (Stms Kernels))
-> Binder Kernels () -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
[SegBinOp Kernels]
scan_ops <- [Scan SOACS]
-> (Scan SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Scan SOACS]
scans ((Scan SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels])
-> (Scan SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels]
forall a b. (a -> b) -> a -> b
$ \(Scan Lambda
scan_lam Result
nes) -> do
(Lambda
scan_lam', Result
nes', Shape
shape) <- Lambda
-> Result
-> BinderT Kernels (State VNameSource) (Lambda, Result, Shape)
forall (m :: * -> *).
MonadBinder m =>
Lambda -> Result -> m (Lambda, Result, Shape)
determineReduceOp Lambda
scan_lam Result
nes
let scan_lam'' :: Lambda Kernels
scan_lam'' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
scan_lam'
-- Scan operators are always recorded as Noncommutative here.
SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels)
forall a b. (a -> b) -> a -> b
$ Commutativity
-> Lambda Kernels -> Result -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
SegBinOp Commutativity
Noncommutative Lambda Kernels
scan_lam'' Result
nes' Shape
shape
let map_lam_sequential :: Lambda Kernels
map_lam_sequential = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
map_lam
SegLevel
lvl <- MkSegLevel Kernels (State VNameSource)
forall (m :: * -> *). MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped [SubExp
w] [Char]
"segscan" (ThreadRecommendation
-> BinderT Kernels (State VNameSource) (SegOpLevel Kernels))
-> ThreadRecommendation
-> BinderT Kernels (State VNameSource) (SegOpLevel Kernels)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
-- Emit the statements built by 'segScan', each one carrying the original
-- certificates. The two trailing lists passed to 'segScan' are empty.
Stms Kernels -> Binder Kernels ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> Binder Kernels ())
-> (Stms Kernels -> Stms Kernels)
-> Stms Kernels
-> Binder Kernels ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm Kernels -> Stm Kernels) -> Stms Kernels -> Stms Kernels
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm Kernels -> Stm Kernels
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs)
(Stms Kernels -> Binder Kernels ())
-> BinderT Kernels (State VNameSource) (Stms Kernels)
-> Binder Kernels ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BinderT Kernels (State VNameSource) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms lore)
segScan SegOpLevel Kernels
SegLevel
lvl Pattern SOACS
Pattern Kernels
res_pat SubExp
w [SegBinOp Kernels]
scan_ops Lambda Kernels
map_lam_sequential [VName]
arrs [] []
-- 'transformStm' equation for a @Screma@ that 'isReduceSOAC' recognises as a
-- reduction with exactly one operator, and for which 'irwim' returns a
-- rewrite. The operator is treated as commutative if 'commutativeLambda'
-- says so, otherwise its declared commutativity is kept. The rewrite runs
-- under the original auxiliary information ('auxing'), its statements are
-- simplified with 'simplifyStms', and the simplified SOACS statements are fed
-- back into 'transformStms' under the same path.
transformStm KernelPath
path (Let Pattern SOACS
res_pat StmAux (ExpDec SOACS)
aux (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just [Reduce Commutativity
comm Lambda
red_fun Result
nes] <- ScremaForm SOACS -> Maybe [Reduce SOACS]
forall lore. ScremaForm lore -> Maybe [Reduce lore]
isReduceSOAC ScremaForm SOACS
form,
let comm' :: Commutativity
comm'
| Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
red_fun = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm,
Just BinderT SOACS DistribM ()
do_irwim <- Pattern SOACS
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (BinderT SOACS DistribM ())
forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
Pattern SOACS
-> SubExp
-> Commutativity
-> Lambda
-> [(SubExp, VName)]
-> Maybe (m ())
irwim Pattern SOACS
res_pat SubExp
w Commutativity
comm' Lambda
red_fun ([(SubExp, VName)] -> Maybe (BinderT SOACS DistribM ()))
-> [(SubExp, VName)] -> Maybe (BinderT SOACS DistribM ())
forall a b. (a -> b) -> a -> b
$ Result -> [VName] -> [(SubExp, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip Result
nes [VName]
arrs = do
Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistribM (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
-- Only the simplified statements are kept: the symbol table returned by
-- 'simplifyStms' and the statements accumulated by the outer binder are
-- both discarded.
(SymbolTable (Wise SOACS)
_, Stms SOACS
bnds) <- ((SymbolTable (Wise SOACS), Stms SOACS), Stms SOACS)
-> (SymbolTable (Wise SOACS), Stms SOACS)
forall a b. (a, b) -> a
fst (((SymbolTable (Wise SOACS), Stms SOACS), Stms SOACS)
-> (SymbolTable (Wise SOACS), Stms SOACS))
-> DistribM ((SymbolTable (Wise SOACS), Stms SOACS), Stms SOACS)
-> DistribM (SymbolTable (Wise SOACS), Stms SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> BinderT SOACS DistribM (SymbolTable (Wise SOACS), Stms SOACS)
-> Scope SOACS
-> DistribM ((SymbolTable (Wise SOACS), Stms SOACS), Stms SOACS)
forall (m :: * -> *) lore a.
MonadFreshNames m =>
BinderT lore m a -> Scope lore -> m (a, Stms lore)
runBinderT (Stms SOACS
-> BinderT SOACS DistribM (SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms (Stms SOACS
-> BinderT SOACS DistribM (SymbolTable (Wise SOACS), Stms SOACS))
-> BinderT SOACS DistribM (Stms SOACS)
-> BinderT SOACS DistribM (SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT SOACS DistribM ()
-> BinderT SOACS DistribM (Stms (Lore (BinderT SOACS DistribM)))
forall (m :: * -> *) a. MonadBinder m => m a -> m (Stms (Lore m))
collectStms_ (StmAux () -> BinderT SOACS DistribM () -> BinderT SOACS DistribM ()
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux ()
StmAux (ExpDec SOACS)
aux BinderT SOACS DistribM ()
do_irwim)) Scope SOACS
types
KernelPath -> [Stm SOACS] -> DistribM (Stms Kernels)
transformStms KernelPath
path ([Stm SOACS] -> DistribM (Stms Kernels))
-> [Stm SOACS] -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> [Stm SOACS]
forall lore. Stms lore -> [Stm lore]
stmsToList Stms SOACS
bnds
-- 'transformStm' equation for a @Screma@ that 'isRedomapSOAC' recognises as a
-- reduction combined with a map. Two code versions are defined locally:
--
-- * paralleliseOuter: a single 'nonSegRed' over the whole input, with the
--   map lambda converted by 'soacsLambdaToKernels';
-- * paralleliseInner: the redomap is split into a map statement and a reduce
--   statement, which are simplified and sent back through 'transformStms'.
--
-- If the map lambda contains no parallelism, or the statement carries the
-- @"sequential_inner"@ attribute, only the outer version is produced.
-- Otherwise both versions are generated and combined by 'kernelAlternatives',
-- selected by the condition computed by 'sufficientParallelism'.
transformStm KernelPath
path (Let Pattern SOACS
pat aux :: StmAux (ExpDec SOACS)
aux@(StmAux Certificates
cs Attrs
_ ExpDec SOACS
_) (Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)))
| Just ([Reduce SOACS]
reds, Lambda
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form = do
-- Outer version: one 'SegBinOp' per reduction operator, then 'nonSegRed'.
let paralleliseOuter :: DistribM (Stms Kernels)
paralleliseOuter = Binder Kernels () -> DistribM (Stms Kernels)
forall (m :: * -> *) somelore lore a.
(MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore a -> m (Stms lore)
runBinder_ (Binder Kernels () -> DistribM (Stms Kernels))
-> Binder Kernels () -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
[SegBinOp Kernels]
red_ops <- [Reduce SOACS]
-> (Reduce SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Reduce SOACS]
reds ((Reduce SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels])
-> (Reduce SOACS
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> BinderT Kernels (State VNameSource) [SegBinOp Kernels]
forall a b. (a -> b) -> a -> b
$ \(Reduce Commutativity
comm Lambda
red_lam Result
nes) -> do
(Lambda
red_lam', Result
nes', Shape
shape) <- Lambda
-> Result
-> BinderT Kernels (State VNameSource) (Lambda, Result, Shape)
forall (m :: * -> *).
MonadBinder m =>
Lambda -> Result -> m (Lambda, Result, Shape)
determineReduceOp Lambda
red_lam Result
nes
-- Commutativity is checked on the lambda returned by 'determineReduceOp';
-- the declared commutativity is the fallback.
let comm' :: Commutativity
comm'
| Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
red_lam' = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
red_lam'' :: Lambda Kernels
red_lam'' = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
red_lam'
SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels)
forall (m :: * -> *) a. Monad m => a -> m a
return (SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels))
-> SegBinOp Kernels
-> BinderT Kernels (State VNameSource) (SegBinOp Kernels)
forall a b. (a -> b) -> a -> b
$ Commutativity
-> Lambda Kernels -> Result -> Shape -> SegBinOp Kernels
forall lore.
Commutativity -> Lambda lore -> Result -> Shape -> SegBinOp lore
SegBinOp Commutativity
comm' Lambda Kernels
red_lam'' Result
nes' Shape
shape
let map_lam_sequential :: Lambda Kernels
map_lam_sequential = Lambda -> Lambda Kernels
soacsLambdaToKernels Lambda
map_lam
SegLevel
lvl <- MkSegLevel Kernels (State VNameSource)
forall (m :: * -> *). MonadFreshNames m => MkSegLevel Kernels m
segThreadCapped [SubExp
w] [Char]
"segred" (ThreadRecommendation
-> BinderT Kernels (State VNameSource) (SegOpLevel Kernels))
-> ThreadRecommendation
-> BinderT Kernels (State VNameSource) (SegOpLevel Kernels)
forall a b. (a -> b) -> a -> b
$ SegVirt -> ThreadRecommendation
NoRecommendation SegVirt
SegNoVirt
-- Emit the statements built by 'nonSegRed', each one carrying the
-- original certificates.
Stms Kernels -> Binder Kernels ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms (Stms Kernels -> Binder Kernels ())
-> (Stms Kernels -> Stms Kernels)
-> Stms Kernels
-> Binder Kernels ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm Kernels -> Stm Kernels) -> Stms Kernels -> Stms Kernels
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> Stm Kernels -> Stm Kernels
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs)
(Stms Kernels -> Binder Kernels ())
-> BinderT Kernels (State VNameSource) (Stms Kernels)
-> Binder Kernels ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel Kernels
-> Pattern Kernels
-> SubExp
-> [SegBinOp Kernels]
-> Lambda Kernels
-> [VName]
-> BinderT Kernels (State VNameSource) (Stms Kernels)
forall (m :: * -> *) lore.
(MonadFreshNames m, DistLore lore, HasScope lore m) =>
SegOpLevel lore
-> Pattern lore
-> SubExp
-> [SegBinOp lore]
-> Lambda lore
-> [VName]
-> m (Stms lore)
nonSegRed SegOpLevel Kernels
SegLevel
lvl Pattern SOACS
Pattern Kernels
pat SubExp
w [SegBinOp Kernels]
red_ops Lambda Kernels
map_lam_sequential [VName]
arrs
-- The outer version wrapped as a body that returns the pattern names,
-- then renamed with 'renameBody'.
outerParallelBody :: DistribM (Body Kernels)
outerParallelBody =
Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody
(Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> DistribM (Stms Kernels)
paralleliseOuter DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern SOACS
pat)))
-- Inner version: split into map + reduce with 'redomapToMapAndReduce',
-- certify and simplify both statements, then transform them under path'.
-- The reduction operator used for the split is the one produced by
-- 'singleReduce' in the where-bindings below.
paralleliseInner :: KernelPath -> DistribM (Stms Kernels)
paralleliseInner KernelPath
path' = do
(Stm SOACS
mapstm, Stm SOACS
redstm) <-
Pattern SOACS
-> (SubExp, Commutativity, Lambda, Lambda, Result, [VName])
-> DistribM (Stm SOACS, Stm SOACS)
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, ExpDec lore ~ (),
Op lore ~ SOAC lore) =>
Pattern lore
-> (SubExp, Commutativity, LambdaT lore, LambdaT lore, Result,
[VName])
-> m (Stm lore, Stm lore)
redomapToMapAndReduce Pattern SOACS
pat (SubExp
w, Commutativity
comm', Lambda
red_lam, Lambda
map_lam, Result
nes, [VName]
arrs)
Scope SOACS
types <- (Scope Kernels -> Scope SOACS) -> DistribM (Scope SOACS)
forall lore (m :: * -> *) a.
HasScope lore m =>
(Scope lore -> a) -> m a
asksScope Scope Kernels -> Scope SOACS
scopeForSOACs
KernelPath -> [Stm SOACS] -> DistribM (Stms Kernels)
transformStms KernelPath
path' ([Stm SOACS] -> DistribM (Stms Kernels))
-> (Stms SOACS -> [Stm SOACS])
-> Stms SOACS
-> DistribM (Stms Kernels)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> [Stm SOACS]
forall lore. Stms lore -> [Stm lore]
stmsToList (Stms SOACS -> DistribM (Stms Kernels))
-> (BinderT SOACS DistribM () -> DistribM (Stms SOACS))
-> BinderT SOACS DistribM ()
-> DistribM (Stms Kernels)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< (BinderT SOACS DistribM () -> Scope SOACS -> DistribM (Stms SOACS)
forall (m :: * -> *) lore.
MonadFreshNames m =>
BinderT lore m () -> Scope lore -> m (Stms lore)
`runBinderT_` Scope SOACS
types) (BinderT SOACS DistribM () -> DistribM (Stms Kernels))
-> BinderT SOACS DistribM () -> DistribM (Stms Kernels)
forall a b. (a -> b) -> a -> b
$ do
(SymbolTable (Wise SOACS)
_, Stms SOACS
stms) <-
Stms SOACS
-> BinderT SOACS DistribM (SymbolTable (Wise SOACS), Stms SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Stms SOACS -> m (SymbolTable (Wise SOACS), Stms SOACS)
simplifyStms ([Stm SOACS] -> Stms SOACS
forall lore. [Stm lore] -> Stms lore
stmsFromList [Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs Stm SOACS
mapstm, Certificates -> Stm SOACS -> Stm SOACS
forall lore. Certificates -> Stm lore -> Stm lore
certify Certificates
cs Stm SOACS
redstm])
Stms (Lore (BinderT SOACS DistribM)) -> BinderT SOACS DistribM ()
forall (m :: * -> *). MonadBinder m => Stms (Lore m) -> m ()
addStms Stms (Lore (BinderT SOACS DistribM))
Stms SOACS
stms
where
-- Here commutativity is checked on the merged lambda from 'singleReduce'.
comm' :: Commutativity
comm'
| Lambda -> Bool
forall lore. Lambda lore -> Bool
commutativeLambda Lambda
red_lam = Commutativity
Commutative
| Bool
otherwise = Commutativity
comm
(Reduce Commutativity
comm Lambda
red_lam Result
nes) = [Reduce SOACS] -> Reduce SOACS
forall lore. Bindable lore => [Reduce lore] -> Reduce lore
singleReduce [Reduce SOACS]
reds
-- The inner version wrapped as a body that returns the pattern names,
-- then renamed with 'renameBody'.
innerParallelBody :: KernelPath -> DistribM (Body Kernels)
innerParallelBody KernelPath
path' =
Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody
(Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< (Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
paralleliseInner KernelPath
path' DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var (PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern SOACS
pat)))
-- Decide which versions to generate.
if Bool -> Bool
not (Lambda -> Bool
lambdaContainsParallelism Lambda
map_lam)
Bool -> Bool -> Bool
|| Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux
then DistribM (Stms Kernels)
paralleliseOuter
else do
((SubExp
outer_suff, Name
outer_suff_key), Stms Kernels
suff_stms) <-
[Char]
-> Result
-> KernelPath
-> Maybe Int64
-> DistribM ((SubExp, Name), Stms Kernels)
sufficientParallelism [Char]
"suff_outer_redomap" [SubExp
w] KernelPath
path Maybe Int64
forall a. Maybe a
Nothing
Body Kernels
outer_stms <- DistribM (Body Kernels)
outerParallelBody
-- The inner version is generated under a path extended with
-- (outer_suff_key, False).
Body Kernels
inner_stms <- KernelPath -> DistribM (Body Kernels)
innerParallelBody ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path)
-- The statements computing the condition come first; the inner body is
-- the default and the outer body is the alternative guarded by outer_suff.
(Stms Kernels
suff_stms Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<>) (Stms Kernels -> Stms Kernels)
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
inner_stms [(SubExp
outer_suff, Body Kernels
outer_stms)]
-- Parallel stream with an empty list of neutral elements (no reduction
-- part), unless the statement carries the "sequential_inner" attribute.
-- The stream is rewritten with 'sequentialStreamWholeArray' (under the
-- statement's certificates), and the statements this produces are fed
-- back through 'transformStms'.
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Stream w arrs Parallel {} [] map_fun)))
  | not ("sequential_inner" `inAttrs` stmAuxAttrs aux) = do
    types <- asksScope scopeForSOACs
    (_, unstreamed) <-
      runBinderT
        (certifying cs (sequentialStreamWholeArray pat w [] map_fun arrs))
        types
    transformStms path (stmsToList unstreamed)
-- Parallel stream with a reduction part.
--
-- With the "sequential_inner" attribute only the outer-parallel version
-- is produced.  Otherwise both an outer-parallel and an inner-parallel
-- body are built, each on a kernel path extended with the threshold key
-- returned by 'sufficientParallelism' ('True' for outer, 'False' for
-- inner), and the two are combined with 'kernelAlternatives'.
transformStm path (Let pat aux@(StmAux cs _ _) (Op (Stream w arrs (Parallel o comm red_fun) nes fold_fun)))
  | "sequential_inner" `inAttrs` stmAuxAttrs aux =
    paralleliseOuter path
  | otherwise = do
    ((outer_suff, outer_suff_key), suff_stms) <-
      sufficientParallelism "suff_outer_stream" [w] path Nothing
    outer_body <- outerParallelBody ((outer_suff_key, True) : path)
    inner_body <- innerParallelBody ((outer_suff_key, False) : path)
    alternatives <-
      kernelAlternatives pat inner_body [(outer_suff, outer_body)]
    pure (suff_stms <> alternatives)
  where
    -- Both alternative bodies return the names bound by the pattern.
    stream_result = map Var (patternNames pat)

    -- Outer-parallel version of the stream.
    paralleliseOuter path'
      -- Some reduction result is not of primitive type: emit a
      -- 'streamMap' followed by a separate Screma reduction over its
      -- results, and transform that reduction statement in turn.
      | any (not . primType) (lambdaReturnType red_fun) = do
        let (red_pat_elems, concat_pat_elems) =
              splitAt (length nes) (patternValueElements pat)
            red_names = [baseString (patElemName pe) | pe <- red_pat_elems]
        ((num_threads, red_results), map_stms) <-
          streamMap
            segThreadCapped
            red_names
            concat_pat_elems
            w
            Noncommutative
            (soacsLambdaToKernels fold_fun)
            nes
            arrs
        reduce_soac <- reduceSOAC [Reduce comm' red_fun nes]
        -- The attributes are cleared on the generated reduction
        -- statement; the rest of the auxiliary information is kept.
        red_stms <-
          inScopeOf map_stms $
            transformStm path' $
              Let (Pattern [] red_pat_elems) aux {stmAuxAttrs = mempty} $
                Op (Screma num_threads red_results reduce_soac)
        pure (map_stms <> red_stms)
      -- All reduction results are primitive: a single 'streamRed',
      -- with the certificates attached to every produced statement.
      | otherwise = do
        red_stms <-
          streamRed
            segThreadCapped
            pat
            w
            comm'
            (soacsLambdaToKernels red_fun)
            (soacsLambdaToKernels fold_fun)
            nes
            arrs
        pure (certify cs <$> red_stms)

    outerParallelBody path' = do
      outer_stms <- paralleliseOuter path'
      renameBody (mkBody outer_stms stream_result)

    -- Inner-parallel version: rewrite the stream with
    -- 'sequentialStreamWholeArray', certify the resulting statements
    -- and transform them again.
    paralleliseInner path' = do
      types <- asksScope scopeForSOACs
      (_, unstreamed) <-
        runBinderT (sequentialStreamWholeArray pat w nes fold_fun arrs) types
      transformStms path' [certify cs stm | stm <- stmsToList unstreamed]

    innerParallelBody path' = do
      inner_stms <- paralleliseInner path'
      renameBody (mkBody inner_stms stream_result)

    -- Upgrade to 'Commutative' when 'commutativeLambda' holds for the
    -- reduction operator and the stream is not 'InOrder'; otherwise
    -- keep the declared commutativity.
    comm'
      | commutativeLambda red_fun && o /= InOrder = Commutative
      | otherwise = comm
-- Any Screma that reaches this clause: split it with 'dissectScrema',
-- attach the original certificates to each resulting statement, and run
-- the pieces through 'transformStms' again.
transformStm path (Let pat (StmAux cs _ _) (Op (Screma w arrs form))) = do
  scope <- asksScope scopeForSOACs
  (_, dissected) <- runBinderT (dissectScrema pat w form arrs) scope
  transformStms path (certify cs <$> stmsToList dissected)
-- Sequential stream: rewrite it with 'sequentialStreamWholeArray' and
-- feed the resulting statements back through 'transformStms'.
--
-- The certificates of the original statement are kept by running the
-- rewrite under 'certifying', as the no-reduction parallel stream
-- clause does; previously the statement's auxiliary information was
-- ignored here, so its certificates were dropped.
transformStm path (Let pat (StmAux cs _ _) (Op (Stream w arrs Sequential nes fold_fun))) = do
  types <- asksScope scopeForSOACs
  transformStms path . stmsToList . snd
    =<< runBinderT
      (certifying cs $ sequentialStreamWholeArray pat w nes fold_fun arrs)
      types
-- Scatter: build a kernel with 'mapKernel' over a fresh index
-- @write_i@ of width @w@.  The kernel body consists of the statements
-- of the (converted) scatter lambda, and its results are 'WriteReturns'
-- built from 'groupScatterResults'.  The kernel is bound to the
-- original pattern under the statement's certificates.
transformStm _ (Let pat (StmAux cs _ _) (Op (Scatter w lam ivs as))) = runBinder_ $ do
  write_i <- newVName "write_i"
  let lam' = soacsLambdaToKernels lam
      lam_body = lambdaBody lam'
      as_ws = [a_w | (a_w, _, _) <- as]
      -- One 'WriteReturns' per group produced by 'groupScatterResults',
      -- with every index turned into a 'DimFix'.
      krets =
        [ WriteReturns a_w a [(map DimFix is, v) | (is, v) <- is_vs]
          | (a_w, a, is_vs) <- groupScatterResults as (bodyResult lam_body)
        ]
      body = KernelBody () (bodyStms lam_body) krets
      -- Each lambda parameter is paired with its input array and
      -- indexed by write_i.
      inputs =
        [ KernelInput (paramName p) (paramType p) p_a [Var write_i]
          | (p, p_a) <- zip (lambdaParams lam') ivs
        ]
      -- Strip as many dimensions from each pattern type as the
      -- corresponding destination shape has.
      ret_types =
        zipWith (\a_w t -> stripArray (length a_w) t) as_ws (patternTypes pat)
  (kernel, stms) <-
    mapKernel segThreadCapped [(write_i, w)] inputs ret_types body
  certifying cs $ do
    addStms stms
    letBind pat (Op (SegOp kernel))
-- Histogram: obtain a segment level from 'segThreadCapped' (described
-- as "seghist", with 'NoRecommendation' and 'SegNoVirt'), then let
-- 'histKernel' produce the statements, passing the statement's
-- certificates along, and add them to the binder.
transformStm _ (Let orig_pat (StmAux cs _ _) (Op (Hist w ops bucket_fun imgs))) =
  runBinder_ $ do
    lvl <- segThreadCapped [w] "seghist" (NoRecommendation SegNoVirt)
    hist_stms <-
      histKernel
        toKernelsLambda
        lvl
        orig_pat
        []
        []
        cs
        w
        ops
        (soacsLambdaToKernels bucket_fun)
        imgs
    addStms hist_stms
  where
    -- Lambda conversion handed to 'histKernel'; it has no effects.
    toKernelsLambda soacs_lam = pure (soacsLambdaToKernels soacs_lam)
-- Fallback for every statement not matched by an earlier clause: hand
-- it to 'FOT.transformStmRecursively' inside a binder.
transformStm _ stm =
  runBinder_ (FOT.transformStmRecursively stm)
-- | Compare the given sizes against a tunable threshold via
-- 'cmpSizeLe'.  The threshold is an 'Out.SizeThreshold' built from the
-- kernel path and the optional default value.  Returns what
-- 'cmpSizeLe' returns: a pair of a subexpression and a name, together
-- with the statements produced.
sufficientParallelism ::
  String ->
  [SubExp] ->
  KernelPath ->
  Maybe Int64 ->
  DistribM ((SubExp, Name), Out.Stms Out.Kernels)
sufficientParallelism desc ws path def = cmpSizeLe desc threshold ws
  where
    threshold = Out.SizeThreshold path def
-- | Heuristic: is it worth attempting an intra-group version of this
-- lambda?  Every statement in the body is given an integer score and
-- the lambda qualifies when the total is strictly greater than one.
worthIntraGroup :: Lambda -> Bool
worthIntraGroup lam = bodyInterest (lambdaBody lam) > 1
  where
    -- Score of a body: the sum of the scores of its statements.
    bodyInterest body = sum (fmap interest (bodyStms body))

    -- Score of one statement.  Alternatives are tried top to bottom; a
    -- failed guard falls through to the next alternative.
    interest stm
      | "sequential" `inAttrs` stm_attrs = 0 :: Int
      | otherwise =
        case stmExp stm of
          Op (Screma width _ form)
            | Just map_lam <- isMapSOAC form -> mapLike width map_lam
          Op (Scatter width scatter_lam _ _) -> mapLike width scatter_lam
          -- Anything nested in a loop counts tenfold.
          DoLoop _ _ _ loop_body -> bodyInterest loop_body * 10
          If _ then_body else_body _ ->
            bodyInterest then_body `max` bodyInterest else_body
          -- Any Screma that is not a plain map.
          Op (Screma width _ (ScremaForm _ _ fun)) ->
            zeroIfTooSmall width + bodyInterest (lambdaBody fun)
          Op (Stream _ _ Sequential _ fun) -> bodyInterest (lambdaBody fun)
          _ -> 0
      where
        stm_attrs = stmAuxAttrs (stmAux stm)
        inner_is_sequential = "sequential_inner" `inAttrs` stm_attrs

        -- 0 for an integer constant below 32, 1 for everything else
        -- (including widths that are not constants).
        zeroIfTooSmall se = case se of
          Constant (IntValue x) | intToInt64 x < 32 -> 0
          _ -> 1

        -- Map-like constructs score nothing when their inside is
        -- marked sequential; otherwise the larger of the width score
        -- and the score of the lambda body.
        mapLike width fun
          | inner_is_sequential = 0
          | otherwise =
            max (zeroIfTooSmall width) (bodyInterest (lambdaBody fun))
-- | Heuristic: is it worth generating a version in which the body of
-- this lambda is sequentialised?  Statements are scored individually
-- and the lambda qualifies when the total is strictly greater than
-- one.
worthSequentialising :: Lambda -> Bool
worthSequentialising lam = bodyInterest (lambdaBody lam) > 1
  where
    -- Score of a body: the sum of the scores of its statements.
    bodyInterest body = sum (fmap interest (bodyStms body))

    -- Score of one statement.
    interest stm
      | "sequential" `inAttrs` stm_attrs = 0 :: Int
      | otherwise =
        case stmExp stm of
          Op (Screma _ _ form@(ScremaForm _ _ fun))
            -- A plain map only contributes whatever is nested in it,
            -- and nothing at all if its inside is marked sequential.
            | isJust (isMapSOAC form) ->
              if inner_is_sequential
                then 0
                else bodyInterest (lambdaBody fun)
            -- Any other Screma scores one plus its nested score, with
            -- one extra point when it is a redomap.
            | otherwise ->
              1 + bodyInterest (lambdaBody fun) + redomapBonus form
          Op Scatter {} -> 0
          -- Only for-loops are considered; their contents count
          -- tenfold.
          DoLoop _ _ ForLoop {} loop_body -> bodyInterest loop_body * 10
          _ -> 0
      where
        stm_attrs = stmAuxAttrs (stmAux stm)
        inner_is_sequential = "sequential_inner" `inAttrs` stm_attrs

        redomapBonus form = case isRedomapSOAC form of
          Nothing -> 0
          Just _ -> 1
-- | Handler for statements encountered at top level during
-- distribution: the statements are turned into a list, passed through
-- 'transformStms' with the given kernel path, and the resulting
-- 'DistribM' action is lifted into 'DistNestT'.
onTopLevelStms ::
  KernelPath ->
  Stms SOACS ->
  DistNestT Out.Kernels DistribM KernelsStms
onTopLevelStms path = liftInner . transformStms path . stmsToList
-- | Transform a map loop.  Two code generators are prepared and both
-- are handed to @onMap'@, which decides how to combine them:
--
-- * one that adds the (converted) lambda body to the accumulator as-is
--   and then distributes, passed in the position @onMap'@ names
--   @mk_seq_stms@;
--
-- * one that distributes the statements of the lambda body
--   individually, passed in the position @onMap'@ names @mk_par_stms@.
onMap :: KernelPath -> MapLoop -> DistribM KernelsStms
onMap path (MapLoop pat aux w lam arrs) = do
  outer_scope <- askScope
  let nesting = MapNesting pat aux w (zip (lambdaParams lam) arrs)

      -- Distribution environment; only the callbacks that carry a
      -- kernel path depend on the argument.
      mkEnv p =
        DistEnv
          { distNest = singleNesting $ Nesting mempty nesting,
            distScope =
              scopeOfPattern pat
                <> scopeForKernels (scopeOf lam)
                <> outer_scope,
            distOnInnerMap = onInnerMap p,
            distOnTopLevelStms = onTopLevelStms p,
            distSegLevel = segThreadCapped,
            distOnSOACSStms = pure . oneStm . soacsStmToKernels,
            distOnSOACSLambda = pure . soacsLambdaToKernels
          }

      -- Distribute the body statements one by one.
      viaInner p =
        runDistNestT (mkEnv p) $
          distributeMapBodyStms initialAcc (bodyStms (lambdaBody lam))

      -- Put the whole body, converted to the kernels representation,
      -- into the accumulator and distribute that.
      viaOuter p =
        runDistNestT (mkEnv p) . distribute $
          addStmsToAcc
            (bodyStms (lambdaBody (soacsLambdaToKernels lam)))
            initialAcc

  onMap' (newKernel nesting) path viaOuter viaInner pat lam
  where
    -- Empty accumulator whose single target is the pattern of the map
    -- paired with the result of the lambda body.
    initialAcc =
      DistAcc
        { distTargets = singleTarget (pat, bodyResult (lambdaBody lam)),
          distStms = mempty
        }
-- | True when the attributes contain
-- @incremental_flattening(only_intra)@.
onlyExploitIntra :: Attrs -> Bool
onlyExploitIntra attrs = inAttrs only_intra attrs
  where
    only_intra = AttrComp "incremental_flattening" ["only_intra"]
-- | True unless the attributes contain
-- @incremental_flattening(no_outer)@ or
-- @incremental_flattening(only_inner)@.
mayExploitOuter :: Attrs -> Bool
mayExploitOuter attrs = not (any requested ["no_outer", "only_inner"])
  where
    -- Is this incremental-flattening option present?
    requested opt = AttrComp "incremental_flattening" [opt] `inAttrs` attrs
-- | True unless the attributes contain
-- @incremental_flattening(no_intra)@ or
-- @incremental_flattening(only_inner)@.
mayExploitIntra :: Attrs -> Bool
mayExploitIntra attrs = not (any requested ["no_intra", "only_inner"])
  where
    -- Is this incremental-flattening option present?
    requested opt = AttrComp "incremental_flattening" [opt] `inAttrs` attrs
-- | Default value supplied for the @suff_intra_par@ threshold when
-- checking whether an intra-group version has enough inner
-- parallelism.
--
-- NOTE(review): 32 presumably corresponds to a GPU warp size --
-- confirm.
intraMinInnerPar :: Int64
intraMinInnerPar = 32
onMap' ::
KernelNest ->
KernelPath ->
(KernelPath -> DistribM (Out.Stms Out.Kernels)) ->
(KernelPath -> DistribM (Out.Stms Out.Kernels)) ->
Pattern ->
Lambda ->
DistribM (Out.Stms Out.Kernels)
onMap' :: KernelNest
-> KernelPath
-> (KernelPath -> DistribM (Stms Kernels))
-> (KernelPath -> DistribM (Stms Kernels))
-> Pattern SOACS
-> Lambda
-> DistribM (Stms Kernels)
onMap' KernelNest
loopnest KernelPath
path KernelPath -> DistribM (Stms Kernels)
mk_seq_stms KernelPath -> DistribM (Stms Kernels)
mk_par_stms Pattern SOACS
pat Lambda
lam = do
Scope Kernels
types <- DistribM (Scope Kernels)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
intra <-
if Attrs -> Bool
onlyExploitIntra (StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
aux)
Bool -> Bool -> Bool
|| (Lambda -> Bool
worthIntraGroup Lambda
lam Bool -> Bool -> Bool
&& Attrs -> Bool
mayExploitIntra Attrs
attrs)
then (ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
-> Scope Kernels
-> DistribM
(Maybe
((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)))
-> Scope Kernels
-> ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
-> DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
forall a b c. (a -> b -> c) -> b -> a -> c
flip ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
-> Scope Kernels
-> DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT Scope Kernels
types (ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
-> DistribM
(Maybe
((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)))
-> ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
-> DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
forall a b. (a -> b) -> a -> b
$ KernelNest
-> Lambda
-> ReaderT
(Scope Kernels)
DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
forall (m :: * -> *).
(MonadFreshNames m, LocalScope Kernels m) =>
KernelNest
-> Lambda
-> m (Maybe
((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
intraGroupParallelise KernelNest
loopnest Lambda
lam
else Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
-> DistribM
(Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
forall a. Maybe a
Nothing
case Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
intra of
Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
_ | Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs -> do
Body Kernels
seq_body <- Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
mk_seq_stms KernelPath
path DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res
Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
seq_body []
Maybe ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
Nothing
| Just DistribM (SubExp, Name, Stms Kernels, Body Kernels)
m <- Maybe (DistribM (SubExp, Name, Stms Kernels, Body Kernels))
mkSeqAlts -> do
(SubExp
outer_suff, Name
outer_suff_key, Stms Kernels
outer_suff_stms, Body Kernels
seq_body) <- DistribM (SubExp, Name, Stms Kernels, Body Kernels)
m
Body Kernels
par_body <-
Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody
(Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
mk_par_stms ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path) DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res
(Stms Kernels
outer_suff_stms Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<>) (Stms Kernels -> Stms Kernels)
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
par_body [(SubExp
outer_suff, Body Kernels
seq_body)]
| Bool
otherwise -> do
Body Kernels
par_body <- Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody (Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
mk_par_stms KernelPath
path DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res
Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
par_body []
Just intra' :: ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
intra'@((SubExp, SubExp)
_, SubExp
_, Log
log, Stms Kernels
intra_prelude, Stms Kernels
intra_stms)
| Attrs -> Bool
onlyExploitIntra Attrs
attrs -> do
Log -> DistribM ()
forall (m :: * -> *). MonadLogger m => Log -> m ()
addLog Log
log
Body Kernels
group_par_body <- Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> Body Kernels -> DistribM (Body Kernels)
forall a b. (a -> b) -> a -> b
$ Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody Stms Kernels
intra_stms Result
res
(Stms Kernels
intra_prelude Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<>) (Stms Kernels -> Stms Kernels)
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
group_par_body []
| Bool
otherwise -> do
Log -> DistribM ()
forall (m :: * -> *). MonadLogger m => Log -> m ()
addLog Log
log
case Maybe (DistribM (SubExp, Name, Stms Kernels, Body Kernels))
mkSeqAlts of
Maybe (DistribM (SubExp, Name, Stms Kernels, Body Kernels))
Nothing -> do
(Body Kernels
group_par_body, SubExp
intra_ok, Name
intra_suff_key, Stms Kernels
intra_suff_stms) <-
KernelPath
-> ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
-> DistribM (Body Kernels, SubExp, Name, Stms Kernels)
checkSuffIntraPar KernelPath
path ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
intra'
Body Kernels
par_body <-
Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody
(Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
mk_par_stms ((Name
intra_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path) DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res
(Stms Kernels
intra_suff_stms Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<>)
(Stms Kernels -> Stms Kernels)
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives Pattern SOACS
Pattern Kernels
pat Body Kernels
par_body [(SubExp
intra_ok, Body Kernels
group_par_body)]
Just DistribM (SubExp, Name, Stms Kernels, Body Kernels)
m -> do
(SubExp
outer_suff, Name
outer_suff_key, Stms Kernels
outer_suff_stms, Body Kernels
seq_body) <- DistribM (SubExp, Name, Stms Kernels, Body Kernels)
m
(Body Kernels
group_par_body, SubExp
intra_ok, Name
intra_suff_key, Stms Kernels
intra_suff_stms) <-
KernelPath
-> ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
-> DistribM (Body Kernels, SubExp, Name, Stms Kernels)
checkSuffIntraPar ((Name
outer_suff_key, Bool
False) (Name, Bool) -> KernelPath -> KernelPath
forall a. a -> [a] -> [a]
: KernelPath
path) ((SubExp, SubExp), SubExp, Log, Stms Kernels, Stms Kernels)
intra'
Body Kernels
par_body <-
Body Kernels -> DistribM (Body Kernels)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Body lore -> m (Body lore)
renameBody (Body Kernels -> DistribM (Body Kernels))
-> DistribM (Body Kernels) -> DistribM (Body Kernels)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Stms Kernels -> Result -> Body Kernels
forall lore. Bindable lore => Stms lore -> Result -> Body lore
mkBody
(Stms Kernels -> Result -> Body Kernels)
-> DistribM (Stms Kernels) -> DistribM (Result -> Body Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> KernelPath -> DistribM (Stms Kernels)
mk_par_stms
( [ (Name
outer_suff_key, Bool
False),
(Name
intra_suff_key, Bool
False)
]
KernelPath -> KernelPath -> KernelPath
forall a. [a] -> [a] -> [a]
++ KernelPath
path
)
DistribM (Result -> Body Kernels)
-> DistribM Result -> DistribM (Body Kernels)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Result -> DistribM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure Result
res
((Stms Kernels
outer_suff_stms Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<> Stms Kernels
intra_suff_stms) Stms Kernels -> Stms Kernels -> Stms Kernels
forall a. Semigroup a => a -> a -> a
<>)
(Stms Kernels -> Stms Kernels)
-> DistribM (Stms Kernels) -> DistribM (Stms Kernels)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Pattern Kernels
-> Body Kernels
-> [(SubExp, Body Kernels)]
-> DistribM (Stms Kernels)
forall (m :: * -> *).
(MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> Body Kernels -> [(SubExp, Body Kernels)] -> m (Stms Kernels)
kernelAlternatives
Pattern SOACS
Pattern Kernels
pat
Body Kernels
par_body
[(SubExp
outer_suff, Body Kernels
seq_body), (SubExp
intra_ok, Body Kernels
group_par_body)]
where
nest_ws :: Result
nest_ws = KernelNest -> Result
kernelNestWidths KernelNest
loopnest
res :: Result
res = (VName -> SubExp) -> [VName] -> Result
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> Result) -> [VName] -> Result
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern SOACS
pat
aux :: StmAux ()
aux = LoopNesting -> StmAux ()
loopNestingAux (LoopNesting -> StmAux ()) -> LoopNesting -> StmAux ()
forall a b. (a -> b) -> a -> b
$ KernelNest -> LoopNesting
innermostKernelNesting KernelNest
loopnest
attrs :: Attrs
attrs = StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
aux
-- Optional action producing a sequentialised alternative.
--
-- It is 'Just' only when both @worthSequentialising lam@ and
-- @mayExploitOuter attrs@ hold; otherwise it is 'Nothing'.  When run,
-- the action returns a 4-tuple of:
--
--   * the sub-expression produced by 'checkSuffOuterPar',
--   * the key produced alongside it,
--   * the statements returned by 'checkSuffOuterPar',
--   * a renamed body whose statements come from @mk_seq_stms@ (called
--     with the key marked 'True' and prepended to @path@) and whose
--     result is @res@.
mkSeqAlts
  | worthSequentialising lam,
    mayExploitOuter attrs = Just $ do
    ((outer_suff, outer_suff_key), outer_suff_stms) <- checkSuffOuterPar
    seq_body <-
      renameBody
        =<< mkBody
          <$> mk_seq_stms ((outer_suff_key, True) : path)
          <*> pure res
    pure (outer_suff, outer_suff_key, outer_suff_stms, seq_body)
  | otherwise =
    Nothing
-- Calls 'sufficientParallelism' with the label @"suff_outer_par"@, the
-- sub-expressions @nest_ws@, the current @path@, and 'Nothing' for the
-- optional 'Int64' argument.  Returns a (sub-expression, key) pair
-- together with the statements that were produced.
checkSuffOuterPar =
  sufficientParallelism "suff_outer_par" nest_ws path Nothing
-- Build the condition and body for an intra-group alternative.
--
-- Arguments:
--
--   * @path'@: the kernel path handed to 'sufficientParallelism'.
--   * a 5-tuple of: a pair whose second component is the available
--     parallelism (the first is ignored), the group size, a log
--     (ignored), prelude statements, and the intra-group statements.
--
-- Returns the renamed body built from the intra-group statements and
-- @res@, the combined condition, the key from 'sufficientParallelism',
-- and the statements computing the condition.
checkSuffIntraPar
  path'
  ((_intra_min_par, intra_avail_par), group_size, _, intra_prelude, intra_stms) = do
    ((intra_ok, intra_suff_key), intra_suff_stms) <- do
      ((intra_suff, suff_key), check_suff_stms) <-
        sufficientParallelism
          "suff_intra_par"
          [intra_avail_par]
          path'
          (Just intraMinInnerPar)

      runBinder $ do
        -- The prelude goes first so that the statements added after it
        -- can refer to what it binds.
        addStms intra_prelude

        max_group_size <-
          letSubExp "max_group_size" $ Op $ SizeOp $ Out.GetSizeMax Out.SizeGroup
        -- Signed 64-bit comparison: group_size <= max_group_size.
        fits <-
          letSubExp "fits" $
            BasicOp $
              CmpOp (CmpSle Int64) group_size max_group_size

        addStms check_suff_stms

        -- The final condition requires both the size check and the
        -- parallelism check to hold.
        intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
        return (intra_ok, suff_key)

    group_par_body <- renameBody $ mkBody intra_stms res
    pure (group_par_body, intra_ok, intra_suff_key, intra_suff_stms)
-- | Handle a map loop met while distributing a nest.
--
-- * If the lambda satisfies both 'unbalancedLambda' and
--   'lambdaContainsParallelism', the map statement is added to the
--   accumulator unchanged via 'addStmToAcc'.
--
-- * Otherwise 'distributeSingleStm' is tried.  If it succeeds and
--   'permutationAndMissing' relates @pat@ to the returned result, the
--   post-statements are recorded with 'addPostStms' and @multiVersion@
--   produces the code.
--
-- * In every remaining case the map is handed to 'distributeMap'.
onInnerMap ::
  KernelPath ->
  MapLoop ->
  DistAcc Out.Kernels ->
  DistNestT Out.Kernels DistribM (DistAcc Out.Kernels)
onInnerMap path maploop@(MapLoop pat aux w lam arrs) acc
  | unbalancedLambda lam,
    lambdaContainsParallelism lam =
    addStmToAcc (mapLoopStm maploop) acc
  | otherwise =
    distributeSingleStm acc (mapLoopStm maploop) >>= \case
      Just (post_kernels, res, nest, acc')
        | Just (perm, _pat_unused) <- permutationAndMissing pat res -> do
          addPostStms post_kernels
          multiVersion perm nest acc'
      _ -> distributeMap maploop acc
  where
    -- Replace the accumulator's targets with a single target made of an
    -- empty pattern and an empty result.
    discardTargets acc' =
      acc' {distTargets = singleTarget (mempty, mempty)}

    -- Build the statements for the map via @onMap'@, post them with
    -- 'postStm', and return @acc'@ unchanged.
    multiVersion perm nest acc' = do
      dist_env <- ask
      -- Scope derived from the targets of the accumulator returned by
      -- 'distributeSingleStm'.
      let extra_scope = targetsScope $ distTargets acc'

      stms <- liftInner $
        localScope extra_scope $ do
          let maploop' = MapLoop pat aux w lam arrs

              -- Distribute the map inside @nest@, using a copy of the
              -- environment whose callbacks use @path'@.  Note that
              -- this starts from the outer @acc@ with its statements
              -- cleared, not from @acc'@.
              exploitInnerParallelism path' = do
                let dist_env' =
                      dist_env
                        { distOnTopLevelStms = onTopLevelStms path',
                          distOnInnerMap = onInnerMap path'
                        }
                runDistNestT dist_env' $
                  inNesting nest $
                    localScope extra_scope $
                      discardTargets
                        <$> distributeMap maploop' acc {distStms = mempty}

          -- The lambda's result order is rearranged by the inverse of
          -- @perm@; the pattern itself is left as it is.
          let lam_res' =
                rearrangeShape (rearrangeInverse perm) $
                  bodyResult $ lambdaBody lam
              lam' = lam {lambdaBody = (lambdaBody lam) {bodyResult = lam_res'}}
              map_nesting = MapNesting pat aux w $ zip (lambdaParams lam) arrs
              nest' = pushInnerKernelNesting (pat, lam_res') map_nesting nest

          -- Kernel built from the body of the lambda converted by
          -- 'soacsLambdaToKernels', at the level given by
          -- 'segThreadCapped'.
          (sequentialised_kernel, nestw_bnds) <- localScope extra_scope $ do
            let sequentialised_lam = soacsLambdaToKernels lam'
            constructKernel segThreadCapped nest' $ lambdaBody sequentialised_lam

          let outer_pat = loopNestingPattern $ fst nest
          -- @onMap'@ receives the unchanged @path@, a constant function
          -- yielding the single kernel statement, and
          -- @exploitInnerParallelism@.  The statements returned with
          -- the kernel are prepended to its output.
          (nestw_bnds <>)
            <$> onMap'
              nest'
              path
              (const $ return $ oneStm sequentialised_kernel)
              exploitInnerParallelism
              outer_pat
              lam'

      postStm stms
      return acc'