{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Pass.ExtractKernels.BlockedKernel
( DistLore,
MkSegLevel,
ThreadRecommendation (..),
segRed,
nonSegRed,
segScan,
segHist,
segMap,
mapKernel,
KernelInput (..),
readKernelInput,
mkSegSpace,
dummyDim,
)
where
import Control.Monad
import Control.Monad.Writer
import Data.List ()
import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.IR.SegOp
import Futhark.MonadFreshNames
import Futhark.Tools
import Futhark.Transform.Rename
import Prelude hiding (quot)
-- | The bundle of constraints that a lore must satisfy to be used with
-- the kernel-construction functions in this module: it must support
-- binding and segmented operations, its let-bound names must be
-- decorated with plain 'Type's, and its expressions and bodies must
-- carry no decoration.
type DistLore lore =
  ( Bindable lore,
    HasSegOp lore,
    BinderOps lore,
    LetDec lore ~ Type,
    ExpDec lore ~ (),
    BodyDec lore ~ (),
    CanBeAliased (Op lore)
  )
-- | A hint handed to a 'MkSegLevel' function when asking it to choose a
-- level for a segmented operation.
data ThreadRecommendation
  = -- | Passed by 'mapKernel' when every result type is primitive.
    ManyThreads
  | -- | No particular recommendation; carries the 'SegVirt' to use.
    -- Passed by 'mapKernel' when some result type is not primitive.
    NoRecommendation SegVirt
-- | A function that picks the 'SegOpLevel' for a segmented operation.
-- 'mapKernel' calls it with the sizes of the index-space dimensions, a
-- descriptive string (@"segmap"@), and a 'ThreadRecommendation'.  It
-- runs in 'BinderT'.
type MkSegLevel lore m =
  [SubExp] -> String -> ThreadRecommendation -> BinderT lore m (SegOpLevel lore)
-- | Construct a 'SegSpace' over the given @(index, size)@ dimensions,
-- allocating a fresh name (based on @"phys_tid"@) for the first field
-- of the 'SegSpace'.
mkSegSpace :: MonadFreshNames m => [(VName, SubExp)] -> m SegSpace
mkSegSpace dims = SegSpace <$> newVName "phys_tid" <*> pure dims
-- | Build the index space and kernel body shared by the segmented
-- operations in this module.
--
-- A fresh @gtid@ index of size @w@ is appended to @ispace@ as the last
-- dimension.  Inside the scope of that space, the kernel body first
-- binds every 'KernelInput' in @inps@, then binds each parameter of
-- @map_lam@ to the @gtid@'th element of the corresponding array in
-- @arrs@, and finally binds the body of @map_lam@ via 'bodyBind'.  Each
-- result of the lambda is returned with 'ResultMaySimplify'.
prepareRedOrScan ::
  (MonadBinder m, DistLore (Lore m)) =>
  SubExp ->
  Lambda (Lore m) ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, KernelBody (Lore m))
prepareRedOrScan w map_lam arrs ispace inps = do
  gtid <- newVName "gtid"
  space <- mkSegSpace $ ispace ++ [(gtid, w)]
  -- 'runBinder' yields the computed results paired with the statements
  -- that were added while computing them.
  (results, stms) <-
    runBinder $
      localScope (scopeOfSegSpace space) $ do
        mapM_ readKernelInput inps
        forM_ (zip (lambdaParams map_lam) arrs) $ \(p, arr) -> do
          arr_t <- lookupType arr
          letBindNames [paramName p] $
            BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var gtid]
        map (Returns ResultMaySimplify) <$> bodyBind (lambdaBody map_lam)
  pure (space, KernelBody () stms results)
-- | Produce the statements binding @pat@ to a 'SegRed' at level @lvl@.
--
-- The index space is @ispace@ extended with a fresh dimension of size
-- @w@; the kernel body reads @inps@ and applies @map_lam@ to the
-- elements of @arrs@ (see 'prepareRedOrScan').  @ops@ are the reduction
-- operators, and the result types are the return types of @map_lam@.
segRed ::
  (MonadFreshNames m, DistLore lore, HasScope lore m) =>
  SegOpLevel lore ->
  Pattern lore ->
  SubExp ->
  [SegBinOp lore] ->
  Lambda lore ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms lore)
segRed lvl pat w ops map_lam arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegRed lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Produce the statements binding @pat@ to a 'SegScan' at level @lvl@.
--
-- The index space is @ispace@ extended with a fresh dimension of size
-- @w@; the kernel body reads @inps@ and applies @map_lam@ to the
-- elements of @arrs@ (see 'prepareRedOrScan').  @ops@ are the scan
-- operators, and the result types are the return types of @map_lam@.
segScan ::
  (MonadFreshNames m, DistLore lore, HasScope lore m) =>
  SegOpLevel lore ->
  Pattern lore ->
  SubExp ->
  [SegBinOp lore] ->
  Lambda lore ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms lore)
segScan lvl pat w ops map_lam arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegScan lvl kspace ops (lambdaReturnType map_lam) kbody
-- | Produce the statements binding @pat@ to a 'SegMap' at level @lvl@.
--
-- The index space is @ispace@ extended with a fresh dimension of size
-- @w@; the kernel body reads @inps@ and applies @map_lam@ to the
-- elements of @arrs@ (see 'prepareRedOrScan').  The result types are
-- the return types of @map_lam@.
segMap ::
  (MonadFreshNames m, DistLore lore, HasScope lore m) =>
  SegOpLevel lore ->
  Pattern lore ->
  SubExp ->
  Lambda lore ->
  [VName] ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (Stms lore)
segMap lvl pat w map_lam arrs ispace inps = runBinder_ $ do
  (kspace, kbody) <- prepareRedOrScan w map_lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegMap lvl kspace (lambdaReturnType map_lam) kbody
-- | Prepare a pattern for being produced under an extra, unit-sized
-- outer dimension.
--
-- Returns a triple of:
--
-- * a renamed copy of @pat@ in which every element type has been
--   wrapped in an additional outer dimension of size 1;
--
-- * a one-dimensional index space of size 1 over a fresh @dummy@ index;
--
-- * an action that, for each name in the original pattern, binds it to
--   element 0 of the correspondingly renamed name.  The action looks up
--   the type of the renamed names, so it must be run after they have
--   been bound.
dummyDim ::
  (MonadFreshNames m, MonadBinder m, DistLore (Lore m)) =>
  Pattern (Lore m) ->
  m (Pattern (Lore m), [(VName, SubExp)], m ())
dummyDim pat = do
  let addDummyDim t = t `arrayOfRow` intConst Int64 1
  pat' <- fmap addDummyDim <$> renamePattern pat
  dummy <- newVName "dummy"
  let ispace = [(dummy, intConst Int64 1)]

  return
    ( pat',
      ispace,
      forM_ (zip (patternNames pat') (patternNames pat)) $ \(from, to) -> do
        from_t <- lookupType from
        letBindNames [to] $
          BasicOp $
            Index from $
              fullSlice from_t [DimFix $ intConst Int64 0]
    )
-- | Produce the statements for a reduction with no enclosing index
-- space.
--
-- This is expressed as a 'segRed' over the unit-sized index space from
-- 'dummyDim' with no kernel inputs.  The reduction writes to the
-- renamed pattern, after which the original names of @pat@ are bound by
-- indexing the results at position 0.
nonSegRed ::
  (MonadFreshNames m, DistLore lore, HasScope lore m) =>
  SegOpLevel lore ->
  Pattern lore ->
  SubExp ->
  [SegBinOp lore] ->
  Lambda lore ->
  [VName] ->
  m (Stms lore)
nonSegRed lvl pat w ops map_lam arrs = runBinder_ $ do
  (pat', ispace, read_dummy) <- dummyDim pat
  addStms
    =<< segRed lvl pat' w ops map_lam arrs ispace []
  -- Must come after the reduction statements, since it indexes the
  -- names that those statements bind.
  read_dummy
-- | Produce the statements binding @pat@ to a 'SegHist' at level @lvl@.
--
-- The index space is @ispace@ extended with a fresh dimension of size
-- @arr_w@; the kernel body reads @inps@ and applies @lam@ to the
-- elements of @arrs@.  @ops@ are the histogram operations, and the
-- result types are the return types of @lam@.
--
-- The space and kernel body are built by 'prepareRedOrScan', exactly as
-- for 'segRed', 'segScan' and 'segMap'.
segHist ::
  (DistLore lore, MonadFreshNames m, HasScope lore m) =>
  SegOpLevel lore ->
  Pattern lore ->
  SubExp ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [HistOp lore] ->
  Lambda lore ->
  [VName] ->
  m (Stms lore)
segHist lvl pat arr_w ispace inps ops lam arrs = runBinder_ $ do
  (space, kbody) <- prepareRedOrScan arr_w lam arrs ispace inps
  letBind pat $
    Op $
      segOp $
        SegHist lvl space ops (lambdaReturnType lam) kbody
-- | Build the two pieces every map kernel needs: a 'SegSpace' over
-- @ispace@, and the statements that bind each of the given
-- 'KernelInput's (see 'readKernelInput').
mapKernelSkeleton ::
  (DistLore lore, HasScope lore m, MonadFreshNames m) =>
  [(VName, SubExp)] ->
  [KernelInput] ->
  m (SegSpace, Stms lore)
mapKernelSkeleton ispace inputs = do
  -- Only the emitted statements matter here, so use 'mapM_' rather than
  -- accumulating a list of units.
  read_input_bnds <- runBinder_ $ mapM_ readKernelInput inputs
  space <- mkSegSpace ispace
  return (space, read_input_bnds)
-- | Construct a 'SegMap' over @ispace@ with result types @rts@.
--
-- The statements reading @inputs@ are prepended to the statements of
-- the supplied kernel body.  The level is chosen by @mk_lvl@, which is
-- given the dimension sizes of @ispace@, the description @"segmap"@,
-- and a 'ThreadRecommendation'.
--
-- Returns the 'SegOp' together with the statements emitted while
-- constructing it (for instance by @mk_lvl@).
mapKernel ::
  (DistLore lore, HasScope lore m, MonadFreshNames m) =>
  MkSegLevel lore m ->
  [(VName, SubExp)] ->
  [KernelInput] ->
  [Type] ->
  KernelBody lore ->
  m (SegOp (SegOpLevel lore) lore, Stms lore)
mapKernel mk_lvl ispace inputs rts (KernelBody () kstms krets) = runBinderT' $ do
  (space, read_input_stms) <- mapKernelSkeleton ispace inputs

  let kbody' = KernelBody () (read_input_stms <> kstms) krets

  -- Recommend many threads only when every result type is primitive;
  -- otherwise make no recommendation.
  let r = if all primType rts then ManyThreads else NoRecommendation SegVirt

  lvl <- mk_lvl (map snd ispace) "segmap" r

  return $ SegMap lvl space rts kbody'
-- | Description of a value that must be made available inside a kernel
-- body.  'readKernelInput' turns one of these into a binding.
data KernelInput = KernelInput
  { -- | The name bound inside the kernel body.
    kernelInputName :: VName,
    -- | The type given to the bound name.
    kernelInputType :: Type,
    -- | The variable being read from.  If it is an accumulator it is
    -- used as-is; otherwise it is indexed.
    kernelInputArray :: VName,
    -- | The indices at which 'kernelInputArray' is read (each becomes a
    -- 'DimFix').  Ignored for accumulators.
    kernelInputIndices :: [SubExp]
  }
  deriving (Show)
-- | Emit the statement that binds 'kernelInputName' (at type
-- 'kernelInputType') for the given input.
--
-- If 'kernelInputArray' has an accumulator type, the name is bound
-- directly to that variable.  Otherwise it is bound to the result of
-- indexing the array with 'kernelInputIndices', padded to a full slice
-- of the array's type.
readKernelInput ::
  (DistLore (Lore m), MonadBinder m) =>
  KernelInput ->
  m ()
readKernelInput inp = do
  let pe = PatElem (kernelInputName inp) $ kernelInputType inp
  arr_t <- lookupType $ kernelInputArray inp
  letBind (Pattern [] [pe]) . BasicOp $
    case arr_t of
      Acc {} ->
        SubExp $ Var $ kernelInputArray inp
      _ ->
        Index (kernelInputArray inp) $
          fullSlice arr_t $ map DimFix $ kernelInputIndices inp