{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.KernelBabysitting ( babysitKernels )
where
import Control.Arrow (first)
import Control.Monad.State.Strict
import qualified Data.Map.Strict as M
import Data.Foldable
import Data.List (elemIndex, isPrefixOf, sort)
import Data.Maybe
import Futhark.MonadFreshNames
import Futhark.IR
import Futhark.IR.Kernels
hiding (Prog, Body, Stm, Pattern, PatElem,
BasicOp, Exp, Lambda, FunDef, FParam, LParam, RetType)
import Futhark.Tools
import Futhark.Pass
import Futhark.Util
-- | The kernel babysitting pass ("Transpose kernel input arrays for
-- better performance.").  Every function body is processed by
-- 'transformStms', which rewrites array indexing inside 'SegOp' kernel
-- bodies (see 'transformKernelBody').
babysitKernels :: Pass Kernels Kernels
babysitKernels = Pass "babysit kernels"
                 "Transpose kernel input arrays for better performance." $
                 intraproceduralTransformation onStms
  where onStms :: Scope Kernels -> Stms Kernels -> PassM (Stms Kernels)
        onStms scope stms = do
          -- Run the transformation in a binder that knows the enclosing
          -- scope.  'transformStms' collects everything it emits (via
          -- 'collectStms_') and returns it as its result, so only the
          -- first component of 'runBinderT' is kept.
          let m = localScope scope $ transformStms mempty stms
          fmap fst $ modifyNameSource $ runState (runBinderT m M.empty)
-- | The monad the transformation runs in: a 'Binder' over 'Kernels',
-- in which new statements can be emitted (e.g. with 'addStm' or
-- 'letSubExp') and fresh names generated.
type BabysitM = Binder Kernels

-- | Transform a sequence of statements and return the rewritten
-- statements (everything emitted while transforming them is collected).
-- The 'ExpMap' returned by each 'transformStm' is threaded into the
-- next one, so every statement sees the definitions preceding it.
transformStms :: ExpMap -> Stms Kernels -> BabysitM (Stms Kernels)
transformStms expmap stms = collectStms_ $ foldM_ transformStm expmap stms
-- | Transform the statements of a body; the body's result is kept as is.
transformBody :: ExpMap -> Body Kernels -> BabysitM (Body Kernels)
transformBody expmap (Body () stms res) =
  (\stms' -> Body () stms' res) <$> transformStms expmap stms
-- | Map from variable names to the statement that defines them.  It is
-- used by 'nonlinearInMemory' to guess, from the defining statement,
-- whether an array has a non-linear memory layout.  Names that are not
-- in the map are simply treated as unknown.
type ExpMap = M.Map VName (Stm Kernels)

-- | Look up what is known about the memory layout of array @name@ from
-- the statement that defines it in @m@.
--
-- * 'Nothing' if the name is not in the map, or is defined by an
--   expression not handled below.
-- * @Just (Just perm)@ otherwise, where @perm@ is a permutation derived
--   from the defining statement.  (In the cases handled here the inner
--   'Maybe' is always 'Just'.)
--
-- 'Opaque' and 'Reshape' are looked through to the array they are
-- derived from.
nonlinearInMemory :: VName -> ExpMap -> Maybe (Maybe [Int])
nonlinearInMemory name m =
  case M.lookup name m of
    Just (Let _ _ (BasicOp (Opaque (Var arr)))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Reshape _ arr))) -> nonlinearInMemory arr m
    Just (Let _ _ (BasicOp (Rearrange perm _))) ->
      Just $ Just $ rearrangeInverse perm
    Just (Let _ _ (BasicOp (Manifest perm _))) ->
      Just $ Just perm
    Just (Let pat _ (Op (SegOp (SegMap _ _ ts _)))) ->
      -- Pair each pattern element with the per-thread return type of
      -- the 'SegMap' and inspect the one that binds @name@.
      nonlinear =<< find ((== name) . patElemName . fst)
                         (zip (patternElements pat) ts)
    _ -> Nothing
  where
    -- Only array-valued per-thread results (inner rank > 0) get a
    -- permutation; it is built from the inner rank (per-thread return
    -- type) and the outer rank (remaining dimensions of the result).
    nonlinear (pe, t)
      | inner_r <- arrayRank t, inner_r > 0 = do
          let outer_r = arrayRank (patElemType pe) - inner_r
          return $ Just $ rearrangeInverse $
            [inner_r .. inner_r + outer_r - 1] ++ [0 .. inner_r - 1]
      | otherwise = Nothing
-- | Transform a single statement, add the result to the current binder,
-- and return the 'ExpMap' extended with every name the statement binds.
--
-- The body of a kernel ('SegOp') is rewritten by 'transformKernelBody';
-- for any other expression the nested bodies are transformed
-- recursively through 'transform'.
transformStm :: ExpMap -> Stm Kernels -> BabysitM ExpMap
transformStm expmap (Let pat aux e) = do
  e' <- case e of
          Op (SegOp op) -> Op . SegOp <$> mapSegOpM (segOpMapper op) op
          _ -> mapExpM (transform expmap) e
  let stm' = Let pat aux e'
  addStm stm'
  -- 'M.Map's '<>' is a left-biased union, so the new definitions take
  -- precedence over any older entries for the same names.
  return $ M.fromList [ (name, stm') | name <- patternNames pat ] <> expmap
  where segOpMapper op =
          identitySegOpMapper
          { mapOnSegOpBody =
              transformKernelBody expmap (segLevel op) (segSpace op) }
-- | An expression mapper that transforms every nested body with
-- 'transformBody', running it in the scope 'mapExpM' supplies for that
-- body (added on top of the current scope with 'localScope').
transform :: ExpMap -> Mapper Kernels Kernels BabysitM
transform expmap =
  identityMapper { mapOnBody = \scope -> localScope scope . transformBody expmap }
-- | Rewrite the array indexing in the body of a kernel with segment
-- level @lvl@ and segment space @space@.
--
-- Every 'Index' statement in the body (including nested ones) is handed
-- to 'ensureCoalescedAccess' through 'traverseKernelBodyArrayIndexes'.
-- The segment's flat index and thread indices are the thread-local
-- names.  As a side effect a @num_threads@ variable (number of groups
-- times group size, 32-bit multiplication) is bound in the current
-- binder, i.e. ahead of the kernel statement.
transformKernelBody :: ExpMap -> SegLevel -> SegSpace -> KernelBody Kernels
                    -> BabysitM (KernelBody Kernels)
transformKernelBody expmap lvl space kbody = do
  scope <- askScope
  let thread_gids = map fst $ unSegSpace space
      thread_local = namesFromList $ segFlat space : thread_gids
      -- Free variables of the body, minus the names bound by the
      -- segment space itself.
      free_ker_vars = freeIn kbody `namesSubtract` getKerVariantIds space
  num_threads <- letSubExp "num_threads" $ BasicOp $
                 BinOp (Mul Int32 OverflowUndef)
                   (unCount $ segNumGroups lvl)
                   (unCount $ segGroupSize lvl)
  -- The 'Replacements' state starts empty for each kernel.
  evalStateT (traverseKernelBodyArrayIndexes
                free_ker_vars
                thread_local
                (scope <> scopeOfSegSpace space)
                (ensureCoalescedAccess expmap (unSegSpace space) num_threads)
                kbody)
    mempty
  where -- Names bound by the segment space.  The lore is fixed
        -- explicitly; previously it was an ambiguous type variable.
        getKerVariantIds =
          namesFromList . M.keys . (scopeOfSegSpace :: SegSpace -> Scope Kernels)
-- | A rewrite applied to each array indexing statement @arr[slice]@ in a
-- kernel body (see 'traverseKernelBodyArrayIndexes').  Its arguments
-- are, in order:
--
-- 1. the free variables of the kernel body;
-- 2. whether a variable is thread-local (varies with a thread index);
-- 3. whether a subexpression varies with a given thread index (first
--    argument);
-- 4. a size substitution that expresses a size using only constants and
--    variables bound outside the kernel, if possible;
-- 5. the type environment outside the kernel;
-- 6. the indexed array and its slice.
--
-- Returning @Just (arr', slice')@ replaces the indexing with
-- @arr'[slice']@; 'Nothing' leaves it unchanged.
type ArrayIndexTransform m =
  Names ->
  (VName -> Bool) ->
  (VName -> SubExp -> Bool) ->
  (SubExp -> Maybe SubExp) ->
  Scope Kernels ->
  VName -> Slice SubExp -> m (Maybe (VName, Slice SubExp))
-- | Apply an 'ArrayIndexTransform' to every 'Index' statement of a
-- kernel body.  This includes statements nested inside the bodies of
-- other expressions and inside the lambdas of SOACs ('OtherOp').
--
-- Arguments: the kernel's free variables (passed unchanged to the
-- transform), the thread-variant names, the scope outside the kernel,
-- the transform, and the kernel body.
--
-- While traversing, a context is maintained for each body:
--
-- * a variance table ('varianceInStms') mapping a name to the names it
--   depends on;
-- * a size substitution built from 'SplitSpace' size operations;
-- * a scope.
traverseKernelBodyArrayIndexes :: (Applicative f, Monad f) =>
                                  Names
                               -> Names
                               -> Scope Kernels
                               -> ArrayIndexTransform f
                               -> KernelBody Kernels
                               -> f (KernelBody Kernels)
traverseKernelBodyArrayIndexes free_ker_vars thread_variant outer_scope f (KernelBody () kstms kres) =
  KernelBody () . stmsFromList
    <$> mapM (onStm (varianceInStms mempty kstms,
                     mkSizeSubsts kstms,
                     outer_scope))
             (stmsToList kstms)
    <*> pure kres
  where
    -- SOAC lambdas: add the lambda parameters to the scope and traverse
    -- the lambda body.
    onLambda (variance, szsubst, scope) lam =
      (\body' -> lam { lambdaBody = body' })
        <$> onBody (variance, szsubst, scope') (lambdaBody lam)
      where scope' = scope <> scopeOfLParams (lambdaParams lam)

    -- Nested bodies: extend the context with the body's own statements
    -- before traversing them.
    onBody (variance, szsubst, scope) (Body bdec stms bres) = do
      stms' <- stmsFromList <$>
               mapM (onStm (variance', szsubst', scope')) (stmsToList stms)
      pure $ Body bdec stms' bres
      where variance' = varianceInStms variance stms
            szsubst' = mkSizeSubsts stms <> szsubst
            scope' = scope <> scopeOf stms

    -- Array indexing: let the transform decide whether to replace it.
    -- NOTE(review): the scope tracked in the context is not consulted
    -- here; the transform and 'sizeSubst' always use @outer_scope@.
    onStm (variance, szsubst, _) (Let pat dec (BasicOp (Index arr is))) =
      Let pat dec . oldOrNew
        <$> f free_ker_vars isThreadLocal isGidVariant sizeSubst outer_scope arr is
      where
        oldOrNew Nothing = BasicOp $ Index arr is
        oldOrNew (Just (arr', is')) = BasicOp $ Index arr' is'

        -- Is the subexpression the given thread index, or does it
        -- depend on it?  Names missing from the variance table depend
        -- only on themselves.
        isGidVariant gid (Var v) =
          gid == v || nameIn gid (M.findWithDefault (oneName v) v variance)
        isGidVariant _ _ = False

        -- Does the variable depend on any thread-variant name?
        isThreadLocal v =
          thread_variant `namesIntersect`
          M.findWithDefault (oneName v) v variance

        -- Express a size with constants and names bound outside the
        -- kernel, following the size substitution; 'Nothing' if that
        -- is not possible.
        sizeSubst (Constant v) = Just $ Constant v
        sizeSubst (Var v)
          | v `M.member` outer_scope      = Just $ Var v
          | Just v' <- M.lookup v szsubst = sizeSubst v'
          | otherwise                     = Nothing

    -- Any other statement: recurse into its nested bodies and ops.
    onStm ctx (Let pat dec e) =
      Let pat dec <$> mapExpM (mapper ctx) e

    -- Only SOACs ('OtherOp') are entered, through their lambdas; other
    -- host operations are returned unchanged.
    onOp ctx (OtherOp soac) =
      OtherOp <$> mapSOACM identitySOACMapper { mapOnSOACLambda = onLambda ctx } soac
    onOp _ op = return op

    -- NOTE(review): 'const' discards the scope that 'mapExpM' passes
    -- for each nested body.
    mapper ctx = identityMapper { mapOnBody = const (onBody ctx)
                                , mapOnOp = onOp ctx
                                }

    -- A 'SplitSpace' size operation binding a single name maps that name
    -- to its elements-per-thread operand.
    mkSizeSubsts = foldMap mkStmSizeSubst
      where mkStmSizeSubst (Let (Pattern [] [pe]) _ (Op (SizeOp (SplitSpace _ _ _ elems_per_i)))) =
              M.singleton (patElemName pe) elems_per_i
            mkStmSizeSubst _ = mempty
-- | Memo table used by 'ensureCoalescedAccess': maps an indexed array
-- together with the slice used to index it to the name of the replacement
-- array that was bound for that access.
type Replacements = M.Map (VName, Slice SubExp) VName
-- | An 'ArrayIndexTransform' that inspects an index expression
-- @arr[slice]@ inside a kernel.  When one of the heuristics below decides
-- the access should be changed, it binds a rearranged (or row-major) copy
-- of @arr@.  It then returns that copy paired with the unchanged @slice@.
-- 'Nothing' means "leave the access alone".
--
-- Replacements are memoised in the 'Replacements' state, keyed on
-- @(arr, slice)@.  A later identical access therefore reuses the same copy.
--
-- Only arrays for which @isThreadLocal@ is 'False' and whose type is found
-- in @outer_scope@ are candidates for rearrangement.
ensureCoalescedAccess :: MonadBinder m =>
                         ExpMap
                      -> [(VName,SubExp)]
                      -> SubExp
                      -> ArrayIndexTransform (StateT Replacements m)
ensureCoalescedAccess expmap thread_space num_threads free_ker_vars isThreadLocal
                      isGidVariant sizeSubst outer_scope arr slice = do
  seen <- gets $ M.lookup (arr, slice)

  case (seen, isThreadLocal arr, typeOf <$> M.lookup arr outer_scope) of
    -- This exact access was already rewritten; reuse that result.
    (Just arr', _, _) ->
      pure $ Just (arr', slice)

    (Nothing, False, Just t)
      -- The array is fully indexed, and 'coalescedIndexes' proposes an
      -- order of the indices that is a permutation of the current one.
      -- Manifest the array according to that permutation.
      | Just is <- sliceIndices slice,
        length is == arrayRank t,
        Just is' <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is,
        Just perm <- is' `isPermutationOf` is ->
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- The array was produced by a 'Rearrange'.  Apply that permutation
      -- to the slice.  If the resulting innermost dimension is a fixed
      -- index that varies with the innermost thread id, the access is
      -- left alone (presumably it is already coalesced).
      | Just (Let _ _ (BasicOp (Rearrange perm _))) <- M.lookup arr expmap,
        not $ null perm,
        not $ null thread_gids,
        inner_gid <- last thread_gids,
        length slice >= length perm,
        slice' <- map (slice !!) perm,
        DimFix inner_ind <- last slice',
        isGidVariant inner_gid inner_ind ->
          return Nothing

      -- Heuristic: the array is not fully indexed.  The remaining slice
      -- consists only of non-fixed dimensions, and neither it nor the
      -- fixed indices mention the innermost thread id.  The remainder is
      -- assumed to be streamed sequentially (and handled later, e.g. by
      -- tiling), so nothing is done here.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        allDimAreSlice rem_slice,
        Nothing <- M.lookup arr expmap,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        not (null thread_gids || null is),
        not (last thread_gids `nameIn` (freeIn is <> freeIn rem_slice)) ->
          return Nothing

      -- Heuristic: the array is not fully indexed.  The fixed indices are
      -- not a proper prefix of the thread ids, and some of them are
      -- thread-local.  The remaining dimensions are assumed to be
      -- traversed sequentially.  Rearrange with 'coalescingPermutation',
      -- which lists the remaining dimensions before the indexed ones.
      | (is, rem_slice) <- splitSlice slice,
        not $ null rem_slice,
        not $ tooSmallSlice (primByteSize (elemType t)) rem_slice,
        is /= map Var (take (length is) thread_gids) || length is == length thread_gids,
        any isThreadLocal (namesToList $ freeIn is) -> do
          let perm = coalescingPermutation (length is) $ arrayRank t
          replace =<< lift (rearrangeInput (nonlinearInMemory arr expmap) perm arr)

      -- The fixed indices line up with the thread ids, followed by a
      -- unit-stride slice with a thread-local offset whose length
      -- @sizeSubst@ knows about.  Split the sliced dimension into chunks
      -- and transpose it via 'rearrangeSlice' (which may pad the array).
      | (is, rem_slice) <- splitSlice slice,
        and $ zipWith (==) is $ map Var thread_gids,
        DimSlice offset len (Constant stride):_ <- rem_slice,
        isThreadLocalSubExp offset,
        Just {} <- sizeSubst len,
        oneIsh stride -> do
          let num_chunks = if null is
                           then primExpFromSubExp int32 num_threads
                           else coerceIntPrimExp Int32 $
                                product $ map (primExpFromSubExp int32) $
                                drop (length is) thread_gdims
          replace =<< lift (rearrangeSlice (length is) (arraySize (length is) t) num_chunks arr)

      -- 'nonlinearInMemory' returned 'Just' for this array (presumably
      -- its layout may not be row-major).  Make a row-major copy unless
      -- the array is fully indexed and 'coalescedIndexes' rejects the
      -- indices.
      | Just{} <- nonlinearInMemory arr expmap ->
          case sliceIndices slice of
            Just is | Just _ <- coalescedIndexes free_ker_vars isGidVariant (map Var thread_gids) is ->
                        replace =<< lift (rowMajorArray arr)
                    | otherwise ->
                        return Nothing
            _ -> replace =<< lift (rowMajorArray arr)

    _ -> return Nothing

  where (thread_gids, thread_gdims) = unzip thread_space

        -- Record @arr'@ as the replacement for this access and return it.
        replace arr' = do
          modify $ M.insert (arr, slice) arr'
          return $ Just (arr', slice)

        isThreadLocalSubExp (Var v) = isThreadLocal v
        isThreadLocalSubExp Constant{} = False
-- | @tooSmallSlice bs slice@ checks the dimensions in @'sliceDims' slice@.
-- It is 'True' when every one of them is a constant 'Int32' and each
-- running product of those dimensions, multiplied by @bs@, stays below 4.
-- Any non-constant dimension makes the result 'False'.  A slice with no
-- dimensions yields 'True'.
--
-- The running product is computed in 'Integer'.  This stops a large
-- constant dimension from wrapping around 'Int32' (e.g. @2^29 * 8@ wraps
-- to 0) and being mistaken for a small one.
tooSmallSlice :: Int32 -> Slice SubExp -> Bool
tooSmallSlice bs = go (toInteger bs) . sliceDims
  where go :: Integer -> [SubExp] -> Bool
        go _ [] = True
        go acc (Constant (IntValue (Int32Value d)) : ds)
          | acc' < 4 = go acc' ds
          | otherwise = False
          where acc' = toInteger d * acc
        go _ _ = False
-- | Split a slice into its leading run of fixed indices ('DimFix') and the
-- remainder.  The remainder starts at the first non-'DimFix' dimension, or
-- is empty.
splitSlice :: Slice SubExp -> ([SubExp], Slice SubExp)
splitSlice (DimFix i:is) = first (i:) $ splitSlice is
splitSlice is = ([], is)
-- | 'True' when no dimension of the slice is a fixed index ('DimFix').
-- An empty slice yields 'True'.
allDimAreSlice :: Slice SubExp -> Bool
allDimAreSlice = all notFixed
  where notFixed DimFix{} = False
        notFixed _        = True
-- | Try to reorder the indices @is@ so that each thread id in @tgids@ ends
-- up at the same position in @is@ that it has in @tgids@.  Positions are
-- counted from the innermost (last) end.
--
-- Returns 'Nothing' (no rearrangement should be attempted) when:
--
-- 1. some index is a constant, or some index variable is in
--    @free_ker_vars@;
-- 2. @is@ is a prefix of @tgids@.
--
-- Returns @Just is@ unchanged when the innermost index is variant to the
-- innermost thread id according to @isGidVariant@.
coalescedIndexes :: Names -> (VName -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> Maybe [SubExp]
coalescedIndexes free_ker_vars isGidVariant tgids is
  | any isCt is =
      Nothing
  | any (`nameIn` free_ker_vars) (subExpVars is) =
      Nothing
  | is `isPrefixOf` tgids =
      Nothing
  | not (null tgids),
    not (null is),
    Var innergid <- last tgids,
    isGidVariant innergid (last is) =
      Just is
  | otherwise =
      Just $ reverse $ foldl' move (reverse is) $ zip [0..] (reverse tgids)
  where num_is = length is

        -- If @tgid@ occurs in @is_rev@ at a position other than @i@, and
        -- position @i@ exists, swap it into position @i@.
        move is_rev (i, tgid)
          | Just j <- elemIndex tgid is_rev, i /= j, i < num_is =
              swap i j is_rev
          | otherwise =
              is_rev

        -- Exchange the elements at positions @i@ and @j@.
        swap i j l
          | Just ix <- maybeNth i l,
            Just jx <- maybeNth j l =
              update i jx $ update j ix l
          | otherwise =
              error $ "coalescedIndexes swap: invalid indices" ++ show (i, j, l)

        -- Replace the element at position @i@ with @x@.
        update 0 x (_:ys) = x : ys
        update i x (y:ys) = y : update (i-1) x ys
        update _ _ []     = error "coalescedIndexes: update"

        isCt :: SubExp -> Bool
        isCt (Constant _) = True
        isCt (Var      _) = False
-- | @coalescingPermutation num_is rank@ lists dimensions
-- @num_is .. rank-1@ followed by @0 .. num_is-1@.  In other words, it
-- rotates the first @num_is@ dimensions to the end.
coalescingPermutation :: Int -> Int -> [Int]
coalescingPermutation num_is rank =
  [num_is..rank-1] ++ [0..num_is-1]
-- | @rearrangeInput manifest perm arr@ yields an array holding @arr@ laid
-- out according to @perm@.  A copy is bound only when needed.
-- @manifest@ is the result of 'nonlinearInMemory' for @arr@:
--
-- * @Just (Just current_perm)@ with @current_perm == perm@: @arr@ is
--   returned as-is.
-- * 'Nothing' with a sorted (identity) @perm@: @arr@ is returned as-is.
--   The current layout is unknown, but the indexing is linear.
-- * @Just (Just _)@ with a sorted (identity) @perm@: a row-major copy is
--   made with 'rowMajorArray'.
-- * Otherwise a @_coalesced@ copy is bound with @Manifest perm@.  When
--   @manifest@ is a 'Just', a row-major copy is made first; if that copy
--   is unnecessary, it is presumably removed by later simplification.
rearrangeInput :: MonadBinder m =>
                  Maybe (Maybe [Int]) -> [Int] -> VName -> m VName
rearrangeInput (Just (Just current_perm)) perm arr
  | current_perm == perm = return arr
rearrangeInput Nothing perm arr
  | sort perm == perm = return arr
rearrangeInput (Just Just{}) perm arr
  | sort perm == perm = rowMajorArray arr
rearrangeInput manifest perm arr = do
  manifested <- if isJust manifest then rowMajorArray arr else return arr
  letExp (baseString arr ++ "_coalesced") $
    BasicOp $ Manifest perm manifested
-- | Bind a copy of @arr@ (base name @<arr>_rowmajor@) created by
-- @Manifest@ with the identity permutation @[0 .. rank-1]@.  The rank is
-- taken from @arr@'s type.
rowMajorArray :: MonadBinder m =>
                 VName -> m VName
rowMajorArray arr = do
  rank <- arrayRank <$> lookupType arr
  letExp (baseString arr ++ "_rowmajor") $ BasicOp $ Manifest [0..rank-1] arr
-- | @rearrangeSlice d w num_chunks arr@ rewrites dimension @d@ (of size
-- @w@) of @arr@ as follows:
--
-- 1. pad dimension @d@ to @w_padded@, a multiple of @num_chunks@, by
--    concatenating a 'Scratch' array (see 'paddedScanReduceInput');
-- 2. reshape dimension @d@ into two dimensions @[num_chunks, per_chunk]@;
-- 3. @Manifest@ the result with @tr_perm@, whose last entry is @d@ (the
--    @num_chunks@ dimension);
-- 4. reshape back to a single dimension of size @w_padded@, then slice
--    @w@ elements starting at index 0 of dimension @d@ to drop the padding.
rearrangeSlice :: MonadBinder m =>
                  Int -> SubExp -> PrimExp VName -> VName
               -> m VName
rearrangeSlice d w num_chunks arr = do
  num_chunks' <- toSubExp "num_chunks" num_chunks

  (w_padded, padding) <- paddedScanReduceInput w num_chunks'

  per_chunk <- letSubExp "per_chunk" $
               BasicOp $ BinOp (SQuot Int32 Unsafe) w_padded num_chunks'
  arr_t <- lookupType arr
  arr_padded <- padArray w_padded padding arr_t
  rearrange num_chunks' w_padded per_chunk (baseString arr) arr_padded arr_t

  where -- Concatenate @arr@ along dimension @d@ with a 'Scratch' array
        -- whose dimension @d@ has size @padding@.
        padArray w_padded padding arr_t = do
          let arr_shape = arrayShape arr_t
              padding_shape = setDim d arr_shape padding
          arr_padding <-
            letExp (baseString arr <> "_padding") $
            BasicOp $ Scratch (elemType arr_t) (shapeDims padding_shape)
          letExp (baseString arr <> "_padded") $
            BasicOp $ Concat d arr [arr_padding] w_padded

        -- Steps 2-4 above: split dimension @d@, manifest with @tr_perm@,
        -- restore the original shape and drop the padding.
        rearrange num_chunks' w_padded per_chunk arr_name arr_padded arr_t = do
          let arr_dims = arrayDims arr_t
              pre_dims = take d arr_dims
              post_dims = drop (d+1) arr_dims
              extradim_shape = Shape $ pre_dims ++ [num_chunks', per_chunk] ++ post_dims
              tr_perm = [0..d-1] ++ map (+d) ([1] ++ [2..shapeRank extradim_shape-1-d] ++ [0])
          arr_extradim <-
            letExp (arr_name <> "_extradim") $
            BasicOp $ Reshape (map DimNew $ shapeDims extradim_shape) arr_padded
          arr_extradim_tr <-
            letExp (arr_name <> "_extradim_tr") $
            BasicOp $ Manifest tr_perm arr_extradim
          arr_inv_tr <- letExp (arr_name <> "_inv_tr") $
            BasicOp $ Reshape (map DimCoercion pre_dims ++ map DimNew (w_padded : post_dims))
            arr_extradim_tr
          letExp (arr_name <> "_inv_tr_init") =<<
            eSliceArray d arr_inv_tr (eSubExp $ constant (0::Int32)) (eSubExp w)
-- | @paddedScanReduceInput w stride@ binds two sizes and returns them as
-- @(w_padded, padding)@:
--
-- * @padded_size@: @w@ rounded to a multiple of @stride@ by
--   'eRoundToMultipleOf';
-- * @padding@: @w_padded - w@.
--
-- Both are 'Int32' computations.
paddedScanReduceInput :: MonadBinder m =>
                         SubExp -> SubExp
                      -> m (SubExp, SubExp)
paddedScanReduceInput w stride = do
  w_padded <- letSubExp "padded_size" =<<
              eRoundToMultipleOf Int32 (eSubExp w) (eSubExp stride)
  padding <- letSubExp "padding" $ BasicOp $ BinOp (Sub Int32 OverflowUndef) w_padded w
  return (w_padded, padding)
-- | Maps each bound variable to the set of names its defining statement
-- depends on, as accumulated by 'varianceInStm'. Each entry includes the
-- dependencies already recorded for those names, so dependencies are
-- followed through earlier bindings in the table.
type VarianceTable = M.Map VName Names
-- | Extend a variance table with the bindings of every statement.
--
-- Statements are processed in order, so each statement sees the entries
-- added by the statements before it.
--
-- A strict left fold is used (as in 'varianceInStm') so the table is forced
-- after each statement. The lazy 'foldl' used previously built up a chain
-- of pending map insertions proportional to the number of statements.
varianceInStms :: VarianceTable -> Stms Kernels -> VarianceTable
varianceInStms t = foldl' varianceInStm t . stmsToList
-- | Record the variance of every name bound by a single statement.
--
-- Every name in the statement's pattern is mapped to the same set: each
-- free variable of the statement, together with whatever that variable
-- already maps to in the incoming table (or nothing if it has no entry).
-- Existing entries for the bound names are overwritten.
varianceInStm :: VarianceTable -> Stm Kernels -> VarianceTable
varianceInStm variance stm =
  foldl' (\tbl name -> M.insert name dependsOn tbl) variance boundNames
  where
    -- Names introduced by this statement.
    boundNames = patternNames (stmPattern stm)

    -- Shared by all bound names; computed once per statement.
    dependsOn = foldMap varianceOf (namesToList (freeIn stm))

    -- A variable depends on itself plus anything already recorded for it.
    varianceOf v = oneName v <> M.findWithDefault mempty v variance