-- |
-- This module implements program analysis to determine which program statements
-- the "Futhark.Optimise.ReduceDeviceSyncs" pass should move into 'GPUBody' kernels
-- to reduce blocking memory transfers between host and device. The results of
-- the analysis are encoded into a 'MigrationTable' which can be queried.
--
-- To reduce blocking scalar reads the module constructs a data flow
-- dependency graph of program variables (see
-- "Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph") in which
-- it finds a minimum vertex cut that separates array reads of scalars
-- from transitive usage that cannot or should not be migrated to
-- device.
--
-- The variables of each partition are assigned a 'MigrationStatus' that states
-- whether the computation of those variables should be moved to device or
-- remain on host. Due to how the graph is built and the vertex cut is found all
-- variables bound by a single statement will belong to the same partition.
--
-- The vertex cut contains all variables that will reside in device memory but
-- are required by host operations. These variables must be read from device
-- memory and cannot be reduced further in number merely by migrating
-- statements (subject to the accuracy of the graph model). The model is built
-- to reduce the worst-case number of scalar reads; an optimal migration of
-- statements depends on runtime data.
--
-- Blocking scalar writes are reduced by either turning such writes into
-- asynchronous kernels, as is done with scalar array literals and accumulator
-- updates, or by transforming host-device writing into device-device copying.
--
-- For details on how the graph is constructed and how the vertex cut is found,
-- see the master's thesis "Reducing Synchronous GPU Memory Transfers" by Philip
-- Børgesen (2022).
module Futhark.Optimise.ReduceDeviceSyncs.MigrationTable
  ( -- * Analysis
    analyseFunDef,
    analyseConsts,
    hostOnlyFunDefs,

    -- * Types
    MigrationTable,
    MigrationStatus (..),

    -- * Query

    -- | These functions all assume that no parent statement should be migrated.
    -- That is @shouldMoveStm stm mt@ should return @False@ for every statement
    -- @stm@ with a body that a queried 'VName' or 'Stm' is nested within,
    -- otherwise the query result may be invalid.
    shouldMoveStm,
    shouldMove,
    usedOnHost,
    statusOf,
  )
where

import Control.Monad
import Control.Monad.Trans.Class
import Control.Monad.Trans.Reader qualified as R
import Control.Monad.Trans.State.Strict ()
import Control.Monad.Trans.State.Strict hiding (State)
import Data.Bifunctor (first, second)
import Data.Foldable
import Data.IntMap.Strict qualified as IM
import Data.IntSet qualified as IS
import Data.List qualified as L
import Data.Map.Strict qualified as M
import Data.Maybe (fromMaybe, isJust, isNothing)
import Data.Sequence qualified as SQ
import Data.Set (Set, (\\))
import Data.Set qualified as S
import Futhark.Error
import Futhark.IR.GPU
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph
  ( EdgeType (..),
    Edges (..),
    Id,
    IdSet,
    Result (..),
    Routing (..),
    Vertex (..),
  )
import Futhark.Optimise.ReduceDeviceSyncs.MigrationTable.Graph qualified as MG

--------------------------------------------------------------------------------
--                              MIGRATION TABLES                              --
--------------------------------------------------------------------------------

-- | Where the value bound by a name should be computed.
--
-- The derived 'Ord' instance follows declaration order, i.e.
-- @MoveToDevice < UsedOnHost < StayOnHost@.
data MigrationStatus
  = -- | The statement that computes the value should be moved to device.
    -- No host usage of the value will be left after the migration.
    MoveToDevice
  | -- | As 'MoveToDevice' but host usage of the value will remain after
    -- migration.
    UsedOnHost
  | -- | The statement that computes the value should remain on host.
    StayOnHost
  deriving (Eq, Ord, Show)

-- | Identifies
--
--     (1) which statements should be moved from host to device to reduce the
--         worst case number of blocking memory transfers.
--
--     (2) which migrated variables will still be used on the host after
--         all such statements have been moved.
--
-- The table maps the 'baseTag' of a variable name to its 'MigrationStatus'
-- (see 'statusOf'); names without an entry are reported as 'StayOnHost'.
newtype MigrationTable = MigrationTable (IM.IntMap MigrationStatus)

-- | Tables are combined with a left-biased union: if both tables contain an
-- entry for the same name, the status from the left table is kept.
instance Semigroup MigrationTable where
  MigrationTable a <> MigrationTable b = MigrationTable (a `IM.union` b)

-- | Where should the value bound by this name be computed?
--
-- The table is keyed by the 'baseTag' of the name. A name that has no entry
-- in the table is reported as 'StayOnHost'.
statusOf :: VName -> MigrationTable -> MigrationStatus
statusOf n (MigrationTable mt) =
  fromMaybe StayOnHost $ IM.lookup (baseTag n) mt

-- | Should this whole statement be moved from host to device?
--
-- The decision depends on the kind of expression the statement binds:
--
--   * An 'Index' is moved if its first bound name is 'MoveToDevice', or if
--     any variable in its slice is 'MoveToDevice'.
--
--   * Any other 'BasicOp' and any 'Apply' is moved unless its first bound
--     name is 'StayOnHost'.
--
--   * A 'Match' is moved if every variable among its scrutinees is
--     'MoveToDevice'.
--
--   * A 'Loop' is moved if its iteration bound variable ('ForLoop') or its
--     condition variable ('WhileLoop') is 'MoveToDevice'.
--
-- Every other statement is left on host.
shouldMoveStm :: Stm GPU -> MigrationTable -> Bool
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp (Index _ slice))) mt =
  statusOf n mt == MoveToDevice || any movedOperand slice
  where
    -- Only variable operands can have been migrated; constants never are.
    movedOperand (Var op) = statusOf op mt == MoveToDevice
    movedOperand _ = False
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ (BasicOp _)) mt =
  statusOf n mt /= StayOnHost
shouldMoveStm (Let (Pat ((PatElem n _) : _)) _ Apply {}) mt =
  statusOf n mt /= StayOnHost
-- NOTE(review): if @cond@ contains no variables at all, 'all' over the
-- (presumably empty) result of 'subExpVars' is 'True', so such a 'Match'
-- would be reported as moved. The fall-through comment below suggests
-- constant conditions were meant to yield 'False' - confirm which is intended.
shouldMoveStm (Let _ _ (Match cond _ _ _)) mt =
  all ((== MoveToDevice) . (`statusOf` mt)) $ subExpVars cond
shouldMoveStm (Let _ _ (Loop _ (ForLoop _ _ (Var n)) _)) mt =
  statusOf n mt == MoveToDevice
shouldMoveStm (Let _ _ (Loop _ (WhileLoop n) _)) mt =
  statusOf n mt == MoveToDevice
-- BasicOp and Apply statements might not bind any variables (shouldn't happen).
-- If statements might use a constant branch condition.
-- For loop statements might use a constant number of iterations.
-- HostOp statements cannot execute on device.
-- WithAcc statements are never moved in their entirety.
shouldMoveStm _ _ = False

-- | Should the value bound by this name be computed on device?
--
-- This holds for both 'MoveToDevice' and 'UsedOnHost', i.e. for every status
-- other than 'StayOnHost'.
shouldMove :: VName -> MigrationTable -> Bool
shouldMove n mt = statusOf n mt /= StayOnHost

-- | Will the value bound by this name be used on host?
--
-- This holds for every status other than 'MoveToDevice', i.e. for both
-- 'UsedOnHost' and 'StayOnHost' (which includes names that have no entry in
-- the table, see 'statusOf').
usedOnHost :: VName -> MigrationTable -> Bool
usedOnHost n mt = statusOf n mt /= MoveToDevice

--------------------------------------------------------------------------------
--                         HOST-ONLY FUNCTION ANALYSIS                        --
--------------------------------------------------------------------------------

-- | The names of top-level function definitions that cannot be run on the
-- device. The application of any such function is host-only.
type HostOnlyFuns = Set Name

-- | Returns the names of all top-level functions that cannot be called from the
-- device. The evaluation of such a function is host-only.
--
-- Each function is first checked in isolation with 'checkFunDef', which
-- yields either 'Nothing' (host-only) or the set of functions it applies.
-- Host-only functions are then removed from the call map repeatedly, turning
-- every remaining function that calls one of them into a host-only function
-- too, until no new host-only functions are found. The result is every
-- function name that did not survive this process.
hostOnlyFunDefs :: [FunDef GPU] -> HostOnlyFuns
hostOnlyFunDefs funs =
  let names = map funDefName funs
      call_map = M.fromList $ zip names (map checkFunDef funs)
   in S.fromList names \\ M.keysSet (removeHostOnly call_map)
  where
    -- Strip host-only functions from the call map until a fixed point is
    -- reached. Every entry of the returned map is a 'Just'.
    removeHostOnly cm =
      let (host_only, cm') = M.partition isHostOnly cm
       in if M.null host_only
            then cm'
            else removeHostOnly $ M.map (checkCalls $ M.keysSet host_only) cm'

    isHostOnly = isNothing

    -- A function that calls a host-only function is itself host-only.
    checkCalls hostOnlyFuns (Just calls)
      | hostOnlyFuns `S.disjoint` calls =
          Just calls
    checkCalls _ _ =
      Nothing

-- | 'checkFunDef' returns 'Nothing' if this function definition uses arrays or
-- HostOps. Otherwise it returns the names of all applied functions, which may
-- include user defined functions that could turn out to be host-only.
--
-- More precisely, 'Nothing' is returned if any function parameter, return
-- type, bound pattern element, or loop parameter is an array, or if the body
-- (including nested 'Match' and 'Loop' bodies) contains an 'Index', a
-- 'WithAcc', or an 'Op' expression.
checkFunDef :: FunDef GPU -> Maybe (Set Name)
checkFunDef fun = do
  checkFParams $ funDefParams fun
  checkRetTypes $ map fst $ funDefRetType fun
  checkBody $ funDefBody fun
  where
    hostOnly = Nothing
    ok = Just ()

    -- Fail (host-only) if any element satisfies the array predicate.
    check isArr as = if any isArr as then hostOnly else ok

    checkFParams = check isArray

    checkLParams = check (isArray . fst)

    checkRetTypes = check isArrayType

    checkPats = check isArray

    checkBody = checkStms . bodyStms

    -- Collect the applied functions of all statements, failing as soon as
    -- any single statement is found to be host-only.
    checkStms stms = S.unions <$> mapM checkStm stms

    checkStm (Let (Pat pats) _ e) = checkPats pats >> checkExp e

    -- Any expression that produces an array is caught by checkPats
    checkExp (BasicOp (Index _ _)) = hostOnly
    checkExp (WithAcc _ _) = hostOnly
    checkExp (Op _) = hostOnly
    checkExp (Apply fn _ _ _) = Just (S.singleton fn)
    checkExp (Match _ cases defbody _) =
      mconcat <$> mapM checkBody (defbody : map caseBody cases)
    checkExp (Loop params _ body) = do
      checkLParams params
      checkBody body
    checkExp BasicOp {} = Just S.empty

--------------------------------------------------------------------------------
--                             MIGRATION ANALYSIS                             --
--------------------------------------------------------------------------------

-- | 'HostUsage' identifies scalar variables that are used on host, each
-- represented by its 'Id'.
type HostUsage = [Id]

-- | The 'Id' used to identify a variable is the 'baseTag' of its name.
nameToId :: VName -> Id
nameToId = baseTag

-- | Analyses top-level constants.
--
-- A constant counts as used on host if it is a scalar and occurs free in any
-- of the given function definitions.
analyseConsts :: HostOnlyFuns -> [FunDef GPU] -> Stms GPU -> MigrationTable
analyseConsts hof funs consts =
  let usage = M.foldlWithKey' (f $ freeIn funs) [] (scopeOf consts)
   in analyseStms hof usage consts
  where
    -- Record the 'Id' of every scalar constant that some function refers to.
    f free usage n t
      | isScalar t,
        n `nameIn` free =
          nameToId n : usage
      | otherwise =
          usage

-- | Analyses a top-level function definition.
--
-- Every variable returned by the function body whose corresponding declared
-- return type is scalar counts as used on host.
analyseFunDef :: HostOnlyFuns -> FunDef GPU -> MigrationTable
analyseFunDef hof fd =
  let body = funDefBody fd
      usage = foldl' f [] $ zip (bodyResult body) (map fst $ funDefRetType fd)
      stms = bodyStms body
   in analyseStms hof usage stms
  where
    -- Only variable results of scalar type contribute to host usage;
    -- constants and non-scalar results are skipped.
    f usage (SubExpRes _ (Var n), t) | isScalarType t = nameToId n : usage
    f usage _ = usage

-- | Analyses statements. The 'HostUsage' list identifies the bound scalar
-- variables that may subsequently be used on host. All free variables such as
-- constants and function parameters are assumed to reside on host.
--
-- A graph is built from the statements, the still unrouted sources are
-- routed, and the resulting graph is then traversed from every source (first
-- the previously unrouted ones, then the already routed ones). Each visited
-- vertex is classified by how it was reached, which determines its
-- 'MigrationStatus' in the returned table.
analyseStms :: HostOnlyFuns -> HostUsage -> Stms GPU -> MigrationTable
analyseStms hof usage stms =
  let (g, srcs, _) = buildGraph hof usage stms
      (routed, unrouted) = srcs
      (_, g') = MG.routeMany unrouted g -- hereby routed
      f st' = MG.fold g' visit st' Normal
      st = foldl' f (initial, MG.none) unrouted
      (vr, vn, tn) = fst $ foldl' f st routed
   in -- TODO: Delay reads into (deeper) branches

      -- 'IM.unions' is left-biased, so a vertex that was reached by a
      -- reversed edge stays 'MoveToDevice' even if it also occurs in @tn@.
      MigrationTable $
        IM.unions
          [ IM.fromSet (const MoveToDevice) vr,
            IM.fromSet (const MoveToDevice) vn,
            -- Read by host if not reached by a reversed edge
            IM.fromSet (const UsedOnHost) tn
          ]
  where
    -- 1) Visited by reversed edge.
    -- 2) Visited by normal edge, no route.
    -- 3) Visited by normal edge, had route; will potentially be read by host.
    initial = (IS.empty, IS.empty, IS.empty)

    -- Add the visited vertex to the set matching how it was reached.
    visit (vr, vn, tn) Reversed v =
      let vr' = IS.insert (vertexId v) vr
       in (vr', vn, tn)
    visit (vr, vn, tn) Normal v@Vertex {vertexRouting = NoRoute} =
      let vn' = IS.insert (vertexId v) vn
       in (vr, vn', tn)
    visit (vr, vn, tn) Normal v =
      let tn' = IS.insert (vertexId v) tn
       in (vr, vn, tn')

--------------------------------------------------------------------------------
--                                TYPE HELPERS                                --
--------------------------------------------------------------------------------

-- | Does the type of the given value count as a scalar according to
-- 'isScalarType'?
isScalar :: (Typed t) => t -> Bool
isScalar x = isScalarType (typeOf x)

-- | A type is a scalar if it is primitive and not the unit type. Every
-- non-primitive type is rejected.
isScalarType :: TypeBase shape u -> Bool
isScalarType t =
  case t of
    Prim Unit -> False
    Prim _ -> True
    _ -> False

-- | Does the type of the given value count as an array according to
-- 'isArrayType'?
isArray :: (Typed t) => t -> Bool
isArray x = isArrayType (typeOf x)

-- | A type is an array type if its rank is strictly positive.
isArrayType :: (ArrayShape shape) => TypeBase shape u -> Bool
isArrayType t = arrayRank t > 0

--------------------------------------------------------------------------------
--                               GRAPH BUILDING                               --
--------------------------------------------------------------------------------

-- | Graph the given statements and return the resulting graph together with
-- the sources and sinks reported by the grapher. Every variable listed in the
-- host usage is additionally connected to a sink in the returned graph.
buildGraph :: HostOnlyFuns -> HostUsage -> Stms GPU -> (Graph, Sources, Sinks)
buildGraph hof usage stms = (connected, srcs, sinks)
  where
    (graph, srcs, sinks) = execGrapher hof (graphStms stms)

    -- Connect each host-used variable to a sink, one at a time.
    connected = foldl' (\acc i -> MG.connectToSink i acc) graph usage

-- | Graph a body.
--
-- The statements of the body are graphed one body level deeper, after which
-- the free variables of the body result are reported as operands. The
-- statistics captured while doing so decide whether the body must be flagged
-- as host-only in the statistics of the enclosing scope.
graphBody :: Body GPU -> Grapher ()
graphBody body = do
  inner_stats <-
    captureBodyStats . incBodyDepthFor $ do
      graphStms (bodyStms body)
      tellOperands result_operands

  -- The depth at which the statements of this body were graphed.
  depth <- fmap (+ 1) getBodyDepth
  let requires_host = depth `IS.member` bodyHostOnlyParents inner_stats
  modify (updateStats depth requires_host)
  where
    result_operands = namesIntSet (freeIn (bodyResult body))

    updateStats depth requires_host st =
      let old = stateStats st
          -- If the body contains a variable that is required on host then
          -- the parent statement that contains this body cannot be migrated
          -- as a whole.
          flagged
            | requires_host = old {bodyHostOnly = True}
            | otherwise = old
          -- This body depth has now been handled; forget it.
          parents = IS.delete depth (bodyHostOnlyParents old)
       in st {stateStats = flagged {bodyHostOnlyParents = parents}}

-- | Graph multiple statements, one after another in order.
graphStms :: Stms GPU -> Grapher ()
graphStms stms = mapM_ graphStm stms

-- | Graph a single statement.
--
-- The expression of the statement decides how it is modelled. The order of
-- the case alternatives below is significant: guarded special cases must be
-- tried before the general alternatives for the same constructor.
graphStm :: Stm GPU -> Grapher ()
graphStm stm =
  -- IMPORTANT! It is generally assumed that all scalars within types and
  -- shapes are present on host. Any expression of a type wherein one of its
  -- scalar operands appears must therefore ensure that that scalar operand is
  -- marked as a size variable (see the 'hostSize' helper).
  case expr of
    BasicOp (SubExp se) -> simpleReusing se
    BasicOp (Opaque _ se) -> simpleReusing se
    BasicOp (ArrayLit arr t)
      | isScalar t,
        any (isJust . subExpVar) arr ->
          -- Migrating an array literal with free variables saves a write for
          -- every scalar it contains. Under some backends the compiler
          -- generates asynchronous writes for scalar constants, but otherwise
          -- each write will be synchronous. If all scalars are constants then
          -- the compiler generates more efficient code that copies static
          -- device memory.
          graphAutoMove sole
    BasicOp UnOp {} -> simple
    BasicOp BinOp {} -> simple
    BasicOp CmpOp {} -> simple
    BasicOp ConvOp {} -> simple
    BasicOp Assert {} ->
      -- == OpenCL =============================================================
      --
      -- The next read after the execution of a kernel containing an assertion
      -- is made asynchronous and is followed by an asynchronous read that
      -- checks whether any assertion failed. The runtime then blocks until
      -- all enqueued operations have finished.
      --
      -- An assertion only binds a certificate of unit type, so it cannot
      -- increase the number of (read) synchronizations that occur; in that
      -- regard it is free to migrate. The synchronization that does occur is
      -- however (presumably) more expensive as the pipeline of GPU work will
      -- be flushed.
      --
      -- This cost is difficult to quantify and amortize over assertion
      -- migration candidates (it depends on the ordering of kernels and
      -- reads), so it is assumed to be insignificant. That will likely hold
      -- for a system where multiple threads or processes schedule GPU work,
      -- as system-wide throughput only decreases if GPU utilization does.
      --
      -- == CUDA ===============================================================
      --
      -- Under the CUDA backend every read is synchronous and is followed by
      -- a full synchronization that blocks until all enqueued operations have
      -- finished. If any enqueued kernel contained an assertion, another
      -- synchronous read is then made to check whether an assertion failed.
      --
      -- Migrating an assertion to save a read may thus introduce new reads,
      -- so the total number of reads can decrease, stay the same, or even
      -- increase, subject to the ordering of reads and kernels that perform
      -- assertions.
      --
      -- Since the OpenCL failure checking scheme could be implemented with
      -- asynchronous reads under CUDA as well (and doing so would be a good
      -- idea!) this is considered acceptable.
      --
      -- TODO: Implement the OpenCL failure checking scheme under CUDA. This
      --       should reduce the number of synchronizations per read to one.
      simple
    BasicOp (Index _ slice)
      | isFixed slice ->
          graphRead sole
    BasicOp {}
      | [(_, t)] <- bindings,
        let dims = arrayDims t,
        not (null dims), -- i.e. produces an array
        all (== intConst Int64 1) dims ->
          -- An expression that produces an array that only contains a single
          -- primitive value is as efficient to compute and copy as a scalar,
          -- and introduces no size variables.
          --
          -- This is an exception to the inefficiency rules that come next.
          simple
    -- Expressions with a cost sublinear to the size of their result arrays are
    -- risky to migrate as we cannot guarantee that their results are not
    -- returned from a GPUBody, which always copies its return values. Since
    -- this would make the effective asymptotic cost of such statements linear
    -- we block them from being migrated on their own.
    --
    -- The parent statement of an enclosing body may still be migrated as a
    -- whole given that each of its returned arrays either
    --   1) is backed by memory used by a migratable statement within its body.
    --   2) contains just a single element.
    -- An array matching either criterion is denoted "copyable memory" because
    -- the asymptotic cost of copying it is less than or equal to the statement
    -- that produced it. This makes the parent of statements with sublinear cost
    -- safe to migrate.
    BasicOp (Index arr s) -> inefficientReusing (sliceDims s) arr
    BasicOp (Update _ arr slice _)
      | isFixed slice ->
          inefficientReusing [] arr
    BasicOp (FlatIndex arr s) ->
      -- Migrating a FlatIndex leads to a memory allocation error.
      --
      -- TODO: Fix FlatIndex memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse2_flatindex.fut
      inefficientReusing (flatSliceDims s) arr
    BasicOp (FlatUpdate arr _ _) -> inefficientReusing [] arr
    BasicOp (Scratch _ s) ->
      -- Migrating a Scratch leads to a memory allocation error.
      --
      -- TODO: Fix Scratch memory allocation error.
      --
      -- Can be replaced with 'graphHostOnly e' to disable migration.
      -- A fix can be verified by enabling tests/migration/reuse4_scratch.fut
      inefficientReturn s
    BasicOp (Reshape _ s arr) -> inefficientReusing (shapeDims s) arr
    BasicOp (Rearrange _ arr) -> inefficientReusing [] arr
    -- Expressions with a cost linear to the size of their result arrays are
    -- inefficient to migrate into GPUBody kernels as such kernels are single-
    -- threaded. For sufficiently large arrays the cost may exceed what is saved
    -- by avoiding reads. We therefore also block these from being migrated,
    -- as well as their parents.
    BasicOp ArrayLit {} ->
      -- An array literal purely of primitive constants can be hoisted out to be
      -- a top-level constant, unless it is to be returned or consumed.
      -- Otherwise its runtime implementation will copy a precomputed static
      -- array and thus behave like a 'Copy'.
      -- Whether the rows are primitive constants or arrays, without any scalar
      -- variable operands such an ArrayLit cannot directly prevent a scalar
      -- read.
      hostOnly
    BasicOp Update {} -> hostOnly
    BasicOp Concat {} ->
      -- Is unlikely to prevent a scalar read as the only SubExp operand in
      -- practice is a computation of host-only size variables.
      hostOnly
    BasicOp Manifest {} ->
      -- Takes no scalar operands so cannot directly prevent a scalar read.
      -- It is introduced as part of the BlkRegTiling kernel optimization and
      -- is thus unlikely to prevent the migration of a parent which was not
      -- already blocked by some host-only operation.
      hostOnly
    BasicOp Iota {} -> hostOnly
    BasicOp Replicate {} -> hostOnly
    -- END
    BasicOp UpdateAcc {} -> graphUpdateAcc sole expr
    Apply fn _ _ _ -> graphApply fn bindings expr
    Match ses cases defbody _ -> graphMatch bindings ses cases defbody
    Loop params lform body -> graphLoop bindings params lform body
    WithAcc inputs f -> graphWithAcc bindings inputs f
    Op GPUBody {} ->
      -- A GPUBody can be migrated into a parent GPUBody by replacing it with
      -- its body statements and binding its return values inside 'ArrayLit's.
      tellGPUBody
    Op _ -> hostOnly
  where
    bindings = boundBy stm
    expr = stmExp stm

    -- The single binding of the statement. Only forced by alternatives whose
    -- expression is expected to bind exactly one pattern element.
    sole =
      case bindings of
        [x] -> x
        _ -> compilerBugS "Type error: unexpected number of pattern elements."

    simple = graphSimple bindings expr

    hostOnly = graphHostOnly expr

    -- Graph as a simple statement, then record reuse of the subexpression.
    simpleReusing se = simple >> (sole `reusesSubExp` se)

    -- Graph as an inefficient return, then record reuse of the array.
    inefficientReusing new_dims arr =
      inefficientReturn new_dims >> (sole `reuses` arr)

    -- A slice is fixed when 'sliceIndices' yields a result for it.
    isFixed slice = isJust (sliceIndices slice)

    -- new_dims may introduce new size variables which must be present on host
    -- when this expression is evaluated. The graphed scalar operands of the
    -- expression are connected to sinks.
    inefficientReturn new_dims = do
      mapM_ hostSize new_dims
      graphedScalarOperands expr >>= addEdges ToSink

    -- Mark a size variable as required on host; constants need no marking.
    hostSize (Var n) = requiredOnHost (nameToId n)
    hostSize _ = pure ()

-- | Bindings for all pattern elements bound by a statement: the id of each
-- bound name paired with its type, in pattern order.
boundBy :: Stm GPU -> [Binding]
boundBy stm = [(nameToId n, t) | PatElem n t <- patElems (stmPat stm)]

-- | Graph a statement which in itself neither reads scalars from device memory
-- nor forces such scalars to be available on host. Such a statement can be
-- moved to device to eliminate the host usage of its operands, which
-- transitively may depend on a scalar device read.
graphSimple :: [Binding] -> Exp GPU -> Grapher ()
graphSimple bs e = do
  -- Only add vertices to the graph if they have a transitive dependency to
  -- an array read. Transitive dependencies through variables connected to
  -- sinks do not count.
  operands <- graphedScalarOperands e
  if IS.null operands
    then pure ()
    else do
      mapM_ addVertex bs
      addEdges (MG.declareEdges [i | (i, _) <- bs]) operands

-- | Graph a statement that reads a scalar from device memory. The binding
-- becomes a source and the read is reported.
graphRead :: Binding -> Grapher ()
graphRead b =
  -- Operands are not important as the source will block routes through b.
  addSource b >> tellRead

-- | Graph a statement that always should be moved to device. The binding
-- simply becomes a source.
graphAutoMove :: Binding -> Grapher ()
graphAutoMove b =
  -- Operands are not important as the source will block routes through b.
  addSource b

-- | Graph a statement that is unfit for execution in a GPUBody and thus must
-- be executed on host, requiring all its operands to be made available there.
-- Parent statements of enclosing bodies are also blocked from being migrated.
graphHostOnly :: Exp GPU -> Grapher ()
graphHostOnly e =
  -- Connect the vertices of all operands to sinks to mark that they are
  -- required on host. Transitive reads that they depend upon can be delayed
  -- no further, and any parent statements cannot be migrated.
  (graphedScalarOperands e >>= addEdges ToSink) >> tellHostOnly

-- | Graph an 'UpdateAcc' statement.
--
-- The statement is not graphed right away. Its binding and expression are
-- instead recorded under the id of the accumulator found in the type of the
-- binding. It is a compiler bug if the binding is not of accumulator type.
graphUpdateAcc :: Binding -> Exp GPU -> Grapher ()
graphUpdateAcc b e =
  case b of
    (_, Acc acc _ _ _) ->
      -- The actual graphing is delayed to the corresponding 'WithAcc' parent.
      modify (record (nameToId acc))
    _ ->
      compilerBugS
        "Type error: UpdateAcc did not produce accumulator typed value."
  where
    record key st =
      st {stateUpdateAccs = IM.alter push key (stateUpdateAccs st)}

    -- Prepend this update to any updates already delayed for the accumulator.
    push = Just . maybe [(b, e)] ((b, e) :)

-- | Graph a function application. Applications of host-only functions are
-- graphed as host-only statements; all others are graphed as simple
-- statements.
graphApply :: Name -> [Binding] -> Exp GPU -> Grapher ()
graphApply fn bs e = isHostOnlyFun fn >>= dispatch
  where
    dispatch True = graphHostOnly e
    dispatch False = graphSimple bs e

-- | Graph a Match statement.
--
-- All branch bodies (the default body first, then each case) are graphed one
-- fork level deeper. The statement may be migrated as a whole only if no
-- branch body is host-only and all returned arrays are copyable.
graphMatch :: [Binding] -> [SubExp] -> [Case (Body GPU)] -> Body GPU -> Grapher ()
graphMatch bs ses cases defbody = do
  any_host_only <- incForkDepthFor $ do
    branch_stats <- mapM (captureBodyStats . graphBody) bodies
    pure (any bodyHostOnly branch_stats)

  -- Record aliases for copyable memory backing returned arrays.
  may_copy_results <- reusesBranches bs branch_results
  let may_migrate = not any_host_only && may_copy_results

  cond_id <-
    if may_migrate
      then onlyGraphedScalars cond_vars
      else do
        -- The migration status of the condition is what determines
        -- whether the statement may be migrated as a whole or
        -- not. See 'shouldMoveStm'.
        mapM_ (connectToSink . nameToId) cond_vars
        pure IS.empty

  tellOperands cond_id

  -- Connect branch results to bound variables to allow delaying reads out of
  -- branches. It might also be beneficial to move the whole statement to
  -- device, to avoid reading the branch condition value. This must be balanced
  -- against the need to read the values bound by the if statement.
  --
  -- By connecting the branch condition to each variable bound by the statement
  -- the condition will only stay on device if
  --
  --   (1) the if statement is not required on host, based on the statements
  --       within its body.
  --
  --   (2) no additional reads will be required to use the if statement bound
  --       variables should the whole statement be migrated.
  --
  -- If the condition is migrated to device and stays there, then the if
  -- statement must necessarily execute on device.
  --
  -- While the graph model built by this module generally migrates no more
  -- statements than necessary to obtain a minimum vertex cut, the branches
  -- of if statements are subject to an inaccuracy. Specifically, the model is
  -- not strong enough to capture their mutual exclusivity and thus encodes
  -- that both branches are taken. While this does not affect the resulting
  -- number of host-device reads, it means that some reads may needlessly be
  -- delayed out of branches. The overhead as measured on futhark-benchmarks
  -- appears to be negligible though.
  ret <- mapM (withCond cond_id) (L.transpose branch_results)
  sequence_ (zipWith createNode bs ret)
  where
    -- The default body comes first, followed by the body of each case.
    bodies = defbody : map caseBody cases

    -- One list of result subexpressions per branch.
    branch_results = map (map resSubExp . bodyResult) bodies

    cond_vars = subExpVars ses

    -- Operands of one bound variable: the condition operands combined with
    -- the graphed scalars among what the branches return in that position.
    withCond cond column = do
      graphed <- onlyGraphedScalars (S.fromList (subExpVars column))
      pure (cond <> graphed)

-----------------------------------------------------
-- These type aliases are only used by 'graphLoop' --
-----------------------------------------------------

-- | Ids of variables bound by the loop statement that were encountered while
-- traversing the graph (see @bindingReach@ in 'graphLoop', which only ever
-- inserts ids that belong to the loop's scalar bindings).
type ReachableBindings = IdSet

-- | The visited-state that is threaded through successive 'MG.reduce' calls
-- in 'graphLoop' and handed back updated by each call.
-- NOTE(review): presumably this memoizes per-vertex results so vertices are
-- not traversed twice; confirm against the Graph module.
type ReachableBindingsCache = MG.Visited (MG.Result ReachableBindings)

-- | Edge targets of an operand that @findBindings@ (in 'graphLoop') did not
-- search because they are not vertices of the loop's own subgraph.
type NonExhausted = [Id]

-- | One scalar loop value as assembled by @loopValues@ in 'graphLoop': the
-- variable bound by the loop statement, the id of the corresponding loop
-- parameter, the parameter's initial argument, and the loop body result
-- in the same position.
type LoopValue = (Binding, Id, SubExp, SubExp)

-----------------------------------------------------
-----------------------------------------------------

-- | Graph a loop statement.
--
-- The arguments are the variables bound by the statement, the loop
-- parameters paired with their initial arguments, the loop form, and the
-- loop body. It is a compiler bug for a loop to bind no variables.
graphLoop ::
  [Binding] ->
  [(FParam GPU, SubExp)] ->
  LoopForm ->
  Body GPU ->
  Grapher ()
graphLoop [] _ _ _ =
  -- We expect each loop to bind a value or be eliminated.
  compilerBugS "Loop statement bound no variable; should have been eliminated."
graphLoop (b : bs) params lform body = do
  -- Graph loop params and body while capturing statistics.
  -- The graph is captured first so that the operand filter further below
  -- only keeps vertices that existed before the loop was graphed.
  g <- getGraph
  stats <- captureBodyStats (subgraphId `graphIdFor` graphTheLoop)

  -- Record aliases for copyable memory backing returned arrays.
  -- Does the loop return any arrays which prevent it from being migrated?
  let args = map snd params
  let results = map resSubExp (bodyResult body)
  may_copy_results <- reusesBranches (b : bs) [args, results]

  -- Connect the loop condition to a sink if the loop cannot be migrated,
  -- ensuring that it will be available to the host. The migration status
  -- of the condition is what determines whether the loop may be migrated
  -- as a whole or not. See 'shouldMoveStm'.
  let may_migrate = not (bodyHostOnly stats) && may_copy_results
  unless may_migrate $ case lform of
    ForLoop _ _ (Var n) -> connectToSink (nameToId n)
    WhileLoop n
      | Just (_, p, _, res) <- loopValueFor n -> do
          connectToSink p
          case res of
            Var v -> connectToSink (nameToId v)
            _ -> pure ()
    _ -> pure ()

  -- Connect graphed return values to their loop parameters.
  mapM_ mergeLoopParam loopValues

  -- Route the sources within the loop body in isolation.
  -- The loop graph must not be altered after this point.
  srcs <- routeSubgraph subgraphId

  -- Graph the variables bound by the statement.
  forM_ loopValues $ \(bnd, p, _, _) -> createNode bnd (IS.singleton p)

  -- If a device read is delayed from one iteration to the next the
  -- corresponding variables bound by the statement must be treated as
  -- sources.
  g' <- getGraph
  let (dbs, rbc) = foldl' (deviceBindings g') (IS.empty, MG.none) srcs
  modifySources $ second (IS.toList dbs <>)

  -- Connect operands to sinks if they can reach a sink within the loop.
  -- Otherwise connect them to the loop bound variables that they can
  -- reach and exhaust their normal entry edges into the loop.
  -- This means a read can be delayed through a loop but not into it if
  -- that would increase the number of reads done by any given iteration.
  let ops = IS.filter (`MG.member` g) (bodyOperands stats)
  foldM_ connectOperand rbc (IS.elems ops)

  -- It might be beneficial to move the whole loop to device, to avoid
  -- reading the (initial) loop condition value. This must be balanced
  -- against the need to read the values bound by the loop statement.
  --
  -- For more details see the similar description for if statements.
  when may_migrate $ case lform of
    ForLoop _ _ n ->
      onlyGraphedScalarSubExp n >>= addEdges (ToNodes bindings Nothing)
    WhileLoop n
      | Just (_, _, arg, _) <- loopValueFor n ->
          onlyGraphedScalarSubExp arg >>= addEdges (ToNodes bindings Nothing)
    _ -> pure ()
  where
    -- The id of the first bound variable doubles as the id of the subgraph
    -- that the loop parameters and body are graphed under.
    subgraphId :: Id
    subgraphId = fst b

    -- Bindings, parameters, initial arguments and body results matched up
    -- by position, restricted to bindings of scalar type.
    loopValues :: [LoopValue]
    loopValues =
      [ (bnd, nameToId (paramName p), arg, resSubExp res)
        | (bnd@(_, t), (p, arg), res) <- zip3 (b : bs) params (bodyResult body),
          isScalar t
      ]

    -- Ids of the scalar variables bound by the loop statement.
    bindings :: IdSet
    bindings = IS.fromList [i | ((i, _), _, _, _) <- loopValues]

    -- Find the scalar loop value whose loop parameter is named @n@, if any.
    loopValueFor :: VName -> Maybe LoopValue
    loopValueFor n =
      find (\(_, p, _, _) -> p == nameToId n) loopValues

    -- Graph the loop parameters, the loop bound (for-loops only) and the
    -- loop body. Run under 'subgraphId' by the caller.
    graphTheLoop :: Grapher ()
    graphTheLoop = do
      mapM_ graphParam loopValues

      -- For simplicity we do not currently track memory reuse through merge
      -- parameters. A parameter does not simply reuse the memory of its
      -- argument; it must also consider the iteration return value, which in
      -- turn may depend on other merge parameters.
      --
      -- Situations that would benefit from this tracking are unlikely to
      -- occur at the time of writing, and if they occur current compiler
      -- limitations will prevent successful compilation.
      -- Specifically it requires the merge parameter argument to reuse memory
      -- from an array literal, and both it and the loop must occur within an
      -- if statement branch. Array literals are generally hoisted out of if
      -- statements however, and when they are not, a memory allocation error
      -- occurs.
      --
      -- TODO: Track memory reuse through merge parameters.

      case lform of
        ForLoop _ _ n ->
          onlyGraphedScalarSubExp n >>= tellOperands
        WhileLoop _ -> pure ()
      graphBody body
      where
        graphParam :: LoopValue -> Grapher ()
        graphParam ((_, t), p, arg, _) =
          do
            -- It is unknown whether a read can be delayed via the parameter
            -- from one iteration to the next, so we have to create a vertex
            -- even if the initial value never depends on a read.
            addVertex (p, t)
            ops <- onlyGraphedScalarSubExp arg
            addEdges (MG.oneEdge p) ops

    -- Add an edge from a variable returned by the loop body to the loop
    -- parameter it feeds in the next iteration. Nothing is added when the
    -- result is not a variable or is the parameter itself.
    mergeLoopParam :: LoopValue -> Grapher ()
    mergeLoopParam (_, p, _, res)
      | Var n <- res,
        ret <- nameToId n,
        ret /= p =
          addEdges (MG.oneEdge p) (IS.singleton ret)
      | otherwise =
          pure ()

    -- Fold step that accumulates the loop bindings reachable from source
    -- @i@. It is a compiler bug if the traversal does not produce a set of
    -- bindings, as the subgraph has already been routed at that point.
    deviceBindings ::
      Graph ->
      (ReachableBindings, ReachableBindingsCache) ->
      Id ->
      (ReachableBindings, ReachableBindingsCache)
    deviceBindings g (rb, rbc) i =
      let (r, rbc') = MG.reduce g bindingReach rbc Normal i
       in case r of
            Produced rb' -> (rb <> rb', rbc')
            _ ->
              compilerBugS
                "Migration graph sink could be reached from source after it\
                \ had been attempted routed."

    -- Reduction step for 'MG.reduce': record the visited vertex if it is
    -- one of the variables bound by the loop statement.
    bindingReach ::
      ReachableBindings ->
      EdgeType ->
      Vertex Meta ->
      ReachableBindings
    bindingReach rb _ v
      | i <- vertexId v,
        IS.member i bindings =
          IS.insert i rb
      | otherwise =
          rb

    -- Process a single operand of the loop. Operands that are absent from
    -- the current graph or already connected to a sink are left untouched.
    -- Otherwise the search starts from the second edge set of the vertex
    -- when it is present and from its first edge set when it is not.
    connectOperand ::
      ReachableBindingsCache ->
      Id ->
      Grapher ReachableBindingsCache
    connectOperand cache op = do
      g <- getGraph
      case MG.lookup op g of
        Nothing -> pure cache
        Just v ->
          case vertexEdges v of
            ToSink -> pure cache
            ToNodes es Nothing -> connectOp g cache op es
            ToNodes _ (Just nx) -> connectOp g cache op nx
      where
        -- Connect the operand to a sink if a sink was found from its edges,
        -- otherwise rewrite its edges using the loop bindings it can reach.
        connectOp ::
          Graph ->
          ReachableBindingsCache ->
          Id -> -- operand id
          IdSet -> -- its edges
          Grapher ReachableBindingsCache
        connectOp g rbc i es = do
          let (res, nx, rbc') = findBindings g (IS.empty, [], rbc) (IS.elems es)
          case res of
            FoundSink -> connectToSink i
            Produced rb -> modifyGraph $ MG.adjust (updateEdges nx rb) i
          pure rbc'

        -- Add the reachable bindings to the first edge set of the vertex
        -- and replace the second with the reachable bindings plus the
        -- edges that were skipped. Vertices with other edge kinds are
        -- returned unchanged.
        updateEdges ::
          NonExhausted ->
          ReachableBindings ->
          Vertex Meta ->
          Vertex Meta
        updateEdges nx rb v
          | ToNodes es _ <- vertexEdges v =
              let nx' = IS.fromList nx
                  es' = ToNodes (rb <> es) $ Just (rb <> nx')
               in v {vertexEdges = es'}
          | otherwise = v

        -- Walk the given edges. Edges leading into the loop's subgraph are
        -- reduced with 'bindingReach'; all other edges are collected as
        -- non-exhausted. The walk stops as soon as a sink is found.
        findBindings ::
          Graph ->
          (ReachableBindings, NonExhausted, ReachableBindingsCache) ->
          [Id] -> -- current non-exhausted edges
          (MG.Result ReachableBindings, NonExhausted, ReachableBindingsCache)
        findBindings _ (rb, nx, rbc) [] =
          (Produced rb, nx, rbc)
        findBindings g (rb, nx, rbc) (i : is)
          | Just v <- MG.lookup i g,
            Just gid <- metaGraphId (vertexMeta v),
            gid == subgraphId -- only search the subgraph
            =
              let (res, rbc') = MG.reduce g bindingReach rbc Normal i
               in case res of
                    FoundSink -> (FoundSink, [], rbc')
                    Produced rb' -> findBindings g (rb <> rb', nx, rbc') is
          | otherwise =
              -- don't exhaust
              findBindings g (rb, i : nx, rbc) is

-- | Graph a 'WithAcc' statement.
--
-- The arguments are the variables bound by the statement, the accumulator
-- inputs, and the lambda that is run with the accumulators.
graphWithAcc ::
  [Binding] ->
  [WithAccInput GPU] ->
  Lambda GPU ->
  Grapher ()
graphWithAcc bs inputs f = do
  -- Graph the body, capturing 'UpdateAcc' statements for delayed graphing.
  graphBody (lambdaBody f)

  -- Graph each accumulator monoid and its associated 'UpdateAcc' statements.
  mapM_ graph $ zip (lambdaReturnType f) inputs

  -- Record aliases for the backing memory of each returned array.
  -- 'WithAcc' statements are never migrated as a whole and always return
  -- arrays backed by memory allocated elsewhere.
  let arrs = concatMap (\(_, as, _) -> map Var as) inputs
  -- The leading body results, one per input, are skipped here.
  let res = drop (length inputs) (bodyResult $ lambdaBody f)
  _ <- reusesReturn bs (arrs ++ map resSubExp res)

  -- Connect return variables to bound values. No outgoing edge exists
  -- from an accumulator vertex so skip those. Note that accumulators do
  -- not map to returned arrays one-to-one but one-to-many.
  ret <- mapM (onlyGraphedScalarSubExp . resSubExp) res
  mapM_ (uncurry createNode) $ zip (drop (length arrs) bs) ret
  where
    -- Graph one accumulator, given the lambda return type in its position
    -- and the matching input. The delayed 'UpdateAcc' statements recorded
    -- for the accumulator are taken out of the state and passed on to
    -- 'graphAcc'. A non-accumulator type is a compiler bug.
    graph (Acc a _ types _, (_, _, comb)) = do
      let i = nameToId a

      delayed <- fromMaybe [] <$> gets (IM.lookup i . stateUpdateAccs)
      modify $ \st -> st {stateUpdateAccs = IM.delete i (stateUpdateAccs st)}

      graphAcc i types (fst <$> comb) delayed

      -- Neutral elements must always be made available on host for 'WithAcc'
      -- to type check.
      mapM_ connectSubExpToSink $ maybe [] snd comb
    graph _ =
      compilerBugS "Type error: WithAcc expression did not return accumulator."

-- Graph the operator and all 'UpdateAcc' statements associated with an
-- accumulator.
--
-- The arguments are the 'Id' for the accumulator token, the element types of
-- the accumulator/operator, its combining function if any, and all associated
-- 'UpdateAcc' statements outside kernels.
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc :: Id -> [Type] -> Maybe (Lambda GPU) -> [Delayed] -> Grapher ()
graphAcc Id
i [Type]
_ Maybe (Lambda GPU)
_ [] = Binding -> Grapher ()
addSource (Id
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit) -- Only used on device.
graphAcc Id
i [Type]
types Maybe (Lambda GPU)
op [Delayed]
delayed = do
  -- Accumulators are intended for use within SegOps but in principle the AST
  -- allows their 'UpdateAcc's to be used outside a kernel. This case handles
  -- that unlikely situation.

  Env
env <- Grapher Env
ask
  State
st <- StateT State (Reader Env) State
forall (m :: * -> *) s. Monad m => StateT s m s
get

  -- Collect statistics about the operator statements.
  let lambda :: Lambda GPU
lambda = Lambda GPU -> Maybe (Lambda GPU) -> Lambda GPU
forall a. a -> Maybe a -> a
fromMaybe ([LParam GPU] -> [Type] -> Body GPU -> Lambda GPU
forall rep. [LParam rep] -> [Type] -> Body rep -> Lambda rep
Lambda [] [] (BodyDec GPU -> Stms GPU -> [SubExpRes] -> Body GPU
forall rep. BodyDec rep -> Stms rep -> [SubExpRes] -> Body rep
Body () Stms GPU
forall a. Seq a
SQ.empty [])) Maybe (Lambda GPU)
op
  let m :: Grapher ()
m = Body GPU -> Grapher ()
graphBody (Lambda GPU -> Body GPU
forall rep. Lambda rep -> Body rep
lambdaBody Lambda GPU
lambda)
  let stats :: BodyStats
stats = Reader Env BodyStats -> Env -> BodyStats
forall r a. Reader r a -> r -> a
R.runReader (Grapher BodyStats -> State -> Reader Env BodyStats
forall (m :: * -> *) s a. Monad m => StateT s m a -> s -> m a
evalStateT (Grapher () -> Grapher BodyStats
forall a. Grapher a -> Grapher BodyStats
captureBodyStats Grapher ()
m) State
st) Env
env
  -- We treat GPUBody kernels as host-only to not bother rewriting them inside
  -- operators and to simplify the analysis. They are unlikely to occur anyway.
  --
  -- NOTE: Performance may degrade if a GPUBody is replaced with its contents
  --       but the containing operator is used on host.
  let host_only :: Bool
host_only = BodyStats -> Bool
bodyHostOnly BodyStats
stats Bool -> Bool -> Bool
|| BodyStats -> Bool
bodyHasGPUBody BodyStats
stats

  -- op operands are read from arrays and written back so if any of the operands
  -- are scalar then a read can be avoided by moving the UpdateAcc usages to
  -- device. If the op itself performs scalar reads its UpdateAcc usages should
  -- also be moved.
  let does_read :: Bool
does_read = BodyStats -> Bool
bodyReads BodyStats
stats Bool -> Bool -> Bool
|| (Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any Type -> Bool
forall t. Typed t => t -> Bool
isScalar [Type]
types

  -- Determine which external variables the operator depends upon.
  -- 'bodyOperands' cannot be used as it might exclude operands that were
  -- connected to sinks within the body, so instead we create an artificial
  -- expression to capture graphed operands from.
  Operands
ops <- Exp GPU -> Grapher Operands
graphedScalarOperands ([WithAccInput GPU] -> Lambda GPU -> Exp GPU
forall rep. [WithAccInput rep] -> Lambda rep -> Exp rep
WithAcc [] Lambda GPU
lambda)

  case (Bool
host_only, Bool
does_read) of
    (Bool
True, Bool
_) -> do
      -- If the operator cannot run well in a GPUBody then all non-kernel
      -- UpdateAcc statements are host-only. The current analysis is ignorant
      -- of what happens inside kernels so we must assume that the operator
      -- is used within a kernel, meaning that we cannot migrate its statements.
      --
      -- TODO: Improve analysis if UpdateAcc ever is used outside kernels.
      (Delayed -> Grapher ()) -> [Delayed] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Exp GPU -> Grapher ()
graphHostOnly (Exp GPU -> Grapher ())
-> (Delayed -> Exp GPU) -> Delayed -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Delayed -> Exp GPU
forall a b. (a, b) -> b
snd) [Delayed]
delayed
      Edges -> Operands -> Grapher ()
addEdges Edges
ToSink Operands
ops
    (Bool
_, Bool
True) -> do
      -- Migrate all accumulator usage to device to avoid reads and writes.
      (Delayed -> Grapher ()) -> [Delayed] -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
(a -> m b) -> t a -> m ()
mapM_ (Binding -> Grapher ()
graphAutoMove (Binding -> Grapher ())
-> (Delayed -> Binding) -> Delayed -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Delayed -> Binding
forall a b. (a, b) -> a
fst) [Delayed]
delayed
      Binding -> Grapher ()
addSource (Id
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit)
    (Bool, Bool)
_ -> do
      -- Only migrate operator and UpdateAcc statements if it can allow their
      -- operands to be migrated.
      Binding -> Operands -> Grapher ()
createNode (Id
i, PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
Unit) Operands
ops
      [Delayed] -> (Delayed -> Grapher ()) -> Grapher ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ [Delayed]
delayed ((Delayed -> Grapher ()) -> Grapher ())
-> (Delayed -> Grapher ()) -> Grapher ()
forall a b. (a -> b) -> a -> b
$
        \(Binding
b, Exp GPU
e) -> Exp GPU -> Grapher Operands
graphedScalarOperands Exp GPU
e Grapher Operands -> (Operands -> Grapher ()) -> Grapher ()
forall a b.
StateT State (Reader Env) a
-> (a -> StateT State (Reader Env) b)
-> StateT State (Reader Env) b
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= Binding -> Operands -> Grapher ()
createNode Binding
b (Operands -> Grapher ())
-> (Operands -> Operands) -> Operands -> Grapher ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Id -> Operands -> Operands
IS.insert Id
i

-- | Returns for an expression all scalar operands that must be made available
-- on host to execute the expression there.
--
-- The expression is traversed with a pure state computation whose state is a
-- pair: the ids of the collected operand names and the tokens of accumulators
-- that are updated by 'UpdateAcc' statements met during the traversal. The
-- collected ids are finally restricted to those that currently are graphed
-- scalars (see 'getGraphedScalars').
graphedScalarOperands :: Exp GPU -> Grapher Operands
graphedScalarOperands e =
  let is = fst $ execState (collect e) initial
   in IS.intersection is <$> getGraphedScalars
  where
    initial = (IS.empty, S.empty) -- scalar operands, accumulator tokens

    -- Record a variable as an operand.
    captureName n = modify $ first $ IS.insert (nameToId n)

    -- Record the token of an accumulator that is updated on host.
    captureAcc a = modify $ second $ S.insert a

    -- Record every free variable of the argument as an operand.
    collectFree x = mapM_ captureName (namesToList $ freeIn x)

    -- Collect the operands of an expression, dispatching on its kind.
    collect b@BasicOp {} =
      collectBasic b
    collect (Apply _ params _ _) =
      mapM_ (collectSubExp . fst) params
    collect (Match ses cases defbody _) = do
      mapM_ collectSubExp ses
      mapM_ (collectBody . caseBody) cases
      collectBody defbody
    collect (Loop params lform body) = do
      -- Only the initial values of the loop parameters are operands.
      mapM_ (collectSubExp . snd) params
      collectLForm lform
      collectBody body
    collect (WithAcc accs f) =
      collectWithAcc accs f
    collect (Op op) =
      collectHostOp op

    collectBasic (BasicOp (Update _ _ slice _)) =
      -- Writing a scalar to an array can be replaced with copying a single-
      -- element slice. If the scalar originates from device memory its read
      -- can thus be prevented without requiring the 'Update' to be migrated.
      collectFree slice
    collectBasic (BasicOp (Replicate shape _)) =
      -- The replicate of a scalar can be rewritten as a replicate of a single
      -- element array followed by a slice index.
      collectFree shape
    collectBasic e' =
      -- Note: Plain VName values only refer to arrays.
      walkExpM (identityWalker {walkOnSubExp = collectSubExp}) e'

    -- Constants contribute no operands.
    collectSubExp (Var n) = captureName n
    collectSubExp _ = pure ()

    collectBody body = do
      collectStms (bodyStms body)
      collectFree (bodyResult body)
    collectStms = mapM_ collectStm

    collectStm (Let pat _ ua)
      | BasicOp UpdateAcc {} <- ua,
        Pat [pe] <- pat,
        Acc a _ _ _ <- typeOf pe =
          -- Capture the tokens of accumulators used on host.
          captureAcc a >> collectBasic ua
    collectStm stm = collect (stmExp stm)

    collectLForm (ForLoop _ _ b) = collectSubExp b
    -- WhileLoop condition is declared as a loop parameter.
    collectLForm (WhileLoop _) = pure ()

    -- The collective operands of an operator lambda body are only used on host
    -- if the associated accumulator is used in an UpdateAcc statement outside a
    -- kernel.
    collectWithAcc inputs f = do
      collectBody (lambdaBody f)
      used_accs <- gets snd
      -- The leading return types of the lambda correspond to the inputs.
      let accs = take (length inputs) (lambdaReturnType f)
      -- NOTE: The lambda is partial on purpose; it relies on these return
      -- types being accumulator types.
      let used = map (\(Acc a _ _ _) -> S.member a used_accs) accs
      mapM_ collectAcc (zip used inputs)

    -- An accumulator without an operator has nothing to collect. Otherwise the
    -- neutral elements are always collected while the operator body only is
    -- collected if the accumulator was found to be used.
    collectAcc (_, (_, _, Nothing)) = pure ()
    collectAcc (used, (_, _, Just (op, nes))) = do
      mapM_ collectSubExp nes
      when used $ collectBody (lambdaBody op)

    -- Does not collect named operands in
    --
    --   * types and shapes; size variables are assumed available to the host.
    --
    --   * use by a kernel body.
    --
    -- All other operands are conservatively collected even if they generally
    -- appear to be size variables or results computed by a SizeOp.
    collectHostOp (SegOp (SegMap lvl sp _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
    collectHostOp (SegOp (SegRed lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegScan lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectSegBinOp ops
    collectHostOp (SegOp (SegHist lvl sp ops _ _)) = do
      collectSegLevel lvl
      collectSegSpace sp
      mapM_ collectHistOp ops
    collectHostOp (SizeOp op) = collectFree op
    collectHostOp (OtherOp op) = collectFree op
    collectHostOp GPUBody {} = pure ()

    collectSegLevel = mapM_ captureName . namesToList . freeIn

    collectSegSpace space =
      mapM_ collectSubExp (segSpaceDims space)

    -- Only the neutral elements are collected; the operator lambda is not.
    collectSegBinOp (SegBinOp _ _ nes _) =
      mapM_ collectSubExp nes

    -- Only the second field and the neutral elements are collected.
    collectHistOp (HistOp _ rf _ nes _ _) = do
      collectSubExp rf
      mapM_ collectSubExp nes

--------------------------------------------------------------------------------
--                        GRAPH BUILDING - PRIMITIVES                         --
--------------------------------------------------------------------------------

-- | Creates a vertex for the given binding, provided that the set of operands
-- is not empty.
--
-- When a vertex is created, every operand vertex is given the edges produced
-- by 'MG.oneEdge' for the id of the binding (see 'addEdges').
createNode :: Binding -> Operands -> Grapher ()
createNode b ops =
  unless (IS.null ops) $ do
    addVertex b
    addEdges (MG.oneEdge $ fst b) ops

-- | Adds a vertex to the graph for the given binding.
--
-- The vertex is tagged with the current metadata (see 'getMeta'). A scalar
-- binding is additionally registered as a graphed scalar, and an array
-- binding is recorded as copyable memory at the current body depth.
addVertex :: Binding -> Grapher ()
addVertex (i, t) = do
  meta <- getMeta
  let v = MG.vertex i meta
  when (isScalar t) $ modifyGraphedScalars (IS.insert i)
  when (isArray t) $ recordCopyableMemory i (metaBodyDepth meta)
  modifyGraph (MG.insert v)

-- | Adds a source connected vertex to the graph for the given binding.
--
-- The id of the binding is prepended to the second component of the tracked
-- sources, which 'routeSubgraph' reads as the sources that are yet to be
-- routed.
addSource :: Binding -> Grapher ()
addSource b = do
  addVertex b
  modifySources $ second (fst b :)

-- | Adds the given edges to each vertex identified by the 'IdSet'. It is
-- assumed that all vertices reside within the body that currently is being
-- graphed.
--
-- Vertices that are connected to a sink are removed from the set of graphed
-- scalars. For any other kind of edges the vertices are instead reported as
-- operands through 'tellOperands'.
addEdges :: Edges -> IdSet -> Grapher ()
addEdges ToSink is = do
  modifyGraph $ \g -> IS.foldl' (flip MG.connectToSink) g is
  modifyGraphedScalars (`IS.difference` is)
addEdges es is = do
  modifyGraph $ \g -> IS.foldl' (flip $ MG.addEdges es) g is
  tellOperands is

-- | Ensure that a variable (which is in scope) will be made available on host
-- before its first use.
--
-- This is a no-op if the graph contains no vertex for the variable. Otherwise
-- the vertex is connected to a sink and the body depth stored in its metadata
-- is reported through 'tellHostOnlyParent'.
requiredOnHost :: Id -> Grapher ()
requiredOnHost i = do
  mv <- MG.lookup i <$> getGraph
  case mv of
    Nothing -> pure ()
    Just v -> do
      connectToSink i
      tellHostOnlyParent (metaBodyDepth $ vertexMeta v)

-- | Connects the vertex of the given id to a sink.
--
-- The id is also removed from the set of graphed scalars.
connectToSink :: Id -> Grapher ()
connectToSink i = do
  modifyGraph (MG.connectToSink i)
  modifyGraphedScalars (IS.delete i)

-- | Like 'connectToSink' but vertex is given by a t'SubExp'. This is a no-op if
-- the t'SubExp' is a constant.
connectSubExpToSink :: SubExp -> Grapher ()
connectSubExpToSink (Var n) = connectToSink (nameToId n)
connectSubExpToSink _ = pure ()

-- | Routes all possible routes within the subgraph identified by this id.
-- Returns the ids of the source connected vertices for which routing was
-- attempted.
--
-- Assumption: The subgraph with the given id has just been created and no path
-- exists from it to an external sink.
routeSubgraph :: Id -> Grapher [Id]
routeSubgraph si = do
  st <- get
  let g = stateGraph st
  let (routed, unrouted) = stateSources st
  -- Only the leading run of unrouted sources that belong to the subgraph is
  -- taken. NOTE(review): this presumably relies on the sources of the newly
  -- created subgraph being at the front of the list, as 'addSource' prepends.
  let (gsrcs, unrouted') = span (inSubGraph si g) unrouted
  let (sinks, g') = MG.routeMany gsrcs g
  put $
    st
      { stateGraph = g',
        stateSources = (gsrcs ++ routed, unrouted'),
        stateSinks = sinks ++ stateSinks st
      }
  pure gsrcs

-- | @inSubGraph si g i@ returns whether @g@ contains a vertex with id @i@ that
-- is declared within the subgraph with id @si@.
--
-- The result is @False@ if @g@ has no vertex with id @i@ or if the metadata of
-- that vertex carries no graph id.
inSubGraph :: Id -> Graph -> Id -> Bool
inSubGraph si g i
  | Just v <- MG.lookup i g,
    Just mgi <- metaGraphId (vertexMeta v) =
      si == mgi
inSubGraph _ _ _ = False

-- | @b `reuses` n@ registers the array bound by @b@ as being backed by the same
-- copyable memory as @n@, at the body depth already recorded for @n@.
--
-- Nothing is recorded if @b@ is not array typed or if @n@ has no entry in the
-- copyable memory map.
reuses :: Binding -> VName -> Grapher ()
reuses (bound_id, bound_type) src =
  if isArray bound_type
    then outermostCopyableArray src >>= mapM_ (recordCopyableMemory bound_id)
    else pure ()

-- | Like 'reuses' but takes a t'SubExp' as the reused value. Anything that is
-- not a variable is ignored.
reusesSubExp :: Binding -> SubExp -> Grapher ()
reusesSubExp binding se =
  case se of
    Var name -> reuses binding name
    _ -> pure ()

-- | @reusesReturn bs res@ pairs each binding in @bs@ with the corresponding
-- return value in @res@ and records every array binding as reusing copyable
-- memory when its return value is a variable known to be backed by copyable
-- memory.
--
-- The result is @True@ only if no array binding failed to be registered and no
-- registered return value was found at a body depth less than or equal to the
-- current one (i.e. it is not a variable free in the body). Bindings whose
-- dimensions are all the constant 1 (which includes primitives, as they have
-- no dimensions) are skipped without affecting the result.
reusesReturn :: [Binding] -> [SubExp] -> Grapher Bool
reusesReturn bs res = do
  depth <- getBodyDepth
  foldM (step depth) True (zip bs res)
  where
    isOne :: SubExp -> Bool
    isOne = (== intConst Int64 1)

    step :: BodyDepth -> Bool -> (Binding, SubExp) -> Grapher Bool
    step depth acc ((i, t), se)
      -- Single element arrays are immediately recognizable as copyable so
      -- they are not recorded. Primitive return values also end up here.
      | all isOne (arrayDims t) = pure acc
      | isArray t,
        Var n <- se =
          outermostCopyableArray n >>= maybe (pure False) (record depth acc i)
      | otherwise = pure acc

    -- Register the binding at the outermost of the two depths and note
    -- whether the returned array was declared strictly inside the body.
    record :: BodyDepth -> Bool -> Id -> BodyDepth -> Grapher Bool
    record depth acc i inner = do
      recordCopyableMemory i (min depth inner)
      pure (acc && inner > depth)

-- | @reusesBranches bs seses@ records each array binding in @bs@ as reusing
-- copyable memory if every corresponding return value across the lists in
-- @seses@ is a variable backed by copyable memory. Each list in @seses@ is the
-- result of one branch body (i.e. for 'if' there are two lists).
--
-- The result is @True@ only if no array binding failed to be registered and no
-- registered binding had a branch result found at a body depth less than or
-- equal to the current one. Bindings whose dimensions are all the constant 1
-- (which includes primitives, as they have no dimensions) are skipped without
-- affecting the result.
reusesBranches :: [Binding] -> [[SubExp]] -> Grapher Bool
reusesBranches bs seses = do
  depth <- getBodyDepth
  -- Transposing groups the results of all branches per binding.
  foldM (step depth) True (zip bs (L.transpose seses))
  where
    isOne :: SubExp -> Bool
    isOne = (== intConst Int64 1)

    step :: BodyDepth -> Bool -> (Binding, [SubExp]) -> Grapher Bool
    step depth acc ((i, t), ses)
      -- Single element arrays are immediately recognizable as copyable so
      -- they are not recorded. Primitive return values also end up here.
      | all isOne (arrayDims t) = pure acc
      | isArray t,
        Just ns <- mapM subExpVar ses = do
          depths <- mapM outermostCopyableArray ns
          -- All branch results must be known as copyable to proceed.
          maybe (pure False) (record depth acc i . minimum) (sequence depths)
      | otherwise = pure acc

    -- Register the binding at the outermost of the two depths and note
    -- whether every branch result was declared strictly inside the body.
    record :: BodyDepth -> Bool -> Id -> BodyDepth -> Grapher Bool
    record depth acc i inner = do
      recordCopyableMemory i (min depth inner)
      pure (acc && inner > depth)

--------------------------------------------------------------------------------
--                           GRAPH BUILDING - TYPES                           --
--------------------------------------------------------------------------------

-- | The monad used for building the graph: a mutable 'State' layered on top of
-- a read-only 'Env'.
type Grapher = StateT State (R.Reader Env)

-- | The read-only environment of a 'Grapher' computation.
data Env = Env
  { -- | The functions that are host-only; see 'HostOnlyFuns'.
    envHostOnlyFuns :: HostOnlyFuns,
    -- | Metadata describing the body that is currently being graphed.
    envMeta :: Meta
  }

-- | A measurement of how many bodies something is nested within. The outermost
-- level has depth 0.
type BodyDepth = Int

-- | Metadata about the environment in which a variable is declared.
data Meta = Meta
  { -- | The number of if statement branch bodies that the variable binding is
    -- nested within. If a route passes through the edge u->v and the fork
    -- depth
    --
    --   1) increases from u to v, then u is within a conditional branch.
    --
    --   2) decreases from u to v, then v binds the result of two or more
    --      branches.
    --
    -- Once the graph has been built and routed, this can be used to delay
    -- reads into deeper branches to reduce their likelihood of manifesting.
    metaForkDepth :: Int,
    -- | The number of bodies that the variable is nested within.
    metaBodyDepth :: BodyDepth,
    -- | The id of the subgraph within which the variable exists, defined at
    -- the body level. A read may only be delayed to a point within its own
    -- subgraph.
    metaGraphId :: Maybe Id
  }

-- | The ids of all variables that are used as operands.
type Operands = IdSet

-- | Statistics about the statements of a body and what they depend on.
data BodyStats = BodyStats
  { -- | Did the body contain any host-only statements?
    bodyHostOnly :: Bool,
    -- | Did the body contain any GPUBody kernels?
    bodyHasGPUBody :: Bool,
    -- | Did the body perform any reads?
    bodyReads :: Bool,
    -- | All scalar variables represented in the graph that have been used
    -- as return values of the body or as operands within it, including those
    -- that are defined within the body itself. Variables with vertices
    -- connected to sinks may be excluded.
    bodyOperands :: Operands,
    -- | Depths of parent bodies with variables that are required on host.
    -- Since those variables are required on host, the parent statements of
    -- these bodies cannot be moved to device as a whole. They are host-only.
    bodyHostOnlyParents :: IS.IntSet
  }

-- | Combining statistics takes the disjunction of the flags and the union of
-- the sets.
instance Semigroup BodyStats where
  (BodyStats hostOnlyA gpuBodyA readsA operandsA parentsA)
    <> (BodyStats hostOnlyB gpuBodyB readsB operandsB parentsB) =
      BodyStats
        { bodyHostOnly = hostOnlyA || hostOnlyB,
          bodyHasGPUBody = gpuBodyA || gpuBodyB,
          bodyReads = readsA || readsB,
          bodyOperands = operandsA <> operandsB,
          bodyHostOnlyParents = parentsA <> parentsB
        }

-- | The statistics of a body in which nothing has been observed: every flag is
-- unset and every set is empty.
instance Monoid BodyStats where
  mempty = BodyStats noHostOnly noGPUBody noReads IS.empty IS.empty
    where
      noHostOnly = False
      noGPUBody = False
      noReads = False

-- | The graph being built, whose vertices are annotated with 'Meta'.
type Graph = MG.Graph Meta

-- | All vertices connected from a source, partitioned into those for which
-- routing has been attempted (first component) and those for which it has not
-- (second component).
type Sources = ([Id], [Id])

-- | All terminal vertices of routes.
type Sinks = [Id]

-- | A captured statement for which graphing has been delayed.
type Delayed = (Binding, Exp GPU)

-- | The vertex handle for a variable and its type.
type Binding = (Id, Type)

-- | Array variables backed by memory segments that may be copied, mapped to the
-- outermost known body depths that declare arrays backed by a superset of
-- those segments.
type CopyableMemoryMap = IM.IntMap BodyDepth

-- | The mutable state of a 'Grapher' computation.
data State = State
  { -- | The graph under construction.
    stateGraph :: Graph,
    -- | Every known scalar that has been graphed.
    stateGraphedScalars :: IdSet,
    -- | Every variable that directly binds a scalar read from device memory.
    stateSources :: Sources,
    -- | Graphed scalars that are used as operands by statements that cannot be
    -- migrated. A read cannot be delayed beyond these, so if the statements
    -- that bind these variables are moved to device, the variables must be read
    -- from device memory.
    stateSinks :: Sinks,
    -- | 'UpdateAcc' host statements that have been observed and are to be
    -- graphed later.
    stateUpdateAccs :: IM.IntMap [Delayed],
    -- | The arrays encountered so far that are backed by copyable memory.
    -- Trivial instances such as single element arrays are excluded.
    stateCopyableMemory :: CopyableMemoryMap,
    -- | Statistics about the body that is currently being graphed.
    stateStats :: BodyStats
  }

--------------------------------------------------------------------------------
--                             GRAPHER OPERATIONS                             --
--------------------------------------------------------------------------------

-- | Runs a 'Grapher' action from an empty state in a root environment (fork
-- depth 0, body depth 0, no subgraph id) and returns the resulting graph,
-- sources and sinks. The value produced by the action itself is discarded.
execGrapher :: HostOnlyFuns -> Grapher a -> (Graph, Sources, Sinks)
execGrapher hof m =
  (stateGraph final, stateSources final, stateSinks final)
  where
    final = R.runReader (execStateT m emptyState) rootEnv

    rootMeta =
      Meta
        { metaForkDepth = 0,
          metaBodyDepth = 0,
          metaGraphId = Nothing
        }

    rootEnv =
      Env
        { envHostOnlyFuns = hof,
          envMeta = rootMeta
        }

    emptyState =
      State
        { stateGraph = MG.empty,
          stateGraphedScalars = IS.empty,
          stateSources = ([], []),
          stateSinks = [],
          stateUpdateAccs = IM.empty,
          stateCopyableMemory = IM.empty,
          stateStats = mempty
        }

-- | Run a computation with the environment transformed by the given function.
-- State changes made by the computation are kept.
local :: (Env -> Env) -> Grapher a -> Grapher a
local = mapStateT . R.local

-- | Retrieve the current environment.
ask :: Grapher Env
ask = lift (R.asks id)

-- | Retrieve the result of applying a function to the current environment.
asks :: (Env -> a) -> Grapher a
asks select = lift (R.asks select)

-- | Flag the current body as containing a host-only statement. This implies
-- that its parent statement and any parent bodies are host-only as well. A
-- host-only statement should not be migrated, either because it cannot run on
-- device or because doing so would be inefficient.
tellHostOnly :: Grapher ()
tellHostOnly = modify flagHostOnly
  where
    flagHostOnly s = s {stateStats = (stateStats s) {bodyHostOnly = True}}

-- | Flag the current body as containing a GPUBody kernel.
tellGPUBody :: Grapher ()
tellGPUBody = modify flagGPUBody
  where
    flagGPUBody s = s {stateStats = (stateStats s) {bodyHasGPUBody = True}}

-- | Flag the current body as containing a statement that reads device memory.
tellRead :: Grapher ()
tellRead = modify flagRead
  where
    flagRead s = s {stateStats = (stateStats s) {bodyReads = True}}

-- | Add these variables to the operands recorded for the current body.
tellOperands :: IdSet -> Grapher ()
tellOperands is = modify addOperands
  where
    addOperands s = s {stateStats = extend (stateStats s)}
    extend stats = stats {bodyOperands = bodyOperands stats <> is}

-- | Record in the current body statistics that the statement with a body at
-- the given body depth is host-only.
tellHostOnlyParent :: BodyDepth -> Grapher ()
tellHostOnlyParent body_depth = modify addParent
  where
    addParent s = s {stateStats = extend (stateStats s)}
    extend stats =
      stats
        { bodyHostOnlyParents =
            IS.insert body_depth (bodyHostOnlyParents stats)
        }

-- | Fetch the graph as it has been built so far.
getGraph :: Grapher Graph
getGraph = stateGraph <$> get

-- | Fetch the set of all scalar variables that are represented by a vertex in
-- the graph.
getGraphedScalars :: Grapher IdSet
getGraphedScalars = stateGraphedScalars <$> get

-- | Fetch the map of every known array that is backed by a memory segment that
-- may be copied. Each array is mapped to the outermost known body depth where
-- an array is backed by a superset of that segment.
--
-- A body where all returned arrays are backed by such memory and are written by
-- its own statements will retain its asymptotic cost if migrated as a whole.
getCopyableMemory :: Grapher CopyableMemoryMap
getCopyableMemory = stateCopyableMemory <$> get

-- | Look up the outermost known body depth of an array backed by the same
-- copyable memory as the array with this name. Yields 'Nothing' if the name
-- has no entry in the copyable memory map.
outermostCopyableArray :: VName -> Grapher (Maybe BodyDepth)
outermostCopyableArray n = do
  copyable <- getCopyableMemory
  pure (IM.lookup (nameToId n) copyable)

-- | Converts the variables to their 'Id's and keeps only those that are in the
-- set of graphed scalars, i.e. scalars with a vertex representation in the
-- graph. Per the original documentation this excludes variables that have been
-- connected to sinks; that depends on how the set is maintained elsewhere.
onlyGraphedScalars :: (Foldable t) => t VName -> Grapher IdSet
onlyGraphedScalars vs = IS.intersection candidates <$> getGraphedScalars
  where
    candidates = foldl' (flip (IS.insert . nameToId)) IS.empty vs

-- | Like 'onlyGraphedScalars' but for a single 'VName': the result is either a
-- singleton set or the empty set.
onlyGraphedScalar :: VName -> Grapher IdSet
onlyGraphedScalar n = select <$> getGraphedScalars
  where
    i = nameToId n
    select graphed
      | IS.member i graphed = IS.singleton i
      | otherwise = IS.empty

-- | Like 'onlyGraphedScalars' but for a single t'SubExp'. Constants always
-- give the empty set.
onlyGraphedScalarSubExp :: SubExp -> Grapher IdSet
onlyGraphedScalarSubExp se =
  case se of
    Var n -> onlyGraphedScalar n
    Constant _ -> pure IS.empty

-- | Apply a function to the graph under construction.
modifyGraph :: (Graph -> Graph) -> Grapher ()
modifyGraph f = modify onGraph
  where
    onGraph s = s {stateGraph = f (stateGraph s)}

-- | Apply a function to the set of graphed scalars.
modifyGraphedScalars :: (IdSet -> IdSet) -> Grapher ()
modifyGraphedScalars f = modify onScalars
  where
    onScalars s = s {stateGraphedScalars = f (stateGraphedScalars s)}

-- | Apply a function to the copyable memory map.
modifyCopyableMemory :: (CopyableMemoryMap -> CopyableMemoryMap) -> Grapher ()
modifyCopyableMemory f = modify onCopyable
  where
    onCopyable s = s {stateCopyableMemory = f (stateCopyableMemory s)}

-- | Apply a function to the recorded source connected vertices.
modifySources :: (Sources -> Sources) -> Grapher ()
modifySources f = modify onSources
  where
    onSources s = s {stateSources = f (stateSources s)}

-- | Register the variable with this id as binding an array backed by copyable
-- memory that is shared by an array at the given outermost body depth. Any
-- previous entry for the variable is replaced.
recordCopyableMemory :: Id -> BodyDepth -> Grapher ()
recordCopyableMemory i = modifyCopyableMemory . IM.insert i

-- | Run the action with the fork depth of the environment raised by one, which
-- applies to the variables graphed by the action.
incForkDepthFor :: Grapher a -> Grapher a
incForkDepthFor = local deeper
  where
    deeper env = env {envMeta = bump (envMeta env)}
    bump meta = meta {metaForkDepth = metaForkDepth meta + 1}

-- | Run the action with the body depth of the environment raised by one, which
-- applies to the variables graphed by the action.
incBodyDepthFor :: Grapher a -> Grapher a
incBodyDepthFor = local deeper
  where
    deeper env = env {envMeta = bump (envMeta env)}
    bump meta = meta {metaBodyDepth = metaBodyDepth meta + 1}

-- | Run the action with the given subgraph id set in the environment, which
-- applies to the variables graphed by the action.
graphIdFor :: Id -> Grapher a -> Grapher a
graphIdFor i = local retag
  where
    retag env = env {envMeta = (envMeta env) {metaGraphId = Just i}}

-- | Run the action and return just the body stats that it produced. The stats
-- that existed before the action are afterwards combined with the produced
-- ones, so the surrounding body observes them too. The result of the action
-- itself is discarded.
captureBodyStats :: Grapher a -> Grapher BodyStats
captureBodyStats m = do
  outer <- gets stateStats
  -- Start from a clean slate so only the stats of the action are observed.
  setStats mempty
  _ <- m
  captured <- gets stateStats
  setStats (outer <> captured)
  pure captured
  where
    setStats :: BodyStats -> Grapher ()
    setStats stats = modify $ \s -> s {stateStats = stats}

-- | Is the function with this name among the host-only functions of the
-- environment? A result of @True@ means the function is host-only, so its
-- applications should not be moved to device.
isHostOnlyFun :: Name -> Grapher Bool
isHostOnlyFun fn = S.member fn <$> asks envHostOnlyFuns

-- | Fetch the 'Meta' of the body that is currently being graphed.
getMeta :: Grapher Meta
getMeta = envMeta <$> ask

-- | Fetch the nesting level (body depth) of the body that is currently being
-- graphed.
getBodyDepth :: Grapher BodyDepth
getBodyDepth = metaBodyDepth <$> getMeta