module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
(
analyseFunDef,
analyseConsts,
hostOnlyFunDefs,
MigrationTable,
MigrationStatus (..),
shouldMoveStm,
shouldMove,
usedOnHost,
statusOf,
)
where
import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader qualified as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first, second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.IntSet qualified as IS
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (fromJust, fromMaybe, isJust, isNothing)
import Data.Sequence qualified as SQ
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
( EdgeType (..),
Edges (..),
Id,
IdSet,
Result (..),
Routing (..),
Vertex (..),
)
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph qualified as MG
-- | The migration decision for the statement that binds a given name.
--
-- The constructor order is significant: it fixes the derived 'Ord' instance
-- (@MoveToDevice < UsedOnHost < StayOnHost@).
data MigrationStatus
  = -- | Move the computing statement to device. 'usedOnHost' reports 'False'
    -- for such a name.
    MoveToDevice
  | -- | Move the computing statement to device ('shouldMove' is 'True'), but
    -- the value is still used on host ('usedOnHost' is 'True').
    UsedOnHost
  | -- | Keep the computing statement on host. 'statusOf' reports this for any
    -- name that has no entry in a 'MigrationTable'.
    StayOnHost
  deriving (Eq, Ord, Show)
-- | Migration decisions keyed by the 'baseTag' of the bound name. Names
-- without an entry are reported as 'StayOnHost' by 'statusOf'.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Left-biased union: when both tables have an entry for the same name, the
-- entry of the left operand is kept.
instance Semigroup MigrationTable where
  (<>) (MigrationTable lhs) (MigrationTable rhs) =
    MigrationTable (IM.union lhs rhs)
-- | The migration status recorded for a name, looked up by its 'baseTag'.
-- Defaults to 'StayOnHost' when the table has no entry for the name.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable table) =
  IM.findWithDefault StayOnHost (baseTag n) table
-- | Should this whole statement be moved from host to device?
--
-- * 'Index': if the first bound name is 'MoveToDevice', or any variable used
--   in the slice is.
-- * Any other 'BasicOp', and 'Apply': if the first bound name is not
--   'StayOnHost'.
-- * 'Match': if it has at least one variable condition and every variable
--   condition is 'MoveToDevice'. A match whose conditions are all constants
--   is never moved, mirroring for-loops with a constant bound.
-- * 'DoLoop': a for-loop with a variable bound, or a while-loop, is moved if
--   that bound/condition variable is 'MoveToDevice'.
-- * Anything else ('BasicOp'/'Apply' binding no names, for-loops with a
--   constant bound, 'Op', 'WithAcc', ...) is not moved.
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm stm mt =
  case stm of
    Let (Pat (PatElem n _ : _)) _ (BasicOp (Index _ slice)) ->
      movedToDevice n || any movedOperand slice
    Let (Pat (PatElem n _ : _)) _ (BasicOp _) ->
      statusOf n mt /= StayOnHost
    Let (Pat (PatElem n _ : _)) _ Apply {} ->
      statusOf n mt /= StayOnHost
    Let _ _ (Match cond _ _ _) ->
      -- 'all' is vacuously True on an empty list; without the 'null' check a
      -- match on constant conditions only would be moved unconditionally.
      let cond_vars = subExpVars cond
       in not (null cond_vars) && all movedToDevice cond_vars
    Let _ _ (DoLoop _ (ForLoop _ _ (Var n) _) _) ->
      movedToDevice n
    Let _ _ (DoLoop _ (WhileLoop n) _) ->
      movedToDevice n
    _ -> False
  where
    movedToDevice v = statusOf v mt == MoveToDevice

    movedOperand (Var op) = movedToDevice op
    movedOperand _ = False
-- | Should the statement binding this name be moved to device? True for
-- every status except 'StayOnHost'.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt =
  case statusOf n mt of
    StayOnHost -> False
    _ -> True
-- | Is the value bound by this name used on host? True for every status
-- except 'MoveToDevice'.
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt =
  case statusOf n mt of
    MoveToDevice -> False
    _ -> True
-- | Names of host-only functions, as computed by 'hostOnlyFunDefs'.
type HostOnlyFuns = Set Name
-- | Returns the names of the given functions that are host-only. A function
-- is host-only if 'checkFunDef' yields 'Nothing' for it, or if it calls
-- (transitively) a host-only function.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs = all_names \\ device_callable
  where
    all_names = S.fromList (map funDefName funs)

    -- Functions that survive pruning, i.e. are not host-only.
    device_callable =
      M.keysSet . prune $
        M.fromList [(funDefName fd, checkFunDef fd) | fd <- funs]

    -- Drop every host-only entry ('Nothing') and mark each remaining function
    -- that calls a dropped one as host-only; repeat until a round drops
    -- nothing.
    prune calls
      | M.null host_only = rest
      | otherwise = prune (M.map (markCallers (M.keysSet host_only)) rest)
      where
        (host_only, rest) = M.partition isNothing calls

    -- A function that calls any host-only function is itself host-only.
    markCallers host_only (Just callees)
      | host_only `S.disjoint` callees = Just callees
    markCallers _ _ = Nothing
-- | Returns 'Nothing' if the function definition is host-only. That is the
-- case if it has an array parameter or array return type, binds an array,
-- has an array loop parameter or for-loop with loop variables, or contains
-- an 'Index', 'WithAcc' or 'Op' expression. Otherwise returns the set of
-- names of all functions it calls via 'Apply'.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  noneOf isArray (funDefParams fun)
  noneOf isArrayType (funDefRetType fun)
  bodyCalls (funDefBody fun)
  where
    -- Fails (host-only) as soon as any element satisfies the predicate.
    noneOf p xs
      | any p xs = Nothing
      | otherwise = Just ()

    bodyCalls = stmsCalls . bodyStms

    stmsCalls stms = S.unions <$> mapM stmCalls stms

    stmCalls (Let (Pat pes) _ e) = noneOf isArray pes >> expCalls e

    expCalls e =
      case e of
        BasicOp (Index _ _) -> Nothing
        WithAcc _ _ -> Nothing
        Op _ -> Nothing
        Apply fn _ _ _ -> Just (S.singleton fn)
        Match _ cases defbody _ ->
          mconcat <$> mapM bodyCalls (defbody : map caseBody cases)
        DoLoop params lform body -> do
          noneOf (isArray . fst) params
          loopFormOk lform
          bodyCalls body
        BasicOp {} -> Just S.empty

    -- For-loops that iterate over arrays (have loop variables) are host-only.
    loopFormOk (ForLoop _ _ _ (_ : _)) = Nothing
    loopFormOk _ = Just ()
-- | Ids ('nameToId') of variables that are needed on host outside of the
-- analysed statements, e.g. scalars returned by a function. 'buildGraph'
-- connects each of them to the sink.
type HostUsage = [Id]
-- | The graph vertex id used for a name: its 'baseTag'.
nameToId :: VName -> Id
nameToId n = baseTag n
-- | Analyses top-level constants. Scalar constants that occur free in any of
-- the given function definitions are treated as used on host.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts = analyseStms hof host_usage consts
  where
    free_in_funs = freeIn funs

    -- Visiting the scope in descending key order builds the same list as a
    -- left fold that prepends each match.
    host_usage =
      [ nameToId n
        | (n, info) <- M.toDescList (scopeOf consts),
          isScalar info,
          n `nameIn` free_in_funs
      ]
-- | Analyses the body of a function definition. Scalar variables returned
-- by the function are treated as used on host.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd = analyseStms hof host_usage (bodyStms body)
  where
    body = funDefBody fd

    -- Returned scalar variables, listed last result first.
    host_usage =
      reverse
        [ nameToId n
          | (SubExpRes _ (Var n), t) <- zip (bodyResult body) (funDefRetType fd),
            isScalarType t
        ]
-- | Builds the migration graph for the statements, routes the unrouted
-- sources with 'MG.routeMany', and then folds over the graph from every
-- source (unrouted ones first, then the previously routed ones) with one
-- shared visited set. Each visited vertex is classified as:
--
-- * reached via a 'Reversed' edge: 'MoveToDevice';
-- * reached via a 'Normal' edge and having no route: 'MoveToDevice';
-- * reached via a 'Normal' edge and having a route: 'UsedOnHost'.
--
-- If a vertex falls into several categories, the earlier one in this list
-- wins ('IM.unions' is left-biased).
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  MigrationTable $
    IM.unions
      [ IM.fromSet (const MoveToDevice) reversed_hits,
        IM.fromSet (const MoveToDevice) unrouted_hits,
        IM.fromSet (const UsedOnHost) routed_hits
      ]
  where
    (initial_graph, (routed, unrouted), _) = buildGraph hof usage stms
    (_, routed_graph) = MG.routeMany unrouted initial_graph

    -- Fold 'classify' over the graph starting from one source, entered as a
    -- 'Normal' edge, threading the classification and the visited set.
    traverseFrom acc = MG.fold routed_graph classify acc Normal

    empty_sets = (IS.empty, IS.empty, IS.empty)
    after_unrouted = foldl' traverseFrom (empty_sets, MG.none) unrouted
    (reversed_hits, unrouted_hits, routed_hits) =
      fst $ foldl' traverseFrom after_unrouted routed

    classify (rev, no_route, has_route) edge v =
      let i = vertexId v
       in case (edge, vertexRouting v) of
            (Reversed, _) -> (IS.insert i rev, no_route, has_route)
            (Normal, NoRoute) -> (rev, IS.insert i no_route, has_route)
            (Normal, _) -> (rev, no_route, IS.insert i has_route)
-- | Does the value have a scalar type (see 'isScalarType')?
isScalar :: Typed t => t -> Bool
isScalar x = isScalarType (typeOf x)
-- | Is the type a primitive type other than 'Unit'?
isScalarType :: TypeBase shape u -> Bool
isScalarType (Prim pt) = pt /= Unit
isScalarType _ = False
-- | Does the value have an array type (see 'isArrayType')?
isArray :: Typed t => t -> Bool
isArray x = isArrayType (typeOf x)
-- | Is the type an array type, i.e. of rank greater than zero?
isArrayType :: ArrayShape shape => TypeBase shape u -> Bool
isArrayType t = arrayRank t > 0
-- | Graphs the statements and then connects every id in the host usage list
-- to the sink. Sources and sinks are those produced while graphing.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms = (with_sinks, srcs, sinks)
  where
    (g, srcs, sinks) = execGrapher hof (graphStms stms)
    with_sinks = foldl' (\acc i -> MG.connectToSink i acc) g usage
-- | Graphs a body one body level deeper: its statements, followed by the
-- operands of its result. If the statistics captured for the body list the
-- body's own depth in 'bodyHostOnlyParents', the current 'stateStats' are
-- marked 'bodyHostOnly'. In either case that depth is removed from the
-- current 'bodyHostOnlyParents'.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  stats <-
    captureBodyStats . incBodyDepthFor $ do
      graphStms (bodyStms body)
      tellOperands (namesIntSet (freeIn (bodyResult body)))
  depth <- (+ 1) <$> getBodyDepth
  modify (settle depth (depth `IS.member` bodyHostOnlyParents stats))
  where
    settle depth host_only st =
      let outer = stateStats st
          outer' = if host_only then outer {bodyHostOnly = True} else outer
          parents' = IS.delete depth (bodyHostOnlyParents outer)
       in st {stateStats = outer' {bodyHostOnlyParents = parents'}}
-- | Graphs each statement in order.
graphStms :: Stms GPU -> Grapher ()
graphStms stms = for_ stms graphStm
-- | Graphs a single statement by dispatching on the form of its expression.
--
-- Expressions graphed via @sizedReturn@ require every size variable of the
-- shape they produce to be available on host ('requiredOnHost'). They also
-- connect their graphed scalar operands to the sink.
graphStm :: Stm GPU -> Grapher ()
graphStm stm =
  case expr of
    BasicOp op -> graphBasicOp op
    Apply fn _ _ _ -> graphApply fn bindings expr
    Match ses cases defbody _ -> graphMatch bindings ses cases defbody
    DoLoop params lform body -> graphLoop bindings params lform body
    WithAcc inputs lam -> graphWithAcc bindings inputs lam
    -- A nested GPUBody is only recorded.
    Op GPUBody {} -> tellGPUBody
    Op _ -> graphHostOnly expr
  where
    bindings = boundBy stm
    expr = stmExp stm

    -- The sole binding of a single-result statement.
    single =
      case bindings of
        [b] -> b
        _ -> compilerBugS "Type error: unexpected number of pattern elements."

    -- A slice is fixed when 'sliceIndices' succeeds.
    isFixed slice = isJust (sliceIndices slice)

    sizedReturn new_dims = do
      traverse_ (requiredOnHost . nameToId) (subExpVars new_dims)
      graphedScalarOperands expr >>= addEdges ToSink

    simpleReusing se = do
      graphSimple bindings expr
      reusesSubExp single se

    graphBasicOp op =
      case op of
        SubExp se -> simpleReusing se
        Opaque _ se -> simpleReusing se
        -- Scalar array literals with at least one variable element become
        -- sources without recording a read.
        ArrayLit lits t
          | isScalar t,
            any (isJust . subExpVar) lits ->
              graphAutoMove single
        UnOp {} -> graphSimple bindings expr
        BinOp {} -> graphSimple bindings expr
        CmpOp {} -> graphSimple bindings expr
        ConvOp {} -> graphSimple bindings expr
        Assert {} -> graphSimple bindings expr
        -- Indexing with a fixed slice is graphed as a read.
        Index _ slice
          | isFixed slice -> graphRead single
        -- A single result whose dimensions are all the constant 1 is graphed
        -- like a simple scalar operation.
        _
          | [(_, t)] <- bindings,
            let dims = arrayDims t,
            not (null dims),
            all (== intConst Int64 1) dims ->
              graphSimple bindings expr
        Index arr slice -> do
          sizedReturn (sliceDims slice)
          reuses single arr
        Update _ arr slice _
          | isFixed slice -> do
              sizedReturn []
              reuses single arr
        FlatIndex arr slice -> do
          sizedReturn (flatSliceDims slice)
          reuses single arr
        FlatUpdate arr _ _ -> do
          sizedReturn []
          reuses single arr
        Scratch _ dims -> sizedReturn dims
        Reshape _ new_shape arr -> do
          sizedReturn (shapeDims new_shape)
          reuses single arr
        Rearrange _ arr -> do
          sizedReturn []
          reuses single arr
        Rotate _ arr -> do
          sizedReturn []
          reuses single arr
        ArrayLit {} -> graphHostOnly expr
        Update {} -> graphHostOnly expr
        Concat {} -> graphHostOnly expr
        Copy {} -> graphHostOnly expr
        Manifest {} -> graphHostOnly expr
        Iota {} -> graphHostOnly expr
        Replicate {} -> graphHostOnly expr
        UpdateAcc {} -> graphUpdateAcc single expr
-- | The vertex id and type of every name bound by the statement's pattern,
-- in pattern order.
boundBy :: Stm GPU -> [Binding]
boundBy stm = map toBinding (patElems (stmPat stm))
  where
    toBinding (PatElem n t) = (nameToId n, t)
-- | Graphs an expression with its bindings as ordinary vertices. If the
-- expression has any graphed scalar operands, every binding is added as a
-- vertex, and edges declared over the binding ids are added for those
-- operands. An expression without graphed scalar operands adds nothing.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  operands <- graphedScalarOperands e
  if IS.null operands
    then pure ()
    else do
      traverse_ addVertex bs
      addEdges (MG.declareEdges (fst <$> bs)) operands
-- | Adds the binding as a source and records a read.
graphRead :: Binding -> Grapher ()
graphRead b = addSource b *> tellRead
-- | Adds the binding as a source without recording a read.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b = addSource b
-- | Connects every graphed scalar operand of the expression to the sink,
-- then records that a host-only statement was encountered.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  graphedScalarOperands e >>= addEdges ToSink >> tellHostOnly
-- | Defers graphing of an 'UpdateAcc'. The (binding, expression) pair is
-- prepended to the entries stored in 'stateUpdateAccs' under the id of the
-- accumulator named by the binding's type. It is a compiler bug if the
-- binding is not accumulator-typed.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e =
  case b of
    (_, Acc acc _ _ _) ->
      modify $ \st ->
        st
          { stateUpdateAccs =
              IM.insertWith (++) (nameToId acc) [(b, e)] (stateUpdateAccs st)
          }
    _ ->
      compilerBugS
        "Type error: UpdateAcc did not produce accumulator typed value."
-- | Graphs a function application: calls to host-only functions are graphed
-- as host-only, all other calls as simple expressions.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e =
  isHostOnlyFun fn >>= \host_only ->
    if host_only
      then graphHostOnly e
      else graphSimple bs e
graphMatch :: [Binding] -> [SubExp] -> [Case (Body GPU)] -> Body GPU -> Grapher ()
graphMatch :: [Binding]
-> [SubExp] -> [Case (Body GPU)] -> Body GPU -> Grapher ()
graphMatch [Binding]
bs [SubExp]
ses [Case (Body GPU)]
cases Body GPU
defbody = do
Bool
body_host_only <-
forall a. Grapher a -> Grapher a
incForkDepthFor forall a b. (a -> b) -> a -> b
$
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any BodyStats -> Bool
bodyHostOnly
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (forall a. Grapher a -> Grapher BodyStats
captureBodyStats forall b c a. (b -> c) -> (a -> b) -> a -> c
. Body GPU -> Grapher ()
graphBody) (Body GPU
defbody forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map forall body. Case body -> body
caseBody [Case (Body GPU)]
cases)
let branch_results :: [[SubExp]]
branch_results = forall {k} {rep :: k}. Body rep -> [SubExp]
results Body GPU
defbody forall a. a -> [a] -> [a]
: forall a b. (a -> b) -> [a] -> [b]
map (forall {k} {rep :: k}. Body rep -> [SubExp]
results forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
Bool
may_copy_results <- [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches [Binding]
bs [[SubExp]]
branch_results
let may_migrate :: Bool
may_migrate = Bool -> Bool
not Bool
body_host_only Bool -> Bool -> Bool
&& Bool
may_copy_results
Operands
cond_id <-
if Bool
may_migrate
then forall (t :: * -> *). Foldable t => t VName -> Grapher Operands
onlyGraphedScalars forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName]
subExpVars [SubExp]
ses
else do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Id -> Grapher ()
connectToSink forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Id
nameToId) ([SubExp] -> [VName]
subExpVars [SubExp]
ses)
forall (f :: * -> *) a. Applicative f => a -> f a
pure Operands
IS.empty
Operands -> Grapher ()
tellOperands Operands
cond_id
[Operands]
ret <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Operands -> [SubExp] -> Grapher Operands
comb Operands
cond_id) forall a b. (a -> b) -> a -> b
$ forall a. [[a]] -> [[a]]
L.transpose [[SubExp]]
branch_results
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Binding -> Operands -> Grapher ()
createNode) (forall a b. [a] -> [b] -> [(a, b)]
zip [Binding]
bs [Operands]
ret)
where
results :: Body rep -> [SubExp]
results = forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). Body rep -> Result
bodyResult
comb :: Operands -> [SubExp] -> Grapher Operands
comb Operands
ci [SubExp]
a = (Operands
ci <>) forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (t :: * -> *). Foldable t => t VName -> Grapher Operands
onlyGraphedScalars (forall a. Ord a => [a] -> Set a
S.fromList forall a b. (a -> b) -> a -> b
$ [SubExp] -> [VName]
subExpVars [SubExp]
a)
type ReachableBindings = IdSet
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)
type NonExhausted = [Id]
type LoopValue = (Binding, Id, SubExp, SubExp)
graphLoop ::
[Binding] ->
[(FParam GPU, SubExp)] ->
LoopForm GPU ->
Body GPU ->
Grapher ()
graphLoop :: [Binding]
-> [(FParam GPU, SubExp)] -> LoopForm GPU -> Body GPU -> Grapher ()
graphLoop [] [(FParam GPU, SubExp)]
_ LoopForm GPU
_ Body GPU
_ =
forall a. String -> a
compilerBugS String
"Loop statement bound no variable; should have been eliminated."
graphLoop (Binding
b : [Binding]
bs) [(FParam GPU, SubExp)]
params LoopForm GPU
lform Body GPU
body = do
Graph
g <- Grapher Graph
getGraph
BodyStats
stats <- forall a. Grapher a -> Grapher BodyStats
captureBodyStats (Id
subgraphId forall a. Id -> Grapher a -> Grapher a
`graphIdFor` Grapher ()
graphTheLoop)
let args :: [SubExp]
args = forall a b. (a -> b) -> [a] -> [b]
map forall a b. (a, b) -> b
snd [(FParam GPU, SubExp)]
params
let results :: [SubExp]
results = forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body)
Bool
may_copy_results <- [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches (Binding
b forall a. a -> [a] -> [a]
: [Binding]
bs) [[SubExp]
args, [SubExp]
results]
let may_migrate :: Bool
may_migrate = Bool -> Bool
not (BodyStats -> Bool
bodyHostOnly BodyStats
stats) Bool -> Bool -> Bool
&& Bool
may_copy_results
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
may_migrate forall a b. (a -> b) -> a -> b
$ case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ (Var VName
n) [(LParam GPU, VName)]
_ -> Id -> Grapher ()
connectToSink (VName -> Id
nameToId VName
n)
WhileLoop VName
n
| (Binding
_, Id
p, SubExp
_, SubExp
res) <- VName -> (Binding, Id, SubExp, SubExp)
loopValueFor VName
n -> do
Id -> Grapher ()
connectToSink Id
p
case SubExp
res of
Var VName
v -> Id -> Grapher ()
connectToSink (VName -> Id
nameToId VName
v)
SubExp
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
LoopForm GPU
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding, Id, SubExp, SubExp) -> Grapher ()
mergeLoopParam [(Binding, Id, SubExp, SubExp)]
loopValues
[Id]
srcs <- Id -> Grapher [Id]
routeSubgraph Id
subgraphId
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [(Binding, Id, SubExp, SubExp)]
loopValues forall a b. (a -> b) -> a -> b
$ \(Binding
bnd, Id
p, SubExp
_, SubExp
_) -> Binding -> Operands -> Grapher ()
createNode Binding
bnd (Id -> Operands
IS.singleton Id
p)
Graph
g' <- Grapher Graph
getGraph
let (Operands
dbs, ReachableBindingsCache
rbc) = forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl' (Graph
-> (Operands, ReachableBindingsCache)
-> Id
-> (Operands, ReachableBindingsCache)
deviceBindings Graph
g') (Operands
IS.empty, forall a. Visited a
MG.none) [Id]
srcs
(Sources -> Sources) -> Grapher ()
modifySources forall a b. (a -> b) -> a -> b
$ forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second (Operands -> [Id]
IS.toList Operands
dbs <>)
let ops :: Operands
ops = (Id -> Bool) -> Operands -> Operands
IS.filter (forall m. Id -> Graph m -> Bool
`MG.member` Graph
g) (BodyStats -> Operands
bodyOperands BodyStats
stats)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m ()
foldM_ ReachableBindingsCache -> Id -> Grapher ReachableBindingsCache
connectOperand ReachableBindingsCache
rbc (Operands -> [Id]
IS.elems Operands
ops)
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
may_migrate forall a b. (a -> b) -> a -> b
$ case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ SubExp
n [(LParam GPU, VName)]
_ ->
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
n forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> Operands -> Grapher ()
addEdges (Operands -> Maybe Operands -> Edges
ToNodes Operands
bindings forall a. Maybe a
Nothing)
WhileLoop VName
n
| (Binding
_, Id
_, SubExp
arg, SubExp
_) <- VName -> (Binding, Id, SubExp, SubExp)
loopValueFor VName
n ->
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
arg forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Edges -> Operands -> Grapher ()
addEdges (Operands -> Maybe Operands -> Edges
ToNodes Operands
bindings forall a. Maybe a
Nothing)
where
subgraphId :: Id
subgraphId :: Id
subgraphId = forall a b. (a, b) -> a
fst Binding
b
loopValues :: [LoopValue]
loopValues :: [(Binding, Id, SubExp, SubExp)]
loopValues =
let tmp :: [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp = forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 (Binding
b forall a. a -> [a] -> [a]
: [Binding]
bs) [(FParam GPU, SubExp)]
params (forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body)
tmp' :: [(Binding, Id, SubExp, SubExp)]
tmp' = forall a b c. (a -> b -> c) -> b -> a -> c
flip forall a b. (a -> b) -> [a] -> [b]
map [(Binding, (Param DeclType, SubExp), SubExpRes)]
tmp forall a b. (a -> b) -> a -> b
$
\(Binding
bnd, (Param DeclType
p, SubExp
arg), SubExpRes
res) ->
let i :: Id
i = VName -> Id
nameToId (forall dec. Param dec -> VName
paramName Param DeclType
p)
in (Binding
bnd, Id
i, SubExp
arg, SubExpRes -> SubExp
resSubExp SubExpRes
res)
in forall a. (a -> Bool) -> [a] -> [a]
filter (\((Id
_, Type
t), Id
_, SubExp
_, SubExp
_) -> forall t. Typed t => t -> Bool
isScalar Type
t) [(Binding, Id, SubExp, SubExp)]
tmp'
bindings :: IdSet
bindings :: Operands
bindings = [Id] -> Operands
IS.fromList forall a b. (a -> b) -> a -> b
$ forall a b. (a -> b) -> [a] -> [b]
map (\((Id
i, Type
_), Id
_, SubExp
_, SubExp
_) -> Id
i) [(Binding, Id, SubExp, SubExp)]
loopValues
loopValueFor :: VName -> LoopValue
loopValueFor :: VName -> (Binding, Id, SubExp, SubExp)
loopValueFor VName
n =
forall a. HasCallStack => Maybe a -> a
fromJust forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find (\(Binding
_, Id
p, SubExp
_, SubExp
_) -> Id
p forall a. Eq a => a -> a -> Bool
== VName -> Id
nameToId VName
n) [(Binding, Id, SubExp, SubExp)]
loopValues
graphTheLoop :: Grapher ()
graphTheLoop :: Grapher ()
graphTheLoop = do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {a} {d}. ((a, Type), Id, SubExp, d) -> Grapher ()
graphParam [(Binding, Id, SubExp, SubExp)]
loopValues
case LoopForm GPU
lform of
ForLoop VName
_ IntType
_ SubExp
n [(LParam GPU, VName)]
elems -> do
SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
n forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Operands -> Grapher ()
tellOperands
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {dec}. Typed dec => (Param dec, VName) -> Grapher ()
graphForInElem [(LParam GPU, VName)]
elems
WhileLoop VName
_ -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
Body GPU -> Grapher ()
graphBody Body GPU
body
where
graphForInElem :: (Param dec, VName) -> Grapher ()
graphForInElem (Param dec
p, VName
arr) = do
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall t. Typed t => t -> Bool
isScalar Param dec
p) forall a b. (a -> b) -> a -> b
$ Binding -> Grapher ()
addSource (VName -> Id
nameToId forall a b. (a -> b) -> a -> b
$ forall dec. Param dec -> VName
paramName Param dec
p, forall t. Typed t => t -> Type
typeOf Param dec
p)
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (forall t. Typed t => t -> Bool
isArray Param dec
p) forall a b. (a -> b) -> a -> b
$ (VName -> Id
nameToId (forall dec. Param dec -> VName
paramName Param dec
p), forall t. Typed t => t -> Type
typeOf Param dec
p) Binding -> VName -> Grapher ()
`reuses` VName
arr
graphParam :: ((a, Type), Id, SubExp, d) -> Grapher ()
graphParam ((a
_, Type
t), Id
p, SubExp
arg, d
_) =
do
Binding -> Grapher ()
addVertex (Id
p, Type
t)
Operands
ops <- SubExp -> Grapher Operands
onlyGraphedScalarSubExp SubExp
arg
Edges -> Operands -> Grapher ()
addEdges (Id -> Edges
MG.oneEdge Id
p) Operands
ops
mergeLoopParam :: LoopValue -> Grapher ()
mergeLoopParam :: (Binding, Id, SubExp, SubExp) -> Grapher ()
mergeLoopParam (Binding
_, Id
p, SubExp
_, SubExp
res)
| Var VName
n <- SubExp
res,
Id
ret <- VName -> Id
nameToId VName
n,
Id
ret forall a. Eq a => a -> a -> Bool
/= Id
p =
Edges -> Operands -> Grapher ()
addEdges (Id -> Edges
MG.oneEdge Id
p) (Id -> Operands
IS.singleton Id
ret)
| Bool
otherwise =
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
deviceBindings ::
Graph ->
(ReachableBindings, ReachableBindingsCache) ->
Id ->
(ReachableBindings, ReachableBindingsCache)
deviceBindings :: Graph
-> (Operands, ReachableBindingsCache)
-> Id
-> (Operands, ReachableBindingsCache)
deviceBindings Graph
g (Operands
rb, ReachableBindingsCache
rbc) Id
i =
let (Result Operands
r, ReachableBindingsCache
rbc') = forall a m.
Monoid a =>
Graph m
-> (a -> EdgeType -> Vertex m -> a)
-> Visited (Result a)
-> EdgeType
-> Id
-> (Result a, Visited (Result a))
MG.reduce Graph
g Operands -> EdgeType -> Vertex Meta -> Operands
bindingReach ReachableBindingsCache
rbc EdgeType
Normal Id
i
in case Result Operands
r of
Produced Operands
rb' -> (Operands
rb forall a. Semigroup a => a -> a -> a
<> Operands
rb', ReachableBindingsCache
rbc')
Result Operands
_ ->
forall a. String -> a
compilerBugS
String
"Migration graph sink could be reached from source after it\
\ had been attempted routed."
bindingReach ::
ReachableBindings ->
EdgeType ->
Vertex Meta ->
ReachableBindings
bindingReach :: Operands -> EdgeType -> Vertex Meta -> Operands
bindingReach Operands
rb EdgeType
_ Vertex Meta
v
| Id
i <- forall m. Vertex m -> Id
vertexId Vertex Meta
v,
Id -> Operands -> Bool
IS.member Id
i Operands
bindings =
Id -> Operands -> Operands
IS.insert Id
i Operands
rb
| Bool
otherwise =
Operands
rb
connectOperand ::
ReachableBindingsCache ->
Id ->
Grapher ReachableBindingsCache
connectOperand :: ReachableBindingsCache -> Id -> Grapher ReachableBindingsCache
connectOperand ReachableBindingsCache
cache Id
op = do
Graph
g <- Grapher Graph
getGraph
case forall m. Id -> Graph m -> Maybe (Vertex m)
MG.lookup Id
op Graph
g of
Maybe (Vertex Meta)
Nothing -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
cache
Just Vertex Meta
v ->
case forall m. Vertex m -> Edges
vertexEdges Vertex Meta
v of
Edges
ToSink -> forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
cache
ToNodes Operands
es Maybe Operands
Nothing -> Graph
-> ReachableBindingsCache
-> Id
-> Operands
-> Grapher ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
cache Id
op Operands
es
ToNodes Operands
_ (Just Operands
nx) -> Graph
-> ReachableBindingsCache
-> Id
-> Operands
-> Grapher ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
cache Id
op Operands
nx
where
connectOp ::
Graph ->
ReachableBindingsCache ->
Id ->
IdSet ->
Grapher ReachableBindingsCache
connectOp :: Graph
-> ReachableBindingsCache
-> Id
-> Operands
-> Grapher ReachableBindingsCache
connectOp Graph
g ReachableBindingsCache
rbc Id
i Operands
es = do
let (Result Operands
res, [Id]
nx, ReachableBindingsCache
rbc') = Graph
-> (Operands, [Id], ReachableBindingsCache)
-> [Id]
-> (Result Operands, [Id], ReachableBindingsCache)
findBindings Graph
g (Operands
IS.empty, [], ReachableBindingsCache
rbc) (Operands -> [Id]
IS.elems Operands
es)
case Result Operands
res of
Result Operands
FoundSink -> Id -> Grapher ()
connectToSink Id
i
Produced Operands
rb -> (Graph -> Graph) -> Grapher ()
modifyGraph forall a b. (a -> b) -> a -> b
$ forall m. (Vertex m -> Vertex m) -> Id -> Graph m -> Graph m
MG.adjust ([Id] -> Operands -> Vertex Meta -> Vertex Meta
updateEdges [Id]
nx Operands
rb) Id
i
forall (f :: * -> *) a. Applicative f => a -> f a
pure ReachableBindingsCache
rbc'
updateEdges ::
NonExhausted ->
ReachableBindings ->
Vertex Meta ->
Vertex Meta
updateEdges :: [Id] -> Operands -> Vertex Meta -> Vertex Meta
updateEdges [Id]
nx Operands
rb Vertex Meta
v
| ToNodes Operands
es Maybe Operands
_ <- forall m. Vertex m -> Edges
vertexEdges Vertex Meta
v =
let nx' :: Operands
nx' = [Id] -> Operands
IS.fromList [Id]
nx
es' :: Edges
es' = Operands -> Maybe Operands -> Edges
ToNodes (Operands
rb forall a. Semigroup a => a -> a -> a
<> Operands
es) forall a b. (a -> b) -> a -> b
$ forall a. a -> Maybe a
Just (Operands
rb forall a. Semigroup a => a -> a -> a
<> Operands
nx')
in Vertex Meta
v {vertexEdges :: Edges
vertexEdges = Edges
es'}
| Bool
otherwise = Vertex Meta
v
findBindings ::
Graph ->
(ReachableBindings, NonExhausted, ReachableBindingsCache) ->
[Id] ->
(MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
findBindings :: Graph
-> (Operands, [Id], ReachableBindingsCache)
-> [Id]
-> (Result Operands, [Id], ReachableBindingsCache)
findBindings Graph
_ (Operands
rb, [Id]
nx, ReachableBindingsCache
rbc) [] =
(forall a. a -> Result a
Produced Operands
rb, [Id]
nx, ReachableBindingsCache
rbc)
findBindings Graph
g (Operands
rb, [Id]
nx, ReachableBindingsCache
rbc) (Id
i : [Id]
is)
| Just Vertex Meta
v <- forall m. Id -> Graph m -> Maybe (Vertex m)
MG.lookup Id
i Graph
g,
Just Id
gid <- Meta -> Maybe Id
metaGraphId (forall m. Vertex m -> m
vertexMeta Vertex Meta
v),
Id
gid forall a. Eq a => a -> a -> Bool
== Id
subgraphId
=
let (Result Operands
res, ReachableBindingsCache
rbc') = forall a m.
Monoid a =>
Graph m
-> (a -> EdgeType -> Vertex m -> a)
-> Visited (Result a)
-> EdgeType
-> Id
-> (Result a, Visited (Result a))
MG.reduce Graph
g Operands -> EdgeType -> Vertex Meta -> Operands
bindingReach ReachableBindingsCache
rbc EdgeType
Normal Id
i
in case Result Operands
res of
Result Operands
FoundSink -> (forall a. Result a
FoundSink, [], ReachableBindingsCache
rbc')
Produced Operands
rb' -> Graph
-> (Operands, [Id], ReachableBindingsCache)
-> [Id]
-> (Result Operands, [Id], ReachableBindingsCache)
findBindings Graph
g (Operands
rb forall a. Semigroup a => a -> a -> a
<> Operands
rb', [Id]
nx, ReachableBindingsCache
rbc') [Id]
is
| Bool
otherwise =
Graph
-> (Operands, [Id], ReachableBindingsCache)
-> [Id]
-> (Result Operands, [Id], ReachableBindingsCache)
findBindings Graph
g (Operands
rb, Id
i forall a. a -> [a] -> [a]
: [Id]
nx, ReachableBindingsCache
rbc) [Id]
is
graphWithAcc ::
[Binding] ->
[WithAccInput GPU] ->
Lambda GPU ->
Grapher ()
graphWithAcc :: [Binding] -> [WithAccInput GPU] -> Lambda GPU -> Grapher ()
graphWithAcc [Binding]
bs [WithAccInput GPU]
inputs Lambda GPU
f = do
Body GPU -> Grapher ()
graphBody (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
f)
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {shape} {u} {a} {b}.
(TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
-> Grapher ()
graph forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda GPU
f) [WithAccInput GPU]
inputs
let arrs :: [SubExp]
arrs = forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap (\(ShapeBase SubExp
_, [VName]
as, Maybe (Lambda GPU, [SubExp])
_) -> forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var [VName]
as) [WithAccInput GPU]
inputs
let res :: Result
res = forall a. Id -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Id
length [WithAccInput GPU]
inputs) (forall {k} (rep :: k). Body rep -> Result
bodyResult forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
f)
Bool
_ <- [Binding] -> [SubExp] -> Grapher Bool
reusesReturn [Binding]
bs ([SubExp]
arrs forall a. [a] -> [a] -> [a]
++ forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp Result
res)
[Operands]
ret <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> Grapher Operands
onlyGraphedScalarSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp) Result
res
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry Binding -> Operands -> Grapher ()
createNode) forall a b. (a -> b) -> a -> b
$ forall a b. [a] -> [b] -> [(a, b)]
zip (forall a. Id -> [a] -> [a]
drop (forall (t :: * -> *) a. Foldable t => t a -> Id
length [SubExp]
arrs) [Binding]
bs) [Operands]
ret
where
graph :: (TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
-> Grapher ()
graph (Acc VName
a ShapeBase SubExp
_ [Type]
types u
_, (a
_, b
_, Maybe (Lambda GPU, [SubExp])
comb)) = do
let i :: Id
i = VName -> Id
nameToId VName
a
[Delayed]
delayed <- forall a. a -> Maybe a -> a
fromMaybe [] forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> forall (m :: * -> *) s a. Monad m => (s -> a) -> StateT s m a
gets (forall a. Id -> IntMap a -> Maybe a
IM.lookup Id
i forall b c a. (b -> c) -> (a -> b) -> a -> c
. State -> IntMap [Delayed]
stateUpdateAccs)
forall (m :: * -> *) s. Monad m => (s -> s) -> StateT s m ()
modify forall a b. (a -> b) -> a -> b
$ \State
st -> State
st {stateUpdateAccs :: IntMap [Delayed]
stateUpdateAccs = forall a. Id -> IntMap a -> IntMap a
IM.delete Id
i (State -> IntMap [Delayed]
stateUpdateAccs State
st)}
Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc Id
i [Type]
types (forall a b. (a, b) -> a
fst forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Maybe (Lambda GPU, [SubExp])
comb) [Delayed]
delayed
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ SubExp -> Grapher ()
connectSubExpToSink forall a b. (a -> b) -> a -> b
$ forall b a. b -> (a -> b) -> Maybe a -> b
maybe [] forall a b. (a, b) -> b
snd Maybe (Lambda GPU, [SubExp])
comb
graph (TypeBase shape u, (a, b, Maybe (Lambda GPU, [SubExp])))
_ =
forall a. String -> a
compilerBugS String
"Type error: WithAcc expression did not return accumulator."
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc Id
i [Type]
_ Maybe (Lambda GPU)
_ [] = Binding -> Grapher ()
addSource (Id
i, forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit)
graphAcc Id
i [Type]
types Maybe (Lambda GPU)
op [Delayed]
delayed = do
Env
env <- Grapher Env
ask
State
st <- forall (m :: * -> *) s. Monad m => StateT s m s
get
let lambda :: Lambda GPU
lambda = forall a. a -> Maybe a -> a
fromMaybe (forall {k} (rep :: k).
[LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [] (forall {k} (rep :: k).
BodyDec rep -> Stms rep -> Result -> Body rep
Body () forall a. Seq a
SQ.empty []) []) Maybe (Lambda GPU)
op
let m :: Grapher ()
m = Body GPU -> Grapher ()
graphBody (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
lambda)
let stats :: BodyStats
stats = forall r a. Reader r a -> r -> a
R.runReader (forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT (forall a. Grapher a -> Grapher BodyStats
captureBodyStats Grapher ()
m) State
st) Env
env
let host_only :: Bool
host_only = BodyStats -> Bool
bodyHostOnly BodyStats
stats Bool -> Bool -> Bool
|| BodyStats -> Bool
bodyHasGPUBody BodyStats
stats
let does_read :: Bool
does_read = BodyStats -> Bool
bodyReads BodyStats
stats Bool -> Bool -> Bool
|| forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any forall t. Typed t => t -> Bool
isScalar [Type]
types
Operands
ops <- Exp GPU -> Grapher Operands
graphedScalarOperands (forall {k} (rep :: k). [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [] Lambda GPU
lambda)
case (Bool
host_only, Bool
does_read) of
(Bool
True, Bool
_) -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Exp GPU -> Grapher ()
graphHostOnly forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) [Delayed]
delayed
Edges -> Operands -> Grapher ()
addEdges Edges
ToSink Operands
ops
(Bool
_, Bool
True) -> do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding -> Grapher ()
graphAutoMove forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [Delayed]
delayed
Binding -> Grapher ()
addSource (Id
i, forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit)
(Bool, Bool)
_ -> do
Binding -> Operands -> Grapher ()
createNode (Id
i, forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit) Operands
ops
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Delayed]
delayed forall a b. (a -> b) -> a -> b
$
\(Binding
b, Exp GPU
e) -> Exp GPU -> Grapher Operands
graphedScalarOperands Exp GPU
e forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Binding -> Operands -> Grapher ()
createNode Binding
b forall b c a. (b -> c) -> (a -> b) -> a -> c
. Id -> Operands -> Operands
IS.insert Id
i
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands Exp GPU
e =
let is :: Operands
is = forall a b. (a, b) -> a
fst forall a b. (a -> b) -> a -> b
$ forall s a. State s a -> s -> s
execState (Exp GPU -> StateT (Operands, Set VName) Identity ()
collect Exp GPU
e) forall {a}. (Operands, Set a)
initial
in Operands -> Operands -> Operands
IS.intersection Operands
is forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Grapher Operands
getGraphedScalars
where
initial :: (Operands, Set a)
initial = (Operands
IS.empty, forall a. Set a
S.empty)
captureName :: VName -> StateT (p Operands c) m ()
captureName VName
n = forall (m :: * -> *) s. Monad m => (s -> s) -> StateT s m ()
modify forall a b. (a -> b) -> a -> b
$ forall (p :: * -> * -> *) a b c.
Bifunctor p =>
(a -> b) -> p a c -> p b c
first forall a b. (a -> b) -> a -> b
$ Id -> Operands -> Operands
IS.insert (VName -> Id
nameToId VName
n)
captureAcc :: a -> StateT (p a (Set a)) m ()
captureAcc a
a = forall (m :: * -> *) s. Monad m => (s -> s) -> StateT s m ()
modify forall a b. (a -> b) -> a -> b
$ forall (p :: * -> * -> *) b c a.
Bifunctor p =>
(b -> c) -> p a b -> p a c
second forall a b. (a -> b) -> a -> b
$ forall a. Ord a => a -> Set a -> Set a
S.insert a
a
collectFree :: a -> StateT (p Operands c) m ()
collectFree a
x = forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
VName -> StateT (p Operands c) m ()
captureName (Names -> [VName]
namesToList forall a b. (a -> b) -> a -> b
$ forall a. FreeIn a => a -> Names
freeIn a
x)
collect :: Exp GPU -> StateT (Operands, Set VName) Identity ()
collect b :: Exp GPU
b@BasicOp {} =
forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
Exp rep -> StateT (p Operands c) m ()
collectBasic Exp GPU
b
collect (Apply Name
_ [(SubExp, Diet)]
params [RetType GPU]
_ (Safety, SrcLoc, [SrcLoc])
_) =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) [(SubExp, Diet)]
params
collect (Match [SubExp]
ses [Case (Body GPU)]
cases Body GPU
defbody MatchDec (BranchType GPU)
_) = do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp [SubExp]
ses
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall body. Case body -> body
caseBody) [Case (Body GPU)]
cases
Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody Body GPU
defbody
collect (DoLoop [(FParam GPU, SubExp)]
params LoopForm GPU
lform Body GPU
body) = do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> b
snd) [(FParam GPU, SubExp)]
params
forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
LoopForm rep -> StateT (p Operands c) m ()
collectLForm LoopForm GPU
lform
Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody Body GPU
body
collect (WithAcc [WithAccInput GPU]
accs Lambda GPU
f) =
[WithAccInput GPU]
-> Lambda GPU -> StateT (Operands, Set VName) Identity ()
collectWithAcc [WithAccInput GPU]
accs Lambda GPU
f
collect (Op Op GPU
op) =
forall {k} {a} {rep :: k} {c}.
FreeIn a =>
HostOp rep a -> StateT (Operands, c) Identity ()
collectHostOp Op GPU
op
collectBasic :: Exp rep -> StateT (p Operands c) m ()
collectBasic (BasicOp (Update Safety
_ VName
_ Slice SubExp
slice SubExp
_)) =
forall {m :: * -> *} {p :: * -> * -> *} {a} {c}.
(Monad m, Bifunctor p, FreeIn a) =>
a -> StateT (p Operands c) m ()
collectFree Slice SubExp
slice
collectBasic (BasicOp (Replicate ShapeBase SubExp
shape SubExp
_)) =
forall {m :: * -> *} {p :: * -> * -> *} {a} {c}.
(Monad m, Bifunctor p, FreeIn a) =>
a -> StateT (p Operands c) m ()
collectFree ShapeBase SubExp
shape
collectBasic Exp rep
e' =
forall {k} (m :: * -> *) (rep :: k).
Monad m =>
Walker rep m -> Exp rep -> m ()
walkExpM (forall {k} (m :: * -> *) (rep :: k). Monad m => Walker rep m
identityWalker {walkOnSubExp :: SubExp -> StateT (p Operands c) m ()
walkOnSubExp = forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp}) Exp rep
e'
collectSubExp :: SubExp -> StateT (p Operands c) m ()
collectSubExp (Var VName
n) = forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
VName -> StateT (p Operands c) m ()
captureName VName
n
collectSubExp SubExp
_ = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
collectBody :: Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody Body GPU
body = do
Stms GPU -> StateT (Operands, Set VName) Identity ()
collectStms (forall {k} (rep :: k). Body rep -> Stms rep
bodyStms Body GPU
body)
forall {m :: * -> *} {p :: * -> * -> *} {a} {c}.
(Monad m, Bifunctor p, FreeIn a) =>
a -> StateT (p Operands c) m ()
collectFree (forall {k} (rep :: k). Body rep -> Result
bodyResult Body GPU
body)
collectStms :: Stms GPU -> StateT (Operands, Set VName) Identity ()
collectStms = forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ Stm GPU -> StateT (Operands, Set VName) Identity ()
collectStm
collectStm :: Stm GPU -> StateT (Operands, Set VName) Identity ()
collectStm (Let Pat (LetDec GPU)
pat StmAux (ExpDec GPU)
_ Exp GPU
ua)
| BasicOp UpdateAcc {} <- Exp GPU
ua,
Pat [PatElem (LetDec GPU)
pe] <- Pat (LetDec GPU)
pat,
Acc VName
a ShapeBase SubExp
_ [Type]
_ NoUniqueness
_ <- forall t. Typed t => t -> Type
typeOf PatElem (LetDec GPU)
pe =
forall {m :: * -> *} {p :: * -> * -> *} {a} {a}.
(Monad m, Bifunctor p, Ord a) =>
a -> StateT (p a (Set a)) m ()
captureAcc VName
a forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
Exp rep -> StateT (p Operands c) m ()
collectBasic Exp GPU
ua
collectStm Stm GPU
stm = Exp GPU -> StateT (Operands, Set VName) Identity ()
collect (forall {k} (rep :: k). Stm rep -> Exp rep
stmExp Stm GPU
stm)
collectLForm :: LoopForm rep -> StateT (p Operands c) m ()
collectLForm (ForLoop VName
_ IntType
_ SubExp
b [(LParam rep, VName)]
_) = forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp SubExp
b
collectLForm (WhileLoop VName
_) = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
collectWithAcc :: [WithAccInput GPU]
-> Lambda GPU -> StateT (Operands, Set VName) Identity ()
collectWithAcc [WithAccInput GPU]
inputs Lambda GPU
f = do
Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
f)
Set VName
used_accs <- forall (m :: * -> *) s a. Monad m => (s -> a) -> StateT s m a
gets forall a b. (a, b) -> b
snd
let accs :: [Type]
accs = forall a. Id -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Id
length [WithAccInput GPU]
inputs) (forall {k} (rep :: k). Lambda rep -> [Type]
lambdaReturnType Lambda GPU
f)
let used :: [Bool]
used = forall a b. (a -> b) -> [a] -> [b]
map (\(Acc VName
a ShapeBase SubExp
_ [Type]
_ NoUniqueness
_) -> forall a. Ord a => a -> Set a -> Bool
S.member VName
a Set VName
used_accs) [Type]
accs
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Bool, WithAccInput GPU)
-> StateT (Operands, Set VName) Identity ()
collectAcc (forall a b. [a] -> [b] -> [(a, b)]
zip [Bool]
used [WithAccInput GPU]
inputs)
collectAcc :: (Bool, WithAccInput GPU)
-> StateT (Operands, Set VName) Identity ()
collectAcc (Bool
_, (ShapeBase SubExp
_, [VName]
_, Maybe (Lambda GPU, [SubExp])
Nothing)) = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
collectAcc (Bool
used, (ShapeBase SubExp
_, [VName]
_, Just (Lambda GPU
op, [SubExp]
nes))) = do
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp [SubExp]
nes
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when Bool
used forall a b. (a -> b) -> a -> b
$ Body GPU -> StateT (Operands, Set VName) Identity ()
collectBody (forall {k} (rep :: k). Lambda rep -> Body rep
lambdaBody Lambda GPU
op)
collectHostOp :: HostOp rep a -> StateT (Operands, c) Identity ()
collectHostOp (SegOp (SegMap SegLevel
lvl SegSpace
sp [Type]
_ KernelBody rep
_)) = do
forall {c}. SegLevel -> StateT (Operands, c) Identity ()
collectSegLevel SegLevel
lvl
forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SegSpace -> StateT (p Operands c) m ()
collectSegSpace SegSpace
sp
collectHostOp (SegOp (SegRed SegLevel
lvl SegSpace
sp [SegBinOp rep]
ops [Type]
_ KernelBody rep
_)) = do
forall {c}. SegLevel -> StateT (Operands, c) Identity ()
collectSegLevel SegLevel
lvl
forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SegSpace -> StateT (p Operands c) m ()
collectSegSpace SegSpace
sp
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
SegBinOp rep -> StateT (p Operands c) m ()
collectSegBinOp [SegBinOp rep]
ops
collectHostOp (SegOp (SegScan SegLevel
lvl SegSpace
sp [SegBinOp rep]
ops [Type]
_ KernelBody rep
_)) = do
forall {c}. SegLevel -> StateT (Operands, c) Identity ()
collectSegLevel SegLevel
lvl
forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SegSpace -> StateT (p Operands c) m ()
collectSegSpace SegSpace
sp
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
SegBinOp rep -> StateT (p Operands c) m ()
collectSegBinOp [SegBinOp rep]
ops
collectHostOp (SegOp (SegHist SegLevel
lvl SegSpace
sp [HistOp rep]
ops [Type]
_ KernelBody rep
_)) = do
forall {c}. SegLevel -> StateT (Operands, c) Identity ()
collectSegLevel SegLevel
lvl
forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SegSpace -> StateT (p Operands c) m ()
collectSegSpace SegSpace
sp
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {k} {m :: * -> *} {p :: * -> * -> *} {rep :: k} {c}.
(Monad m, Bifunctor p) =>
HistOp rep -> StateT (p Operands c) m ()
collectHistOp [HistOp rep]
ops
collectHostOp (SizeOp SizeOp
op) = forall {m :: * -> *} {p :: * -> * -> *} {a} {c}.
(Monad m, Bifunctor p, FreeIn a) =>
a -> StateT (p Operands c) m ()
collectFree SizeOp
op
collectHostOp (OtherOp a
op) = forall {m :: * -> *} {p :: * -> * -> *} {a} {c}.
(Monad m, Bifunctor p, FreeIn a) =>
a -> StateT (p Operands c) m ()
collectFree a
op
collectHostOp GPUBody {} = forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
collectSegLevel :: SegLevel -> StateT (Operands, c) Identity ()
collectSegLevel = forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
VName -> StateT (p Operands c) m ()
captureName forall b c a. (b -> c) -> (a -> b) -> a -> c
. Names -> [VName]
namesToList forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a. FreeIn a => a -> Names
freeIn
collectSegSpace :: SegSpace -> StateT (p Operands c) m ()
collectSegSpace SegSpace
space =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp (SegSpace -> [SubExp]
segSpaceDims SegSpace
space)
collectSegBinOp :: SegBinOp rep -> StateT (p Operands c) m ()
collectSegBinOp (SegBinOp Commutativity
_ Lambda rep
_ [SubExp]
nes ShapeBase SubExp
_) =
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp [SubExp]
nes
collectHistOp :: HistOp rep -> StateT (p Operands c) m ()
collectHistOp (HistOp ShapeBase SubExp
_ SubExp
rf [VName]
_ [SubExp]
nes ShapeBase SubExp
_ Lambda rep
_) = do
forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp SubExp
rf
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ forall {m :: * -> *} {p :: * -> * -> *} {c}.
(Monad m, Bifunctor p) =>
SubExp -> StateT (p Operands c) m ()
collectSubExp [SubExp]
nes
-- | Graph a binding that depends on the given operands.
--
-- When @ops@ is empty nothing happens. Otherwise a vertex is added for the
-- binding, and @MG.oneEdge@ to the binding's id is recorded on every operand
-- vertex via 'addEdges'.
createNode :: Binding -> Operands -> Grapher ()
createNode b ops
  | IS.null ops = pure ()
  | otherwise = do
      let target = fst b
      addVertex b
      addEdges (MG.oneEdge target) ops
-- | Insert a vertex for the binding into the graph, tagged with the current
-- 'Meta'.
--
-- A scalar binding is also added to the set of graphed scalars. An array
-- binding is recorded as copyable memory at the current body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  modifyGraph $ MG.insert (MG.vertex i meta)
  when (isScalar t) $
    modifyGraphedScalars (IS.insert i)
  when (isArray t) $
    recordCopyableMemory i (metaBodyDepth meta)
-- | Add a vertex for the binding, then register it as a source.
--
-- The id is prepended to the unrouted half of 'Sources', so the most recently
-- added source comes first.
addSource :: Binding -> Grapher ()
addSource b =
  addVertex b >> modifySources (second (fst b :))
-- | Apply the edges @es@ to every vertex whose id is in @is@.
--
-- For 'ToSink', each vertex is connected to the sink and removed from the
-- graphed scalars. For any other edge kind, the edges are added with
-- 'MG.addEdges' and @is@ is recorded as operands of the current body.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges es is =
  case es of
    ToSink -> do
      modifyGraph (overAll MG.connectToSink)
      modifyGraphedScalars (\gss -> gss `IS.difference` is)
    _ -> do
      modifyGraph (overAll (MG.addEdges es))
      tellOperands is
  where
    -- Thread the graph through a per-vertex update for each id in @is@, in
    -- ascending id order.
    overAll :: (Id -> Graph -> Graph) -> Graph -> Graph
    overAll step g0 = IS.foldl' (\g j -> step j g) g0 is
-- | Mark the vertex with id @i@ as required on the host.
--
-- If the vertex is in the graph, it is connected to the sink and the body
-- depth stored in its metadata is reported as a host-only parent. If it is
-- not in the graph, nothing happens.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  g <- getGraph
  for_ (MG.lookup i g) $ \v -> do
    connectToSink i
    tellHostOnlyParent (metaBodyDepth (vertexMeta v))
-- | Connect the vertex with id @i@ to the sink and drop it from the set of
-- graphed scalars.
connectToSink :: Id -> Grapher ()
connectToSink i =
  modifyGraph (MG.connectToSink i) >> modifyGraphedScalars (IS.delete i)
-- | 'connectToSink' for a sub-expression.
--
-- A variable is connected to the sink; a constant is ignored.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink = traverse_ (connectToSink . nameToId) . subExpVar
-- | Route the unrouted sources that belong to the subgraph identified by @si@.
--
-- Only the leading run of the unrouted list is considered, via 'span'. Sources
-- are prepended by 'addSource', so this presumably relies on callers invoking
-- it right after the subgraph's sources were added (TODO confirm at call
-- sites). The routed ids move to the routed half of 'Sources'. Any sinks
-- reported by 'MG.routeMany' are prepended to the state's sinks. Returns the
-- ids that were routed.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let g = stateGraph st
      (routed, unrouted) = stateSources st
      (gsrcs, remaining) = span (inSubGraph si g) unrouted
      (newSinks, g') = MG.routeMany gsrcs g
  put
    st
      { stateGraph = g',
        stateSources = (gsrcs ++ routed, remaining),
        stateSinks = newSinks ++ stateSinks st
      }
  pure gsrcs
-- | Whether vertex @i@ of graph @g@ is tagged with subgraph id @si@.
--
-- Returns 'False' if the vertex is missing or carries no subgraph id.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i =
  (MG.lookup i g >>= metaGraphId . vertexMeta) == Just si
-- | Record that the binding reuses the memory of @n@.
--
-- This applies only to array bindings. If @n@ has a recorded copyable-memory
-- body depth, the same depth is recorded for the binding. Otherwise nothing
-- happens.
reuses :: Binding -> VName -> Grapher ()
reuses (i, t) n =
  when (isArray t) $
    outermostCopyableArray n >>= traverse_ (recordCopyableMemory i)
-- | 'reuses' for a sub-expression; a constant is ignored.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp b = traverse_ (reuses b) . subExpVar
-- | Record copyable-memory information for bindings @bs@ whose values are the
-- body results @res@, paired positionally.
--
-- For each array binding returning a variable that has a recorded copyable
-- body depth @inner@, the binding is recorded at @min body_depth inner@.
--
-- The result is 'False' if some returned array variable either has no
-- recorded copyable memory, or has @inner <= body_depth@. Bindings whose
-- dimensions are all 1 (which includes scalars) do not affect the result, and
-- neither do non-array bindings or constant results. Every pair is processed
-- even after the result has become 'False'.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = do
  body_depth <- getBodyDepth
  foldM (step body_depth) True (zip bs res)
  where
    step :: BodyDepth -> Bool -> (Binding, SubExp) -> Grapher Bool
    step body_depth acc ((i, t), se)
      -- Single-element arrays are trivially copyable and are not recorded.
      -- This guard also matches scalars, since they have no dimensions.
      | all (== intConst Int64 1) (arrayDims t) = pure acc
      | isArray t,
        Var n <- se = do
          found <- outermostCopyableArray n
          case found of
            Nothing -> pure False
            Just inner -> do
              recordCopyableMemory i (min body_depth inner)
              let returns_free_var = inner <= body_depth
              pure (acc && not returns_free_var)
      | otherwise = pure acc
-- | 'reusesReturn' for a branching construct.
--
-- @seses@ holds one result list per branch. After transposition, each binding
-- is paired with the results every branch returns for it. An array binding
-- for which all branches return variables, and all of those variables have a
-- recorded copyable body depth, is recorded at
-- @min body_depth (minimum depths)@.
--
-- The result is 'False' if, for such a binding, some variable lacks a
-- recorded depth or the minimum depth is @<= body_depth@. Bindings whose
-- dimensions are all 1 (including scalars), non-array bindings, and bindings
-- with any constant branch result leave the result unchanged.
reusesBranches :: [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches bs seses = do
  body_depth <- getBodyDepth
  foldM (step body_depth) True $ zip bs (L.transpose seses)
  where
    step :: BodyDepth -> Bool -> (Binding, [SubExp]) -> Grapher Bool
    step body_depth acc ((i, t), ses)
      -- Single-element arrays (and scalars) are trivially copyable.
      | all (== intConst Int64 1) (arrayDims t) = pure acc
      | isArray t,
        Just ns <- traverse subExpVar ses = do
          depths <- traverse outermostCopyableArray ns
          case sequenceA depths of
            Nothing -> pure False
            Just bds -> do
              -- 'bds' is non-empty here: columns produced by 'L.transpose'
              -- are never empty.
              let inner = minimum bds
              recordCopyableMemory i (min body_depth inner)
              let returns_free_var = inner <= body_depth
              pure (acc && not returns_free_var)
      | otherwise = pure acc
-- | The monad in which the migration graph is built: a mutable 'State'
-- threaded over a read-only 'Env'.
type Grapher = StateT State (R.Reader Env)

-- | Read-only context available during graph construction.
data Env = Env
  { -- | Names of functions treated as host-only (queried by 'isHostOnlyFun').
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata attached to vertices created in the current scope.
    envMeta :: Meta
  }

-- | Nesting depth of the body currently being graphed.
type BodyDepth = Int

-- | Per-vertex metadata, taken from the current 'Env'.
data Meta = Meta
  { -- | Fork depth; 'incForkDepthFor' increments it for a sub-computation.
    metaForkDepth :: Int,
    -- | Body depth; 'incBodyDepthFor' increments it for a sub-computation.
    metaBodyDepth :: BodyDepth,
    -- | Subgraph id set by 'graphIdFor', if any.
    metaGraphId :: Maybe Id
  }

-- | A set of vertex ids used as operands.
type Operands = IdSet
-- | Facts gathered while graphing a body. They are combined with '<>' and
-- reset and merged by 'captureBodyStats'.
data BodyStats = BodyStats
  { -- | Set by 'tellHostOnly'.
    bodyHostOnly :: Bool,
    -- | Set by 'tellGPUBody'.
    bodyHasGPUBody :: Bool,
    -- | Set by 'tellRead'.
    bodyReads :: Bool,
    -- | Ids accumulated by 'tellOperands'.
    bodyOperands :: Operands,
    -- | Body depths accumulated by 'tellHostOnlyParent'.
    bodyHostOnlyParents :: IS.IntSet
  }
-- | Combine two sets of body statistics: flags are or-ed together and id sets
-- are unioned ('<>' on 'IS.IntSet' is 'IS.union').
instance Semigroup BodyStats where
  (BodyStats ho gb r o hop) <> (BodyStats ho' gb' r' o' hop') =
    BodyStats (ho || ho') (gb || gb') (r || r') (o <> o') (hop <> hop')
-- | The empty statistics: every flag is cleared and every id set is empty.
instance Monoid BodyStats where
  mempty = BodyStats False False False IS.empty IS.empty
-- | The migration graph, with 'Meta' attached to each vertex.
type Graph = MG.Graph Meta

-- | Source vertex ids as @(routed, unrouted)@; see 'addSource' and
-- 'routeSubgraph'.
type Sources = ([Id], [Id])

-- | Vertex ids reported as sinks by 'MG.routeMany' in 'routeSubgraph'.
type Sinks = [Id]

-- | A binding paired with the expression that defines it.
type Delayed = (Binding, Exp GPU)

-- | A vertex id together with the type of the bound value.
type Binding = (Id, Type)

-- | Body depth recorded by 'recordCopyableMemory', keyed by vertex id.
type CopyableMemoryMap = IM.IntMap BodyDepth
-- | Mutable state threaded through graph construction.
data State = State
  { -- | The graph built so far.
    stateGraph :: Graph,
    -- | Ids of scalar vertices added by 'addVertex' that have not since been
    -- connected to the sink.
    stateGraphedScalars :: IdSet,
    -- | Source vertices as @(routed, unrouted)@.
    stateSources :: Sources,
    -- | Sinks found while routing subgraphs.
    stateSinks :: Sinks,
    -- | Delayed bindings keyed by id. Not used in the visible code;
    -- presumably accumulator updates, given the name (TODO confirm).
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | Copyable-memory body depths; see 'recordCopyableMemory'.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics of the body currently being graphed.
    stateStats :: BodyStats
  }
-- | Run a 'Grapher' computation and return the resulting graph, sources and
-- sinks.
--
-- The computation starts from an empty state, with the given host-only
-- functions and with fork depth 0, body depth 0 and no subgraph id. Its own
-- result is discarded.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m = summarize (R.runReader (execStateT m initial) env)
  where
    summarize final = (stateGraph final, stateSources final, stateSinks final)

    env =
      Env
        { envHostOnlyFuns = hof,
          envMeta = Meta 0 0 Nothing
        }

    initial =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }
-- | Run a computation with a modified environment. The state still flows
-- through unchanged.
local :: (Env -> Env) -> Grapher a -> Grapher a
local = mapStateT . R.local
-- | The current environment.
ask :: Grapher Env
ask = lift (R.reader id)
-- | Project a value out of the current environment.
asks :: (Env -> a) -> Grapher a
asks f = lift (R.asks f)
-- | Set 'bodyHostOnly' in the current body statistics.
tellHostOnly :: Grapher ()
tellHostOnly = modify markHostOnly
  where
    markHostOnly st@State {stateStats = stats} =
      st {stateStats = stats {bodyHostOnly = True}}
-- | Set 'bodyHasGPUBody' in the current body statistics.
tellGPUBody :: Grapher ()
tellGPUBody = modify markGPUBody
  where
    markGPUBody st@State {stateStats = stats} =
      st {stateStats = stats {bodyHasGPUBody = True}}
-- | Set 'bodyReads' in the current body statistics.
tellRead :: Grapher ()
tellRead = modify markRead
  where
    markRead st@State {stateStats = stats} =
      st {stateStats = stats {bodyReads = True}}
-- | Add @is@ to the operands of the current body statistics.
tellOperands :: IdSet -> Grapher ()
tellOperands is = modify addOperands
  where
    addOperands st@State {stateStats = stats} =
      st {stateStats = stats {bodyOperands = bodyOperands stats <> is}}
-- | Add a body depth to 'bodyHostOnlyParents' in the current body statistics.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = modify addParent
  where
    addParent st@State {stateStats = stats} =
      let parents' = IS.insert body_depth (bodyHostOnlyParents stats)
       in st {stateStats = stats {bodyHostOnlyParents = parents'}}
-- | The graph built so far.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get
-- | Ids of the graphed scalars that are not connected to the sink.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get
-- | The recorded copyable-memory body depths.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get
-- | The copyable-memory body depth recorded for @n@, if any.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  mem <- getCopyableMemory
  pure (IM.lookup (nameToId n) mem)
-- | The ids of those names in @vs@ that are currently graphed scalars.
onlyGraphedScalars :: Foldable t => t VName -> Grapher IdSet
onlyGraphedScalars vs = do
  gss <- getGraphedScalars
  let wanted = foldMap (IS.singleton . nameToId) vs
  pure (IS.intersection wanted gss)
-- | A singleton set holding @n@'s id if @n@ is a graphed scalar, otherwise
-- the empty set.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n = select <$> getGraphedScalars
  where
    i = nameToId n
    select gss
      | i `IS.member` gss = IS.singleton i
      | otherwise = IS.empty
-- | 'onlyGraphedScalar' for a sub-expression; a constant gives the empty set.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se =
  maybe (pure IS.empty) onlyGraphedScalar (subExpVar se)
-- | Apply a function to the graph.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f =
  modify $ \st@State {stateGraph = g} -> st {stateGraph = f g}
-- | Apply a function to the set of graphed scalars.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f =
  modify $ \st@State {stateGraphedScalars = gss} ->
    st {stateGraphedScalars = f gss}
-- | Apply a function to the copyable-memory map.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f =
  modify $ \st@State {stateCopyableMemory = mem} ->
    st {stateCopyableMemory = f mem}
-- | Apply a function to the @(routed, unrouted)@ sources.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f =
  modify $ \st@State {stateSources = srcs} -> st {stateSources = f srcs}
-- | Record body depth @bd@ for id @i@ in the copyable-memory map, replacing
-- any previous entry.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i = modifyCopyableMemory . IM.insert i
-- | Run a computation with the fork depth increased by one.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local deeper
  where
    deeper env@Env {envMeta = meta} =
      env {envMeta = meta {metaForkDepth = metaForkDepth meta + 1}}
-- | Run a computation with the body depth increased by one.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local deeper
  where
    deeper env@Env {envMeta = meta} =
      env {envMeta = meta {metaBodyDepth = metaBodyDepth meta + 1}}
-- | Run a computation with the subgraph id set to @i@.
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local withId
  where
    withId env@Env {envMeta = meta} =
      env {envMeta = meta {metaGraphId = Just i}}
-- | Run @m@ starting from empty body statistics and return the statistics it
-- produced.
--
-- Afterwards, the state holds the statistics from before the call combined
-- with the captured ones, so they also count towards the enclosing body.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  setStats mempty
  _ <- m
  inner <- gets stateStats
  setStats (outer <> inner)
  pure inner
  where
    setStats s = modify $ \st -> st {stateStats = s}
-- | Whether @fn@ is in the environment's set of host-only functions.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = asks $ \env -> fn `S.member` envHostOnlyFuns env
-- | The vertex metadata of the current scope.
getMeta :: Grapher Meta
getMeta = envMeta <$> ask
-- | The body depth of the current scope.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta