{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | Extraction of parallelism from a SOACs program.  This generates
-- parallel constructs aimed at CPU execution, which in particular may
-- involve ad-hoc irregular nested parallelism.
module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.MC
import qualified Futhark.IR.MC as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pat,
    Stm,
  )
import qualified Futhark.IR.SOACS as SOACS
import qualified Futhark.IR.SOACS.Simplify as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which multicore extraction runs.
--
-- It is a reader over a 'Scope' of the 'MC' representation, stacked on
-- a 'VNameSource' state used for generating fresh names.  Every
-- instance below is obtained by newtype deriving from that stack.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadFreshNames,
      HasScope MC,
      LocalScope MC
    )

-- | Logging inside 'ExtractM' is a no-op: any 'Log' handed to 'addLog'
-- is discarded.
--
-- XXX: throwing away the log here...
instance MonadLogger ExtractM where
  addLog = const (pure ())

-- | @indexArray i param arr@ builds a statement that binds the name of
-- @param@ (with @param@'s type) to row @i@ of @arr@.  The index fixes
-- the outermost dimension to @i@ and slices every dimension of the
-- parameter's type in full via 'sliceDim'.  An accumulator-typed
-- parameter is not indexed; it is bound directly to @arr@.
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param _ p t) arr = Let pat (defAux ()) (BasicOp op)
  where
    pat = Pat [PatElem p t]
    op
      | Acc {} <- t = SubExp (Var arr)
      | otherwise = Index arr (Slice (DimFix (Var i) : map sliceDim (arrayDims t)))

-- | Turn a map lambda into a body for iteration @i@.
--
-- Each lambda parameter is bound to row @i@ of the corresponding array
-- in @arrs@ (see 'indexArray').  The lambda body is then transformed
-- with @onBody@, with those bindings in scope.  The result body is the
-- indexing statements followed by the transformed statements, and it
-- keeps the transformed body's results.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs =
  prependIndexing <$> inScopeOf indexings (onBody (lambdaBody lam))
  where
    indexings = zipWith (indexArray i) (lambdaParams lam) arrs
    prependIndexing (Body () stms res) =
      Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produces a 'KernelBody'.
--
-- Every result of the body becomes a 'Returns' kernel result marked
-- 'ResultMaySimplify', keeping that result's certificates.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs =
  toKernelBody <$> mapLambdaToBody onBody i lam arrs
  where
    toKernelBody (Body () stms res) = KernelBody () stms (map toResult res)
    toResult (SubExpRes cs se) = Returns ResultMaySimplify cs se

-- | Convert a SOACS 'Reduce' into an MC 'SegBinOp'.
--
-- The operator and neutral elements go through 'determineReduceOp'.
-- Any statements that step emits are returned alongside the operator.
-- When 'commutativeLambda' recognises the resulting lambda, the
-- operator is marked 'Commutative'.  Otherwise the commutativity
-- annotation of the original 'Reduce' is kept.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((red_lam, red_nes, red_shape), prelude) <-
    runBuilder (determineReduceOp lam nes)
  red_lam_mc <- transformLambda red_lam
  let comm' = if commutativeLambda red_lam then Commutative else comm
  pure (prelude, SegBinOp comm' red_lam_mc red_nes red_shape)

-- | Convert a SOACS 'Scan' into an MC 'SegBinOp'.
--
-- This is the same as 'reduceToSegBinOp', except that the operator is
-- always marked 'Noncommutative'.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((scan_lam, scan_nes, scan_shape), prelude) <-
    runBuilder (determineReduceOp lam nes)
  binop <-
    SegBinOp Noncommutative
      <$> transformLambda scan_lam
      <*> pure scan_nes
      <*> pure scan_shape
  pure (prelude, binop)

-- | Convert a SOACS histogram operator into an MC 'MC.HistOp'.
--
-- The bin count, the @rf@ value and the destination arrays are carried
-- over unchanged.  The operator and neutral elements go through
-- 'determineReduceOp', and the statements it emits are returned
-- alongside the new operator.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((hist_lam, hist_nes, hist_shape), prelude) <-
    runBuilder (determineReduceOp op nes)
  hist_lam_mc <- transformLambda hist_lam
  let hist = MC.HistOp num_bins rf dests hist_nes hist_shape hist_lam_mc
  pure (prelude, hist)

-- | Create a one-dimensional 'SegSpace' of width @w@.
--
-- The space uses a fresh @flat_tid@ flat index.  The fresh @gtid@
-- index variable is returned together with the space.
mkSegSpace :: MonadFreshNames m => SubExp -> m (VName, SegSpace)
mkSegSpace w = build <$> newVName "flat_tid" <*> newVName "gtid"
  where
    build flat gtid = (gtid, SegSpace flat [(gtid, w)])

-- | Re-tag a loop form from SOACS to MC.
--
-- Neither constructor holds anything that depends on the
-- representation, so every field is copied as-is.
transformLoopForm :: LoopForm SOACS -> LoopForm MC
transformLoopForm form =
  case form of
    WhileLoop cond -> WhileLoop cond
    ForLoop i it bound params -> ForLoop i it bound params

-- | Transform one SOACS statement into MC statements.
--
-- Non-SOAC expressions keep their pattern and auxiliary information.
-- Any nested bodies or lambdas (loop bodies, match branches, 'WithAcc'
-- lambdas and operators) are transformed recursively, and the result
-- is a single statement.  A loop body is transformed with the merge
-- parameters and the loop form's bindings in scope.  SOACs are passed
-- to 'transformSOAC', and every statement it produces is certified
-- with the original statement's certificates.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux e) =
  case e of
    BasicOp op ->
      single $ BasicOp op
    Apply f args ret info ->
      single $ Apply f args ret info
    DoLoop merge form body -> do
      let form' = transformLoopForm form
          loop_scope = scopeOfFParams (map fst merge) <> scopeOf form'
      body' <- localScope loop_scope $ transformBody body
      single $ DoLoop merge form' body'
    Match ses cases defbody ret -> do
      cases' <- mapM transformCase cases
      defbody' <- transformBody defbody
      single $ Match ses cases' defbody' ret
    WithAcc inputs lam -> do
      inputs' <- mapM transformInput inputs
      lam' <- transformLambda lam
      single $ WithAcc inputs' lam'
    Op op ->
      fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op
  where
    -- Rebuild the statement with the original pattern and aux.
    single :: Exp MC -> ExtractM (Stms MC)
    single = pure . oneStm . Let pat aux

    transformCase (Case vs body) = Case vs <$> transformBody body

    -- Only the operator lambdas of an accumulator input are transformed.
    -- The shape and arrays are copied unchanged.
    transformInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) op

-- | Transform a lambda's body, with its parameters in scope.
--
-- The parameters and return types are kept unchanged.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params body ret) = do
  body' <- localScope (scopeOfLParams params) $ transformBody body
  pure $ Lambda params body' ret

-- | Transform a sequence of statements from front to back.
--
-- The MC statements produced for each SOACS statement are brought into
-- scope before the remaining statements are transformed.  The outputs
-- are concatenated in their original order.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms = maybe (pure mempty) step (stmsHead stms)
  where
    step (stm, rest) = do
      stm_stms <- transformStm stm
      rest_stms <- inScopeOf stm_stms $ transformStms rest
      pure $ stm_stms <> rest_stms

-- | Transform the statements of a body.  The body's results are kept
-- unchanged.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) = do
  stms' <- transformStms stms
  pure $ Body () stms' res

-- | Convert a SOACS body to MC without extracting any parallelism.
--
-- The body is rephrased with 'injectSOACS', using 'OtherOp' to wrap
-- each SOAC.  The rephrasing runs in 'Identity', so it is pure.
-- Presumably the wrapped SOACs are later executed sequentially;
-- confirm against the consumers of 'OtherOp'.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody body =
  pure $ runIdentity $ rephraseBody (injectSOACS OtherOp) body

-- | Transform a function definition's body, with the function's
-- parameters in scope.
--
-- The entry point information, attributes, name, return type and
-- parameters are copied unchanged.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) =
  FunDef entry attrs name rettype params
    <$> localScope (scopeOfFParams params) (transformBody body)

-- | Sets the chunk size to one: turn a chunked stream lambda into a lambda
-- that processes a single element at a time.
--
-- The lambda's parameters are split by 'partitionChunkedFoldParameters'
-- into a chunk-size parameter, @length nes@ accumulator parameters and the
-- remaining slice parameters.  In the new body:
--
--   * the chunk-size parameter is bound to the constant @1@;
--   * each accumulator parameter is bound to its neutral element from @nes@;
--   * each slice parameter is rebuilt as a one-element array literal of a
--     fresh parameter whose type is the slice's row type.
--
-- The first @length nes@ body results are passed through unchanged.  Each
-- remaining result is indexed at position 0 to extract its single row.
-- The resulting lambda is simplified.  When the @sequential_inner@
-- attribute is present in @attrs@, it is additionally passed through the
-- first-order transformation.
unstreamLambda :: Attrs -> [SubExp] -> Lambda SOACS -> ExtractM (Lambda SOACS)
unstreamLambda attrs nes lam = do
  let num_accs = length nes
      (chunk_param, acc_params, slice_params) =
        partitionChunkedFoldParameters num_accs (lambdaParams lam)

  -- A fresh parameter (of the row type) replaces each slice parameter.
  elem_params <-
    mapM (\(Param _ p t) -> newParam (baseString p) (rowType t)) slice_params

  new_body <- runBodyBuilder . localScope (scopeOfLParams elem_params) $ do
    -- The chunk size is fixed at one.
    letBindNames [paramName chunk_param] . BasicOp . SubExp $ intConst Int64 1

    -- Accumulators start out as the neutral elements.
    forM_ (zip acc_params nes) $ \(acc_p, ne) ->
      letBindNames [paramName acc_p] . BasicOp $ SubExp ne

    -- Every original slice becomes a singleton array of the new element.
    forM_ (zip slice_params elem_params) $ \(slice_p, elem_p) ->
      letBindNames [paramName slice_p] . BasicOp $
        ArrayLit [Var (paramName elem_p)] (paramType elem_p)

    body_res <- bodyBind (lambdaBody lam)
    let (acc_res, chunk_res) = splitAt num_accs body_res

    -- Pull the single row out of each per-chunk result.  The result is
    -- first bound to a name, since 'Index' needs a variable.
    row_res <- forM chunk_res $ \(SubExpRes cs se) -> do
      arr <- letExp "map_res" . BasicOp $ SubExp se
      arr_t <- lookupType arr
      certifying cs . letSubExp "chunk" . BasicOp . Index arr $
        fullSlice arr_t [DimFix (intConst Int64 0)]

    pure . mkBody mempty $ acc_res <> subExpsRes row_res

  let (acc_ts, chunk_ts) = splitAt num_accs (lambdaReturnType lam)
      elem_lam =
        Lambda
          { lambdaParams = elem_params,
            lambdaBody = new_body,
            lambdaReturnType = acc_ts ++ map rowType chunk_ts
          }

  soacs_scope <- castScope <$> askScope
  simplified <- runReaderT (SOACS.simplifyLambda elem_lam) soacs_scope

  if inAttrs "sequential_inner" attrs
    then FOT.transformLambda simplified
    else pure simplified

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Whether a freshly constructed operation should have its bound names
-- renamed (via 'renameIfNeeded') before it is returned.
data NeedsRename
  = -- | Rename the operation.
    DoRename
  | -- | Return the operation unchanged.
    DoNotRename

-- | Apply 'renameSomething' when asked to ('DoRename'); otherwise return
-- the argument untouched ('DoNotRename').
renameIfNeeded :: Rename a => NeedsRename -> a -> ExtractM a
renameIfNeeded needs_rename =
  case needs_rename of
    DoRename -> renameSomething
    DoNotRename -> pure

-- | Build a 'SegMap' over the segment space produced by 'mkSegSpace' for
-- width @w@.  Its kernel body is the map lambda applied to @arrs@, with the
-- lambda body translated by @onBody@.  The result is renamed only if
-- @rename@ is 'DoRename'.
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (thread_id, seg_space) <- mkSegSpace w
  kernel_body <- mapLambdaToKernelBody onBody thread_id map_lam arrs
  let seg_map = SegMap () seg_space (lambdaReturnType map_lam) kernel_body
  renameIfNeeded rename seg_map

-- | Build a 'SegRed' for a redomap: the map lambda (body translated by
-- @onBody@) applied to @arrs@ over a segment space of width @w@.  Each
-- operator in @reds@ is converted with 'reduceToSegBinOp'.
--
-- Returns, in order, the statement groups produced by those conversions,
-- and the (optionally renamed) operation.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (thread_id, seg_space) <- mkSegSpace w
  kernel_body <- mapLambdaToKernelBody onBody thread_id map_lam arrs
  converted <- mapM reduceToSegBinOp reds
  let (binop_stms, seg_binops) = unzip converted
      seg_red =
        SegRed () seg_space seg_binops (lambdaReturnType map_lam) kernel_body
  (binop_stms,) <$> renameIfNeeded rename seg_red

-- | Build a 'SegHist' from histogram operators: the map lambda (body
-- translated by @onBody@) applied to @arrs@ over a segment space of
-- width @w@.  Each operator in @hists@ is converted with 'histToSegBinOp'.
--
-- Returns, in order, the statement groups produced by those conversions,
-- and the (optionally renamed) operation.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (thread_id, seg_space) <- mkSegSpace w
  kernel_body <- mapLambdaToKernelBody onBody thread_id map_lam arrs
  converted <- mapM histToSegBinOp hists
  let (hist_stms, seg_hists) = unzip converted
      seg_hist =
        SegHist () seg_space seg_hists (lambdaReturnType map_lam) kernel_body
  (hist_stms,) <$> renameIfNeeded rename seg_hist

-- | Build a 'SegRed' for a parallel stream.  The operation combines:
--
--   * the map lambda (body translated by @onBody@) applied to @arrs@ over a
--     segment space of width @w@;
--   * a single reduction built from @comm@, @red_lam@ and @red_nes@, which
--     is converted with 'reduceToSegBinOp'.
--
-- Returns the statements produced by that conversion, together with the
-- (optionally renamed) operation.
transformParStream ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Commutativity ->
  Lambda SOACS ->
  [SubExp] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Stms MC, SegOp () MC)
transformParStream rename onBody w comm red_lam red_nes map_lam arrs = do
  (thread_id, seg_space) <- mkSegSpace w
  kernel_body <- mapLambdaToKernelBody onBody thread_id map_lam arrs
  let reduction = Reduce comm red_lam red_nes
  (red_stms, seg_binop) <- reduceToSegBinOp reduction
  let seg_red =
        SegRed () seg_space [seg_binop] (lambdaReturnType map_lam) kernel_body
  (red_stms,) <$> renameIfNeeded rename seg_red

transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC :: Pat (TypeBase Shape NoUniqueness)
-> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC Pat (TypeBase Shape NoUniqueness)
_ Attrs
_ JVP {} =
  String -> ExtractM (Stms MC)
forall a. HasCallStack => String -> a
error String
"transformSOAC: unhandled JVP"
transformSOAC Pat (TypeBase Shape NoUniqueness)
_ Attrs
_ VJP {} =
  String -> ExtractM (Stms MC)
forall a. HasCallStack => String -> a
error String
"transformSOAC: unhandled VJP"
transformSOAC Pat (TypeBase Shape NoUniqueness)
pat Attrs
_ (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
  | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form = do
      SegOp () MC
seq_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w Lambda SOACS
lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
lam
        then do
          SegOp () MC
par_op <- NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> Lambda SOACS
-> [VName]
-> ExtractM (SegOp () MC)
transformMap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w Lambda SOACS
lam [VName]
arrs
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase Shape NoUniqueness)
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$ Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase Shape NoUniqueness)
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form = do
      ([Stms MC]
seq_reds_stms, SegOp () MC
seq_op) <-
        NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoNotRename Body SOACS -> ExtractM (Body MC)
sequentialiseBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
      if Lambda SOACS -> Bool
lambdaContainsParallelism Lambda SOACS
map_lam
        then do
          ([Stms MC]
par_reds_stms, SegOp () MC
par_op) <-
            NeedsRename
-> (Body SOACS -> ExtractM (Body MC))
-> SubExp
-> [Reduce SOACS]
-> Lambda SOACS
-> [VName]
-> ExtractM ([Stms MC], SegOp () MC)
transformRedomap NeedsRename
DoRename Body SOACS -> ExtractM (Body MC)
transformBody SubExp
w [Reduce SOACS]
reds Lambda SOACS
map_lam [VName]
arrs
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
            [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat ([Stms MC]
seq_reds_stms [Stms MC] -> [Stms MC] -> [Stms MC]
forall a. Semigroup a => a -> a -> a
<> [Stms MC]
par_reds_stms)
              Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase Shape NoUniqueness)
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp (SegOp () MC -> Maybe (SegOp () MC)
forall a. a -> Maybe a
Just SegOp () MC
par_op) SegOp () MC
seq_op)
        else
          Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
            [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
seq_reds_stms
              Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm (Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase Shape NoUniqueness)
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$ Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$ Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing SegOp () MC
seq_op)
  | Just ([Scan SOACS]
scans, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form = do
      (VName
gtid, SegSpace
space) <- SubExp -> ExtractM (VName, SegSpace)
forall (m :: * -> *).
MonadFreshNames m =>
SubExp -> m (VName, SegSpace)
mkSegSpace SubExp
w
      KernelBody MC
kbody <- (Body SOACS -> ExtractM (Body MC))
-> VName -> Lambda SOACS -> [VName] -> ExtractM (KernelBody MC)
mapLambdaToKernelBody Body SOACS -> ExtractM (Body MC)
transformBody VName
gtid Lambda SOACS
map_lam [VName]
arrs
      ([Stms MC]
scans_stms, [SegBinOp MC]
scans') <- [(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Stms MC, SegBinOp MC)] -> ([Stms MC], [SegBinOp MC]))
-> ExtractM [(Stms MC, SegBinOp MC)]
-> ExtractM ([Stms MC], [SegBinOp MC])
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Scan SOACS -> ExtractM (Stms MC, SegBinOp MC))
-> [Scan SOACS] -> ExtractM [(Stms MC, SegBinOp MC)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp [Scan SOACS]
scans
      Stms MC -> ExtractM (Stms MC)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms MC -> ExtractM (Stms MC)) -> Stms MC -> ExtractM (Stms MC)
forall a b. (a -> b) -> a -> b
$
        [Stms MC] -> Stms MC
forall a. Monoid a => [a] -> a
mconcat [Stms MC]
scans_stms
          Stms MC -> Stms MC -> Stms MC
forall a. Semigroup a => a -> a -> a
<> Stm MC -> Stms MC
forall rep. Stm rep -> Stms rep
oneStm
            ( Pat (LetDec MC) -> StmAux (ExpDec MC) -> Exp MC -> Stm MC
forall rep.
Pat (LetDec rep) -> StmAux (ExpDec rep) -> Exp rep -> Stm rep
Let Pat (TypeBase Shape NoUniqueness)
Pat (LetDec MC)
pat (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (Exp MC -> Stm MC) -> Exp MC -> Stm MC
forall a b. (a -> b) -> a -> b
$
                Op MC -> Exp MC
forall rep. Op rep -> Exp rep
Op (Op MC -> Exp MC) -> Op MC -> Exp MC
forall a b. (a -> b) -> a -> b
$
                  Maybe (SegOp () MC) -> SegOp () MC -> MCOp MC (SOAC MC)
forall rep op. Maybe (SegOp () rep) -> SegOp () rep -> MCOp rep op
ParOp Maybe (SegOp () MC)
forall a. Maybe a
Nothing (SegOp () MC -> MCOp MC (SOAC MC))
-> SegOp () MC -> MCOp MC (SOAC MC)
forall a b. (a -> b) -> a -> b
$
                    ()
-> SegSpace
-> [SegBinOp MC]
-> [TypeBase Shape NoUniqueness]
-> KernelBody MC
-> SegOp () MC
forall lvl rep.
lvl
-> SegSpace
-> [SegBinOp rep]
-> [TypeBase Shape NoUniqueness]
-> KernelBody rep
-> SegOp lvl rep
SegScan () SegSpace
space [SegBinOp MC]
scans' (Lambda SOACS -> [TypeBase Shape NoUniqueness]
forall rep. Lambda rep -> [TypeBase Shape NoUniqueness]
lambdaReturnType Lambda SOACS
map_lam) KernelBody MC
kbody
            )
  | Bool
otherwise = do
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      Scope SOACS
scope <- Scope MC -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope (Scope MC -> Scope SOACS)
-> ExtractM (Scope MC) -> ExtractM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ExtractM (Scope MC)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
      Stms SOACS -> ExtractM (Stms MC)
transformStms (Stms SOACS -> ExtractM (Stms MC))
-> ExtractM (Stms SOACS) -> ExtractM (Stms MC)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS ExtractM () -> Scope SOACS -> ExtractM (Stms SOACS)
forall (m :: * -> *) rep.
MonadFreshNames m =>
BuilderT rep m () -> Scope rep -> m (Stms rep)
runBuilderT_ (Pat (LetDec (Rep (BuilderT SOACS ExtractM)))
-> SubExp
-> ScremaForm (Rep (BuilderT SOACS ExtractM))
-> [VName]
-> BuilderT SOACS ExtractM ()
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat (TypeBase Shape NoUniqueness)
Pat (LetDec (Rep (BuilderT SOACS ExtractM)))
pat SubExp
w ScremaForm (Rep (BuilderT SOACS ExtractM))
ScremaForm SOACS
form [VName]
arrs) Scope SOACS
scope
transformSOAC pat _ (Scatter w ivs lam dests) = do
  -- Scatter becomes a SegMap whose kernel body is the transformed
  -- lambda body, and whose results are one 'WriteReturns' per group
  -- produced by 'groupScatterResults'.
  (gtid, space) <- mkSegSpace w
  Body () kstms res <- mapLambdaToBody transformBody gtid lam ivs
  let rets = takeLast (length dests) (lambdaReturnType lam)
      kbody = KernelBody () kstms (map writeReturns (groupScatterResults dests res))
  pure . oneStm . Let pat (defAux ()) . Op . ParOp Nothing $
    SegMap () space rets kbody
  where
    -- Build the kernel result for one destination: every index
    -- becomes a 'DimFix' slice, paired with the value written there.
    writeReturns (a_w, a, is_vs) =
      WriteReturns
        (scatterCerts is_vs)
        a_w
        a
        [(Slice (map (DimFix . resSubExp) is), resSubExp v) | (is, v) <- is_vs]
    -- Certificates of all index results followed by those of all
    -- value results.
    scatterCerts is_vs =
      foldMap (foldMap resCerts . fst) is_vs
        <> foldMap (resCerts . snd) is_vs
transformSOAC pat _ (Hist w arrs hists map_lam) = do
  -- A version built with 'sequentialiseBody' is always produced; a
  -- second version built with 'transformBody' (and renaming) is only
  -- produced when the map lambda contains parallelism.
  let transformWith rename body_f =
        transformHist rename body_f w hists map_lam arrs
  (seq_hist_stms, seq_op) <- transformWith DoNotRename sequentialiseBody
  (par_hist_stms, maybe_par_op) <-
    if lambdaContainsParallelism map_lam
      then fmap Just <$> transformWith DoRename transformBody
      else pure ([], Nothing)
  let hist_stm = Let pat (defAux ()) $ Op $ ParOp maybe_par_op seq_op
  pure $ mconcat (seq_hist_stms <> par_hist_stms) <> oneStm hist_stm
transformSOAC pat attrs (Stream w arrs (Parallel _ comm red_lam) red_nes fold_lam)
  | not (null red_nes) = do
      -- The fold lambda is first turned into a map lambda; a version
      -- built with 'sequentialiseBody' is always produced, and one built
      -- with 'transformBody' (with renaming) only when that map lambda
      -- contains parallelism.
      map_lam <- unstreamLambda attrs red_nes fold_lam
      let transformWith rename body_f =
            transformParStream rename body_f w comm red_lam red_nes map_lam arrs
      (seq_red_stms, seq_op) <- transformWith DoNotRename sequentialiseBody
      (par_red_stms, maybe_par_op) <-
        if lambdaContainsParallelism map_lam
          then fmap Just <$> transformWith DoRename transformBody
          else pure (mempty, Nothing)
      let red_stm = Let pat (defAux ()) $ Op $ ParOp maybe_par_op seq_op
      pure $ seq_red_stms <> par_red_stms <> oneStm red_stm
transformSOAC pat _ (Stream w arrs _ nes lam) = do
  -- Eliminate the stream with 'sequentialStreamWholeArray' (run in the
  -- current scope, cast to SOACS), then transform the resulting stms.
  scope <- askScope
  let build = sequentialStreamWholeArray pat w nes lam arrs
  transformStms =<< runBuilderT_ build (castScope scope)

-- | Transform every top-level constant and function of a SOACS program
-- into the 'MC' representation.  Functions are transformed in a scope
-- that contains the already-transformed constants.  The 'ExtractM'
-- computation is run with an empty initial scope, threading the pass's
-- name source through it.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg prog =
  modifyNameSource $ runState (runReaderT m mempty)
  where
    -- Unwrap the 'ExtractM' newtype to get the underlying
    -- ReaderT/State computation.
    ExtractM m = do
      consts' <- transformStms $ progConsts prog
      funs' <- inScopeOf consts' $ mapM transformFunDef $ progFuns prog
      pure $ prog {progConsts = consts', progFuns = funs'}

-- | Transform a program using SOACs to a program in the 'MC'
-- representation, using some amount of flattening.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passName = "extract multicore parallelism",
      passDescription = "Extract multicore parallelism",
      passFunction = transformProg
    }