{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pattern,
    Stm,
  )
import qualified Futhark.IR.SOACS as SOACS
import qualified Futhark.IR.SOACS.Simplify as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToKernels (injectSOACS)
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which the extraction functions of this module run.
--
-- It is a reader over the current @'Scope' 'MC'@ layered on top of a
-- 'State' carrying the 'VNameSource'.  All listed instances are
-- obtained from the underlying transformer stack through
-- @GeneralizedNewtypeDeriving@.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...
--
-- 'addLog' ignores its argument and performs no effect, so anything
-- logged inside 'ExtractM' is dropped.
instance MonadLogger ExtractM where
  addLog _ = pure ()

-- | @indexArray i param arr@ builds a single 'MC' statement that binds
-- the name of @param@ (keeping its type as the pattern element's
-- decoration) to an 'Index' into @arr@.
--
-- The slice fixes the first dimension at the variable @i@ and applies
-- 'sliceDim' to every dimension reported by 'arrayDims' for the
-- parameter's type.  The statement carries the default auxiliary
-- information (@'defAux' ()@).
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param p t) arr =
  Let (Pattern [] [PatElem p t]) (defAux ()) $
    BasicOp $ Index arr $ DimFix (Var i) : map sliceDim (arrayDims t)

-- | Turn the lambda of a map-like operation into an 'MC' body that
-- reads its own inputs.
--
-- Each lambda parameter is paired positionally (via 'zipWith') with
-- one of the arrays in @arrs@ and bound by an 'indexArray' statement
-- indexing that array at @i@.  The lambda body is converted by the
-- supplied @onBody@ function, which runs with those indexing
-- statements in scope.  The resulting body consists of the indexing
-- statements followed by the converted statements, with the converted
-- body's result.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  let indexings = zipWith (indexArray i) (lambdaParams lam) arrs
  Body () stms res <- inScopeOf indexings $ onBody $ lambdaBody lam
  return $ Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produces a 'KernelBody'.
--
-- The statements are those computed by 'mapLambdaToBody'; every
-- result is wrapped as @'Returns' 'ResultMaySimplify'@.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs = do
  Body () stms res <- mapLambdaToBody onBody i lam arrs
  return $ KernelBody () stms $ map (Returns ResultMaySimplify) res

-- | Convert a SOACS 'Reduce' into a 'SegBinOp' for the 'MC'
-- representation.
--
-- The operator and neutral elements are first passed through
-- 'determineReduceOp' inside a 'Binder'; the statements that produces
-- are returned as the first component of the result.  The resulting
-- lambda is then converted with 'transformLambda'.  The commutativity
-- of the original reduction is preserved.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((lam', nes', shape), stms) <- runBinder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  return (stms, SegBinOp comm lam'' nes' shape)

-- | Convert a SOACS 'Scan' into a 'SegBinOp' for the 'MC'
-- representation.
--
-- The operator and neutral elements are passed through
-- 'determineReduceOp' inside a 'Binder' (its statements are returned
-- as the first component), and the resulting lambda is converted with
-- 'transformLambda'.  The operator is always marked 'Noncommutative'.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((lam', nes', shape), stms) <- runBinder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  return (stms, SegBinOp Noncommutative lam'' nes' shape)

-- | Convert a SOACS histogram operation into the 'MC' one.
--
-- The operator and neutral elements are passed through
-- 'determineReduceOp' inside a 'Binder' (its statements are returned
-- as the first component), and the resulting lambda is converted with
-- 'transformLambda'.  The number of bins, the @rf@ field and the
-- destination arrays are copied over unchanged.
--
-- Both representations export a constructor named @HistOp@, hence the
-- qualified names.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((op', nes', shape), stms) <- runBinder $ determineReduceOp op nes
  op'' <- transformLambda op'
  return (stms, MC.HistOp num_bins rf dests nes' shape op'')

-- | Create a 'SegSpace' with a single dimension that pairs a fresh
-- index name with @w@.
--
-- Two fresh names are generated, based on @"flat_tid"@ (used as the
-- first field of the space) and @"gtid"@ (the dimension's index).
-- The @gtid@ name is returned alongside the space so callers can
-- refer to it.
mkSegSpace :: MonadFreshNames m => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  let space = SegSpace flat [(gtid, w)]
  return (gtid, space)

-- | Re-tag a loop form from the 'SOACS' representation to 'MC'.
--
-- Every field is carried over unchanged; the constructors are merely
-- rebuilt so that the result has the 'MC' type index.
transformLoopForm :: LoopForm SOACS -> LoopForm MC
transformLoopForm (WhileLoop cond) = WhileLoop cond
transformLoopForm (ForLoop i it bound params) = ForLoop i it bound params

-- | Transform a single 'SOACS' statement into zero or more 'MC'
-- statements.
--
-- One equation per expression form handled here: 'BasicOp', 'Apply',
-- 'DoLoop', 'If' and 'Op'.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
-- Basic operations are rebuilt as-is.
transformStm (Let pat aux (BasicOp op)) =
  pure $ oneStm $ Let pat aux $ BasicOp op
-- Function applications are rebuilt as-is.
transformStm (Let pat aux (Apply f args ret info)) =
  pure $ oneStm $ Let pat aux $ Apply f args ret info
-- Loops: the body is transformed with the context parameters, value
-- parameters and the loop form in scope.
transformStm (Let pat aux (DoLoop ctx val form body)) = do
  let form' = transformLoopForm form
  body' <-
    localScope
      ( scopeOfFParams (map fst ctx)
          <> scopeOfFParams (map fst val)
          <> scopeOf form'
      )
      $ transformBody body
  return $ oneStm $ Let pat aux $ DoLoop ctx val form' body'
-- Conditionals: both branches are transformed; the condition and the
-- return annotation are kept.
transformStm (Let pat aux (If cond tbranch fbranch ret)) =
  oneStm . Let pat aux
    <$> (If cond <$> transformBody tbranch <*> transformBody fbranch <*> pure ret)
-- SOACs: delegated to 'transformSOAC', which receives the statement's
-- attributes; the statement's certificates are then attached to every
-- statement it produced.
transformStm (Let pat aux (Op op)) =
  fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op

-- | Transform a 'SOACS' lambda into an 'MC' lambda.
--
-- The parameters and return types are kept; the body is converted by
-- 'transformBody' with the lambda parameters added to the scope.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params body ret) =
  Lambda params
    <$> localScope (scopeOfLParams params) (transformBody body)
    <*> pure ret

-- | Transform a sequence of statements from first to last.
--
-- The statements produced for the head statement are brought into
-- scope (via 'inScopeOf') while the remaining statements are
-- transformed, and are prepended to the result.  An empty sequence
-- yields 'mempty'.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms =
  case stmsHead stms of
    Nothing -> return mempty
    Just (stm, stms') -> do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms stms'

-- | Transform a 'SOACS' body into an 'MC' body.
--
-- The statements are converted with 'transformStms'; the result list
-- is kept unchanged.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Convert a 'SOACS' body to 'MC' by pure rephrasing, without going
-- through 'transformBody'.
--
-- The body is rephrased in the 'Identity' monad using the rephraser
-- obtained from @'injectSOACS' 'OtherOp'@, i.e. SOAC operations are
-- mapped into 'MC' operations with the 'OtherOp' constructor.  The
-- result is lifted into 'ExtractM' with 'pure'; no monadic effects
-- are performed.
--
-- NOTE(review): judging by the name, the intent is that such bodies
-- are not parallelised further -- confirm against the callers.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody = pure . runIdentity . rephraseBody toMC
  where
    toMC = injectSOACS OtherOp

-- | Transform a 'SOACS' function definition into an 'MC' one.
--
-- Only the body changes: it is converted by 'transformBody' with the
-- function parameters in scope.  The entry point information,
-- attributes, name, return type and parameters are carried over.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) = do
  body' <- localScope (scopeOfFParams params) $ transformBody body
  return $ FunDef entry attrs name rettype params body'

-- Sets the chunk size to one.
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
unstreamLambda Attrs
attrs [SubExp]
nes Lambda SOACS
lam = do
  let (Param (TypeBase Shape NoUniqueness)
chunk_param, [Param (TypeBase Shape NoUniqueness)]
acc_params, [Param (TypeBase Shape NoUniqueness)]
slice_params) =
        Int
-> [Param (TypeBase Shape NoUniqueness)]
-> (Param (TypeBase Shape NoUniqueness),
    [Param (TypeBase Shape NoUniqueness)],
    [Param (TypeBase Shape NoUniqueness)])
forall dec.
Int -> [Param dec] -> (Param dec, [Param dec], [Param dec])
partitionChunkedFoldParameters ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) (Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam)

  [Param (TypeBase Shape NoUniqueness)]
inp_params <- [Param (TypeBase Shape NoUniqueness)]
-> (Param (TypeBase Shape NoUniqueness)
    -> ExtractM (Param (TypeBase Shape NoUniqueness)))
-> ExtractM [Param (TypeBase Shape NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param (TypeBase Shape NoUniqueness)]
slice_params ((Param (TypeBase Shape NoUniqueness)
  -> ExtractM (Param (TypeBase Shape NoUniqueness)))
 -> ExtractM [Param (TypeBase Shape NoUniqueness)])
-> (Param (TypeBase Shape NoUniqueness)
    -> ExtractM (Param (TypeBase Shape NoUniqueness)))
-> ExtractM [Param (TypeBase Shape NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ \(Param VName
p TypeBase Shape NoUniqueness
t) ->
    String
-> TypeBase Shape NoUniqueness
-> ExtractM (Param (TypeBase Shape NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
p) (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType TypeBase Shape NoUniqueness
t)

  Body SOACS
body <- Binder SOACS (Body SOACS) -> ExtractM (Body SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (Body SOACS) -> ExtractM (Body SOACS))
-> Binder SOACS (Body SOACS) -> ExtractM (Body SOACS)
forall a b. (a -> b) -> a -> b
$
    Scope SOACS
-> Binder SOACS (Body SOACS) -> Binder SOACS (Body SOACS)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param (TypeBase Shape NoUniqueness)] -> Scope SOACS
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param (TypeBase Shape NoUniqueness)]
inp_params) (Binder SOACS (Body SOACS) -> Binder SOACS (Body SOACS))
-> Binder SOACS (Body SOACS) -> Binder SOACS (Body SOACS)
forall a b. (a -> b) -> a -> b
$ do
      [VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
chunk_param] (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
        BasicOp -> ExpT SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT SOACS) -> BasicOp -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1

      [(Param (TypeBase Shape NoUniqueness), SubExp)]
-> ((Param (TypeBase Shape NoUniqueness), SubExp)
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape NoUniqueness)]
-> [SubExp] -> [(Param (TypeBase Shape NoUniqueness), SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape NoUniqueness)]
acc_params [SubExp]
nes) (((Param (TypeBase Shape NoUniqueness), SubExp)
  -> BinderT SOACS (State VNameSource) ())
 -> BinderT SOACS (State VNameSource) ())
-> ((Param (TypeBase Shape NoUniqueness), SubExp)
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape NoUniqueness)
p, SubExp
ne) ->
        [VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
p] (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT SOACS) -> BasicOp -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
ne

      [(Param (TypeBase Shape NoUniqueness),
  Param (TypeBase Shape NoUniqueness))]
-> ((Param (TypeBase Shape NoUniqueness),
     Param (TypeBase Shape NoUniqueness))
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param (TypeBase Shape NoUniqueness)]
-> [Param (TypeBase Shape NoUniqueness)]
-> [(Param (TypeBase Shape NoUniqueness),
     Param (TypeBase Shape NoUniqueness))]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param (TypeBase Shape NoUniqueness)]
slice_params [Param (TypeBase Shape NoUniqueness)]
inp_params) (((Param (TypeBase Shape NoUniqueness),
   Param (TypeBase Shape NoUniqueness))
  -> BinderT SOACS (State VNameSource) ())
 -> BinderT SOACS (State VNameSource) ())
-> ((Param (TypeBase Shape NoUniqueness),
     Param (TypeBase Shape NoUniqueness))
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param (TypeBase Shape NoUniqueness)
slice, Param (TypeBase Shape NoUniqueness)
v) ->
        [VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
slice] (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
          BasicOp -> ExpT SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT SOACS) -> BasicOp -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ [SubExp] -> TypeBase Shape NoUniqueness -> BasicOp
ArrayLit [VName -> SubExp
Var (VName -> SubExp) -> VName -> SubExp
forall a b. (a -> b) -> a -> b
$ Param (TypeBase Shape NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName Param (TypeBase Shape NoUniqueness)
v] (Param (TypeBase Shape NoUniqueness) -> TypeBase Shape NoUniqueness
forall dec. Typed dec => Param dec -> TypeBase Shape NoUniqueness
paramType Param (TypeBase Shape NoUniqueness)
v)

      ([SubExp]
red_res, [SubExp]
map_res) <- Int -> [SubExp] -> ([SubExp], [SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([SubExp] -> ([SubExp], [SubExp]))
-> BinderT SOACS (State VNameSource) [SubExp]
-> BinderT SOACS (State VNameSource) ([SubExp], [SubExp])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Body (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) [SubExp]
forall (m :: * -> *). MonadBinder m => Body (Lore m) -> m [SubExp]
bodyBind (Lambda SOACS -> Body SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam)

      [SubExp]
map_res' <- [SubExp]
-> (SubExp -> BinderT SOACS (State VNameSource) SubExp)
-> BinderT SOACS (State VNameSource) [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [SubExp]
map_res ((SubExp -> BinderT SOACS (State VNameSource) SubExp)
 -> BinderT SOACS (State VNameSource) [SubExp])
-> (SubExp -> BinderT SOACS (State VNameSource) SubExp)
-> BinderT SOACS (State VNameSource) [SubExp]
forall a b. (a -> b) -> a -> b
$ \SubExp
se -> do
        VName
v <- String
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"map_res" (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) VName)
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> ExpT SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT SOACS) -> BasicOp -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
        TypeBase Shape NoUniqueness
v_t <- VName
-> BinderT SOACS (State VNameSource) (TypeBase Shape NoUniqueness)
forall lore (m :: * -> *).
HasScope lore m =>
VName -> m (TypeBase Shape NoUniqueness)
lookupType VName
v
        String
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"chunk" (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) SubExp)
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) SubExp
forall a b. (a -> b) -> a -> b
$
          BasicOp -> ExpT SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> ExpT SOACS) -> BasicOp -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$
            VName -> Slice SubExp -> BasicOp
Index VName
v (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
              TypeBase Shape NoUniqueness -> Slice SubExp -> Slice SubExp
fullSlice TypeBase Shape NoUniqueness
v_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0]

      Body SOACS -> Binder SOACS (Body SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body SOACS -> Binder SOACS (Body SOACS))
-> Body SOACS -> Binder SOACS (Body SOACS)
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Body SOACS
forall lore. Bindable lore => [SubExp] -> Body lore
resultBody ([SubExp] -> Body SOACS) -> [SubExp] -> Body SOACS
forall a b. (a -> b) -> a -> b
$ [SubExp]
red_res [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> [SubExp]
map_res'

  let ([TypeBase Shape NoUniqueness]
red_ts, [TypeBase Shape NoUniqueness]
map_ts) = Int
-> [TypeBase Shape NoUniqueness]
-> ([TypeBase Shape NoUniqueness], [TypeBase Shape NoUniqueness])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) ([TypeBase Shape NoUniqueness]
 -> ([TypeBase Shape NoUniqueness], [TypeBase Shape NoUniqueness]))
-> [TypeBase Shape NoUniqueness]
-> ([TypeBase Shape NoUniqueness], [TypeBase Shape NoUniqueness])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam
      map_lam :: Lambda SOACS
map_lam =
        Lambda :: forall lore.
[LParam lore]
-> BodyT lore -> [TypeBase Shape NoUniqueness] -> LambdaT lore
Lambda
          { lambdaReturnType :: [TypeBase Shape NoUniqueness]
lambdaReturnType = [TypeBase Shape NoUniqueness]
red_ts [TypeBase Shape NoUniqueness]
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. [a] -> [a] -> [a]
++ (TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness)
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase Shape NoUniqueness -> TypeBase Shape NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType [TypeBase Shape NoUniqueness]
map_ts,
            lambdaParams :: [LParam SOACS]
lambdaParams = [Param (TypeBase Shape NoUniqueness)]
[LParam SOACS]
inp_params,
            lambdaBody :: Body SOACS
lambdaBody = Body SOACS
body
          }

  Scope SOACS
soacs_scope <- Scope MC -> Scope SOACS
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope (Scope MC -> Scope SOACS)
-> ExtractM (Scope MC) -> ExtractM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExtractM (Scope MC)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
  Lambda SOACS
map_lam' <- ReaderT (Scope SOACS) ExtractM (Lambda SOACS)
-> Scope SOACS -> ExtractM (Lambda SOACS)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Lambda SOACS -> ReaderT (Scope SOACS) ExtractM (Lambda SOACS)
forall (m :: * -> *).
(HasScope SOACS m, MonadFreshNames m) =>
Lambda SOACS -> m (Lambda SOACS)
SOACS.simplifyLambda Lambda SOACS
map_lam) Scope SOACS
soacs_scope

  if Attr
"sequential_inner" Attr -> Attrs -> Bool
`inAttrs` Attrs
attrs
    then Lambda SOACS -> ExtractM (Lambda SOACS)
forall (m :: * -> *) lore somelore.
(MonadFreshNames m, Bindable lore, BinderOps lore,
 LocalScope somelore m, SameScope somelore lore,
 LetDec lore ~ LetDec SOACS) =>
Lambda SOACS -> m (Lambda lore)
FOT.transformLambda Lambda SOACS
map_lam'
    else Lambda SOACS -> ExtractM (Lambda SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return Lambda SOACS
map_lam'

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Flag telling 'renameIfNeeded' whether a freshly built value must be
-- passed through the renamer before it is returned.
data NeedsRename
  = -- | Run the value through 'renameSomething'.
    DoRename
  | -- | Return the value unchanged.
    DoNotRename

-- | Conditionally rename a value inside 'ExtractM'.
--
-- With 'DoRename' the value is passed through 'renameSomething'; with
-- 'DoNotRename' it is returned as-is via 'pure'.
renameIfNeeded :: Rename a => NeedsRename -> a -> ExtractM a
renameIfNeeded DoRename = renameSomething
renameIfNeeded DoNotRename = pure

-- | Turn a map (width @w@, lambda @map_lam@, input arrays @arrs@) into a
-- 'SegMap'.
--
-- The segment space and its thread index come from 'mkSegSpace'; the
-- kernel body is built by 'mapLambdaToKernelBody', which is handed
-- @onBody@ to translate the lambda body from SOACS to MC.  The
-- resulting 'SegOp' is renamed or not according to @rename@ (see
-- 'renameIfNeeded').
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  renameIfNeeded rename $
    SegMap () space (lambdaReturnType map_lam) kbody

-- | Turn a redomap (width @w@, reduction operators @reds@, map lambda
-- @map_lam@, input arrays @arrs@) into a 'SegRed'.
--
-- Each 'Reduce' is converted with 'reduceToSegBinOp', which yields a
-- group of statements alongside the 'SegBinOp'; those statement groups
-- are returned (one per reduction, in order) as the first component of
-- the result.  The kernel body is built by 'mapLambdaToKernelBody'
-- using @onBody@, and the final 'SegOp' is renamed or not according to
-- @rename@ (see 'renameIfNeeded').
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (reds_stms, reds') <- unzip <$> mapM reduceToSegBinOp reds
  op' <-
    renameIfNeeded rename $
      SegRed () space reds' (lambdaReturnType map_lam) kbody
  pure (reds_stms, op')

-- | Turn a histogram (width @w@, histogram operators @hists@, map
-- lambda @map_lam@, input arrays @arrs@) into a 'SegHist'.
--
-- Each SOACS-level 'SOACS.HistOp' is converted with 'histToSegBinOp',
-- which yields a group of statements alongside the MC-level histogram
-- operator; those statement groups are returned (one per operator, in
-- order) as the first component of the result.  The kernel body is
-- built by 'mapLambdaToKernelBody' using @onBody@, and the final
-- 'SegOp' is renamed or not according to @rename@ (see
-- 'renameIfNeeded').
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (hists_stms, hists') <- unzip <$> mapM histToSegBinOp hists
  op' <-
    renameIfNeeded rename $
      SegHist () space hists' (lambdaReturnType map_lam) kbody
  pure (hists_stms, op')

-- | Turn a parallel stream into a 'SegRed' with a single reduction
-- operator.
--
-- The reduction is assembled as @'Reduce' comm red_lam red_nes@ and
-- converted with 'reduceToSegBinOp', whose accompanying statements are
-- returned as the first component of the result.  The kernel body is
-- built from @map_lam@ and @arrs@ by 'mapLambdaToKernelBody' using
-- @onBody@, and the final 'SegOp' is renamed or not according to
-- @rename@ (see 'renameIfNeeded').
transformParStream ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Commutativity ->
  Lambda SOACS ->
  [SubExp] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Stms MC, SegOp () MC)
transformParStream rename onBody w comm red_lam red_nes map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (red_stms, red) <- reduceToSegBinOp $ Reduce comm red_lam red_nes
  op <-
    renameIfNeeded rename $
      SegRed () space [red] (lambdaReturnType map_lam) kbody
  pure (red_stms, op)

transformSOAC :: Pattern SOACS -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC :: Pattern SOACS -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC Pattern SOACS
pat Attrs
_ (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form = do
    SegOp () MC
seq_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w Lambda SOACS
lam [VName]
arrs
    if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam
      then do
        SegOp () MC
par_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w Lambda SOACS
lam [VName]
arrs
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
      else Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall lore. ScremaForm lore -> Maybe ([Reduce lore], Lambda lore)
isRedomapSOAC ScremaForm SOACS
form = do
    ([Stms MC]
seq_reds_stms, SegOp () MC
seq_op) <-
      NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
    if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam
      then do
        ([Stms MC]
par_reds_stms, SegOp () MC
par_op) <-
          NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
          [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat ([Stms MC]
seq_reds_stms [Stms MC] -> [Stms MC] -> [Stms MC]
forall a. Semigroup a => a -> a -> a
<> [Stms MC]
par_reds_stms)
            Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
      else
        Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
          [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
seq_reds_stms
            Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm (Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall lore. ScremaForm lore -> Maybe ([Scan lore], Lambda lore)
isScanomapSOAC ScremaForm SOACS
form = do
    (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w
    KernelBody MC
kbody <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (KernelBody MC)
mapLambdaToKernelBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
map_lam [VName]
arrs
    ([Stms MC]
scans_stms, [SegBinOp MC]
scans') <- [(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC]))
-> ExtractM [(Stms MC, SegBinOp MC)]
-> ExtractM ([Stms MC], [SegBinOp MC])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Scan SOACS -> ExtractM (Stms MC, SegBinOp MC))
-> [Scan SOACS] -> ExtractM [(Stms MC, SegBinOp MC)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp [Scan SOACS]
scans
    Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
      [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
scans_stms
        Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm
          ( Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
              Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
                Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
                  ()
-> SegSpace
-> [SegBinOp MC]
-> [TypeBase Shape NoUniqueness]
-> KernelBody MC
-> SegOp () MC
forall lvl lore.
lvl
-> SegSpace
-> [SegBinOp lore]
-> [TypeBase Shape NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegScan () SegSpace
space [SegBinOp MC]
scans' (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
map_lam) KernelBody MC
kbody
          )
  | Bool
otherwise = do
    -- This screma is too complicated for us to immediately do
    -- anything, so split it up and try again.
    Scope SOACS
scope <- Scope MC -> Scope SOACS
forall fromlore tolore.
SameScope fromlore tolore =>
Scope fromlore -> Scope tolore
castScope (Scope MC -> Scope SOACS)
-> ExtractM (Scope MC) -> ExtractM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExtractM (Scope MC)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
    Stms SOACS -> ExtractM (Stms MC)
transformStms (Stms SOACS -> ExtractM (Stms MC))
-> ExtractM (Stms SOACS) -> ExtractM (Stms MC)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BinderT SOACS ExtractM () -> Scope SOACS -> ExtractM (Stms SOACS)
forall (m :: * -> *) lore.
MonadFreshNames m =>
BinderT lore m () -> Scope lore -> m (Stms lore)
runBinderT_ (Pattern (Lore (BinderT SOACS ExtractM))
-> SubExp
-> ScremaForm (Lore (BinderT SOACS ExtractM))
-> [VName]
-> BinderT SOACS ExtractM ()
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m), Bindable (Lore m)) =>
Pattern (Lore m)
-> SubExp -> ScremaForm (Lore m) -> [VName] -> m ()
dissectScrema Pattern (Lore (BinderT SOACS ExtractM))
Pattern SOACS
pat SubExp
w ScremaForm (Lore (BinderT SOACS ExtractM))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
transformSOAC Pattern SOACS
pat Attrs
_ (Scatter SubExp
w Lambda SOACS
lam [VName]
ivs [(Shape, Int, VName)]
dests) = do
  (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w

  Body () Stms MC
kstms [SubExp]
res <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (Body MC)
mapLambdaToBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
lam [VName]
ivs

  let rets :: [TypeBase Shape NoUniqueness]
rets = Int
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a. Int -> [a] -> [a]
takeLast ([(Shape, Int, VName)] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(Shape, Int, VName)]
dests) ([TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness])
-> [TypeBase Shape NoUniqueness] -> [TypeBase Shape NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall lore. LambdaT lore -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
lam
      kres :: [KernelResult]
kres = do
        (Shape
a_w, VName
a, [([SubExp], SubExp)]
is_vs) <-
          [(Shape, Int, VName)]
-> [SubExp] -> [(Shape, VName, [([SubExp], SubExp)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests [SubExp]
res
        KernelResult -> [KernelResult]
forall (m :: * -> *) a. Monad m => a -> m a
return (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Shape -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Shape
a_w VName
a [((SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix [SubExp]
is, SubExp
v) | ([SubExp]
is, SubExp
v) <- [([SubExp], SubExp)]
is_vs]
      kbody :: KernelBody MC
kbody = BodyDec MC -> Stms MC -> [KernelResult] -> KernelBody MC
forall lore.
BodyDec lore -> Stms lore -> [KernelResult] -> KernelBody lore
KernelBody () Stms MC
kstms [KernelResult]
kres
  Stms MC -> ExtractM (Stms MC)
forall (m :: * -> *) a. Monad m => a -> m a
return (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
    Stm MC -> Stms MC
forall lore. Stm lore -> Stms lore
oneStm (Stm MC -> Stms MC) -> Stm MC -> Stms MC
forall a b. (a -> b) -> a -> b
$
      Pattern MC -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let Pattern SOACS
Pattern MC
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
        Op MC -> Exp MC
forall lore. Op lore -> ExpT lore
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
          Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall lore op.
Maybe (SegOp () lore) -> SegOp () lore -> MCOp lore op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
            ()
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody MC
-> SegOp () MC
forall lvl lore.
lvl
-> SegSpace
-> [TypeBase Shape NoUniqueness]
-> KernelBody lore
-> SegOp lvl lore
SegMap () SegSpace
space [TypeBase Shape NoUniqueness]
rets KernelBody MC
kbody
-- Histogram case.  A sequentialised SegOp is always produced (via
-- 'transformHist' with 'sequentialiseBody' and no renaming).  When the map
-- lambda contains parallelism, a second SegOp is produced with
-- 'transformBody' (with renaming) and passed as the optional first
-- component of the resulting 'ParOp'.
transformSOAC pat _ (Hist w hists map_lam arrs) = do
  (seq_stms, seq_op) <-
    transformHist DoNotRename sequentialiseBody w hists map_lam arrs

  -- The extra statements and the optional SegOp of the nested-parallel
  -- version; empty/absent when the lambda contains no parallelism.
  (par_stms, maybe_par_op) <-
    if lambdaContainsParallelism map_lam
      then do
        (stms, par_op) <-
          transformHist DoRename transformBody w hists map_lam arrs
        return (stms, Just par_op)
      else return ([], Nothing)

  let hist_stm = Let pat (defAux ()) (Op (ParOp maybe_par_op seq_op))
  return $ mconcat (seq_stms <> par_stms) <> oneStm hist_stm
-- Parallel stream with at least one neutral element.  The fold lambda is
-- first turned into a map lambda by 'unstreamLambda'; the result is then
-- handled like the other cases: a sequentialised SegOp is always built, and
-- a second one (body via 'transformBody', with renaming) is added when the
-- map lambda contains parallelism.  If 'red_nes' is empty the guard fails
-- and matching falls through to the next clause.
transformSOAC pat attrs (Stream w arrs (Parallel _ comm red_lam) red_nes fold_lam)
  | not (null red_nes) = do
    map_lam <- unstreamLambda attrs red_nes fold_lam

    -- Both variants share every argument except the renaming mode and the
    -- body transformation.
    let parStream rename onBody =
          transformParStream rename onBody w comm red_lam red_nes map_lam arrs

    (seq_stms, seq_op) <- parStream DoNotRename sequentialiseBody

    (par_stms, maybe_par_op) <-
      if lambdaContainsParallelism map_lam
        then do
          (stms, par_op) <- parStream DoRename transformBody
          return (stms, Just par_op)
        else return (mempty, Nothing)

    let red_stm = Let pat (defAux ()) (Op (ParOp maybe_par_op seq_op))
    return $ seq_stms <> par_stms <> oneStm red_stm
-- Any other stream: just remove the stream and transform the resulting
-- statements.  The stream is expanded by 'sequentialStreamWholeArray' inside
-- a SOACS binder that runs in the current scope (cast from MC to SOACS), and
-- the statements it produces are passed to 'transformStms'.
transformSOAC pat _ (Stream w arrs _ nes lam) = do
  soacs_scope <- fmap castScope askScope
  transformStms
    =<< runBinderT_
      (sequentialStreamWholeArray pat w nes lam arrs)
      soacs_scope

-- | Transform a whole SOACS program into an MC program.  The constants are
-- transformed first; the functions are then transformed with the transformed
-- constants in scope.  The computation is run with an initially empty scope
-- and threads the pass's name source through 'modifyNameSource'.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg (Prog consts funs) =
  modifyNameSource (runState (runReaderT extraction mempty))
  where
    -- Pattern binding that unwraps the 'ExtractM' newtype.
    ExtractM extraction = do
      consts' <- transformStms consts
      Prog consts' <$> inScopeOf consts' (mapM transformFunDef funs)

-- | The pass from the SOACS representation to the MC representation; its
-- pass function is 'transformProg'.
extractMulticore :: Pass SOACS MC
extractMulticore =
  -- Constructor arguments, in order: pass name, pass description, pass
  -- function.
  Pass
    "extract multicore parallelism"
    "Extract multicore parallelism"
    transformProg