{-# LANGUAGE TypeFamilies #-}

-- | Extraction of parallelism from a SOACs program.  This generates
-- parallel constructs aimed at CPU execution, which in particular may
-- involve ad-hoc irregular nested parallelism.
module Futhark.Pass.ExtractMulticore (extractMulticore) where

import Control.Monad
import Control.Monad.Identity
import Control.Monad.Reader
import Control.Monad.State
import Data.Bitraversable
import Futhark.IR
import Futhark.IR.MC
import Futhark.IR.MC qualified as MC
import Futhark.IR.SOACS hiding
  ( Body,
    Exp,
    LParam,
    Lambda,
    Pat,
    Stm,
  )
import Futhark.IR.SOACS qualified as SOACS
import Futhark.Pass
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.ToGPU (injectSOACS)
import Futhark.Tools
import Futhark.Transform.Rename (Rename, renameSomething)
import Futhark.Util (takeLast)
import Futhark.Util.Log

-- | The monad in which the pass runs: a reader carrying the current
-- 'Scope' of the 'MC' representation, layered over a state holding
-- the 'VNameSource' used for generating fresh names.
newtype ExtractM a = ExtractM (ReaderT (Scope MC) (State VNameSource) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      HasScope MC,
      LocalScope MC,
      MonadFreshNames
    )

-- XXX: throwing away the log here...
instance MonadLogger ExtractM where
  -- Log entries are accepted but discarded; nothing is recorded.
  addLog _ = pure ()

-- | @indexArray i param arr@ produces the statement that binds the
-- name of the lambda parameter @param@ (with the parameter's type as
-- its pattern type) to the element of @arr@ selected by @i@.
--
-- If the parameter has accumulator type, it is bound to @arr@ itself
-- rather than to an index into it.  Otherwise @arr@ is indexed with
-- @i@ fixed in the outermost dimension, followed by one 'sliceDim'
-- for every dimension of the parameter type (presumably a full slice
-- of each remaining dimension - confirm against 'sliceDim').
indexArray :: VName -> LParam SOACS -> VName -> Stm MC
indexArray i (Param _ p t) arr =
  Let (Pat [PatElem p t]) (defAux ()) . BasicOp $
    case t of
      Acc {} -> SubExp $ Var arr
      _ -> Index arr $ Slice $ DimFix (Var i) : map sliceDim (arrayDims t)

-- | @mapLambdaToBody onBody i lam arrs@ turns the lambda @lam@ into a
-- body.  Each lambda parameter is bound by an 'indexArray' statement
-- that reads the corresponding array in @arrs@ at index @i@; the
-- lambda body is then transformed with @onBody@ while those bindings
-- are in scope, and the indexing statements are prepended to the
-- transformed statements.
--
-- Parameters and arrays are paired with 'zipWith', so any surplus in
-- the longer of the two lists is ignored.
mapLambdaToBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (Body MC)
mapLambdaToBody onBody i lam arrs = do
  let indexings = zipWith (indexArray i) (lambdaParams lam) arrs
  Body () stms res <- inScopeOf indexings $ onBody $ lambdaBody lam
  pure $ Body () (stmsFromList indexings <> stms) res

-- | Like 'mapLambdaToBody', but produces a 'KernelBody': every result
-- of the transformed body becomes a 'Returns' kernel result marked
-- 'ResultMaySimplify', keeping the certificates of the original
-- result.
mapLambdaToKernelBody ::
  (Body SOACS -> ExtractM (Body MC)) ->
  VName ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (KernelBody MC)
mapLambdaToKernelBody onBody i lam arrs = do
  Body () stms res <- mapLambdaToBody onBody i lam arrs
  let ret (SubExpRes cs se) = Returns ResultMaySimplify cs se
  pure $ KernelBody () stms $ map ret res

-- | Convert a SOACS 'Reduce' into a 'SegBinOp'.
--
-- The operator and neutral elements are first passed through
-- 'determineReduceOp' inside a 'Builder'; the statements emitted by
-- that builder are returned alongside the operator and must be
-- inserted by the caller.  The resulting lambda is then transformed
-- with 'transformLambda'.
--
-- The operator is marked 'Commutative' when 'commutativeLambda'
-- recognises the lambda returned by 'determineReduceOp'; otherwise
-- the commutativity recorded in the original 'Reduce' is kept.
reduceToSegBinOp :: Reduce SOACS -> ExtractM (Stms MC, SegBinOp MC)
reduceToSegBinOp (Reduce comm lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  let comm'
        | commutativeLambda lam' = Commutative
        | otherwise = comm
  pure (stms, SegBinOp comm' lam'' nes' shape)

-- | Convert a SOACS 'Scan' into a 'SegBinOp'.
--
-- The operator and neutral elements are passed through
-- 'determineReduceOp' inside a 'Builder' (whose emitted statements
-- are returned for the caller to insert), and the resulting lambda is
-- transformed with 'transformLambda'.  The operator is always marked
-- 'Noncommutative'.
scanToSegBinOp :: Scan SOACS -> ExtractM (Stms MC, SegBinOp MC)
scanToSegBinOp (Scan lam nes) = do
  ((lam', nes', shape), stms) <- runBuilder $ determineReduceOp lam nes
  lam'' <- transformLambda lam'
  pure (stms, SegBinOp Noncommutative lam'' nes' shape)

-- | Convert a SOACS histogram operation into its 'MC' counterpart.
--
-- The combining operator and neutral elements are passed through
-- 'determineReduceOp' inside a 'Builder' (whose emitted statements
-- are returned for the caller to insert), and the resulting lambda is
-- transformed with 'transformLambda'.  The bin shape, race factor and
-- destination arrays are carried over unchanged.
histToSegBinOp :: SOACS.HistOp SOACS -> ExtractM (Stms MC, MC.HistOp MC)
histToSegBinOp (SOACS.HistOp num_bins rf dests nes op) = do
  ((op', nes', shape), stms) <- runBuilder $ determineReduceOp op nes
  op'' <- transformLambda op'
  pure (stms, MC.HistOp num_bins rf dests nes' shape op'')

-- | Create a one-dimensional 'SegSpace' of width @w@.
--
-- Two fresh names are generated, based on @"flat_tid"@ and @"gtid"@.
-- The first is used as the flat index of the space and the second as
-- the index of its single dimension.  Returns the dimension index
-- together with the space.
mkSegSpace :: (MonadFreshNames m) => SubExp -> m (VName, SegSpace)
mkSegSpace w = do
  flat <- newVName "flat_tid"
  gtid <- newVName "gtid"
  let space = SegSpace flat [(gtid, w)]
  pure (gtid, space)

-- | Transform a single SOACS statement into a sequence of 'MC'
-- statements.
--
-- * 'BasicOp' and 'Apply' statements are carried over unchanged.
--
-- * 'Loop', 'Match' and 'WithAcc' keep their structure; their nested
--   bodies and lambdas are transformed recursively.
--
-- * 'Op' statements (SOACs) are handed to 'transformSOAC' together
--   with the statement's attributes, and the certificates of the
--   original statement are attached to every resulting statement.
transformStm :: Stm SOACS -> ExtractM (Stms MC)
transformStm (Let pat aux (BasicOp op)) =
  pure $ oneStm $ Let pat aux $ BasicOp op
transformStm (Let pat aux (Apply f args ret info)) =
  pure $ oneStm $ Let pat aux $ Apply f args ret info
transformStm (Let pat aux (Loop merge form body)) = do
  -- The loop body sees the merge parameters and whatever the loop
  -- form brings into scope.
  body' <-
    localScope (scopeOfFParams (map fst merge) <> scopeOfLoopForm form) $
      transformBody body
  pure $ oneStm $ Let pat aux $ Loop merge form body'
transformStm (Let pat aux (Match ses cases defbody ret)) =
  oneStm . Let pat aux
    <$> (Match ses <$> mapM transformCase cases <*> transformBody defbody <*> pure ret)
  where
    -- Transform the body of one case, keeping its pattern values.
    transformCase (Case vs body) = Case vs <$> transformBody body
transformStm (Let pat aux (WithAcc inputs lam)) =
  oneStm . Let pat aux
    <$> (WithAcc <$> mapM transformInput inputs <*> transformLambda lam)
  where
    -- Only the lambda inside the optional combining operator needs
    -- transforming; the shape, arrays and the operator's second
    -- component are left as they are.
    transformInput (shape, arrs, op) =
      (shape,arrs,) <$> traverse (bitraverse transformLambda pure) op
transformStm (Let pat aux (Op op)) =
  fmap (certify (stmAuxCerts aux)) <$> transformSOAC pat (stmAuxAttrs aux) op

-- | Transform a SOACS lambda into an 'MC' lambda.  The parameters and
-- return types are kept; the body is transformed with 'transformBody'
-- while the lambda parameters are in scope.
transformLambda :: Lambda SOACS -> ExtractM (Lambda MC)
transformLambda (Lambda params ret body) =
  Lambda params ret
    <$> localScope (scopeOfLParams params) (transformBody body)

-- | Transform a sequence of SOACS statements, in order.  The 'MC'
-- statements produced for each statement are brought into scope
-- before the statements after it are transformed, and the results are
-- concatenated in the original order.
transformStms :: Stms SOACS -> ExtractM (Stms MC)
transformStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (stm, stms') -> do
      stm_stms <- transformStm stm
      inScopeOf stm_stms $ (stm_stms <>) <$> transformStms stms'

-- | Transform a SOACS body into an 'MC' body: the statements are
-- transformed with 'transformStms' and the result is kept unchanged.
transformBody :: Body SOACS -> ExtractM (Body MC)
transformBody (Body () stms res) =
  Body () <$> transformStms stms <*> pure res

-- | Convert a SOACS body to the 'MC' representation without running
-- the statement transformation of this pass: the body is merely
-- rephrased, with the rephraser built by passing 'OtherOp' to
-- 'injectSOACS' (so SOACs are presumably kept as-is, wrapped in
-- 'OtherOp').  This is a pure computation lifted into 'ExtractM'.
sequentialiseBody :: Body SOACS -> ExtractM (Body MC)
sequentialiseBody = pure . runIdentity . rephraseBody toMC
  where
    toMC = injectSOACS OtherOp

-- | Transform a SOACS function definition into an 'MC' one.  Only the
-- body changes; it is transformed with 'transformBody' while the
-- function parameters are in scope.  The entry point information,
-- attributes, name, return type and parameters are carried over.
transformFunDef :: FunDef SOACS -> ExtractM (FunDef MC)
transformFunDef (FunDef entry attrs name rettype params body) = do
  body' <- localScope (scopeOfFParams params) $ transformBody body
  pure $ FunDef entry attrs name rettype params body'

-- Code generation for each parallel basic block is parameterised over
-- how we handle parallelism in the body (whether it's sequentialised
-- by keeping it as SOACs, or turned into SegOps).

-- | Flag passed to 'renameIfNeeded': 'DoRename' requests that the
-- value be renamed, 'DoNotRename' that it be returned unchanged.
data NeedsRename = DoRename | DoNotRename

-- | Rename the given value with 'renameSomething' when asked to
-- ('DoRename'); otherwise ('DoNotRename') return it unchanged.
renameIfNeeded :: (Rename a) => NeedsRename -> a -> ExtractM a
renameIfNeeded DoRename = renameSomething
renameIfNeeded DoNotRename = pure

-- | @transformMap rename onBody w map_lam arrs@ builds a 'SegMap'
-- over a fresh one-dimensional space of width @w@ (see 'mkSegSpace').
-- The kernel body is produced by 'mapLambdaToKernelBody', which binds
-- the parameters of @map_lam@ from @arrs@ at the space's index and
-- transforms the lambda body with @onBody@.  The result types are the
-- return types of @map_lam@.  The finished 'SegOp' is renamed or left
-- alone according to @rename@.
transformMap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  ExtractM (SegOp () MC)
transformMap rename onBody w map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  renameIfNeeded rename $
    SegMap () space (lambdaReturnType map_lam) kbody

-- | Turn a redomap (reductions @reds@ fused with the map lambda
-- @map_lam@ over @arrs@) into a 'SegRed'.
--
-- Each reduction is converted with 'reduceToSegBinOp', which also
-- yields statements; those are returned (one 'Stms' per reduction)
-- alongside the 'SegOp'.  Only the 'SegOp' itself is passed through
-- 'renameIfNeeded'; the returned statements are not.
transformRedomap ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [Reduce SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformRedomap rename onBody w reds map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (reds_stms, reds') <- mapAndUnzipM reduceToSegBinOp reds
  op' <-
    renameIfNeeded rename $
      SegRed () space reds' (lambdaReturnType map_lam) kbody
  pure (reds_stms, op')

-- | Turn a histogram (operators @hists@ fused with the map lambda
-- @map_lam@ over @arrs@) into a 'SegHist'.
--
-- Each histogram operator is converted with 'histToSegBinOp', which
-- also yields statements; those are returned (one 'Stms' per
-- operator) alongside the 'SegOp'.  Only the 'SegOp' itself is passed
-- through 'renameIfNeeded'; the returned statements are not.
transformHist ::
  NeedsRename ->
  (Body SOACS -> ExtractM (Body MC)) ->
  SubExp ->
  [SOACS.HistOp SOACS] ->
  Lambda SOACS ->
  [VName] ->
  ExtractM ([Stms MC], SegOp () MC)
transformHist rename onBody w hists map_lam arrs = do
  (gtid, space) <- mkSegSpace w
  kbody <- mapLambdaToKernelBody onBody gtid map_lam arrs
  (hists_stms, hists') <- mapAndUnzipM histToSegBinOp hists
  op' <-
    renameIfNeeded rename $
      SegHist () space hists' (lambdaReturnType map_lam) kbody
  pure (hists_stms, op')

-- | Translate a single SOAC, bound to the pattern @pat@, into 'MC'
-- statements.  The attributes argument is ignored by every equation.
--
-- * 'JVP' and 'VJP' are not handled and raise an 'error'.
--
-- * Map, redomap and 'Hist' always get a version whose lambda body is
--   translated with 'sequentialiseBody'.  When
--   'lambdaContainsParallelism' holds for the lambda, a second
--   version translated with 'transformBody' is built as well, and
--   both are packaged in a 'ParOp'.
--
-- * Scanomap and 'Scatter' produce a single 'SegOp' whose body is
--   translated with 'transformBody'.
--
-- * Any other 'Screma' is split with 'dissectScrema', and a 'Stream'
--   is rewritten with 'sequentialStreamWholeArray'; the resulting
--   SOACS statements are then fed back through 'transformStms'.
transformSOAC :: Pat Type -> Attrs -> SOAC SOACS -> ExtractM (Stms MC)
transformSOAC _ _ JVP {} =
  error "transformSOAC: unhandled JVP"
transformSOAC _ _ VJP {} =
  error "transformSOAC: unhandled VJP"
transformSOAC pat _ (Screma w arrs form)
  | Just lam <- isMapSOAC form = do
      seq_op <- transformMap DoNotRename sequentialiseBody w lam arrs
      if lambdaContainsParallelism lam
        then do
          -- The second version is built from the same lambda, and is
          -- the one that gets renamed (presumably so the two copies
          -- do not share bound names -- TODO confirm).
          par_op <- transformMap DoRename transformBody w lam arrs
          pure $ oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
        else pure $ oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
  | Just (reds, map_lam) <- isRedomapSOAC form = do
      (seq_reds_stms, seq_op) <-
        transformRedomap DoNotRename sequentialiseBody w reds map_lam arrs
      if lambdaContainsParallelism map_lam
        then do
          (par_reds_stms, par_op) <-
            transformRedomap DoRename transformBody w reds map_lam arrs
          -- Statements from both versions precede the ParOp itself.
          pure $
            mconcat (seq_reds_stms <> par_reds_stms)
              <> oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
        else
          pure $
            mconcat seq_reds_stms
              <> oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
  | Just (scans, map_lam) <- isScanomapSOAC form = do
      (gtid, space) <- mkSegSpace w
      kbody <- mapLambdaToKernelBody transformBody gtid map_lam arrs
      (scans_stms, scans') <- mapAndUnzipM scanToSegBinOp scans
      pure $
        mconcat scans_stms
          <> oneStm
            ( Let pat (defAux ()) $
                Op $
                  ParOp Nothing $
                    SegScan () space scans' (lambdaReturnType map_lam) kbody
            )
  | otherwise = do
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      scope <- castScope <$> askScope
      transformStms =<< runBuilderT_ (dissectScrema pat w form arrs) scope
transformSOAC pat _ (Scatter w ivs lam dests) = do
  (gtid, space) <- mkSegSpace w

  Body () kstms res <- mapLambdaToBody transformBody gtid lam ivs

  let -- Only the last @length dests@ lambda return types are kept as
      -- the return types of the SegMap.
      rets = takeLast (length dests) $ lambdaReturnType lam
      -- List monad: one 'WriteReturns' per group produced by
      -- 'groupScatterResults'.
      kres = do
        (a_w, a, is_vs) <- groupScatterResults dests res
        let -- Certificates of all index results and all value results.
            cs =
              foldMap (foldMap resCerts . fst) is_vs
                <> foldMap (resCerts . snd) is_vs
            is_vs' =
              [ (Slice $ map (DimFix . resSubExp) is, resSubExp v)
                | (is, v) <- is_vs
              ]
        pure $ WriteReturns cs a_w a is_vs'
      kbody = KernelBody () kstms kres
  pure $
    oneStm $
      Let pat (defAux ()) $
        Op $
          ParOp Nothing $
            SegMap () space rets kbody
transformSOAC pat _ (Hist w arrs hists map_lam) = do
  (seq_hist_stms, seq_op) <-
    transformHist DoNotRename sequentialiseBody w hists map_lam arrs

  if lambdaContainsParallelism map_lam
    then do
      (par_hist_stms, par_op) <-
        transformHist DoRename transformBody w hists map_lam arrs
      pure $
        mconcat (seq_hist_stms <> par_hist_stms)
          <> oneStm (Let pat (defAux ()) $ Op $ ParOp (Just par_op) seq_op)
    else
      pure $
        mconcat seq_hist_stms
          <> oneStm (Let pat (defAux ()) $ Op $ ParOp Nothing seq_op)
transformSOAC pat _ (Stream w arrs nes lam) = do
  -- Just remove the stream and transform the resulting stms.
  soacs_scope <- castScope <$> askScope
  stream_stms <-
    flip runBuilderT_ soacs_scope $
      sequentialStreamWholeArray pat w nes lam arrs
  transformStms stream_stms

-- | Transform a whole SOACS program to 'MC'.
--
-- The 'ExtractM' computation is run with an empty ('mempty') scope,
-- and its 'VNameSource' state is threaded through the pass monad via
-- 'modifyNameSource'.  Constants are transformed first; the function
-- definitions are then transformed inside the scope of the
-- transformed constants.
transformProg :: Prog SOACS -> PassM (Prog MC)
transformProg prog =
  modifyNameSource $ runState (runReaderT m mempty)
  where
    -- Pattern binding that unwraps the 'ExtractM' newtype so the
    -- underlying reader/state stack can be run directly above.
    ExtractM m = do
      consts' <- transformStms $ progConsts prog
      funs' <- inScopeOf consts' $ mapM transformFunDef $ progFuns prog
      pure $
        prog
          { progConsts = consts',
            progFuns = funs'
          }

-- | Transform a program using SOACs to a program in the 'MC'
-- representation, using some amount of flattening.  The pass
-- function is 'transformProg'.
extractMulticore :: Pass SOACS MC
extractMulticore =
  Pass
    { passName = "extract multicore parallelism",
      passDescription = "Extract multicore parallelism",
      passFunction = transformProg
    }