{-# LANGUAGE TypeFamilies #-}

-- | Extract limited nested parallelism for execution inside
-- individual kernel threadblocks.
module Futhark.Pass.ExtractKernels.Intragroup (intraGroupParallelise) where

import Control.Monad
import Control.Monad.RWS
import Control.Monad.Trans.Maybe
import Data.Map.Strict qualified as M
import Data.Set qualified as S
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.GPU hiding (HistOp)
import Futhark.IR.GPU.Op qualified as GPU
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.BlockedKernel
import Futhark.Pass.ExtractKernels.DistributeNests
import Futhark.Pass.ExtractKernels.Distribution
import Futhark.Pass.ExtractKernels.ToGPU
import Futhark.Tools
import Futhark.Transform.FirstOrderTransform qualified as FOT
import Futhark.Util.Log
import Prelude hiding (log)

-- | Convert the statements inside a map nest to kernel statements,
-- attempting to parallelise any remaining (top-level) parallel
-- statements.  Anything that is not a map, scan or reduction will
-- simply be sequentialised.  This includes sequential loops that
-- contain maps, scans or reductions.  In the future, we could probably
-- do something more clever.  Make sure that the amount of parallelism
-- to be exploited does not exceed the group size.  Further, as a hack
-- we also consider the size of all intermediate arrays as
-- "parallelism to be exploited" to avoid exploding local memory.
--
-- We distinguish between "minimum group size" and "maximum
-- exploitable parallelism".
--
-- Returns 'Nothing' when the parallelism widths found in the body
-- refer to names that are not available outside the kernel
-- (irregular parallelism).  Otherwise the result contains, in order:
--
-- * the minimum and the available intra-group parallelism (at present
--   both are the same value);
-- * the computed thread block size;
-- * the log produced while transforming the body;
-- * prelude statements that compute the number of blocks, the
--   parallelism widths and the block size, to be bound before the
--   kernel;
-- * a single statement binding the block-level 'SegMap'.
intraGroupParallelise ::
  (MonadFreshNames m, LocalScope GPU m) =>
  KernelNest ->
  Lambda SOACS ->
  m
    ( Maybe
        ( (SubExp, SubExp),
          SubExp,
          Log,
          Stms GPU,
          Stms GPU
        )
    )
intraGroupParallelise knest lam = runMaybeT $ do
  (ispace, inps) <- lift $ flatKernel knest

  -- One thread block per point in the (flattened) outer map nest.
  (num_tblocks, w_stms) <-
    lift $
      runBuilder $
        letSubExp "intra_num_tblocks"
          =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) (map snd ispace)

  let body = lambdaBody lam

  tblock_size <- newVName "computed_tblock_size"
  (wss_min, wss_avail, log, kbody) <-
    lift . localScope (scopeOfLParams $ lambdaParams lam) $
      intraGroupParalleliseBody body

  outside_scope <- lift askScope
  -- outside_scope may also contain the inputs, even though those are
  -- not actually available outside the kernel.
  let available v =
        v `M.member` outside_scope
          && v `notElem` map kernelInputName inps
  -- In MaybeT, 'fail' short-circuits to Nothing; the message is only
  -- descriptive.
  unless (all available $ namesToList $ freeIn (wss_min ++ wss_avail)) $
    fail "Irregular parallelism"

  ((intra_avail_par, kspace, read_input_stms), prelude_stms) <- lift $
    runBuilder $ do
      -- Like 'foldBinOp', but an empty list folds to the constant 1.
      let foldBinOp' _ [] = eSubExp $ intConst Int64 1
          foldBinOp' bop (x : xs) = foldBinOp bop x xs
      ws_min <-
        mapM (letSubExp "one_intra_par_min" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_min
      ws_avail <-
        mapM (letSubExp "one_intra_par_avail" <=< foldBinOp' (Mul Int64 OverflowUndef)) $
          filter (not . null) wss_avail

      -- The amount of parallelism available *in the worst case* is
      -- equal to the smallest parallel loop, or *at least* 1.
      intra_avail_par <-
        letSubExp "intra_avail_par"
          =<< foldBinOp' (SMin Int64) ws_avail

      -- The group size is either the maximum of the minimum parallelism
      -- exploited, or the desired parallelism (bounded by the max group
      -- size) in case there is no minimum.
      letBindNames [tblock_size]
        =<< if null ws_min
          then
            eBinOp
              (SMin Int64)
              (eSubExp =<< letSubExp "max_tblock_size" (Op $ SizeOp $ GetSizeMax SizeThreadBlock))
              (eSubExp intra_avail_par)
          else foldBinOp' (SMax Int64) ws_min

      -- Only inputs actually referenced by the body need to be read.
      let inputIsUsed input = kernelInputName input `nameIn` freeIn body
          used_inps = filter inputIsUsed inps

      addStms w_stms
      -- These statements go inside the kernel body, not the prelude.
      read_input_stms <- runBuilder_ $ mapM readGroupKernelInput used_inps
      space <- SegSpace <$> newVName "phys_tblock_id" <*> pure ispace
      pure (intra_avail_par, space, read_input_stms)

  let kbody' = kbody {kernelBodyStms = read_input_stms <> kernelBodyStms kbody}

  let nested_pat = loopNestingPat first_nest
      rts = map (length ispace `stripArray`) $ patTypes nested_pat
      grid = KernelGrid (Count num_tblocks) (Count $ Var tblock_size)
      lvl = SegBlock SegNoVirt (Just grid)
      kstm =
        Let nested_pat aux $ Op $ SegOp $ SegMap lvl kspace rts kbody'

  let intra_min_par = intra_avail_par
  pure
    ( (intra_min_par, intra_avail_par),
      Var tblock_size,
      log,
      prelude_stms,
      oneStm kstm
    )
  where
    first_nest = fst knest
    aux = loopNestingAux first_nest

-- | Read a kernel input inside an intra-group kernel.
--
-- For array-typed inputs, the value is first read into a fresh name.
-- The input's original name is then bound to a zero-dimensional
-- 'Replicate' (an explicit copy) of that fresh name.  Presumably this
-- gives the kernel body its own copy of the array rather than an alias
-- of the outer one; NOTE(review): confirm the motivation.
-- Non-array inputs are read directly with 'readKernelInput'.
readGroupKernelInput ::
  (DistRep (Rep m), MonadBuilder m) =>
  KernelInput ->
  m ()
readGroupKernelInput inp
  | Array {} <- kernelInputType inp = do
      v <- newVName $ baseString $ kernelInputName inp
      readKernelInput inp {kernelInputName = v}
      letBindNames [kernelInputName inp] $ BasicOp $ Replicate mempty $ Var v
  | otherwise =
      readKernelInput inp

-- | Information accumulated (via the writer of 'IntraGroupM') while
-- transforming the body of an intra-group kernel.
data IntraAcc = IntraAcc
  { -- | Parallelism widths that set a lower bound on the block size;
    -- each element is a list of sizes to be multiplied together.
    accMinPar :: S.Set [SubExp],
    -- | Parallelism widths that are available to exploit; each element
    -- is a list of sizes to be multiplied together.
    accAvailPar :: S.Set [SubExp],
    -- | Log messages produced during the transformation.
    accLog :: Log
  }

-- | Field-wise combination.  The width sets are unioned, and the logs
-- are combined with 'Log''s own '<>'.
instance Semigroup IntraAcc where
  IntraAcc min_x avail_x log_x <> IntraAcc min_y avail_y log_y =
    IntraAcc (min_x <> min_y) (avail_x <> avail_y) (log_x <> log_y)

-- | The empty accumulator: no widths and an empty log.
instance Monoid IntraAcc where
  mempty = IntraAcc mempty mempty mempty

-- | The monad in which intra-group kernel bodies are constructed.  It is
-- a 'BuilderT' producing 'GPU' statements on top of an 'RWS' with a
-- trivial reader.  The writer accumulates an 'IntraAcc' (parallelism
-- widths and log messages), and the state is the 'VNameSource' used for
-- fresh names.
type IntraGroupM =
  BuilderT GPU (RWS () IntraAcc VNameSource)

-- | Log messages are recorded in the 'accLog' field of the writer
-- output; the other fields are left empty.
instance MonadLogger IntraGroupM where
  addLog log = tell mempty {accLog = log}

-- | Run an 'IntraGroupM' action in the current scope and name source.
-- Returns everything the action wrote to the 'IntraAcc' writer,
-- together with the statements it produced.  The updated name source
-- is threaded back into the caller's monad.
runIntraGroupM ::
  (MonadFreshNames m, HasScope GPU m) =>
  IntraGroupM () ->
  m (IntraAcc, Stms GPU)
runIntraGroupM m = do
  scope <- castScope <$> askScope
  modifyNameSource $ \src ->
    let (((), kstms), src', acc) = runRWS (runBuilderT m scope) () src
     in ((acc, kstms), src')

-- | Record a parallel construct whose widths (to be multiplied together)
-- are both a minimum requirement on the block size and exploitable
-- parallelism.
parallelMin :: [SubExp] -> IntraGroupM ()
parallelMin ws =
  tell
    mempty
      { accMinPar = S.singleton ws,
        accAvailPar = S.singleton ws
      }

-- | Transform the statements of a SOACS body with 'intraGroupStms' and
-- collect them into a GPU body with the original result unchanged.
intraGroupBody :: Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody body = do
  stms <- collectStms_ $ intraGroupStms $ bodyStms body
  pure $ mkBody stms $ bodyResult body

-- | Transform a SOACS lambda into a GPU lambda with the same parameters.
-- The body is transformed by 'intraGroupBody'.
intraGroupLambda :: Lambda SOACS -> IntraGroupM (Lambda GPU)
intraGroupLambda lam =
  mkLambda (lambdaParams lam) $
    bodyBind =<< intraGroupBody (lambdaBody lam)

-- | Transform a 'WithAcc' input.  The shape and arrays are kept as-is.
-- If there is an operator, its lambda is transformed by
-- 'intraGroupLambda' and its neutral elements are preserved.
intraGroupWithAccInput :: WithAccInput SOACS -> IntraGroupM (WithAccInput GPU)
intraGroupWithAccInput (shape, arrs, Nothing) =
  pure (shape, arrs, Nothing)
intraGroupWithAccInput (shape, arrs, Just (lam, nes)) = do
  lam' <- intraGroupLambda lam
  pure (shape, arrs, Just (lam', nes))

intraGroupStm :: Stm SOACS -> IntraGroupM ()
intraGroupStm :: Stm SOACS -> IntraGroupM ()
intraGroupStm stm :: Stm SOACS
stm@(Let Pat (LetDec SOACS)
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) = do
  Scope GPU
scope <- BuilderT GPU (RWS () IntraAcc VNameSource) (Scope GPU)
forall rep (m :: * -> *). HasScope rep m => m (Scope rep)
askScope
  let lvl :: SegLevel
lvl = SegVirt -> Maybe KernelGrid -> SegLevel
SegThread SegVirt
SegNoVirt Maybe KernelGrid
forall a. Maybe a
Nothing

  case Exp SOACS
e of
    Loop [(FParam SOACS, SubExp)]
merge LoopForm
form Body SOACS
loopbody ->
      Scope GPU -> IntraGroupM () -> IntraGroupM ()
forall a.
Scope GPU
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (LoopForm -> Scope GPU
forall rep. LoopForm -> Scope rep
scopeOfLoopForm LoopForm
form Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> [Param (FParamInfo GPU)] -> Scope GPU
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU))
-> [(Param (FParamInfo GPU), SubExp)] -> [Param (FParamInfo GPU)]
forall a b. (a -> b) -> [a] -> [b]
map (Param (FParamInfo GPU), SubExp) -> Param (FParamInfo GPU)
forall a b. (a, b) -> a
fst [(FParam SOACS, SubExp)]
[(Param (FParamInfo GPU), SubExp)]
merge)) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        Body GPU
loopbody' <- Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody Body SOACS
loopbody
        Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
Pat (LetDec SOACS)
pat (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
          [(Param (FParamInfo GPU), SubExp)]
-> LoopForm -> Body GPU -> Exp GPU
forall rep.
[(FParam rep, SubExp)] -> LoopForm -> Body rep -> Exp rep
Loop [(FParam SOACS, SubExp)]
[(Param (FParamInfo GPU), SubExp)]
merge LoopForm
form Body GPU
loopbody'
    Match [SubExp]
cond [Case (Body SOACS)]
cases Body SOACS
defbody MatchDec (BranchType SOACS)
ifdec -> do
      [Case (Body GPU)]
cases' <- (Case (Body SOACS)
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU)))
-> [Case (Body SOACS)]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [Case (Body GPU)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Body SOACS -> IntraGroupM (Body GPU))
-> Case (Body SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Case (Body GPU))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Case a -> f (Case b)
traverse Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody) [Case (Body SOACS)]
cases
      Body GPU
defbody' <- Body SOACS -> IntraGroupM (Body GPU)
intraGroupBody Body SOACS
defbody
      Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
Pat (LetDec SOACS)
pat (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        [SubExp]
-> [Case (Body GPU)]
-> Body GPU
-> MatchDec (BranchType GPU)
-> Exp GPU
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp]
cond [Case (Body GPU)]
cases' Body GPU
defbody' MatchDec (BranchType SOACS)
MatchDec (BranchType GPU)
ifdec
    WithAcc [WithAccInput SOACS]
inputs Lambda SOACS
lam -> do
      [WithAccInput GPU]
inputs' <- (WithAccInput SOACS -> IntraGroupM (WithAccInput GPU))
-> [WithAccInput SOACS]
-> BuilderT GPU (RWS () IntraAcc VNameSource) [WithAccInput GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM WithAccInput SOACS -> IntraGroupM (WithAccInput GPU)
intraGroupWithAccInput [WithAccInput SOACS]
inputs
      Lambda GPU
lam' <- Lambda SOACS -> IntraGroupM (Lambda GPU)
intraGroupLambda Lambda SOACS
lam
      Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
Pat (LetDec SOACS)
pat (Exp GPU -> IntraGroupM ()) -> Exp GPU -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ [WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [WithAccInput GPU]
inputs' Lambda GPU
lam'
    Op Op SOACS
soac
      | Attr
"sequential_outer" Attr -> Attrs -> Bool
`inAttrs` StmAux () -> Attrs
forall dec. StmAux dec -> Attrs
stmAuxAttrs StmAux ()
StmAux (ExpDec SOACS)
aux ->
          Stms SOACS -> IntraGroupM ()
intraGroupStms (Stms SOACS -> IntraGroupM ())
-> (Stms SOACS -> Stms SOACS) -> Stms SOACS -> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux))
            (Stms SOACS -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Builder SOACS ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms SOACS)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> Builder SOACS ()
forall (m :: * -> *).
Transformer m =>
Pat (LetDec (Rep m)) -> SOAC (Rep m) -> m ()
FOT.transformSOAC Pat (LetDec (Rep (BuilderT SOACS (State VNameSource))))
Pat (LetDec SOACS)
pat Op SOACS
SOAC (Rep (BuilderT SOACS (State VNameSource)))
soac)
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just Lambda SOACS
lam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC ScremaForm SOACS
form -> do
          let loopnest :: LoopNesting
loopnest = Pat Type
-> StmAux () -> SubExp -> [(Param Type, VName)] -> LoopNesting
MapNesting Pat Type
Pat (LetDec SOACS)
pat StmAux ()
StmAux (ExpDec SOACS)
aux SubExp
w ([(Param Type, VName)] -> LoopNesting)
-> [(Param Type, VName)] -> LoopNesting
forall a b. (a -> b) -> a -> b
$ [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam) [VName]
arrs
              env :: DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env =
                DistEnv
                  { distNest :: Nestings
distNest =
                      Nesting -> Nestings
singleNesting (Nesting -> Nestings) -> Nesting -> Nestings
forall a b. (a -> b) -> a -> b
$ Names -> LoopNesting -> Nesting
Nesting Names
forall a. Monoid a => a
mempty LoopNesting
loopnest,
                    distScope :: Scope GPU
distScope =
                      Pat Type -> Scope GPU
forall rep dec. (LetDec rep ~ dec) => Pat dec -> Scope rep
scopeOfPat Pat Type
Pat (LetDec SOACS)
pat
                        Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope SOACS -> Scope GPU
scopeForGPU (Lambda SOACS -> Scope SOACS
forall rep a. Scoped rep a => a -> Scope rep
scopeOf Lambda SOACS
lam)
                        Scope GPU -> Scope GPU -> Scope GPU
forall a. Semigroup a => a -> a -> a
<> Scope GPU
scope,
                    distOnInnerMap :: MapLoop
-> DistAcc GPU
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
distOnInnerMap =
                      MapLoop
-> DistAcc GPU
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
MapLoop -> DistAcc rep -> DistNestT rep m (DistAcc rep)
distributeMap,
                    distOnTopLevelStms :: Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
distOnTopLevelStms =
                      BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall rep (m :: * -> *) a.
(LocalScope rep m, DistRep rep) =>
m a -> DistNestT rep m a
liftInner (BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
 -> DistNestT
      GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU))
-> (Stms SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IntraGroupM ()
-> BuilderT
     GPU
     (RWS () IntraAcc VNameSource)
     (Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
IntraGroupM ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (IntraGroupM ()
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> (Stms SOACS -> IntraGroupM ())
-> Stms SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stms SOACS -> IntraGroupM ()
intraGroupStms,
                    distSegLevel :: MkSegLevel GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
distSegLevel = \[SubExp]
minw String
_ ThreadRecommendation
_ -> do
                      IntraGroupM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *) a. Monad m => m a -> BuilderT GPU m a
forall (t :: (* -> *) -> * -> *) (m :: * -> *) a.
(MonadTrans t, Monad m) =>
m a -> t m a
lift (IntraGroupM ()
 -> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ())
-> IntraGroupM ()
-> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall a b. (a -> b) -> a -> b
$ [SubExp] -> IntraGroupM ()
parallelMin [SubExp]
minw
                      SegLevel
-> BuilderT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) SegLevel
forall a.
a -> BuilderT GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure SegLevel
lvl,
                    distOnSOACSStms :: Stm SOACS -> BuilderT GPU (State VNameSource) (Stms GPU)
distOnSOACSStms =
                      Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Stms GPU -> BuilderT GPU (State VNameSource) (Stms GPU))
-> (Stm SOACS -> Stms GPU)
-> Stm SOACS
-> BuilderT GPU (State VNameSource) (Stms GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm GPU -> Stms GPU
forall rep. Stm rep -> Stms rep
oneStm (Stm GPU -> Stms GPU)
-> (Stm SOACS -> Stm GPU) -> Stm SOACS -> Stms GPU
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm SOACS -> Stm GPU
soacsStmToGPU,
                    distOnSOACSLambda :: Lambda SOACS -> Builder GPU (Lambda GPU)
distOnSOACSLambda =
                      Lambda GPU -> Builder GPU (Lambda GPU)
forall a. a -> BuilderT GPU (State VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Lambda GPU -> Builder GPU (Lambda GPU))
-> (Lambda SOACS -> Lambda GPU)
-> Lambda SOACS
-> Builder GPU (Lambda GPU)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Lambda SOACS -> Lambda GPU
soacsLambdaToGPU
                  }
              acc :: DistAcc GPU
acc =
                DistAcc
                  { distTargets :: Targets
distTargets = Target -> Targets
singleTarget (Pat Type
Pat (LetDec SOACS)
pat, Body SOACS -> Result
forall rep. Body rep -> Result
bodyResult (Body SOACS -> Result) -> Body SOACS -> Result
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam),
                    distStms :: Stms GPU
distStms = Stms GPU
forall a. Monoid a => a
mempty
                  }

          Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms
            (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadLogger m, DistRep rep) =>
DistEnv rep m -> DistNestT rep m (DistAcc rep) -> m (Stms rep)
runDistNestT DistEnv GPU (BuilderT GPU (RWS () IntraAcc VNameSource))
env (DistAcc GPU
-> Stms SOACS
-> DistNestT
     GPU (BuilderT GPU (RWS () IntraAcc VNameSource)) (DistAcc GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, LocalScope rep m, DistRep rep) =>
DistAcc rep -> Stms SOACS -> DistNestT rep m (DistAcc rep)
distributeMapBodyStms DistAcc GPU
acc (Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms (Body SOACS -> Stms SOACS) -> Body SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> Body SOACS
forall rep. Lambda rep -> Body rep
lambdaBody Lambda SOACS
lam))
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Scan SOACS]
scans, Lambda SOACS
mapfun) <- ScremaForm SOACS -> Maybe ([Scan SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC ScremaForm SOACS
form,
        -- FIXME: Futhark.CodeGen.ImpGen.GPU.Block.compileGroupOp
        -- cannot handle multiple scan operators yet.
        Scan Lambda SOACS
scanfun [SubExp]
nes <- [Scan SOACS] -> Scan SOACS
forall rep. Buildable rep => [Scan rep] -> Scan rep
singleScan [Scan SOACS]
scans -> do
          let scanfun' :: Lambda GPU
scanfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
scanfun
              mapfun' :: Lambda GPU
mapfun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
mapfun
          Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
            Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segScan SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
Noncommutative Lambda GPU
scanfun' [SubExp]
nes Shape
forall a. Monoid a => a
mempty] Lambda GPU
mapfun' [VName]
arrs [] []
          [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form)
      | Just ([Reduce SOACS]
reds, Lambda SOACS
map_lam) <- ScremaForm SOACS -> Maybe ([Reduce SOACS], Lambda SOACS)
forall rep. ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC ScremaForm SOACS
form -> do
          let onReduce :: Reduce SOACS -> SegBinOp GPU
onReduce (Reduce Commutativity
comm Lambda SOACS
red_lam [SubExp]
nes) =
                Commutativity -> Lambda GPU -> [SubExp] -> Shape -> SegBinOp GPU
forall rep.
Commutativity -> Lambda rep -> [SubExp] -> Shape -> SegBinOp rep
SegBinOp Commutativity
comm (Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
red_lam) [SubExp]
nes Shape
forall a. Monoid a => a
mempty
              reds' :: [SegBinOp GPU]
reds' = (Reduce SOACS -> SegBinOp GPU) -> [Reduce SOACS] -> [SegBinOp GPU]
forall a b. (a -> b) -> [a] -> [b]
map Reduce SOACS -> SegBinOp GPU
onReduce [Reduce SOACS]
reds
              map_lam' :: Lambda GPU
map_lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
map_lam
          Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
            Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat (LetDec GPU)
-> Certs
-> SubExp
-> [SegBinOp GPU]
-> Lambda GPU
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) rep.
(MonadFreshNames m, DistRep rep, HasScope rep m) =>
SegOpLevel rep
-> Pat (LetDec rep)
-> Certs
-> SubExp
-> [SegBinOp rep]
-> Lambda rep
-> [VName]
-> [(VName, SubExp)]
-> [KernelInput]
-> m (Stms rep)
segRed SegOpLevel GPU
SegLevel
lvl Pat (LetDec SOACS)
Pat (LetDec GPU)
pat Certs
forall a. Monoid a => a
mempty SubExp
w [SegBinOp GPU]
reds' Lambda GPU
map_lam' [VName]
arrs [] []
          [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Screma SubExp
w [VName]
arrs ScremaForm SOACS
form) ->
      -- This screma is too complicated for us to immediately do
      -- anything, so split it up and try again.
      (Stm SOACS -> IntraGroupM ()) -> Stms SOACS -> IntraGroupM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm SOACS -> IntraGroupM ()
intraGroupStm (Stms SOACS -> IntraGroupM ())
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> IntraGroupM ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Stm SOACS -> Stm SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> Seq a -> Seq b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> Stm SOACS -> Stm SOACS
forall rep. Certs -> Stm rep -> Stm rep
certify (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux)) (Stms SOACS -> Stms SOACS)
-> (((), Stms SOACS) -> Stms SOACS)
-> ((), Stms SOACS)
-> Stms SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((), Stms SOACS) -> Stms SOACS
forall a b. (a, b) -> b
snd
        (((), Stms SOACS) -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
-> Scope SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
-> SubExp
-> ScremaForm
     (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> [VName]
-> BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m), Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> ScremaForm (Rep m) -> [VName] -> m ()
dissectScrema Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
Pat (LetDec SOACS)
pat SubExp
w ScremaForm
  (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
ScremaForm SOACS
form [VName]
arrs) (Scope GPU -> Scope SOACS
scopeForSOACs Scope GPU
scope)
    Op (Hist SubExp
w [VName]
arrs [HistOp SOACS]
ops Lambda SOACS
bucket_fun) -> do
      [HistOp GPU]
ops' <- [HistOp SOACS]
-> (HistOp SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [HistOp SOACS]
ops ((HistOp SOACS
  -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
 -> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU])
-> (HistOp SOACS
    -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> BuilderT GPU (RWS () IntraAcc VNameSource) [HistOp GPU]
forall a b. (a -> b) -> a -> b
$ \(HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes Lambda SOACS
op) -> do
        (Lambda SOACS
op', [SubExp]
nes', Shape
shape) <- Lambda SOACS
-> [SubExp]
-> BuilderT
     GPU (RWS () IntraAcc VNameSource) (Lambda SOACS, [SubExp], Shape)
forall (m :: * -> *).
MonadBuilder m =>
Lambda SOACS -> [SubExp] -> m (Lambda SOACS, [SubExp], Shape)
determineReduceOp Lambda SOACS
op [SubExp]
nes
        let op'' :: Lambda GPU
op'' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
op'
        HistOp GPU
-> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU)
forall a. a -> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (HistOp GPU
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU))
-> HistOp GPU
-> BuilderT GPU (RWS () IntraAcc VNameSource) (HistOp GPU)
forall a b. (a -> b) -> a -> b
$ Shape
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda GPU
-> HistOp GPU
forall rep.
Shape
-> SubExp
-> [VName]
-> [SubExp]
-> Shape
-> Lambda rep
-> HistOp rep
GPU.HistOp Shape
num_bins SubExp
rf [VName]
dests [SubExp]
nes' Shape
shape Lambda GPU
op''

      let bucket_fun' :: Lambda GPU
bucket_fun' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
bucket_fun
      Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$
        Stms (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
Stms GPU -> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms GPU -> IntraGroupM ())
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
-> IntraGroupM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< SegOpLevel GPU
-> Pat Type
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp GPU]
-> Lambda GPU
-> [VName]
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall rep (m :: * -> *).
(DistRep rep, MonadFreshNames m, HasScope rep m) =>
SegOpLevel rep
-> Pat Type
-> SubExp
-> [(VName, SubExp)]
-> [KernelInput]
-> [HistOp rep]
-> Lambda rep
-> [VName]
-> m (Stms rep)
segHist SegOpLevel GPU
SegLevel
lvl Pat Type
Pat (LetDec SOACS)
pat SubExp
w [] [] [HistOp GPU]
ops' Lambda GPU
bucket_fun' [VName]
arrs
      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Op (Stream SubExp
w [VName]
arrs [SubExp]
accs Lambda SOACS
lam)
      | LParam SOACS
chunk_size_param : [LParam SOACS]
_ <- Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam -> do
          Scope SOACS
types <- (Scope GPU -> Scope SOACS)
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Scope SOACS)
forall a.
(Scope GPU -> a) -> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall rep (m :: * -> *) a.
HasScope rep m =>
(Scope rep -> a) -> m a
asksScope Scope GPU -> Scope SOACS
forall fromrep torep.
SameScope fromrep torep =>
Scope fromrep -> Scope torep
castScope
          ((), Stms SOACS
stream_stms) <-
            BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
-> Scope SOACS
-> BuilderT GPU (RWS () IntraAcc VNameSource) ((), Stms SOACS)
forall (m :: * -> *) rep a.
MonadFreshNames m =>
BuilderT rep m a -> Scope rep -> m (a, Stms rep)
runBuilderT (Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
-> SubExp
-> [SubExp]
-> Lambda
     (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> [VName]
-> BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)) ()
forall (m :: * -> *).
(MonadBuilder m, Buildable (Rep m)) =>
Pat (LetDec (Rep m))
-> SubExp -> [SubExp] -> Lambda (Rep m) -> [VName] -> m ()
sequentialStreamWholeArray Pat
  (LetDec
     (Rep
        (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource)))))
Pat (LetDec SOACS)
pat SubExp
w [SubExp]
accs Lambda
  (Rep (BuilderT SOACS (BuilderT GPU (RWS () IntraAcc VNameSource))))
Lambda SOACS
lam [VName]
arrs) Scope SOACS
types
          let replace :: SubExp -> SubExp
replace (Var VName
v) | VName
v VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
LParam SOACS
chunk_size_param = SubExp
w
              replace SubExp
se = SubExp
se
              replaceSets :: IntraAcc -> IntraAcc
replaceSets (IntraAcc Set [SubExp]
x Set [SubExp]
y Log
log) =
                Set [SubExp] -> Set [SubExp] -> Log -> IntraAcc
IntraAcc (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
x) (([SubExp] -> [SubExp]) -> Set [SubExp] -> Set [SubExp]
forall b a. Ord b => (a -> b) -> Set a -> Set b
S.map ((SubExp -> SubExp) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> SubExp
replace) Set [SubExp]
y) Log
log
          (IntraAcc -> IntraAcc) -> IntraGroupM () -> IntraGroupM ()
forall w (m :: * -> *) a. MonadWriter w m => (w -> w) -> m a -> m a
censor IntraAcc -> IntraAcc
replaceSets (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> IntraGroupM ()
intraGroupStms Stms SOACS
stream_stms
    Op (Scatter SubExp
w [VName]
ivs Lambda SOACS
lam [(Shape, Int, VName)]
dests) -> do
      VName
write_i <- String -> BuilderT GPU (RWS () IntraAcc VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"write_i"
      SegSpace
space <- [(VName, SubExp)]
-> BuilderT GPU (RWS () IntraAcc VNameSource) SegSpace
forall (m :: * -> *).
MonadFreshNames m =>
[(VName, SubExp)] -> m SegSpace
mkSegSpace [(VName
write_i, SubExp
w)]

      let lam' :: Lambda GPU
lam' = Lambda SOACS -> Lambda GPU
soacsLambdaToGPU Lambda SOACS
lam
          krets :: [KernelResult]
krets = do
            (Shape
_a_w, VName
a, [(Result, SubExpRes)]
is_vs) <-
              [(Shape, Int, VName)]
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall array a.
[(Shape, Int, array)] -> [a] -> [(Shape, array, [([a], a)])]
groupScatterResults [(Shape, Int, VName)]
dests (Result -> [(Shape, VName, [(Result, SubExpRes)])])
-> Result -> [(Shape, VName, [(Result, SubExpRes)])]
forall a b. (a -> b) -> a -> b
$ Body GPU -> Result
forall rep. Body rep -> Result
bodyResult (Body GPU -> Result) -> Body GPU -> Result
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'
            let cs :: Certs
cs =
                  ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap ((SubExpRes -> Certs) -> Result -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap SubExpRes -> Certs
resCerts (Result -> Certs)
-> ((Result, SubExpRes) -> Result) -> (Result, SubExpRes) -> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> Result
forall a b. (a, b) -> a
fst) [(Result, SubExpRes)]
is_vs
                    Certs -> Certs -> Certs
forall a. Semigroup a => a -> a -> a
<> ((Result, SubExpRes) -> Certs) -> [(Result, SubExpRes)] -> Certs
forall m a. Monoid m => (a -> m) -> [a] -> m
forall (t :: * -> *) m a.
(Foldable t, Monoid m) =>
(a -> m) -> t a -> m
foldMap (SubExpRes -> Certs
resCerts (SubExpRes -> Certs)
-> ((Result, SubExpRes) -> SubExpRes)
-> (Result, SubExpRes)
-> Certs
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Result, SubExpRes) -> SubExpRes
forall a b. (a, b) -> b
snd) [(Result, SubExpRes)]
is_vs
                is_vs' :: [(Slice SubExp, SubExp)]
is_vs' = [([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ (SubExpRes -> DimIndex SubExp) -> Result -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp)
-> (SubExpRes -> SubExp) -> SubExpRes -> DimIndex SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
is, SubExpRes -> SubExp
resSubExp SubExpRes
v) | (Result
is, SubExpRes
v) <- [(Result, SubExpRes)]
is_vs]
            KernelResult -> [KernelResult]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (KernelResult -> [KernelResult]) -> KernelResult -> [KernelResult]
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> [(Slice SubExp, SubExp)] -> KernelResult
WriteReturns Certs
cs VName
a [(Slice SubExp, SubExp)]
is_vs'
          inputs :: [KernelInput]
inputs = do
            (Param Type
p, VName
p_a) <- [Param Type] -> [VName] -> [(Param Type, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [Param (LParamInfo GPU)]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda GPU
lam') [VName]
ivs
            KernelInput -> [KernelInput]
forall a. a -> [a]
forall (f :: * -> *) a. Applicative f => a -> f a
pure (KernelInput -> [KernelInput]) -> KernelInput -> [KernelInput]
forall a b. (a -> b) -> a -> b
$ VName -> Type -> VName -> [SubExp] -> KernelInput
KernelInput (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p) (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p) VName
p_a [VName -> SubExp
Var VName
write_i]

      Stms GPU
kstms <- BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall (m :: * -> *) somerep rep a.
(MonadFreshNames m, HasScope somerep m, SameScope somerep rep) =>
Builder rep a -> m (Stms rep)
runBuilder_ (BuilderT GPU (State VNameSource) ()
 -> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU))
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (RWS () IntraAcc VNameSource) (Stms GPU)
forall a b. (a -> b) -> a -> b
$
        Scope GPU
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall a.
Scope GPU
-> BuilderT GPU (State VNameSource) a
-> BuilderT GPU (State VNameSource) a
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (SegSpace -> Scope GPU
forall rep. SegSpace -> Scope rep
scopeOfSegSpace SegSpace
space) (BuilderT GPU (State VNameSource) ()
 -> BuilderT GPU (State VNameSource) ())
-> BuilderT GPU (State VNameSource) ()
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ do
          (KernelInput -> BuilderT GPU (State VNameSource) ())
-> [KernelInput] -> BuilderT GPU (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ KernelInput -> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *).
(DistRep (Rep m), MonadBuilder m) =>
KernelInput -> m ()
readKernelInput [KernelInput]
inputs
          Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms (Rep (BuilderT GPU (State VNameSource)))
 -> BuilderT GPU (State VNameSource) ())
-> Stms (Rep (BuilderT GPU (State VNameSource)))
-> BuilderT GPU (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ Body (Rep (BuilderT GPU (State VNameSource)))
-> Stms (Rep (BuilderT GPU (State VNameSource)))
forall rep. Body rep -> Stms rep
bodyStms (Body (Rep (BuilderT GPU (State VNameSource)))
 -> Stms (Rep (BuilderT GPU (State VNameSource))))
-> Body (Rep (BuilderT GPU (State VNameSource)))
-> Stms (Rep (BuilderT GPU (State VNameSource)))
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lam'

      Certs -> IntraGroupM () -> IntraGroupM ()
forall a.
Certs
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
-> BuilderT GPU (RWS () IntraAcc VNameSource) a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (StmAux () -> Certs
forall dec. StmAux dec -> Certs
stmAuxCerts StmAux ()
StmAux (ExpDec SOACS)
aux) (IntraGroupM () -> IntraGroupM ())
-> IntraGroupM () -> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ do
        let body :: KernelBody GPU
body = BodyDec GPU -> Stms GPU -> [KernelResult] -> KernelBody GPU
forall rep.
BodyDec rep -> Stms rep -> [KernelResult] -> KernelBody rep
KernelBody () Stms GPU
kstms [KernelResult]
krets
        Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
Pat (LetDec SOACS)
pat (Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
 -> IntraGroupM ())
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Op (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
forall rep. Op rep -> Exp rep
Op (Op (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
 -> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource))))
-> Op (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> Exp (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
forall a b. (a -> b) -> a -> b
$ SegOp SegLevel GPU -> HostOp SOAC GPU
forall (op :: * -> *) rep. SegOp SegLevel rep -> HostOp op rep
SegOp (SegOp SegLevel GPU -> HostOp SOAC GPU)
-> SegOp SegLevel GPU -> HostOp SOAC GPU
forall a b. (a -> b) -> a -> b
$ SegLevel
-> SegSpace -> [Type] -> KernelBody GPU -> SegOp SegLevel GPU
forall lvl rep.
lvl -> SegSpace -> [Type] -> KernelBody rep -> SegOp lvl rep
SegMap SegLevel
lvl SegSpace
space (Pat Type -> [Type]
forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat Type
Pat (LetDec SOACS)
pat) KernelBody GPU
body

      [SubExp] -> IntraGroupM ()
parallelMin [SubExp
w]
    Exp SOACS
_ ->
      Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall (m :: * -> *). MonadBuilder m => Stm (Rep m) -> m ()
addStm (Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
 -> IntraGroupM ())
-> Stm (Rep (BuilderT GPU (RWS () IntraAcc VNameSource)))
-> IntraGroupM ()
forall a b. (a -> b) -> a -> b
$ Stm SOACS -> Stm GPU
soacsStmToGPU Stm SOACS
stm

-- | Process every statement of a statement sequence with
-- 'intraGroupStm', one after the other in sequence order, discarding
-- the individual results.
intraGroupStms :: Stms SOACS -> IntraGroupM ()
intraGroupStms stms = forM_ stms intraGroupStm

-- | Run the intra-group extraction over the statements of a body.
--
-- The statements are processed by 'intraGroupStms' inside
-- 'runIntraGroupM'. The result is a 4-tuple:
--
--   * the contents of the first 'IntraAcc' set (presumably the minimum
--     parallelism widths, e.g. as recorded by 'parallelMin') as a list;
--   * the contents of the second 'IntraAcc' set (presumably the available
--     parallelism widths) as a list;
--   * the accumulated 'Log';
--   * a 'KernelBody' holding the generated statements, whose results are
--     the body's results, each wrapped as 'Returns' 'ResultMaySimplify'
--     with its certificates preserved.
intraGroupParalleliseBody ::
  (MonadFreshNames m, HasScope GPU m) =>
  Body SOACS ->
  m ([[SubExp]], [[SubExp]], Log, KernelBody GPU)
intraGroupParalleliseBody body = do
  -- Match the accumulator strictly in the bind, so the tuple and the
  -- 'IntraAcc' constructor are forced before we continue.
  (IntraAcc minWidths availWidths logs, generated) <-
    runIntraGroupM (intraGroupStms (bodyStms body))
  let toKernelResult (SubExpRes certs se) =
        Returns ResultMaySimplify certs se
      kernelResults = map toKernelResult (bodyResult body)
      kbody = KernelBody () generated kernelResults
  pure (S.toList minWidths, S.toList availWidths, logs, kbody)