{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.KernelBabysitting (babysitKernels) where
import Control.Arrow (first)
import Control.Monad.State.Strict
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.IR
import Futhark.IR.Kernels hiding
( BasicOp,
Body,
Exp,
FParam,
FunDef,
LParam,
Lambda,
PatElem,
Pattern,
Prog,
RetType,
Stm,
)
import Futhark.MonadFreshNames
import Futhark.Pass
import Futhark.Tools
import Futhark.Util
-- | The kernel babysitting pass ("Transpose kernel input arrays for
-- better performance.").  Every top-level statement sequence is
-- rewritten by 'transformStms', starting from an empty 'ExpMap'.
babysitKernels :: Pass Kernels Kernels
babysitKernels =
  Pass
    "babysit kernels"
    "Transpose kernel input arrays for better performance."
    $ intraproceduralTransformation onStms
  where
    -- Run the rewrite in a 'Binder' over a pure 'State' of the name
    -- source, threading the pass's name source in and out through
    -- 'modifyNameSource'.  'transformStms' returns the statements it
    -- produced (via 'collectStms_'), so only the first component of
    -- the 'runBinderT' result is kept.
    onStms :: Scope Kernels -> Stms Kernels -> PassM (Stms Kernels)
    onStms scope stms = do
      let m = localScope scope $ transformStms mempty stms
      fmap fst $ modifyNameSource $ runState (runBinderT m M.empty)
-- | The monad the rewrite runs in: a 'Binder' for the 'Kernels'
-- representation.
type BabysitM = Binder Kernels
-- | Rewrite a sequence of statements with 'transformStm', threading the
-- 'ExpMap' from each statement to the next.  The final 'ExpMap' is
-- discarded; the statements added while rewriting are collected and
-- returned.
transformStms :: ExpMap -> Stms Kernels -> BabysitM (Stms Kernels)
transformStms expmap stms = collectStms_ $ foldM_ transformStm expmap stms
-- | Rewrite the statements of a body with 'transformStms'.  The body's
-- result is kept unchanged.
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () stms res) = do
  stms' <- transformStms expmap stms
  pure $ Body () stms' res
-- | Map from a variable name to the statement that binds it.
-- 'nonlinearInMemory' consults it to guess whether an array has a
-- non-default layout in memory.  Names missing from the map are
-- treated as having no known layout.
type ExpMap = M.Map VName (Stm Kernels)
-- | Try to determine the in-memory layout permutation of the array
-- @name@ by inspecting its defining statement in the 'ExpMap'.
--
-- Returns @Just (Just perm)@ when a permutation is known, and
-- 'Nothing' when nothing can be concluded, including when @name@ is
-- not in the map.  The defining statement is handled as follows:
--
-- * 'Opaque' and 'Reshape' are looked through.
-- * 'Rearrange' yields the inverse of its permutation.
-- * 'Manifest' yields its permutation as is.
-- * For a 'SegMap' result, a permutation is derived when the
--   corresponding per-thread result type is itself an array.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque (Var arr)))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) ->
      Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Reshape _ arr))) ->
      nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Manifest perm _))) ->
      Just $ Just perm
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      -- Pair each pattern element with its per-thread result type and
      -- pick the one that binds @name@.
      nonlinear =<< find ((== name) . patElemName . fst) (zip (patternElements pat) ts)
    _ -> Nothing
  where
    -- @inner_r@ is the rank of the per-thread result type.
    -- @outer_r@ is the number of remaining dimensions of the pattern
    -- element.  The answer is the inverse of the permutation
    -- @[inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]@.
    -- Scalar per-thread results (rank 0) give no information.
    nonlinear (pe, t)
      | inner_r > 0 =
        let outer_r = arrayRank (patElemType pe) - inner_r
         in Just $
              Just $
                rearrangeInverse $
                  [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing
      where
        inner_r = arrayRank t
-- | Rewrite a single statement and add the result to the current binder.
--
-- * Thread-level ('SegThread') 'SegOp's have their kernel bodies
--   rewritten by 'transformKernelBody'.
-- * Every other expression has its nested bodies rewritten through
--   'transform'.
--
-- Returns the 'ExpMap' extended with the rewritten statement (see
-- 'recordStm').
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux (Op (SegOp op)))
  | SegThread {} <- segLevel op = do
    let mapper =
          identitySegOpMapper
            { mapOnSegOpBody =
                transformKernelBody expmap (segLevel op) (segSpace op)
            }
    op' <- mapSegOpM mapper op
    recordStm expmap $ Let pat aux $ Op $ SegOp op'
transformStm expmap (Let pat aux e) = do
  e' <- mapExpM (transform expmap) e
  recordStm expmap $ Let pat aux e'

-- | Add a statement to the current binder.  Return the 'ExpMap'
-- extended with that statement under every name its pattern binds.
-- The new entries take precedence over existing ones, because
-- 'Data.Map' union ('<>') is left-biased.
recordStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
recordStm expmap stm@(Let pat _ _) = do
  addStm stm
  pure $ M.fromList [(name, stm) | name <- patternNames pat] <> expmap
-- | An expression mapper, built from 'identityMapper', that overrides
-- 'mapOnBody'.  Every nested body is rewritten with 'transformBody'
-- under the scope it is entered with.
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper {mapOnBody = \scope -> localScope scope . transformBody expmap}
-- | Rewrite the body of a thread-level kernel.  Array index expressions
-- in the body are handed to 'ensureCoalescedAccess' through
-- 'traverseKernelBodyArrayIndexes'; the traversal runs with an empty
-- 'Replacements' table.
--
-- Before the traversal, a @num_threads@ variable (number of groups
-- times group size, as an 'Int64' multiplication) is bound in the
-- enclosing binder.
transformKernelBody ::
  ExpMap ->
  SegLevel ->
  SegSpace ->
  KernelBody Kernels ->
  BabysitM (KernelBody Kernels)
transformKernelBody expmap lvl space kbody = do
  -- Ask for the scope before binding num_threads, as the original
  -- statement order did.
  scope <- askScope
  num_threads <-
    letSubExp "num_threads" $
      BasicOp $
        BinOp
          (Mul Int64 OverflowUndef)
          (unCount $ segNumGroups lvl)
          (unCount $ segGroupSize lvl)
  evalStateT
    ( traverseKernelBodyArrayIndexes
        free_ker_vars
        thread_local
        (scope <> scopeOfSegSpace space)
        (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
        kbody
    )
    mempty
  where
    -- The segment's flat index and its per-dimension index variables.
    thread_local = namesFromList $ segFlat space : map fst (unSegSpace space)

    -- Variables free in the kernel body, minus those bound by the
    -- segment space itself.
    free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space

    getKerVariantIds = namesFromList . M.keys . scopeOfSegSpace
-- | A transformation applied to each array index expression found in a
-- kernel body.  It receives, in order:
--
-- 1. the kernel's free variables;
-- 2. a predicate telling whether a variable is thread-local;
-- 3. a predicate telling whether a 'SubExp' varies with a given thread
--    index (the index is the first argument);
-- 4. a substitution for size expressions;
-- 5. the type environment;
-- 6. the indexed array and its slice.
--
-- It returns @Just (arr', slice')@ to replace the index, or 'Nothing'
-- to keep it unchanged.
type ArrayIndexTransform m =
  Names ->
  (VName -> Bool) ->
  (VName -> SubExp -> Bool) ->
  (SubExp -> Maybe SubExp) ->
  Scope Kernels ->
  VName ->
  Slice SubExp ->
  m (Maybe (VName, Slice SubExp))
-- | Traverse the statements of a kernel body and apply the given
-- 'ArrayIndexTransform' to every 'Index' expression.  The traversal
-- also enters bodies nested in expressions and the lambdas of
-- 'OtherOp' SOACs.  When the transform returns @Just (arr', slice')@,
-- the index is replaced by @arr'[slice']@; otherwise it is left
-- alone.  Kernel results are returned unchanged.
--
-- Arguments:
--
-- * the kernel's free variables;
-- * the thread-variant names (a variable counts as thread-local when
--   its variance intersects these);
-- * the outer type environment;
-- * the transform;
-- * the kernel body.
traverseKernelBodyArrayIndexes ::
  (Applicative f, Monad f) =>
  Names ->
  Names ->
  Scope Kernels ->
  ArrayIndexTransform f ->
  KernelBody Kernels ->
  f (KernelBody Kernels)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList
    <$> mapM (onStm top_ctx) (stmsToList kstms)
    <*> pure kres
  where
    -- The traversal context is a triple:
    -- (variance table, size substitutions, accumulated scope).
    top_ctx = (varianceInStms mempty kstms, mkSizeSubsts kstms, outer_scope)

    onLambda (variance, szsubst, scope) lam =
      (\body' -> lam {lambdaBody = body'})
        <$> onBody
          (variance, szsubst, scope <> scopeOfLParams (lambdaParams lam))
          (lambdaBody lam)

    -- Extend the context with what the body's own statements define,
    -- then rewrite each of them.
    onBody (variance, szsubst, scope) (Body bdec stms bres) = do
      stms' <- stmsFromList <$> mapM (onStm ctx') (stmsToList stms)
      pure $ Body bdec stms' bres
      where
        ctx' =
          ( varianceInStms variance stms,
            mkSizeSubsts stms <> szsubst,
            scope <> scopeOf stms
          )

    -- NOTE(review): the scope component of the context is accumulated
    -- but never consulted here; the transform always receives
    -- outer_scope.
    onStm (variance, szsubst, _) (Let pat dec (BasicOp (Index arr is))) =
      Let pat dec . BasicOp . uncurry Index . fromMaybe (arr, is)
        <$> f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
      where
        -- A variable is variant to @gid@ if it is @gid@ itself or
        -- @gid@ occurs in its recorded variance.
        isGidVariant gid (Var v) =
          gid == v || nameIn gid (M.findWithDefault (oneName v) v variance)
        isGidVariant _ _ = False

        isThreadLocal v =
          thread_variant
            `namesIntersect` M.findWithDefault (oneName v) v variance

        -- Resolve a size to something available in the outer scope,
        -- following the SplitSpace substitutions when needed.
        sizeSubst (Constant v) = Just $ Constant v
        sizeSubst (Var v)
          | v `M.member` outer_scope = Just $ Var v
          | Just v' <- M.lookup v szsubst = sizeSubst v'
          | otherwise = Nothing
    onStm ctx (Let pat dec e) =
      Let pat dec <$> mapExpM (mapper ctx) e

    onOp ctx (OtherOp soac) =
      OtherOp <$> mapSOACM identitySOACMapper {mapOnSOACLambda = onLambda ctx} soac
    onOp _ op = pure op

    mapper ctx =
      identityMapper
        { mapOnBody = const (onBody ctx),
          mapOnOp = onOp ctx
        }

    -- Each single-element SplitSpace size binding substitutes its
    -- bound name with its last (elements-per-thread) operand.
    mkSizeSubsts = foldMap mkStmSizeSubst
      where
        mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SizeOp (SplitSpace _ _ _ elems_per_i)))) =
          M.singleton (patElemName pe) elems_per_i
        mkStmSizeSubst _ = mempty
-- | Memoisation table used by 'ensureCoalescedAccess'.  It maps an
-- array together with the slice it is indexed by to the array that
-- should be read instead, so that each distinct access is rewritten at
-- most once.
type Replacements = M.Map (VName, Slice SubExp) VName

-- | An 'ArrayIndexTransform' that inspects the kernel access
-- @arr[slice]@.  When one of the heuristics below applies, it binds a
-- rearranged (or row-major) copy of @arr@ to be indexed instead.
--
-- Returns @'Just' (arr', slice)@ when the access should be redirected to
-- @arr'@; that decision is also recorded in the 'Replacements' state.
-- Returns 'Nothing' when this function makes no change.
--
-- * @expmap@: presumably maps variable names to the statements that
--   bound them.  It is used to see whether @arr@ came from a 'Rearrange'.
-- * @thread_space@: pairs of thread-ID variables and dimension sizes.
--   The last entry is treated as the innermost thread dimension.
-- * @num_threads@: used as the chunk count when a unit-stride slice is
--   rearranged and there are no leading fixed indices.
ensureCoalescedAccess ::
  MonadBinder m =>
  ExpMap ->
  [(VName, SubExp)] ->
  SubExp ->
  ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess
  expmap
  thread_space
  num_threads
  free_ker_vars
  isThreadLocal
  isGidVariant
  sizeSubst
  outer_scope
  arr
  slice = do
    seen <- gets $ M.lookup (arr, slice)

    case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
      -- This exact access was already handled; reuse that decision.
      (Just arr', _, _) ->
        pure $ Just (arr', slice)
      (Nothing, False, Just t)
        -- The array is fully indexed, and 'coalescedIndexes' proposes an
        -- order of the indices that is a permutation of the current one.
        -- Transpose the array to match that order.
        | Just is <- sliceIndices slice,
          length is == arrayRank t,
          Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
          Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
        -- The array is itself bound by a 'Rearrange'.  Apply that
        -- permutation to the slice.  If the resulting innermost dimension
        -- is a fixed index that varies with the innermost thread ID, the
        -- access is left alone.
        | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
          not $ null perm,
          not $ null thread_gids,
          inner_gid <- last thread_gids,
          -- Keeps the (!!) below in range, assuming perm is a permutation
          -- of [0 .. length perm - 1].
          length slice >= length perm,
          slice' <- map (slice !!) perm,
          DimFix inner_ind <- last slice',
          isGidVariant inner_gid inner_ind ->
          return Nothing
        -- The array is not fully indexed and the remaining dimensions are
        -- all non-fixed.  If the array is not bound in 'expmap', the
        -- remainder is not tiny, the fixed indices are not a proper prefix
        -- of the thread IDs, and the innermost thread ID does not occur in
        -- the slice, then leave the access alone.
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          allDimAreSlice rem_slice,
          Nothing <- M.lookup arr expmap,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          not (null thread_gids || null is),
          not (last thread_gids `nameIn` (freeIn is <> freeIn rem_slice)) ->
          return Nothing
        -- The array is not fully indexed, the fixed indices are not a
        -- proper prefix of the thread IDs, and some of them are
        -- thread-local.  Heuristically assume the remaining dimensions are
        -- traversed sequentially, and move the fixed dimensions innermost.
        | (is, rem_slice) <- splitSlice slice,
          not $ null rem_slice,
          not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
          is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
          any isThreadLocal (namesToList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)
        -- The fixed indices match the thread IDs pairwise.  They are
        -- followed by a unit-stride slice whose offset is thread-local and
        -- whose length is known to 'sizeSubst'.  Split that dimension into
        -- chunks.
        | (is, rem_slice) <- splitSlice slice,
          and $ zipWith (==) is $ map Var thread_gids,
          DimSlice offset len (Constant stride) : _ <- rem_slice,
          isThreadLocalSubExp offset,
          Just {} <- sizeSubst len,
          oneIsh stride -> do
          -- Both alternatives are 64-bit.  'rearrangeSlice' divides the
          -- padded width by this value with a 64-bit 'SQuot', so the
          -- chunk count must be typed as 64-bit here as well.
          let num_chunks =
                untyped $
                  if null is
                    then pe64 num_threads
                    else product $ map pe64 $ drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)
        -- Otherwise the access is assumed fine as long as the array is in
        -- row-major order.  When 'nonlinearInMemory' reports a layout,
        -- force a row-major copy, unless the access is fully indexed and
        -- 'coalescedIndexes' suggests nothing.
        | Just {} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is
              | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                replace =<< lift (rowMajorArray arr)
              | otherwise ->
                return Nothing
            _ -> replace =<< lift (rowMajorArray arr)
      _ -> return Nothing
  where
    (thread_gids, thread_gdims) = unzip thread_space

    -- Record the replacement for this (arr, slice) pair and redirect the
    -- access to it.
    replace arr' = do
      modify $ M.insert (arr, slice) arr'
      return $ Just (arr', slice)

    -- Constants are never thread-local; variables are thread-local
    -- according to the predicate supplied by the caller.
    isThreadLocalSubExp (Var v) = isThreadLocal v
    isThreadLocalSubExp Constant {} = False
-- | @tooSmallSlice bs slice@ is 'True' when two conditions hold:
--
-- * every dimension of @slice@ is an integer constant;
-- * the running product of @bs@ with those dimensions, taken in order,
--   stays below 4.
--
-- At the call sites in this module, @bs@ is the element size from
-- 'primByteSize'.  An empty slice counts as too small.
--
-- Both 32-bit and 64-bit constant dimensions are recognised.  The product
-- is computed in 'Integer' so that large dimensions cannot overflow and
-- wrap around into a spuriously small value.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = go (toInteger bs) . sliceDims
  where
    go :: Integer -> [SubExp] -> Bool
    go _ [] = True
    go acc (d : ds)
      | Just d' <- constDim d,
        acc * d' < 4 =
        go (acc * d') ds
      | otherwise = False

    -- The value of a constant integer dimension, if it is one.
    constDim :: SubExp -> Maybe Integer
    constDim (Constant (IntValue (Int32Value d))) = Just $ toInteger d
    constDim (Constant (IntValue (Int64Value d))) = Just $ toInteger d
    constDim _ = Nothing
-- | Separate a slice into two parts:
--
-- * the indices of its leading run of fixed ('DimFix') dimensions;
-- * the remaining slice, starting at the first non-fixed dimension.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice slice =
  let (leading, remainder) = span isFixed slice
   in ([i | DimFix i <- leading], remainder)
  where
    isFixed DimFix {} = True
    isFixed _ = False
-- | 'True' when no dimension of the slice is a fixed ('DimFix') index.
-- An empty slice trivially satisfies this.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed
  where
    notFixed DimFix {} = False
    notFixed _ = True
-- | Propose a reordering of the indices @is@ that fully index an array,
-- given the thread IDs @tgids@.  The innermost thread ID is the last
-- element of @tgids@.
--
-- 'Nothing' (no suggestion) is returned in three cases:
--
-- * some index is a constant;
-- * some index mentions a variable from @free_ker_vars@;
-- * @is@ is a prefix of @tgids@.
--
-- If the innermost index varies with the innermost thread ID (according
-- to @isGidVariant@), @is@ is returned unchanged.
--
-- Otherwise the thread IDs are walked from innermost outwards.  Each
-- thread ID found among the indices is swapped into the position it has
-- counting from the end, provided that position is inside the index list.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  | any isCt is = Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) = Nothing
  | is `isPrefixOf` tgids = Nothing
  | Var innergid : _ <- reverse tgids,
    innermost_index : _ <- reverse is,
    isGidVariant innergid innermost_index =
    Just is
  | otherwise =
    Just . reverse . foldl' move (reverse is) $ zip [0 ..] (reverse tgids)
  where
    num_is = length is

    -- One fold step over (position-from-the-end, thread ID) pairs.  If the
    -- thread ID occurs among the reversed indices at a different position,
    -- swap it into place.
    move is_rev (i, tgid) =
      case elemIndex tgid is_rev of
        Just j
          | i /= j && i < num_is -> swap i j is_rev
        _ -> is_rev

    -- Exchange the elements at positions i and j of l.
    swap i j l =
      case (maybeNth i l, maybeNth j l) of
        (Just ix, Just jx) ->
          [ if k == i then jx else if k == j then ix else x
            | (k, x) <- zip [0 ..] l
          ]
        _ ->
          error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)
-- | Is this sub-expression a constant, as opposed to a variable?
isCt :: SubExp -> Bool
isCt se = case se of
  Constant {} -> True
  Var {} -> False
-- | @coalescingPermutation num_is rank@ moves the first @num_is@
-- dimensions of a rank-@rank@ array to the innermost positions.  The
-- remaining dimensions keep their relative order in front of them.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank = trailing <> leading
  where
    trailing = enumFromTo num_is (rank - 1)
    leading = enumFromTo 0 (num_is - 1)
-- | Produce an array that holds the contents of @arr@ laid out according
-- to @perm@.  @manifest@ is the layout information for @arr@ as passed in
-- by the callers in this module.  The cases are tried in order:
--
-- * @Just (Just current_perm)@ with @current_perm == perm@: return @arr@.
-- * 'Nothing' with @perm@ already in ascending order: return @arr@.
-- * @Just (Just _)@ with @perm@ in ascending order: make a row-major copy.
-- * Otherwise: bind @Manifest perm@ of @arr@, named @<arr>_coalesced@.
--   If @manifest@ is a 'Just', a row-major copy is made first.
rearrangeInput ::
  MonadBinder m =>
  Maybe (Maybe [Int]) ->
  [Int] ->
  VName ->
  m VName
rearrangeInput manifest perm arr =
  case manifest of
    Just (Just current_perm)
      | current_perm == perm -> return arr
    Nothing
      | ascending -> return arr
    Just Just {}
      | ascending -> rowMajorArray arr
    _ -> do
      source <- if isJust manifest then rowMajorArray arr else return arr
      letExp (baseString arr ++ "_coalesced") . BasicOp $ Manifest perm source
  where
    -- Failing guards fall through to the next case alternative, mirroring
    -- a sequence of guarded equations.
    ascending = sort perm == perm
-- | Look up the rank of @arr@ and bind @Manifest [0 .. rank - 1]@ of it,
-- i.e. a copy under the identity permutation.  The copy is named
-- @<arr>_rowmajor@, and its name is returned.
rowMajorArray ::
  MonadBinder m =>
  VName ->
  m VName
rowMajorArray arr = do
  arr_t <- lookupType arr
  let id_perm = [0 .. arrayRank arr_t - 1]
      copy_name = baseString arr ++ "_rowmajor"
  letExp copy_name . BasicOp $ Manifest id_perm arr
rearrangeSlice ::
MonadBinder m =>
Int ->
SubExp ->
PrimExp VName ->
VName ->
m VName
rearrangeSlice :: forall (m :: * -> *).
MonadBinder m =>
Int -> SubExp -> PrimExp VName -> VName -> m VName
rearrangeSlice Int
d SubExp
w PrimExp VName
num_chunks VName
arr = do
SubExp
num_chunks' <- String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"num_chunks" PrimExp VName
num_chunks
(SubExp
w_padded, SubExp
padding) <- SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
paddedScanReduceInput SubExp
w SubExp
num_chunks'
SubExp
per_chunk <-
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"per_chunk" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SQuot IntType
Int64 Safety
Unsafe) SubExp
w_padded SubExp
num_chunks'
Type
arr_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
arr
VName
arr_padded <- SubExp -> SubExp -> Type -> m VName
padArray SubExp
w_padded SubExp
padding Type
arr_t
SubExp -> SubExp -> SubExp -> String -> VName -> Type -> m VName
rearrange SubExp
num_chunks' SubExp
w_padded SubExp
per_chunk (VName -> String
baseString VName
arr) VName
arr_padded Type
arr_t
where
padArray :: SubExp -> SubExp -> Type -> m VName
padArray SubExp
w_padded SubExp
padding Type
arr_t = do
let arr_shape :: Shape
arr_shape = Type -> Shape
forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
arr_t
padding_shape :: Shape
padding_shape = Int -> Shape -> SubExp -> Shape
forall d. Int -> ShapeBase d -> d -> ShapeBase d
setDim Int
d Shape
arr_shape SubExp
padding
VName
arr_padding <-
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
arr String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_padding") (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ PrimType -> Result -> BasicOp
Scratch (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
arr_t) (Shape -> Result
forall d. ShapeBase d -> [d]
shapeDims Shape
padding_shape)
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (VName -> String
baseString VName
arr String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_padded") (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Int -> VName -> [VName] -> SubExp -> BasicOp
Concat Int
d VName
arr [VName
arr_padding] SubExp
w_padded
rearrange :: SubExp -> SubExp -> SubExp -> String -> VName -> Type -> m VName
rearrange SubExp
num_chunks' SubExp
w_padded SubExp
per_chunk String
arr_name VName
arr_padded Type
arr_t = do
let arr_dims :: Result
arr_dims = Type -> Result
forall u. TypeBase Shape u -> Result
arrayDims Type
arr_t
pre_dims :: Result
pre_dims = Int -> Result -> Result
forall a. Int -> [a] -> [a]
take Int
d Result
arr_dims
post_dims :: Result
post_dims = Int -> Result -> Result
forall a. Int -> [a] -> [a]
drop (Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) Result
arr_dims
extradim_shape :: Shape
extradim_shape = Result -> Shape
forall d. [d] -> ShapeBase d
Shape (Result -> Shape) -> Result -> Shape
forall a b. (a -> b) -> a -> b
$ Result
pre_dims Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ [SubExp
num_chunks', SubExp
per_chunk] Result -> Result -> Result
forall a. [a] -> [a] -> [a]
++ Result
post_dims
tr_perm :: [Int]
tr_perm = [Int
0 .. Int
d Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ (Int -> Int) -> [Int] -> [Int]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
d) ([Int
1] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
2 .. Shape -> Int
forall a. ArrayShape a => a -> Int
shapeRank Shape
extradim_shape Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
d] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ [Int
0])
VName
arr_extradim <-
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (String
arr_name String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_extradim") (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> VName -> BasicOp
Reshape ((SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (Result -> ShapeChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> a -> b
$ Shape -> Result
forall d. ShapeBase d -> [d]
shapeDims Shape
extradim_shape) VName
arr_padded
VName
arr_extradim_tr <-
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (String
arr_name String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_extradim_tr") (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Manifest [Int]
tr_perm VName
arr_extradim
VName
arr_inv_tr <-
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (String
arr_name String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_inv_tr") (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
ShapeChange SubExp -> VName -> BasicOp
Reshape
((SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimCoercion Result
pre_dims ShapeChange SubExp -> ShapeChange SubExp -> ShapeChange SubExp
forall a. [a] -> [a] -> [a]
++ (SubExp -> DimChange SubExp) -> Result -> ShapeChange SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimChange SubExp
forall d. d -> DimChange d
DimNew (SubExp
w_padded SubExp -> Result -> Result
forall a. a -> [a] -> [a]
: Result
post_dims))
VName
arr_extradim_tr
String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp (String
arr_name String -> String -> String
forall a. Semigroup a => a -> a -> a
<> String
"_inv_tr_init")
(Exp (Lore m) -> m VName) -> m (Exp (Lore m)) -> m VName
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Int
-> VName
-> m (Exp (Lore m))
-> m (Exp (Lore m))
-> m (Exp (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Int
-> VName
-> m (Exp (Lore m))
-> m (Exp (Lore m))
-> m (Exp (Lore m))
eSliceArray Int
d VName
arr_inv_tr (SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp (SubExp -> m (Exp (Lore m))) -> SubExp -> m (Exp (Lore m))
forall a b. (a -> b) -> a -> b
$ Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) (SubExp -> m (Exp (Lore m))
forall (m :: * -> *). MonadBinder m => SubExp -> m (Exp (Lore m))
eSubExp SubExp
w)
-- | Round the size @w@ up to the nearest multiple of @stride@ and compute
-- how much extra space that rounding adds.
--
-- Binds two statements, named @padded_size@ and @padding@, and returns
-- @(padded_size, padding)@ where @padding = padded_size - w@.  Both the
-- rounding and the subtraction are performed at type 'Int64'.
paddedScanReduceInput ::
  MonadBinder m =>
  SubExp ->
  SubExp ->
  m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  padded <-
    letSubExp "padded_size"
      =<< eRoundToMultipleOf Int64 (eSubExp w) (eSubExp stride)
  -- The subtraction uses 'OverflowUndef'.
  let extra = BasicOp $ BinOp (Sub Int64 OverflowUndef) padded w
  (,) padded <$> letSubExp "padding" extra
-- | Maps each variable bound so far to the set of names its value was
-- computed from.  That set includes names reached through earlier bindings
-- already recorded in the table (see 'varianceInStm').
type VarianceTable = M.Map VName Names
-- | Extend a variance table with the bindings of every statement in the
-- sequence.
--
-- Statements are processed in order, so a later statement sees the
-- variance recorded for names bound by earlier ones.  A strict left fold
-- is used so the table is forced after each statement instead of building
-- a chain of thunks proportional to the number of statements.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms t = foldl' varianceInStm t . stmsToList
-- | Record the variance of every name bound by a single statement.
--
-- Every name in the statement's pattern is mapped to the same set: each
-- free variable of the statement, together with whatever that variable is
-- already recorded as varying with in the incoming table.  Entries for
-- pattern names already present in the table are overwritten.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance bnd =
  foldl' bindOne variance bound
  where
    bound = patternNames (stmPattern bnd)

    bindOne table name = M.insert name dependsOn table

    -- Lookups consult the incoming table only, not the one being built.
    dependsOn = foldMap viaTable (namesToList (freeIn bnd))

    viaTable v = oneName v <> M.findWithDefault mempty v variance