module Futhark.Optimise.TileLoops.Shared
( TileM,
segMap2D,
segMap3D,
segScatter2D,
VarianceTable,
varianceInStms,
isTileableRedomap,
)
where
import Control.Monad.Reader
import Control.Monad.State
import Data.List (foldl', zip4)
import qualified Data.Map as M
import Futhark.IR.Kernels
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
-- | Monad used by the tiling pass: read-only access to a 'Kernels' type
-- scope (via 'ReaderT') layered over a 'State' holding a 'VNameSource'
-- (presumably the source of fresh names for generated bindings).
type TileM = ReaderT (Scope Kernels) (State VNameSource)
-- | Bind the results of a two-dimensional 'SegMap'.
--
-- The segment space is @[(ltid_y, dim_y), (ltid_x, dim_x)]@ over a fresh
-- flat index, i.e. the first component of the dimension pair is listed
-- first. The callback is handed the thread indices in that same order and
-- is run in a nested binder: the statements it emits become the kernel
-- body, and each 'SubExp' it returns becomes a kernel result returned with
-- the given 'ResultManifest'. The 'SegMap' expression is renamed before it
-- is bound with 'letTupExp'.
segMap2D ::
  -- | Name hint passed to 'letTupExp'.
  String ->
  -- | Level of the generated 'SegMap'.
  SegLevel ->
  -- | Manifest used for every kernel result.
  ResultManifest ->
  -- | Dimensions @(dim_y, dim_x)@.
  (SubExp, SubExp) ->
  -- | Per-thread body, given @(ltid_y, ltid_x)@.
  ( (VName, VName) ->
    Binder Kernels [SubExp]
  ) ->
  Binder Kernels [VName]
segMap2D desc lvl manifest (dim_y, dim_x) f = do
  idx_x <- newVName "ltid_x"
  idx_flat <- newVName "ltid_flat"
  idx_y <- newVName "ltid_y"
  ((res_ts, results), body_stms) <-
    runBinder $ do
      results <- f (idx_y, idx_x)
      res_ts <- traverse subExpType results
      pure (res_ts, results)
  let space = SegSpace idx_flat [(idx_y, dim_y), (idx_x, dim_x)]
      kbody = KernelBody () body_stms $ Returns manifest <$> results
  seg_exp <- renameExp $ Op $ SegOp $ SegMap lvl space res_ts kbody
  letTupExp desc seg_exp
-- | Bind the results of a three-dimensional 'SegMap'.
--
-- The segment space is
-- @[(ltid_z, dim_z), (ltid_y, dim_y), (ltid_x, dim_x)]@ over a fresh flat
-- index, so the dimensions appear in the order they are given. The
-- callback is handed @(ltid_z, ltid_y, ltid_x)@ and is run in a nested
-- binder: its statements become the kernel body, and each returned
-- 'SubExp' becomes a kernel result returned with the given
-- 'ResultManifest'. The 'SegMap' expression is renamed before it is bound
-- with 'letTupExp'.
segMap3D ::
  -- | Name hint passed to 'letTupExp'.
  String ->
  -- | Level of the generated 'SegMap'.
  SegLevel ->
  -- | Manifest used for every kernel result.
  ResultManifest ->
  -- | Dimensions @(dim_z, dim_y, dim_x)@.
  (SubExp, SubExp, SubExp) ->
  -- | Per-thread body, given @(ltid_z, ltid_y, ltid_x)@.
  ( (VName, VName, VName) ->
    Binder Kernels [SubExp]
  ) ->
  Binder Kernels [VName]
segMap3D desc lvl manifest (dim_z, dim_y, dim_x) f = do
  idx_x <- newVName "ltid_x"
  idx_flat <- newVName "ltid_flat"
  idx_y <- newVName "ltid_y"
  idx_z <- newVName "ltid_z"
  ((res_ts, results), body_stms) <-
    runBinder $ do
      results <- f (idx_z, idx_y, idx_x)
      res_ts <- traverse subExpType results
      pure (res_ts, results)
  let space =
        SegSpace idx_flat [(idx_z, dim_z), (idx_y, dim_y), (idx_x, dim_x)]
      kbody = KernelBody () body_stms $ Returns manifest <$> results
  seg_exp <- renameExp $ Op $ SegOp $ SegMap lvl space res_ts kbody
  letTupExp desc seg_exp
-- | Bind the result of a two-dimensional 'SegMap' whose only kernel result
-- is a 'WriteReturns' into @updt_arr@.
--
-- The segment space is @[(ltid_x, dim_x), (ltid_y, dim_y)]@ over a fresh
-- flat index. Note that, unlike 'segMap2D', the first component of the
-- dimension pair is named @dim_x@ here. The callback is handed
-- @(ltid_x, ltid_y)@ and returns a @(value, index)@ pair. Each thread's
-- result is @WriteReturns (Shape [arr_size]) updt_arr [([DimFix index], value)]@,
-- and the 'SegMap' is typed with the type of @value@. The expression is
-- renamed before it is bound with 'letTupExp'.
segScatter2D ::
  -- | Name hint passed to 'letTupExp'.
  String ->
  -- | Size of the (one-dimensional) destination array.
  SubExp ->
  -- | Destination array written by the kernel.
  VName ->
  -- | Level of the generated 'SegMap'.
  SegLevel ->
  -- | Dimensions @(dim_x, dim_y)@.
  (SubExp, SubExp) ->
  -- | Per-thread body, given @(ltid_x, ltid_y)@; returns @(value, index)@.
  ((VName, VName) -> Binder Kernels (SubExp, SubExp)) ->
  Binder Kernels [VName]
segScatter2D desc arr_size updt_arr lvl (dim_x, dim_y) f = do
  idx_x <- newVName "ltid_x"
  idx_y <- newVName "ltid_y"
  idx_flat <- newVName "ltid_flat"
  ((val_t, val, ind), body_stms) <-
    runBinder $ do
      (val, ind) <- f (idx_x, idx_y)
      val_t <- subExpType val
      pure (val_t, val, ind)
  let space = SegSpace idx_flat [(idx_x, dim_x), (idx_y, dim_y)]
      write = WriteReturns (Shape [arr_size]) updt_arr [([DimFix ind], val)]
      kbody = KernelBody () body_stms [write]
  seg_exp <- renameExp $ Op $ SegOp $ SegMap lvl space [val_t] kbody
  letTupExp desc seg_exp
-- | Maps a bound variable to the set of names it depends on, built up
-- statement by statement by 'varianceInStms'. An entry is the union, over
-- the names free in the binding statement, of each such name together with
-- that name's own entry, so dependencies accumulate transitively. A name
-- with no entry is looked up as 'mempty' (see 'defVarianceInStm').
type VarianceTable = M.Map VName Names
-- | Recognise a statement that is a redomap suitable for tiling.
--
-- On success, returns the width, the input arrays, and
-- @(commutativity, reduction lambda, neutral elements, map lambda)@ of the
-- single reduction obtained via 'singleReduce'. The statement qualifies
-- when it is a 'Screma' accepted by 'isRedomapSOAC' and all of the
-- following hold:
--
-- * every reduction-lambda parameter is primitive after 'rowType';
-- * every map-lambda parameter is primitive after 'rowType';
-- * the map lambda's return types equal the reduction lambda's (so there
--   are no map-out arrays);
-- * there is at least one input array;
-- * the map lambda returns only primitive types;
-- * every map-lambda parameter is primitive.
isTileableRedomap ::
  Stm Kernels ->
  Maybe
    ( SubExp,
      [VName],
      (Commutativity, Lambda Kernels, [SubExp], Lambda Kernels)
    )
isTileableRedomap stm
  | Op (OtherOp (Screma w arrs form)) <- stmExp stm,
    Just (reds, map_lam) <- isRedomapSOAC form,
    Reduce red_comm red_lam red_nes <- singleReduce reds,
    paramsSatisfy (primType . rowType) red_lam,
    paramsSatisfy (primType . rowType) map_lam,
    lambdaReturnType map_lam == lambdaReturnType red_lam,
    not (null arrs),
    all primType (lambdaReturnType map_lam),
    paramsSatisfy primType map_lam =
    Just (w, arrs, (red_comm, red_lam, red_nes, map_lam))
  | otherwise =
    Nothing
  where
    -- True when the predicate holds for the type of every lambda parameter.
    paramsSatisfy :: (Type -> Bool) -> Lambda Kernels -> Bool
    paramsSatisfy p = all (p . paramType) . lambdaParams
-- | Default variance rule for one statement.
--
-- Every name bound by the statement's pattern is mapped to the same set:
-- the union, over all names free in the statement, of the free name itself
-- plus whatever the incoming table records for it (or nothing, if it has
-- no entry). Existing entries for the pattern names are overwritten.
defVarianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
defVarianceInStm variance bnd =
  foldr (\name -> M.insert name deps) variance (patternNames (stmPattern bnd))
  where
    deps = foldMap depsOf (namesToList (freeIn bnd))
    depsOf v = oneName v <> M.findWithDefault mempty v variance
-- | Extend a variance table with the names bound by one statement.
--
-- Every statement is handled by 'defVarianceInStm'. A statement accepted by
-- 'isTileableRedomap' is additionally walked into: the map-lambda
-- parameters and the reduction-lambda parameters are entered into the
-- table, and then the statements of both lambda bodies are processed with
-- 'varianceInStms', so names bound inside the lambdas are tracked too.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm v0 bnd@(Let _ _ (Op (OtherOp Screma {})))
  | Just (_, arrs, (_, red_lam, red_nes, map_lam)) <- isTileableRedomap bnd =
    let v = defVarianceInStm v0 bnd
        -- A reduction operator over @length red_nes@ values takes twice
        -- that many parameters: the accumulators first, then the incoming
        -- elements. The split must therefore be at @length red_nes@.
        -- Splitting at half of it, as was done before, leaves the
        -- accumulator list empty for a single-valued reduction. That makes
        -- the 'zip4' below produce nothing, so no lambda parameter would
        -- receive a variance entry at all.
        (acc_ps, elem_ps) = splitAt (length red_nes) (lambdaParams red_lam)
        map_ps = lambdaParams map_lam
        lam_stms =
          bodyStms (lambdaBody map_lam) <> bodyStms (lambdaBody red_lam)
        -- Input array -> map parameter -> (reduction accumulator, reduction
        -- element). 'zip4' truncates to the shortest of its four lists, so
        -- only positions present in all of them are propagated.
        step vacc (v_a, v_fm, v_fr_acc, v_fr_var) =
          let vrc = oneName v_a <> M.findWithDefault mempty v_a vacc
              vrc' = oneName v_fm <> vrc
           in M.insert v_fr_acc (oneName v_fr_var <> vrc') $
                M.insert v_fr_var vrc' $
                  M.insert v_fm vrc vacc
        v' =
          foldl' step v $
            zip4
              arrs
              (map paramName map_ps)
              (map paramName acc_ps)
              (map paramName elem_ps)
     in varianceInStms v' lam_stms
varianceInStm v0 bnd = defVarianceInStm v0 bnd
-- | Thread a variance table through a sequence of statements from first to
-- last, applying 'varianceInStm' to each one in turn.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms table stms = foldl' varianceInStm table stms