module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
(
analyseFunDef,
analyseConsts,
hostOnlyFunDefs,
MigrationTable,
MigrationStatus (..),
shouldMoveStm,
shouldMove,
usedOnHost,
statusOf,
)
where
import Control.Monad
import Control.Monad.Trans.Class
import qualified Control.Monad.Trans.Reader as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first, second)
import Data.Foldable
import qualified Data.IntMap.Strict as IM
import qualified Data.IntSet as IS
import qualified Data.Map.Strict as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing)
import qualified Data.Sequence as SQ
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
( EdgeType (..),
Edges (..),
Id,
IdSet,
Result (..),
Routing (..),
Vertex (..),
)
import qualified Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph as MG
-- | Where a statement (identified by a name it binds) should be executed, as
-- decided by the migration analysis.
--
-- The derived 'Ord' instance follows declaration order:
-- @MoveToDevice < UsedOnHost < StayOnHost@.
data MigrationStatus
  = -- | Migrate the statement to device; its result is not needed on host.
    MoveToDevice
  | -- | Migrate the statement to device, but its result is also needed on
    -- host ('shouldMove' and 'usedOnHost' are both true for this status).
    UsedOnHost
  | -- | Keep the statement on host. This is also the status reported for any
    -- name that has no entry in a 'MigrationTable'.
    StayOnHost
  deriving (Eq, Ord, Show)
-- | Migration decisions for a set of names, keyed by each name's 'baseTag'.
-- Names without an entry are treated as 'StayOnHost' by 'statusOf'.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Combine two tables. The union is left-biased: when both tables hold an
-- entry for the same name, the entry from the left operand wins.
instance Semigroup MigrationTable where
  (<>) (MigrationTable lhs) (MigrationTable rhs) =
    MigrationTable (IM.union lhs rhs)
-- | Look up the migration status of a name. A name that is absent from the
-- table is reported as 'StayOnHost'.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable table) =
  IM.findWithDefault StayOnHost (baseTag n) table
-- | Should the given statement be migrated to device?
--
-- The decision is read from the migration table:
--
-- * An 'Index' with at least one bound name is migrated when the first bound
--   name is 'MoveToDevice', or when any variable in its slice is
--   'MoveToDevice'.
-- * Any other 'BasicOp', or an 'Apply', with at least one bound name is
--   migrated when the first bound name is not 'StayOnHost'.
-- * An 'If' with a variable condition, a 'ForLoop' with a variable bound, or
--   a 'WhileLoop' is migrated when that variable is 'MoveToDevice'.
-- * Every other statement is never migrated.
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat pes) _ e) mt =
  case (pes, e) of
    (PatElem n _ : _, BasicOp (Index _ slice)) ->
      onDevice n || any operandOnDevice slice
    (PatElem n _ : _, BasicOp _) -> notStaying n
    (PatElem n _ : _, Apply {}) -> notStaying n
    (_, If (Var c) _ _ _) -> onDevice c
    (_, DoLoop _ (ForLoop _ _ (Var bound) _) _) -> onDevice bound
    (_, DoLoop _ (WhileLoop c) _) -> onDevice c
    _ -> False
  where
    onDevice v = statusOf v mt == MoveToDevice
    notStaying v = statusOf v mt /= StayOnHost
    -- Constant slice components never force migration.
    operandOnDevice (Var v) = onDevice v
    operandOnDevice _ = False
-- | Should the statement that binds this name be migrated to device?
-- This holds for every status except 'StayOnHost', which is also the status of
-- names that are missing from the table.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt =
  case statusOf n mt of
    StayOnHost -> False
    _ -> True
-- | Is the value bound by this name needed on host? This holds for every
-- status except 'MoveToDevice', so names missing from the table (treated as
-- 'StayOnHost') count as used on host.
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt =
  case statusOf n mt of
    MoveToDevice -> False
    _ -> True
-- | A set of function names; used for the functions that must execute on host.
type HostOnlyFuns = Set Name

-- | Determine which of the given functions must execute on host.
--
-- A function is host-only when 'checkFunDef' rejects it on its own, or when
-- it (transitively) calls a host-only function. The computation repeatedly
-- moves functions that call a newly found host-only function into the
-- host-only set until no new ones are found.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      -- 'Nothing' marks a host-only function; 'Just calls' lists the
      -- functions a device-capable function calls.
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ M.keysSet (removeHostOnly call_map)
  where
    -- Strip host-only entries until a fixpoint; the result only contains
    -- functions that may run on device.
    removeHostOnly ::
      M.Map Name (Maybe HostOnlyFuns) -> M.Map Name (Maybe HostOnlyFuns)
    removeHostOnly cm =
      let (host_only, cm') = M.partition isNothing cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ M.keysSet host_only) cm'

    -- A function stays device-capable only if it calls none of the newly
    -- discovered host-only functions.
    checkCalls :: HostOnlyFuns -> Maybe HostOnlyFuns -> Maybe HostOnlyFuns
    checkCalls host_only_funs (Just calls)
      | host_only_funs `S.disjoint` calls = Just calls
    checkCalls _ _ = Nothing
-- | Check whether a function, considered on its own, could execute on device.
--
-- Returns 'Nothing' (host-only) if any parameter or return type is an array,
-- any statement binds an array, or the body contains an 'Index', a 'WithAcc',
-- an 'Op', a loop with array-typed loop parameters, or a 'ForLoop' with a
-- non-empty list of loop variables. Otherwise returns the names of all
-- functions called via 'Apply', including calls inside 'If' branches and loop
-- bodies.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  forbid isArray (funDefParams fun)
  forbid isArrayType (funDefRetType fun)
  checkBody (funDefBody fun)
  where
    -- Host-only ('Nothing') if any element satisfies the predicate.
    forbid :: Foldable f => (a -> Bool) -> f a -> Maybe ()
    forbid p xs = if any p xs then Nothing else Just ()

    checkBody :: Body GPU -> Maybe (Set Name)
    checkBody = checkStms . bodyStms

    checkStms :: Stms GPU -> Maybe (Set Name)
    checkStms stms = S.unions <$> traverse checkStm stms

    -- Array-producing expressions are rejected here through their pattern,
    -- so 'checkExp' does not need to inspect them.
    checkStm :: Stm GPU -> Maybe (Set Name)
    checkStm (Let (Pat pes) _ e) = forbid isArray pes >> checkExp e

    checkExp :: Exp GPU -> Maybe (Set Name)
    checkExp (BasicOp Index {}) = Nothing
    checkExp WithAcc {} = Nothing
    checkExp Op {} = Nothing
    checkExp (Apply fn _ _ _) = Just (S.singleton fn)
    checkExp (If _ tbranch fbranch _) =
      S.union <$> checkBody tbranch <*> checkBody fbranch
    checkExp (DoLoop params lform body) = do
      forbid (isArray . fst) params
      checkLoopForm lform
      checkBody body
    checkExp _ = Just S.empty

    checkLoopForm :: LoopForm GPU -> Maybe ()
    checkLoopForm (ForLoop _ _ _ (_ : _)) = Nothing
    checkLoopForm _ = Just ()
-- | Ids of variables whose values are needed on host. 'buildGraph' connects
-- each of them to the sink of the migration graph.
type HostUsage = [Id]

-- | The graph vertex id of a name: its 'baseTag'.
nameToId :: VName -> Id
nameToId n = baseTag n
-- | Analyse the top-level constant statements of a program.
--
-- Every scalar constant bound by @consts@ that occurs free in the function
-- definitions is treated as host usage.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts = analyseStms hof host_usage consts
  where
    -- Names the function definitions refer to from outside themselves.
    used_by_funs = freeIn funs

    host_usage = M.foldlWithKey collect [] (scopeOf consts)

    collect acc n info
      | isScalar info, n `nameIn` used_by_funs = nameToId n : acc
      | otherwise = acc
-- | Analyse the body statements of a function definition.
--
-- Every body result that is a variable with a scalar return type is treated
-- as host usage.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd = analyseStms hof host_usage (bodyStms body)
  where
    body = funDefBody fd

    host_usage =
      foldl' addResult [] $ zip (bodyResult body) (funDefRetType fd)

    addResult acc (SubExpRes _ (Var n), ret)
      | isScalarType ret = nameToId n : acc
    addResult acc _ = acc
-- | Build the migration graph for some statements and derive the migration
-- table from it.
--
-- The unrouted sources are passed to 'MG.routeMany', whose first result is
-- discarded. The resulting graph is then folded over from every unrouted
-- source and afterwards from every routed source, sharing a single visited
-- set. Each reached vertex is classified by the edge it was reached through:
--
-- * a 'Reversed' edge: 'MoveToDevice';
-- * a 'Normal' edge to a vertex with 'NoRoute': 'MoveToDevice';
-- * any other 'Normal' edge: 'UsedOnHost'.
--
-- 'IM.unions' is left-biased, so a vertex that appears in several sets gets
-- 'MoveToDevice' whenever it is in one of the first two.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  MigrationTable $
    IM.unions
      [ IM.fromSet (const MoveToDevice) via_reversed,
        IM.fromSet (const MoveToDevice) no_route,
        IM.fromSet (const UsedOnHost) routed_normal
      ]
  where
    (g, srcs, _) = buildGraph hof usage stms
    (routed, unrouted) = srcs
    (_, g') = MG.routeMany unrouted g

    explore st src = MG.fold g' visit st Normal src

    after_unrouted =
      foldl' explore ((IS.empty, IS.empty, IS.empty), MG.none) unrouted
    (via_reversed, no_route, routed_normal) =
      fst (foldl' explore after_unrouted routed)

    visit (rev, nr, rn) edge v =
      let i = vertexId v
       in case edge of
            Reversed -> (IS.insert i rev, nr, rn)
            Normal
              | NoRoute <- vertexRouting v -> (rev, IS.insert i nr, rn)
              | otherwise -> (rev, nr, IS.insert i rn)
-- | Does the value have a scalar type (see 'isScalarType')?
isScalar :: Typed t => t -> Bool
isScalar x = isScalarType (typeOf x)
-- | Is the type a primitive other than 'Unit'? Every non-'Prim' type is
-- rejected.
isScalarType :: TypeBase shape u -> Bool
isScalarType t =
  case t of
    Prim pt -> pt /= Unit
    _ -> False
-- | Does the value have an array type (see 'isArrayType')?
isArray :: Typed t => t -> Bool
isArray x = isArrayType (typeOf x)
-- | Is the type an array, i.e. does it have a rank greater than zero?
isArrayType :: ArrayShape shape => TypeBase shape u -> Bool
isArrayType t = arrayRank t > 0
-- | Build the migration graph for some statements by running 'graphStms', then
-- connect every id in @usage@ to the graph's sink with 'MG.connectToSink'.
-- The sources and sinks are returned as produced by 'execGrapher'.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms = (connectUsage g, srcs, sinks)
  where
    (g, srcs, sinks) = execGrapher hof (graphStms stms)
    connectUsage g0 = foldl' (\acc i -> MG.connectToSink i acc) g0 usage
-- | Graph the statements of a body one body-depth level deeper, then report
-- the variables used in its result as operands.
--
-- Afterwards, this body's depth (current depth + 1) is removed from
-- 'bodyHostOnlyParents'. If the statistics captured for the body contained
-- that depth, 'bodyHostOnly' is set. NOTE(review): presumably this marks the
-- statement enclosing the body as unable to migrate as a whole because
-- something inside it must stay on host. The meaning of these fields is
-- defined elsewhere; confirm.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  stats <-
    captureBodyStats . incBodyDepthFor $ do
      graphStms (bodyStms body)
      tellOperands result_ops
  depth <- (+ 1) <$> getBodyDepth
  let host_only = depth `IS.member` bodyHostOnlyParents stats
  modify (clearDepth depth host_only)
  where
    result_ops = namesIntSet (freeIn (bodyResult body))

    clearDepth depth host_only st =
      let old = stateStats st
          trimmed =
            old {bodyHostOnlyParents = IS.delete depth (bodyHostOnlyParents old)}
          updated = if host_only then trimmed {bodyHostOnly = True} else trimmed
       in st {stateStats = updated}
-- | Graph each statement in order with 'graphStm'.
graphStms :: Stms GPU -> Grapher ()
graphStms stms = traverse_ graphStm stms
graphStm :: Stm GPU -> Grapher ()
graphStm :: Stm GPU -> Grapher ()
graphStm Stm GPU
stm = do
let bs :: [Binding]
bs = Stm GPU -> [Binding]
boundBy Stm GPU
stm
let e :: Exp GPU
e = Stm GPU -> Exp GPU
forall rep. Stm rep -> Exp rep
stmExp Stm GPU
stm
case Exp GPU
e of
BasicOp (SubExp SubExp
se) -> do
[Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> SubExp -> Grapher ()
`reusesSubExp` SubExp
se
BasicOp (Opaque OpaqueOp
_ SubExp
se) -> do
[Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> SubExp -> Grapher ()
`reusesSubExp` SubExp
se
BasicOp (ArrayLit [SubExp]
arr Type
t)
| Type -> Bool
forall t. Typed t => t -> Bool
isScalar Type
t,
(SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Maybe VName -> Bool) -> (SubExp -> Maybe VName) -> SubExp -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> Maybe VName
subExpVar) [SubExp]
arr ->
Binding -> Grapher ()
graphAutoMove ([Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs)
BasicOp UnOp {} -> [Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp BinOp {} -> [Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp CmpOp {} -> [Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp ConvOp {} -> [Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp Assert {} ->
[Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp (Index VName
_ Slice SubExp
slice)
| Slice SubExp -> Bool
forall d. Slice d -> Bool
isFixed Slice SubExp
slice ->
Binding -> Grapher ()
graphRead ([Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs)
BasicOp {}
| [(Int
_, Type
t)] <- [Binding]
bs,
[SubExp]
dims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims Type
t,
[SubExp]
dims [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [],
(SubExp -> Bool) -> [SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) [SubExp]
dims ->
[Binding] -> Exp GPU -> Grapher ()
graphSimple [Binding]
bs Exp GPU
e
BasicOp (Index VName
arr Slice SubExp
s) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn (Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
s) Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (Update Safety
_ VName
arr Slice SubExp
_ SubExp
_) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn [] Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (FlatIndex VName
arr FlatSlice SubExp
s) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn (FlatSlice SubExp -> [SubExp]
forall d. FlatSlice d -> [d]
flatSliceDims FlatSlice SubExp
s) Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (FlatUpdate VName
arr FlatSlice SubExp
_ VName
_) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn [] Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (Scratch PrimType
_ [SubExp]
s) ->
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn [SubExp]
s Exp GPU
e
BasicOp (Reshape ShapeChange SubExp
s VName
arr) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn (ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
s) Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (Rearrange [Int]
_ VName
arr) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn [] Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp (Rotate [SubExp]
_ VName
arr) -> do
[SubExp] -> Exp GPU -> Grapher ()
forall (t :: * -> *).
Foldable t =>
t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn [] Exp GPU
e
[Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs Binding -> VName -> Grapher ()
`reuses` VName
arr
BasicOp ArrayLit {} ->
Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp Concat {} ->
Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp Copy {} ->
Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp Manifest {} ->
Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp Iota {} -> Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp Replicate {} -> Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
BasicOp UpdateAcc {} ->
Binding -> Exp GPU -> Grapher ()
graphUpdateAcc ([Binding] -> Binding
forall p. [p] -> p
one [Binding]
bs) Exp GPU
e
Apply Name
fn [(SubExp, Diet)]
_ [RetType GPU]
_ (Safety, SrcLoc, [SrcLoc])
_ ->
Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply Name
fn [Binding]
bs Exp GPU
e
If SubExp
cond Body GPU
tbody Body GPU
fbody IfDec (BranchType GPU)
_ ->
[Binding] -> SubExp -> Body GPU -> Body GPU -> Grapher ()
graphIf [Binding]
bs SubExp
cond Body GPU
tbody Body GPU
fbody
DoLoop [(FParam GPU, SubExp)]
params LoopForm GPU
lform Body GPU
body ->
[Binding]
-> [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Grapher ()
graphLoop [Binding]
bs [(FParam GPU, SubExp)]
params LoopForm GPU
lform Body GPU
body
WithAcc [WithAccInput GPU]
inputs Lambda GPU
f ->
[Binding] -> [WithAccInput GPU] -> Lambda GPU -> Grapher ()
graphWithAcc [Binding]
bs [WithAccInput GPU]
inputs Lambda GPU
f
Op GPUBody {} ->
Grapher ()
tellGPUBody
Op Op GPU
_ ->
Exp GPU -> Grapher ()
graphHostOnly Exp GPU
e
where
one :: [p] -> p
one [p
x] = p
x
one [p]
_ = String -> p
forall a. String -> a
compilerBugS String
"Type error: unexpected number of pattern elements."
isFixed :: Slice d -> Bool
isFixed = Maybe [d] -> Bool
forall a. Maybe a -> Bool
isJust (Maybe [d] -> Bool) -> (Slice d -> Maybe [d]) -> Slice d -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Slice d -> Maybe [d]
forall d. Slice d -> Maybe [d]
sliceIndices
graphInefficientReturn :: t SubExp -> Exp GPU -> Grapher ()
graphInefficientReturn t SubExp
new_dims Exp GPU
e = do
(SubExp -> Grapher ()) -> t SubExp -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ SubExp -> Grapher ()
hostSize t SubExp
new_dims
Exp GPU -> Grapher IntSet
graphedScalarOperands Exp GPU
e Grapher IntSet -> (IntSet -> Grapher ()) -> Grapher ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> IntSet -> Grapher ()
addEdges Edges
ToSink
hostSize :: SubExp -> Grapher ()
hostSize (Var VName
n) = VName -> Grapher ()
hostSizeVar VName
n
hostSize SubExp
_ = () -> Grapher ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
hostSizeVar :: VName -> Grapher ()
hostSizeVar = Int -> Grapher ()
requiredOnHost (Int -> Grapher ()) -> (VName -> Int) -> VName -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Int
nameToId
-- | One 'Binding' per pattern element of the statement: the element's name
-- converted with 'nameToId', paired with its type.
boundBy :: Stm GPU -> [Binding]
boundBy stm =
  [(nameToId name, t) | PatElem name t <- patElems (stmPat stm)]
-- | Graph a statement with bound variables @bs@ and expression @e@.
--
-- The bound variables are added as vertices only when @e@ has at least one
-- graphed scalar operand (as reported by 'graphedScalarOperands'). In that
-- case, edges declared for the ids of @bs@ ('MG.declareEdges') are added to
-- those operands. Otherwise nothing is added to the graph.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  operands <- graphedScalarOperands e
  unless (IS.null operands) $ do
    mapM_ addVertex bs
    addEdges (MG.declareEdges (map fst bs)) operands
-- | Graph a statement that binds @b@ by reading a scalar: @b@ is added as a
-- source ('addSource') and the read is recorded ('tellRead').
-- The statement's operands are not graphed.
graphRead :: Binding -> Grapher ()
graphRead b = addSource b >> tellRead
-- | Graph a statement binding @b@ by registering @b@ as a source
-- ('addSource'). The statement's operands are not graphed.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b = addSource b
-- | Graph a statement that must execute on host.
--
-- Every graphed scalar operand of @e@ is connected to the sink ('ToSink'),
-- and the statement is recorded as host-only ('tellHostOnly').
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  (graphedScalarOperands e >>= addEdges ToSink) >> tellHostOnly
-- | Record an @UpdateAcc@ statement for later graphing.
--
-- Nothing is added to the graph here. Instead, the pair @(b, e)@ is prepended
-- to the list stored under the accumulator's id in 'stateUpdateAccs'. That
-- list is consumed (and removed) when the enclosing @WithAcc@ is graphed.
--
-- Calls 'compilerBugS' if @b@ does not have an accumulator ('Acc') type.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e =
  case b of
    (_, Acc acc _ _ _) ->
      modify $ \st ->
        let pending = stateUpdateAccs st
            -- Newest update first: insertWith (++) puts [(b, e)] before
            -- any updates already recorded for this accumulator.
            pending' = IM.insertWith (++) (nameToId acc) [(b, e)] pending
         in st {stateUpdateAccs = pending'}
    _ ->
      compilerBugS
        "Type error: UpdateAcc did not produce accumulator typed value."
-- | Graph a call to function @fn@ that binds @bs@.
--
-- If 'isHostOnlyFun' reports @fn@ as host-only, the call is graphed with
-- 'graphHostOnly'. Otherwise it is graphed with 'graphSimple'.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e =
  isHostOnlyFun fn >>= \host_only ->
    if host_only
      then graphHostOnly e
      else graphSimple bs e
-- | Graph an @if@ statement that binds @bs@, with condition @cond@ and branch
-- bodies @tbody@ and @fbody@.
--
-- Both branches are graphed inside 'incForkDepthFor'. The statement counts as
-- migratable only if neither branch is host-only and 'reusesBranches' allows
-- the branch results to be copied.
--
-- If the statement is not migratable, a variable condition is connected to
-- the sink. If it is migratable, the condition's graphed scalars are
-- reported via 'tellOperands' and included in the operand set of every bound
-- variable. Each bound variable additionally gets the graphed scalars
-- returned at its position by either branch.
graphIf :: [Binding] -> SubExp -> Body GPU -> Body GPU -> Grapher ()
graphIf bs cond tbody fbody = do
  some_branch_host_only <-
    incForkDepthFor $ do
      tstats <- captureBodyStats (graphBody tbody)
      fstats <- captureBodyStats (graphBody fbody)
      pure (bodyHostOnly tstats || bodyHostOnly fstats)

  copyable <- reusesBranches bs (resultsOf tbody) (resultsOf fbody)
  let migratable = not some_branch_host_only && copyable

  cond_ops <- case cond of
    Var n
      | migratable -> onlyGraphedScalar n
      | otherwise -> do
          -- The statement stays on host, so its condition is needed there.
          connectToSink (nameToId n)
          pure IS.empty
    _ -> pure IS.empty

  tellOperands cond_ops

  ret_ops <-
    zipWithM (branchOperands cond_ops) (bodyResult tbody) (bodyResult fbody)
  zipWithM_ createNode bs ret_ops
  where
    resultsOf = map resSubExp . bodyResult

    -- Operands of one bound variable: the condition operands plus the
    -- graphed scalars returned at that position by either branch.
    branchOperands ci a b =
      (ci <>) <$> onlyGraphedScalars (varsOf a <> varsOf b)

    varsOf (SubExpRes _ (Var n)) = S.singleton n
    varsOf _ = S.empty
-- | Ids of loop-bound variables reached while traversing the migration graph.
type ReachableBindings = IdSet

-- | Memoisation of graph reductions that produce 'ReachableBindings'.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Vertices left unsearched because they lie outside the loop's subgraph.
type NonExhausted = [Id]

-- | A scalar loop value: (variable bound by the loop, merge parameter id,
-- initial argument, value returned by the loop body).
type LoopValue = (Binding, Id, SubExp, SubExp)
-- | Graph a loop statement that binds @b : bs@.
--
-- Arguments are the merge parameters with their initial arguments
-- (@params@), the loop form (@lform@) and the loop body (@body@).
--
-- The parameters and body are graphed as a subgraph whose id is the id of the
-- first bound variable. The loop counts as migratable only if its body is not
-- host-only and 'reusesBranches' allows the results to be copied. When it is
-- not migratable, a variable iteration bound or the while-condition parameter
-- is connected to the sink.
--
-- After the subgraph is routed ('routeSubgraph'), the loop bindings reachable
-- from its sources are added to the second component of the tracked sources.
-- Body operands that were already graphed before the loop are then connected
-- to the loop bindings they can reach.
--
-- Calls 'compilerBugS' if the loop binds no variables.
graphLoop ::
  [Binding] ->
  [(FParam GPU, SubExp)] ->
  LoopForm GPU ->
  Body GPU ->
  Grapher ()
graphLoop [] _ _ _ =
  compilerBugS "Loop statement bound no variable; should have been eliminated."
graphLoop (b : bs) params lform body = do
  -- Graph as it was before the loop, used to pick out outer operands.
  g_before <- getGraph
  stats <- captureBodyStats (subgraphId `graphIdFor` graphTheLoop)

  let initial_args = map snd params
      body_results = map resSubExp (bodyResult body)
  copyable <- reusesBranches (b : bs) initial_args body_results

  let migratable = not (bodyHostOnly stats) && copyable

  unless migratable $
    case lform of
      ForLoop _ _ (Var n) _ -> connectToSink (nameToId n)
      WhileLoop n -> connectToSink (nameToId n)
      _ -> pure ()

  g_merge <- getGraph
  forM_ loopValues (mergeLoopParam g_merge)

  srcs <- routeSubgraph subgraphId

  forM_ loopValues $ \(bnd, p, _, _) ->
    createNode bnd (IS.singleton p)

  g_routed <- getGraph
  let (device_bnds, cache) =
        foldl' (deviceBindings g_routed) (IS.empty, MG.none) srcs
  modifySources $ second (IS.toList device_bnds <>)

  let outer_ops = IS.filter (`MG.member` g_before) (bodyOperands stats)
  foldM_ connectOperand cache (IS.elems outer_ops)

  when migratable $
    case lform of
      ForLoop _ _ n _ ->
        onlyGraphedScalarSubExp n >>= addEdges (ToNodes bindings Nothing)
      WhileLoop n ->
        let (_, _, arg, _) = loopValueFor n
         in onlyGraphedScalarSubExp arg >>= addEdges (ToNodes bindings Nothing)
  where
    subgraphId :: Id
    subgraphId = fst b

    -- Scalar-typed loop values, in the order of the bound variables.
    loopValues :: [LoopValue]
    loopValues =
      [ (bnd, nameToId (paramName p), arg, resSubExp res)
        | (bnd@(_, t), (p, arg), res) <- zip3 (b : bs) params (bodyResult body),
          isScalar t
      ]

    -- Ids of the variables bound by the scalar loop values.
    bindings :: IdSet
    bindings = IS.fromList [i | ((i, _), _, _, _) <- loopValues]

    -- The loop value whose merge parameter is @n@. Partial: @n@ must name a
    -- scalar merge parameter (used here for the while-loop condition).
    loopValueFor :: VName -> LoopValue
    loopValueFor n =
      let target = nameToId n
       in fromJust $ find (\(_, p, _, _) -> p == target) loopValues

    -- Graph the merge parameters, the for-loop bound and elements, and the body.
    graphTheLoop :: Grapher ()
    graphTheLoop = do
      mapM_ graphParam loopValues
      case lform of
        ForLoop _ _ n elems -> do
          onlyGraphedScalarSubExp n >>= tellOperands
          mapM_ graphForInElem elems
        WhileLoop _ -> pure ()
      graphBody body
      where
        -- Scalar elements become sources; array elements reuse the array.
        graphForInElem (p, arr) = do
          when (isScalar p) $ addSource (nameToId $ paramName p, typeOf p)
          when (isArray p) $ (nameToId (paramName p), typeOf p) `reuses` arr

        -- Add the parameter vertex and link its initial argument's graphed
        -- scalars to it with a single edge.
        graphParam ((_, t), p, arg, _) = do
          addVertex (p, t)
          arg_ops <- onlyGraphedScalarSubExp arg
          addEdges (MG.oneEdge p) arg_ops

    -- Relate a merge parameter to the variable returned for it by the body:
    -- if the parameter is sink-connected in @g@, so is the returned variable;
    -- otherwise the returned variable gets a single edge to the parameter.
    mergeLoopParam :: Graph -> LoopValue -> Grapher ()
    mergeLoopParam g (_, p, _, Var n)
      | ret <- nameToId n,
        ret /= p =
          if MG.isSinkConnected p g
            then connectToSink ret
            else addEdges (MG.oneEdge p) (IS.singleton ret)
    mergeLoopParam _ _ = pure ()

    -- Accumulate the loop bindings reachable from source @i@.
    deviceBindings ::
      Graph ->
      (ReachableBindings, ReachableBindingsCache) ->
      Id ->
      (ReachableBindings, ReachableBindingsCache)
    deviceBindings g (rb, rbc) i =
      case MG.reduce g bindingReach rbc Normal i of
        (Produced rb', rbc') -> (rb <> rb', rbc')
        _ ->
          compilerBugS
            "Migration graph sink could be reached from source after it\
            \ had been attempted routed."

    -- Reduction step: collect visited vertices that are loop bindings.
    bindingReach ::
      ReachableBindings ->
      EdgeType ->
      Vertex Meta ->
      ReachableBindings
    bindingReach rb _ v
      | IS.member i bindings = IS.insert i rb
      | otherwise = rb
      where
        i = vertexId v

    -- Connect an operand vertex (defined outside the loop) either to the sink
    -- or to the loop bindings reachable through its edges.
    connectOperand ::
      ReachableBindingsCache ->
      Id ->
      Grapher ReachableBindingsCache
    connectOperand cache op = do
      g <- getGraph
      case vertexEdges <$> MG.lookup op g of
        Nothing -> pure cache
        Just ToSink -> pure cache
        Just (ToNodes es Nothing) -> connectOp g cache op es
        Just (ToNodes _ (Just nx)) -> connectOp g cache op nx
      where
        connectOp ::
          Graph ->
          ReachableBindingsCache ->
          Id ->
          IdSet ->
          Grapher ReachableBindingsCache
        connectOp g rbc i targets = do
          let (found, unsearched, rbc') =
                findBindings g (IS.empty, [], rbc) (IS.elems targets)
          case found of
            FoundSink -> connectToSink i
            Produced reached ->
              modifyGraph $ MG.adjust (updateEdges unsearched reached) i
          pure rbc'

        -- Add the reached bindings to the vertex's edges, and record the
        -- reached bindings plus unsearched vertices as its second edge set.
        updateEdges ::
          NonExhausted ->
          ReachableBindings ->
          Vertex Meta ->
          Vertex Meta
        updateEdges nx rb v =
          case vertexEdges v of
            ToNodes es _ ->
              v {vertexEdges = ToNodes (rb <> es) (Just (rb <> IS.fromList nx))}
            _ -> v

        -- Reduce from each vertex that belongs to this loop's subgraph;
        -- vertices outside it are collected as non-exhausted instead.
        findBindings ::
          Graph ->
          (ReachableBindings, NonExhausted, ReachableBindingsCache) ->
          [Id] ->
          (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
        findBindings _ (rb, nx, rbc) [] = (Produced rb, nx, rbc)
        findBindings g (rb, nx, rbc) (i : rest)
          | Just v <- MG.lookup i g,
            Just gid <- metaGraphId (vertexMeta v),
            gid == subgraphId =
              case MG.reduce g bindingReach rbc Normal i of
                (FoundSink, rbc') -> (FoundSink, [], rbc')
                (Produced rb', rbc') -> findBindings g (rb <> rb', nx, rbc') rest
          | otherwise =
              findBindings g (rb, i : nx, rbc) rest
graphWithAcc ::
[Binding] ->
[WithAccInput GPU] ->
Lambda GPU ->
Grapher ()
graphWithAcc :: [Binding] -> [WithAccInput GPU] -> Lambda GPU -> Grapher ()
graphWithAcc [Binding]
bs [WithAccInput GPU]
inputs Lambda GPU
f = do
Body GPU -> Grapher ()
graphBody (Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
f)
((Type, WithAccInput GPU) -> Grapher ())
-> [(Type, WithAccInput GPU)] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Type, WithAccInput GPU) -> Grapher ()
forall shape u a b.
(TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
-> Grapher ()
graph ([(Type, WithAccInput GPU)] -> Grapher ())
-> [(Type, WithAccInput GPU)] -> Grapher ()
forall a b. (a -> b) -> a -> b
$ [Type] -> [WithAccInput GPU] -> [(Type, WithAccInput GPU)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Lambda GPU -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda GPU
f) [WithAccInput GPU]
inputs
let arrs :: [SubExp]
arrs = (WithAccInput GPU -> [SubExp]) -> [WithAccInput GPU] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\(Shape
_, [VName]
as, Maybe (Lambda GPU, [SubExp])
_) -> (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as) [WithAccInput GPU]
inputs
let res :: [SubExpRes]
res = Int -> [SubExpRes] -> [SubExpRes]
forall a. Int -> [a] -> [a]
drop ([WithAccInput GPU] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput GPU]
inputs) (Body GPU -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult (Body GPU -> [SubExpRes]) -> Body GPU -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
f)
Bool
_ <- [Binding] -> [SubExp] -> Grapher Bool
reusesReturn [Binding]
bs ([SubExp]
arrs [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (SubExpRes -> SubExp) -> [SubExpRes] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp [SubExpRes]
res)
[IntSet]
ret <- (SubExpRes -> Grapher IntSet)
-> [SubExpRes] -> StateT State (Reader Env) [IntSet]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> Grapher IntSet
onlyGraphedScalarSubExp (SubExp -> Grapher IntSet)
-> (SubExpRes -> SubExp) -> SubExpRes -> Grapher IntSet
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) [SubExpRes]
res
((Binding, IntSet) -> Grapher ())
-> [(Binding, IntSet)] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ ((Binding -> IntSet -> Grapher ())
-> (Binding, IntSet) -> Grapher ()
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Binding -> IntSet -> Grapher ()
createNode) ([(Binding, IntSet)] -> Grapher ())
-> [(Binding, IntSet)] -> Grapher ()
forall a b. (a -> b) -> a -> b
$ [Binding] -> [IntSet] -> [(Binding, IntSet)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [Binding] -> [Binding]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
arrs) [Binding]
bs) [IntSet]
ret
where
graph :: (TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
-> Grapher ()
graph (Acc VName
a Shape
_ [Type]
types u
_, (a
_, b
_, Maybe (Lambda GPU, [SubExp])
comb)) = do
let i :: Int
i = VName -> Int
nameToId VName
a
[Delayed]
delayed <- [Delayed] -> Maybe [Delayed] -> [Delayed]
forall a. a -> Maybe a -> a
fromMaybe [] (Maybe [Delayed] -> [Delayed])
-> StateT State (Reader Env) (Maybe [Delayed])
-> StateT State (Reader Env) [Delayed]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (State -> Maybe [Delayed])
-> StateT State (Reader Env) (Maybe [Delayed])
forall (m :: * -> *) s a. Monad m => (s -> a) -> StateT s m a
gets (Int -> IntMap [Delayed] -> Maybe [Delayed]
forall a. Int -> IntMap a -> Maybe a
IM.lookup Int
i (IntMap [Delayed] -> Maybe [Delayed])
-> (State -> IntMap [Delayed]) -> State -> Maybe [Delayed]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. State -> IntMap [Delayed]
stateUpdateAccs)
(State -> State) -> Grapher ()
forall (m :: * -> *) s. Monad m => (s -> s) -> StateT s m ()
modify ((State -> State) -> Grapher ()) -> (State -> State) -> Grapher ()
forall a b. (a -> b) -> a -> b
$ \State
st -> State
st {stateUpdateAccs :: IntMap [Delayed]
stateUpdateAccs = Int -> IntMap [Delayed] -> IntMap [Delayed]
forall a. Int -> IntMap a -> IntMap a
IM.delete Int
i (State -> IntMap [Delayed]
stateUpdateAccs State
st)}
Int -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc Int
i [Type]
types ((Lambda GPU, [SubExp]) -> Lambda GPU
forall a b. (a, b) -> a
fst ((Lambda GPU, [SubExp]) -> Lambda GPU)
-> Maybe (Lambda GPU, [SubExp]) -> Maybe (Lambda GPU)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (Lambda GPU, [SubExp])
comb) [Delayed]
delayed
(SubExp -> Grapher ()) -> [SubExp] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ SubExp -> Grapher ()
connectSubExpToSink ([SubExp] -> Grapher ()) -> [SubExp] -> Grapher ()
forall a b. (a -> b) -> a -> b
$ [SubExp]
-> ((Lambda GPU, [SubExp]) -> [SubExp])
-> Maybe (Lambda GPU, [SubExp])
-> [SubExp]
forall b a. b -> (a -> b) -> Maybe a -> b
maybe [] (Lambda GPU, [SubExp]) -> [SubExp]
forall a b. (a, b) -> b
snd Maybe (Lambda GPU, [SubExp])
comb
graph (TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
_ =
String -> Grapher ()
forall a. String -> a
compilerBugS String
"Type error: WithAcc expression did not return accumulator."
-- | Graph the combining operator of an accumulator together with the
-- 'UpdateAcc' statements ('Delayed' binding/expression pairs) associated with
-- it.
--
-- Arguments: the 'Id' of the accumulator token, the accumulator element
-- types, the combining operator (if any), and the associated statements.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc i _ _ [] =
  -- No associated statements: only register the token as a source.
  addSource (i, Prim Unit)
graphAcc i types op delayed = do
  env <- ask
  st <- get

  -- A missing operator is treated as an empty lambda.
  let lambda = fromMaybe (Lambda [] (Body () SQ.empty []) []) op
      -- Graph the operator body against a copy of the current state and keep
      -- only the resulting statistics; 'evalStateT' discards the final state,
      -- so this run leaves the real graph untouched.
      stats =
        R.runReader
          (evalStateT (captureBodyStats (graphBody (lambdaBody lambda))) st)
          env
      host_only = bodyHostOnly stats || bodyHasGPUBody stats
      does_read = bodyReads stats || any isScalar types

  -- Graphed scalar operands of the operator, captured through an artificial
  -- 'WithAcc' expression that has no inputs.
  ops <- graphedScalarOperands (WithAcc [] lambda)

  case (host_only, does_read) of
    (True, _) -> do
      -- Host-only operator (or one containing a GPUBody): every associated
      -- statement is graphed as host-only and the operands go to the sink.
      mapM_ (graphHostOnly . snd) delayed
      addEdges ToSink ops
    (_, True) -> do
      -- The operator reads: auto-move every associated binding and register
      -- the token as a source.
      mapM_ (graphAutoMove . fst) delayed
      addSource (i, Prim Unit)
    _ -> do
      -- Otherwise the token depends on the operator's operands, and each
      -- associated statement depends on its own operands plus the token.
      createNode (i, Prim Unit) ops
      forM_ delayed $ \(b, e) ->
        graphedScalarOperands e >>= createNode b . IS.insert i
-- | Returns the 'Id's of the scalar variables that @e@ uses as operands and
-- that are currently registered as graphed scalars ('getGraphedScalars').
--
-- Operands are collected syntactically by walking @e@ in a local 'State'.
-- The collected set is then intersected with the graphed-scalar set.
-- Kernel bodies of 'SegOp's and 'GPUBody' expressions are not entered.
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e =
  let is = fst $ execState (collect e) initial
   in IS.intersection is <$> getGraphedScalars
  where
    -- Walk state: (collected operand ids, tokens of accumulators updated by an
    -- 'UpdateAcc' statement reached during the walk).
    initial = (IS.empty, S.empty)
    captureName n = modify $ first $ IS.insert (nameToId n)
    captureAcc a = modify $ second $ S.insert a
    collectFree x = mapM_ captureName (namesToList $ freeIn x)

    collect b@BasicOp {} = collectBasic b
    collect (Apply _ params _ _) = mapM_ (collectSubExp . fst) params
    collect (If cond tbranch fbranch _) =
      collectSubExp cond >> collectBody tbranch >> collectBody fbranch
    collect (DoLoop params lform body) = do
      mapM_ (collectSubExp . snd) params
      collectLForm lform
      collectBody body
    collect (WithAcc accs f) = collectWithAcc accs f
    collect (Op op) = collectHostOp op

    -- For 'Update' only the free variables of the slice are collected, and for
    -- 'Replicate' only those of the shape; other basic operations contribute
    -- every 'Var' subexpression found by the walker.
    collectBasic (BasicOp (Update _ _ slice _)) = collectFree slice
    collectBasic (BasicOp (Replicate shape _)) = collectFree shape
    collectBasic e' =
      walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e'

    collectSubExp (Var n) = captureName n
    collectSubExp _ = pure ()

    collectBody body = do
      collectStms (bodyStms body)
      collectFree (bodyResult body)

    collectStms = mapM_ collectStm

    collectStm (Let pat _ ua)
      | BasicOp UpdateAcc {} <- ua,
        Pat [pe] <- pat,
        Acc a _ _ _ <- typeOf pe =
          -- Remember the token of each accumulator updated by this walk.
          captureAcc a >> collectBasic ua
    collectStm stm = collect (stmExp stm)

    collectLForm (ForLoop _ _ b _) = collectSubExp b
    -- NOTE(review): the WhileLoop condition is presumably bound among the loop
    -- parameters, whose initial values are already collected above.
    collectLForm (WhileLoop _) = pure ()

    -- The operator body and neutral elements of each accumulator input are
    -- collected; the operator body only if its accumulator was updated in
    -- the lambda body.
    collectWithAcc inputs f = do
      collectBody (lambdaBody f)
      used_accs <- gets snd
      let accs = take (length inputs) (lambdaReturnType f)
          used = map (isUsedAcc used_accs) accs
      mapM_ collectAcc (zip used inputs)

    -- Total replacement for the former partial lambda pattern: a non-Acc
    -- type here indicates an ill-typed program.
    isUsedAcc used_accs (Acc a _ _ _) = S.member a used_accs
    isUsedAcc _ _ =
      compilerBugS "Type error: WithAcc lambda did not return accumulator."

    collectAcc (_, (_, _, Nothing)) = pure ()
    collectAcc (used, (_, _, Just (op, nes))) = do
      mapM_ collectSubExp nes
      when used $ collectBody (lambdaBody op)

    -- Only the segment level, segment space, neutral elements, and histogram
    -- race factors are collected; kernel bodies are not traversed.
    collectHostOp (SegOp (SegMap lvl sp _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
    collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = collectFree op
    collectHostOp (OtherOp op) = collectFree op
    collectHostOp GPUBody {} = pure ()

    collectSegLevel (SegThread (Count num) (Count size) _) =
      collectSubExp num >> collectSubExp size
    collectSegLevel (SegGroup (Count num) (Count size) _) =
      collectSubExp num >> collectSubExp size

    collectSegSpace space = mapM_ collectSubExp (segSpaceDims space)

    collectSegBinOp (SegBinOp _ _ nes _) = mapM_ collectSubExp nes

    collectHistOp (HistOp _ rf _ nes _ _) = do
      collectSubExp rf
      mapM_ collectSubExp nes
-- | @createNode b ops@ adds a vertex for @b@ and attaches @'MG.oneEdge' (fst b)@
-- to every operand vertex in @ops@. Does nothing when @ops@ is empty.
createNode :: Binding -> Operands -> Grapher ()
createNode b ops =
  unless (IS.null ops) $ do
    addVertex b
    addEdges (MG.oneEdge $ fst b) ops
-- | Inserts a vertex for the binding into the graph, tagged with the current
-- 'Meta'. Scalar bindings are also registered as graphed scalars. Array
-- bindings are passed to 'recordCopyableMemory' at the current body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  let v = MG.vertex i meta
  when (isScalar t) $ modifyGraphedScalars (IS.insert i)
  when (isArray t) $ recordCopyableMemory i (metaBodyDepth meta)
  modifyGraph (MG.insert v)
-- | Adds a vertex for the binding and prepends its 'Id' to the second
-- component of the source lists (the component 'routeSubgraph' treats as
-- not-yet-routed sources).
addSource :: Binding -> Grapher ()
addSource b = do
  addVertex b
  modifySources $ second (fst b :)
-- | Adds the given edges to every vertex identified by the 'IdSet'.
--
-- 'ToSink' connects each vertex to the sink and removes it from the graphed
-- scalars. Any other 'Edges' value is added with 'MG.addEdges', and the ids
-- are reported through 'tellOperands'.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges ToSink is = do
  modifyGraph $ \g -> IS.foldl' (flip MG.connectToSink) g is
  modifyGraphedScalars (`IS.difference` is)
addEdges es is = do
  modifyGraph $ \g -> IS.foldl' (flip $ MG.addEdges es) g is
  tellOperands is
-- | Marks a variable as required on host. If the graph has a vertex for the
-- id, that vertex is connected to the sink and the body depth stored in its
-- 'Meta' is passed to 'tellHostOnlyParent'. Ids without a vertex are ignored.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  mv <- MG.lookup i <$> getGraph
  forM_ mv $ \v -> do
    connectToSink i
    tellHostOnlyParent (metaBodyDepth $ vertexMeta v)
-- | Connects the vertex with the given id to the sink and removes the id
-- from the set of graphed scalars.
connectToSink :: Id -> Grapher ()
connectToSink i = do
  modifyGraph (MG.connectToSink i)
  modifyGraphedScalars (IS.delete i)
-- | 'connectToSink' for a subexpression; constants are ignored.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink (Var n) = connectToSink (nameToId n)
connectSubExpToSink _ = pure ()
-- | Routes the not-yet-routed sources that belong to the subgraph @si@.
--
-- Only the leading run of unrouted sources that lie in the subgraph is taken
-- ('span'). Those sources are routed with 'MG.routeMany' and moved to the
-- routed list, and the returned sinks are prepended to 'stateSinks'. Returns
-- the ids of the sources that were attempted routed.
--
-- NOTE(review): relies on the subgraph's sources sitting at the front of the
-- unrouted list ('addSource' prepends); confirm callers invoke this right after
-- graphing the subgraph.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let g = stateGraph st
      (routed, unrouted) = stateSources st
      (gsrcs, unrouted') = span (inSubGraph si g) unrouted
      (sinks, g') = MG.routeMany gsrcs g
  put $
    st
      { stateGraph = g',
        stateSources = (gsrcs ++ routed, unrouted'),
        stateSinks = sinks ++ stateSinks st
      }
  pure gsrcs
-- | @inSubGraph si g i@ is 'True' when @g@ has a vertex with id @i@ whose
-- 'Meta' carries the subgraph id @si@. It is 'False' when the vertex is
-- missing or has no subgraph id.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i =
  case MG.lookup i g >>= metaGraphId . vertexMeta of
    Just mgi -> si == mgi
    Nothing -> False
-- | @b `reuses` n@: if @b@ is array-typed and 'outermostCopyableArray' yields
-- a body depth for @n@, that depth is recorded for @b@ via
-- 'recordCopyableMemory'. Otherwise nothing happens.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n
  | isArray t = outermostCopyableArray n >>= traverse_ (recordCopyableMemory i)
  | otherwise = pure ()
-- | 'reuses' for a subexpression; constants are ignored.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b (Var n) = b `reuses` n
reusesSubExp _ _ = pure ()
-- | @reusesReturn bs res@ pairs each binding with its returned subexpression.
-- For an array binding returned from a variable with a known copyable depth,
-- it records @min body_depth inner@ via 'recordCopyableMemory'.
--
-- Returns 'True' only if no pair fails. A pair fails when it is an array
-- whose returned variable has no copyable depth, or whose depth is at or
-- above the current body depth (a free variable). Bindings whose dimensions
-- are all 1 are skipped and never fail.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = do
  body_depth <- metaBodyDepth <$> getMeta
  foldM (reuse body_depth) True (zip bs res)
  where
    reuse :: Int -> Bool -> (Binding, SubExp) -> Grapher Bool
    reuse body_depth onlyCopyable (b, se)
      -- All-ones dimensions (presumably also non-array types, whose dimension
      -- list is empty) are left unrecorded.
      | all (== intConst Int64 1) (arrayDims $ snd b) =
          pure onlyCopyable
      | (i, t) <- b,
        isArray t,
        Var n <- se = do
          res_body_depth <- outermostCopyableArray n
          case res_body_depth of
            Just inner -> do
              recordCopyableMemory i (min body_depth inner)
              let returns_free_var = inner <= body_depth
              pure (onlyCopyable && not returns_free_var)
            Nothing ->
              pure False
      | otherwise =
          pure onlyCopyable
-- | Checks the pairs of branch results @b1@ and @b2@ bound to @bs@ against
-- the copyable-memory map and records every array binding that reuses
-- copyable memory.
--
-- Bindings whose type has only dimensions equal to the constant 1 (this
-- includes types with no dimensions at all) never change the answer. For an
-- array binding where both branch results are variables, both are looked up
-- with 'outermostCopyableArray':
--
--   * if either has no recorded depth, the answer becomes 'False';
--   * otherwise let @inner@ be the shallower of the two recorded depths. The
--     binding is recorded at the shallower of @inner@ and the current body
--     depth, and the answer becomes 'False' when @inner@ is at the current
--     body depth or shallower.
--
-- All other bindings leave the answer unchanged. Every triple is visited even
-- after the answer has become 'False', so recording happens for all of them.
reusesBranches :: [Binding] -> [SubExp] -> [SubExp] -> Grapher Bool
reusesBranches bs b1 b2 = do
  outer_depth <- getBodyDepth
  foldM (checkBranches outer_depth) True (zip3 bs b1 b2)
  where
    checkBranches ::
      BodyDepth -> Bool -> (Binding, SubExp, SubExp) -> Grapher Bool
    checkBranches outer_depth acc ((i, t), se1, se2)
      | all (== intConst Int64 1) (arrayDims t) =
          pure acc
      | isArray t,
        Var n1 <- se1,
        Var n2 <- se2 = do
          found1 <- outermostCopyableArray n1
          found2 <- outermostCopyableArray n2
          -- Defined only when both branch results have a recorded depth.
          case min <$> found1 <*> found2 of
            Nothing ->
              pure False
            Just inner_depth -> do
              recordCopyableMemory i (min outer_depth inner_depth)
              pure (acc && inner_depth > outer_depth)
      | otherwise =
          pure acc
-- | The monad in which the migration graph is built: the strict 'StateT'
-- over 'State', layered on a reader of 'Env'.
type Grapher = StateT State (R.Reader Env)

-- | The read-only environment of 'Grapher'.
data Env = Env
  { -- | The functions queried by 'isHostOnlyFun'.
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata about the current position during graph construction.
    envMeta :: Meta
  }
-- | A nesting depth of bodies, as tracked in 'metaBodyDepth'.
type BodyDepth = Int

-- | Position metadata carried in 'Env' while graphing.
data Meta = Meta
  { -- | Number of enclosing scopes entered through 'incForkDepthFor'.
    metaForkDepth :: Int,
    -- | Number of enclosing bodies entered through 'incBodyDepthFor'.
    metaBodyDepth :: BodyDepth,
    -- | The 'Id' set by the nearest enclosing 'graphIdFor', if any.
    metaGraphId :: Maybe Id
  }
-- | A set of operand ids, as reported through 'tellOperands'.
type Operands = IdSet

-- | Facts accumulated about the body currently being graphed; see
-- 'captureBodyStats'.
data BodyStats = BodyStats
  { -- | Set by 'tellHostOnly'.
    bodyHostOnly :: Bool,
    -- | Set by 'tellGPUBody'.
    bodyHasGPUBody :: Bool,
    -- | Set by 'tellRead'.
    bodyReads :: Bool,
    -- | Ids accumulated by 'tellOperands'.
    bodyOperands :: Operands,
    -- | Body depths accumulated by 'tellHostOnlyParent'.
    bodyHostOnlyParents :: IS.IntSet
  }
-- | Combines two sets of statistics. Each flag becomes the disjunction of the
-- two flags, and each id set becomes the union of the two sets.
instance Semigroup BodyStats where
  BodyStats hoA gbA rdA opsA parentsA <> BodyStats hoB gbB rdB opsB parentsB =
    BodyStats
      (hoA || hoB)
      (gbA || gbB)
      (rdA || rdB)
      (opsA `IS.union` opsB)
      (parentsA `IS.union` parentsB)
-- | The identity for '<>': every flag is 'False' and every set is empty.
instance Monoid BodyStats where
  mempty = BodyStats False False False IS.empty IS.empty
-- | The migration graph, parameterised by 'Meta'.
type Graph = MG.Graph Meta

-- | A pair of id lists kept as the graph's sources.
type Sources = ([Id], [Id])

-- | Ids kept as the graph's sinks.
type Sinks = [Id]

-- | A binding paired with an expression; kept in 'stateUpdateAccs'.
type Delayed = (Binding, Exp GPU)

-- | An id paired with a type.
type Binding = (Id, Type)

-- | Maps array ids to the body depth recorded for them by
-- 'recordCopyableMemory'.
type CopyableMemoryMap = IM.IntMap BodyDepth

-- | The mutable state of 'Grapher'.
data State = State
  { -- | The graph built so far.
    stateGraph :: Graph,
    -- | Ids of graphed scalars; consulted by 'onlyGraphedScalars'.
    stateGraphedScalars :: IdSet,
    -- | Source ids; updated through 'modifySources'.
    stateSources :: Sources,
    -- | Sink ids.
    stateSinks :: Sinks,
    -- | Delayed bindings, keyed by id.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Recorded copyable arrays and their body depths.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics of the body currently being graphed.
    stateStats :: BodyStats
  }
-- | Runs a 'Grapher' computation and returns the resulting graph, sources and
-- sinks.
--
-- The computation starts from an empty state, with host-only functions @hof@,
-- zero fork and body depths, and no graph id. Its own result is discarded.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m =
  (stateGraph final, stateSources final, stateSinks final)
  where
    final = R.runReader (execStateT m initial) env0

    env0 =
      Env
        { envHostOnlyFuns = hof,
          envMeta =
            Meta
              { metaForkDepth = 0,
                metaBodyDepth = 0,
                metaGraphId = Nothing
              }
        }

    initial =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }
-- | Runs @m@ with its environment transformed by @f@. The state is threaded
-- through as usual; only the environment seen by @m@ changes.
local :: (Env -> Env) -> Grapher a -> Grapher a
local f m = StateT $ \s -> R.local f (runStateT m s)
-- | Retrieves the whole environment.
ask :: Grapher Env
ask = lift (R.asks id)
-- | Retrieves a projection of the environment.
asks :: (Env -> a) -> Grapher a
asks f = lift (f <$> R.ask)
-- | Sets 'bodyHostOnly' in the statistics of the current body.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
tellHostOnly :: Grapher ()
tellHostOnly =
  modify' $ \st -> st {stateStats = (stateStats st) {bodyHostOnly = True}}
-- | Sets 'bodyHasGPUBody' in the statistics of the current body.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
tellGPUBody :: Grapher ()
tellGPUBody =
  modify' $ \st -> st {stateStats = (stateStats st) {bodyHasGPUBody = True}}
-- | Sets 'bodyReads' in the statistics of the current body.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
tellRead :: Grapher ()
tellRead =
  modify' $ \st -> st {stateStats = (stateStats st) {bodyReads = True}}
-- | Adds the ids in @is@ to 'bodyOperands' of the current body.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
tellOperands :: IdSet -> Grapher ()
tellOperands is =
  modify' $ \st ->
    let stats = stateStats st
        operands = bodyOperands stats
     in st {stateStats = stats {bodyOperands = operands <> is}}
-- | Adds @body_depth@ to 'bodyHostOnlyParents' of the current body.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth =
  modify' $ \st ->
    let stats = stateStats st
        parents' = IS.insert body_depth (bodyHostOnlyParents stats)
     in st {stateStats = stats {bodyHostOnlyParents = parents'}}
-- | Returns the graph built so far.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get
-- | Returns the ids recorded in 'stateGraphedScalars'.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get
-- | Returns the map of recorded copyable arrays and their body depths.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get
-- | Looks up the body depth recorded for @n@ in the copyable-memory map.
-- Returns 'Nothing' if @n@ has not been recorded.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  mem <- getCopyableMemory
  pure (IM.lookup (nameToId n) mem)
-- | Returns the ids of those names in @vs@ that are present in
-- 'stateGraphedScalars'.
onlyGraphedScalars :: Foldable t => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  let wanted = IS.fromList (map nameToId (toList vs))
  graphed <- getGraphedScalars
  pure (IS.intersection wanted graphed)
-- | Returns a singleton set with the id of @n@ if that id is present in
-- 'stateGraphedScalars', and the empty set otherwise.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n =
  IS.intersection (IS.singleton (nameToId n)) <$> getGraphedScalars
-- | Like 'onlyGraphedScalar' for a 'SubExp'. Constants never contribute an
-- id.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se =
  case se of
    Var n -> onlyGraphedScalar n
    Constant _ -> pure IS.empty
-- | Applies @f@ to the graph stored in the state.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f =
  modify' $ \st -> st {stateGraph = f (stateGraph st)}
-- | Applies @f@ to the set of graphed scalar ids.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f =
  modify' $ \st -> st {stateGraphedScalars = f (stateGraphedScalars st)}
-- | Applies @f@ to the map of recorded copyable arrays.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f =
  modify' $ \st -> st {stateCopyableMemory = f (stateCopyableMemory st)}
-- | Applies @f@ to the recorded sources.
--
-- Uses 'modify'' so the new 'State' is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f =
  modify' $ \st -> st {stateSources = f (stateSources st)}
-- | Records (or overwrites) body depth @bd@ for the array with id @i@ in the
-- copyable-memory map.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i = modifyCopyableMemory . IM.insert i
-- | Runs a computation with 'metaForkDepth' increased by one.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local bump
  where
    bump env =
      let meta = envMeta env
       in env {envMeta = meta {metaForkDepth = metaForkDepth meta + 1}}
-- | Runs a computation with 'metaBodyDepth' increased by one.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local bump
  where
    bump env =
      let meta = envMeta env
       in env {envMeta = meta {metaBodyDepth = metaBodyDepth meta + 1}}
-- | Runs a computation with 'metaGraphId' set to @i@.
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local setId
  where
    setId env = env {envMeta = (envMeta env) {metaGraphId = Just i}}
-- | Runs @m@ starting from empty 'BodyStats' and returns the statistics it
-- recorded. The result of @m@ itself is discarded.
--
-- Afterwards the state holds the previous statistics combined with those of
-- @m@, so the enclosing body also observes everything @m@ reported.
--
-- Uses 'modify'' so each state update is forced to weak head normal form,
-- rather than leaving a chain of unevaluated record updates in the state.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  modify' $ \st -> st {stateStats = mempty}
  _ <- m
  inner <- gets stateStats
  modify' $ \st -> st {stateStats = outer <> inner}
  pure inner
-- | Reports whether @fn@ is one of the host-only functions in the
-- environment.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = do
  hof <- asks envHostOnlyFuns
  pure (fn `S.member` hof)
-- | Returns the metadata of the current environment.
getMeta :: Grapher Meta
getMeta = lift (R.asks envMeta)
-- | Returns the current 'metaBodyDepth'.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta