-- | Shared helpers for the @TileLoops@ optimisation modules:
--
--   * builders that bind two- and three-dimensional 'SegMap's
--     ('segMap2D', 'segMap3D') and a two-dimensional scatter
--     ('segScatter2D') inside a 'Binder';
--   * recognition of redomaps that are candidates for tiling
--     ('isTileableRedomap');
--   * a variance analysis ('VarianceTable', 'varianceInStms') that records,
--     for every bound name, the names it transitively depends on.
module Futhark.Optimise.TileLoops.Shared
  ( TileM,
    segMap2D,
    segMap3D,
    segScatter2D,
    VarianceTable,
    varianceInStms,
    isTileableRedomap,
  )
where

import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import qualified Data.Map as M
import Futhark.IR.Kernels
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename

-- | A reader over the 'Kernels' 'Scope', layered on top of a 'State'
-- that carries the 'VNameSource' used to generate fresh names.
type TileM = ReaderT (Scope Kernels) (State VNameSource)

-- | Bind the results of a two-dimensional 'SegMap' at level @lvl@.
--
-- The dimension pair is @(dim_y, dim_x)@, and @dim_y@ is listed first in
-- the resulting 'SegSpace'.  The kernel body is produced by running @f@
-- in a nested 'Binder' on the matching thread indices @(ltid_y, ltid_x)@.
-- Every 'SubExp' that @f@ returns becomes a 'Returns' kernel result with
-- the given 'ResultManifest', and the types of those 'SubExp's are the
-- 'SegMap' return types.  The expression is passed through 'renameExp'
-- and then bound with 'letTupExp' under the name @desc@.
segMap2D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp) -> -- (dim_y, dim_x); dim_y comes first in the SegSpace
  ( (VName, VName) -> -- f, applied to (ltid_y, ltid_x)
    Binder Kernels [SubExp]
  ) ->
  Binder Kernels [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  ltid_xx <- newVName "ltid_x"
  ltid_flat <- newVName "ltid_flat"
  ltid_yy <- newVName "ltid_y"
  let segspace = SegSpace ltid_flat [(ltid_yy, dim_y), (ltid_xx, dim_x)]

  -- Build the kernel body separately so its statements can be placed
  -- inside the SegMap rather than in the enclosing binder.
  ((ts, res), stms) <- runBinder $ do
    res <- f (ltid_yy, ltid_xx)
    ts <- mapM subExpType res
    return (ts, res)

  letTupExp desc <=< renameExp $
    Op $
      SegOp $
        SegMap lvl segspace ts $ KernelBody () stms $ map (Returns manifest) res

-- | Three-dimensional analogue of 'segMap2D': bind the results of a
-- 'SegMap' at level @lvl@ whose segment space lists the dimensions in the
-- order @dim_z@, @dim_y@, @dim_x@.
--
-- @f@ runs in a nested 'Binder' on the thread indices
-- @(ltid_z, ltid_y, ltid_x)@.  Each 'SubExp' it yields becomes a 'Returns'
-- result with manifest @manifest@, and their types form the 'SegMap'
-- return types.  The 'SegMap' expression is renamed with 'renameExp' and
-- bound with 'letTupExp' under @desc@.
segMap3D ::
  String -> -- desc
  SegLevel -> -- lvl
  ResultManifest -> -- manifest
  (SubExp, SubExp, SubExp) -> -- (dim_z, dim_y, dim_x)
  ( (VName, VName, VName) -> -- f
    Binder Kernels [SubExp]
  ) ->
  Binder Kernels [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  -- Fresh names are requested in a fixed order: x, flat, y, z.
  ltid_x <- newVName "ltid_x"
  ltid_flat <- newVName "ltid_flat"
  ltid_y <- newVName "ltid_y"
  ltid_z <- newVName "ltid_z"

  ((ret_ts, ses), body_stms) <- runBinder $ do
    results <- f (ltid_z, ltid_y, ltid_x)
    result_ts <- traverse subExpType results
    pure (result_ts, results)

  let space =
        SegSpace ltid_flat [(ltid_z, dim_z), (ltid_y, dim_y), (ltid_x, dim_x)]
      kbody = KernelBody () body_stms [Returns manifest se | se <- ses]
  segmap <- renameExp $ Op $ SegOp $ SegMap lvl space ret_ts kbody
  letTupExp desc segmap

-- | Bind the result of a two-dimensional scatter, expressed as a 'SegMap'
-- at level @lvl@ with a single 'WriteReturns' result.
--
-- The dimension pair is @(dim_x, dim_y)@, and @dim_x@ is listed first in
-- the resulting 'SegSpace'.  For each thread, @f (ltid_x, ltid_y)@ is run
-- in a nested 'Binder' and yields a pair @(value, index)@.  The kernel
-- writes @value@ at @index@ of @updt_arr@, using @[arr_size]@ as the
-- destination shape given to 'WriteReturns'.  The 'SegMap' return type is
-- the type of @value@.  The expression is renamed with 'renameExp' and
-- bound with 'letTupExp' under @desc@.
segScatter2D ::
  String -> -- desc
  SubExp -> -- arr_size
  VName -> -- updt_arr: the array written into
  SegLevel -> -- lvl
  (SubExp, SubExp) -> -- (dim_x, dim_y); dim_x comes first in the SegSpace
  ((VName, VName) -> Binder Kernels (SubExp, SubExp)) -> -- f, returning (value, index)
  Binder Kernels [VName]
segScatter2D desc arr_size updt_arr lvl (dim_x, dim_y) f = do
  ltid_x <- newVName "ltid_x"
  ltid_y <- newVName "ltid_y"
  ltid_flat <- newVName "ltid_flat"
  let segspace = SegSpace ltid_flat [(ltid_x, dim_x), (ltid_y, dim_y)]

  -- The per-thread computation goes into its own statement list, which
  -- becomes the body of the kernel.
  ((t_v, res_v, res_i), stms) <- runBinder $ do
    (res_v, res_i) <- f (ltid_x, ltid_y)
    t_v <- subExpType res_v
    return (t_v, res_v, res_i)

  let ret = WriteReturns (Shape [arr_size]) updt_arr [([DimFix res_i], res_v)]
  let body = KernelBody () stms [ret]

  letTupExp desc <=< renameExp $ Op $ SegOp $ SegMap lvl segspace [t_v] body

-- | The variance table maps a variable name to the set of names that it
-- (transitively) depends on.  The key is something bound by a 'Stm'
-- pattern, or a lambda parameter of a tileable redomap.  The kernel
-- thread indices in that set are the dimensions the variable varies over.
-- If a variable is not present in this table, that means it is bound
-- outside the kernel (and so can be considered invariant to all
-- dimensions).
type VarianceTable = M.Map VName Names

-- | Recognise a statement that is a tileable redomap.
--
-- The statement must be an 'OtherOp' 'Screma' whose form is a redomap
-- (per 'isRedomapSOAC').  Its reductions are merged with 'singleReduce',
-- and the following must then hold:
--
--   * the row types of all reduction-lambda and map-lambda parameters
--     are primitive;
--   * the map lambda returns exactly the reduction's types, so there are
--     no map-out arrays;
--   * there is at least one input array;
--   * the map lambda's results and parameters are all primitive.
--
-- On success, this returns the width, the input arrays, and the tuple
-- (commutativity, reduction lambda, neutral elements, map lambda).
isTileableRedomap ::
  Stm Kernels ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
    )
isTileableRedomap stm = do
  Op (OtherOp (Screma w arrs form)) <- Just $ stmExp stm
  (reds, map_lam) <- isRedomapSOAC form
  Reduce red_comm red_lam red_nes <- Just $ singleReduce reds
  let map_ret = lambdaReturnType map_lam
      acceptable =
        all (primType . rowType . paramType) (lambdaParams red_lam)
          && all (primType . rowType . paramType) (lambdaParams map_lam)
          && map_ret == lambdaReturnType red_lam -- No mapout arrays.
          && not (null arrs)
          && all primType map_ret
          && all (primType . paramType) (lambdaParams map_lam)
  if acceptable
    then Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
    else Nothing

-- | The default variance rule for a single statement.
--
-- The dependency set is the union, over every name free in the statement,
-- of that name together with whatever the table already records for it.
-- Every name bound by the statement's pattern is mapped to this set,
-- replacing any earlier entry for that name.
defVarianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
defVarianceInStm variance bnd =
  foldl' bindName variance (patternNames (stmPattern bnd))
  where
    bindName table name = M.insert name depends table
    depends = foldMap transitive (namesToList (freeIn bnd))
    transitive name = oneName name <> M.findWithDefault mempty name variance

-- | Extend a variance table with the names bound by one statement.
--
-- This is the place to treat a 'Screma' differently from the default rule
-- ('defVarianceInStm'); it was previously extended (by Cosmin) to deal
-- with streams.  A 'Screma' accepted by 'isTileableRedomap' is handled as
-- follows:
--
--   1. its pattern is processed by 'defVarianceInStm';
--   2. input arrays are paired positionally with map-lambda parameters,
--      reduction accumulator parameters and reduction element parameters.
--      For each such @(a, p, acc, x)@: @p@ gets the variance of @a@ plus
--      @a@ itself; @x@ gets that plus @p@; @acc@ gets that plus @x@;
--   3. the statements of the map and reduction lambda bodies are then
--      analysed with 'varianceInStms'.
--
-- Every other statement uses 'defVarianceInStm'.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm v0 bnd@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap bnd =
    let v = defVarianceInStm v0 bnd
        red_ps = lambdaParams red_lam
        map_ps = lambdaParams map_lam
        card_red = length red_nes
        -- The reduction operator's parameters are split after the first
        -- card_red (one per neutral element): accumulators first, then
        -- elements.  NOTE(review): relies on the usual Futhark reduce-lambda
        -- parameter order.  Splitting at (card_red `quot` 2) produced no
        -- pairs at all for a single reduction and mismatched the
        -- parameters otherwise.
        (acc_lam_f, arr_lam_f) = splitAt card_red red_ps
        stm_lam = bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)

        f vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
          let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc
              vacc' = M.insert v_fm vrc vacc
              vrc' = oneName v_fm <> vrc
           in M.insert v_fr_acc (oneName v_fr_var <> vrc') $
                M.insert v_fr_var vrc' vacc'

        v' =
          foldl' f v $
            zip4
              arrs
              (map paramName map_ps)
              (map paramName acc_lam_f)
              (map paramName arr_lam_f)
     in varianceInStms v' stm_lam
varianceInStm v0 bnd = defVarianceInStm v0 bnd

-- | Thread a variance table through a sequence of statements from left to
-- right, extending it with 'varianceInStm' at each statement.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table stms