{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}

-- | Perform horizontal and vertical fusion of SOACs.  See the paper
-- /A T2 Graph-Reduction Approach To Fusion/ for the basic idea (some
-- extensions discussed in /Design and GPGPU Performance of Futhark’s
-- Redomap Construct/).
module Futhark.Optimise.Fusion (fuseSOACs) where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import qualified Futhark.IR.Aliases as Aliases
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.IR.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (maxinum)

-- | What the fusion environment records about a variable in scope
-- (see the @varsInScope@ field of 'FusionGEnv').
data VarEntry
  = IsArray VName (NameInfo SOACS) Names SOAC.Input
  -- ^ An entry that can be used as a SOAC input.  The fields are: a
  -- name ('bindVar' stores the variable's own name, while
  -- 'bindingTransform' copies the name from the source entry), the
  -- type information, the alias set returned by 'varEntryAliases',
  -- and the 'SOAC.Input' returned by 'lookupArr'.
  | IsNotArray (NameInfo SOACS)
  -- ^ Any other variable; only its type information is kept.

-- | The type information ('NameInfo') stored in an entry, regardless
-- of which constructor was used to build it.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType (IsArray _ dec _ _) = dec
varEntryType (IsNotArray dec) = dec

-- | The alias set stored in an entry.  Only 'IsArray' entries carry
-- aliases; every other entry yields the empty set.
varEntryAliases :: VarEntry -> Names
varEntryAliases (IsArray _ _ x _) = x
varEntryAliases _ = mempty

-- | The read-only environment carried by 'FusionGM'.
data FusionGEnv = FusionGEnv
  { -- | Mapping from variable name to its entire family.
    soacs :: M.Map VName [VName],
    -- | What is known about each variable currently in scope; this
    -- is also the source of the 'Scope' reported by 'askScope'.
    varsInScope :: M.Map VName VarEntry,
    -- | The fusion result made available to the computation (it is
    -- replaced locally by 'bindRes').
    fusedRes :: FusedRes
  }

-- | Look up the 'SOAC.Input' recorded for a variable.  Yields
-- 'Nothing' when the variable is not in scope or when it is in scope
-- but recorded as 'IsNotArray'.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env = asArray =<< M.lookup v (varsInScope env)
  where
    asArray (IsArray _ _ _ input) = Just input
    asArray IsNotArray {} = Nothing

-- | An error raised by the fusion pass, carrying a free-form message.
newtype Error = Error String

-- | Renders the message prefixed with a @Fusion error:@ header line.
instance Show Error where
  show (Error msg) = "Fusion error:\n" ++ msg

-- | The monad in which fusion runs: 'ExceptT' 'Error' for failure,
-- layered over 'StateT' 'VNameSource' (the name source exposed through
-- 'MonadFreshNames'), layered over a 'Reader' of 'FusionGEnv'.  All
-- instances are obtained from the underlying stack via
-- @GeneralizedNewtypeDeriving@.
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadError Error,
      MonadState VNameSource,
      MonadReader FusionGEnv
    )

-- | The name source is exactly the 'StateT' component of 'FusionGM'.
instance MonadFreshNames FusionGM where
  getNameSource = get
  putNameSource = put

-- | The scope is derived from the environment's @varsInScope@ table
-- by keeping only the type information of every entry.
instance HasScope SOACS FusionGM where
  askScope = asks $ toScope . varsInScope
    where
      toScope = M.map varEntryType

------------------------------------------------------------------------
--- Monadic Helpers: bind/new/runFusionGatherM, etc
------------------------------------------------------------------------

-- | Insert one identifier, together with its given alias set, into the
-- @varsInScope@ table of the environment.
--
-- An identifier of array type becomes an 'IsArray' entry whose input
-- is the identifier itself ('SOAC.identInput') and whose alias set is
-- the given aliases extended, one level deep, with the aliases already
-- recorded in scope for each of them.  Any other identifier becomes an
-- 'IsNotArray' entry and its alias set is dropped.
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (Ident name t, aliases) =
  env {varsInScope = M.insert name entry $ varsInScope env}
  where
    entry = case t of
      Array {} -> IsArray name (LetName t) aliases' $ SOAC.identInput $ Ident name t
      _ -> IsNotArray $ LetName t
    -- Aliases recorded for a name already in scope (empty if unknown).
    expand = maybe mempty varEntryAliases . flip M.lookup (varsInScope env)
    aliases' = aliases <> mconcat (map expand $ namesToList aliases)

-- | Insert several identifiers into the environment, left to right,
-- using 'bindVar'.  Later bindings therefore see (and may shadow)
-- earlier ones.  A strict left fold is used so that no chain of
-- pending environment updates is built up.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars = L.foldl' bindVar

-- | Run an action in an environment extended (via 'bindVars') with the
-- given identifiers and their alias sets.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local (`bindVars` vs)

-- | Run an action with the identifiers of a statement's pattern bound
-- in the environment.  Context names of the pattern get empty alias
-- sets; the remaining identifiers are paired, in order, with the alias
-- sets that alias analysis (started from an empty alias table) reports
-- for the expression.
gatherStmPattern :: Pattern -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPattern pat e = binding $ zip idents aliases
  where
    idents = patternIdents pat
    aliases =
      replicate (length (patternContextNames pat)) mempty
        ++ expAliases (Alias.analyseExp mempty e)

-- | Run an action with every identifier of the pattern bound in the
-- environment, each with an empty alias set.
bindingPat :: Pattern -> FusionGM a -> FusionGM a
bindingPat = binding . (`zip` repeat mempty) . patternIdents

-- | Run an action with every parameter bound in the environment, each
-- with an empty alias set.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams = binding . (`zip` repeat mempty) . map paramIdent

-- | Register one identifier as a member of the given family: the
-- @soacs@ table maps its name to the whole family, and @varsInScope@
-- receives an 'IsArray' entry for it with an empty alias set and the
-- identifier itself as input.  Note that the entry is 'IsArray'
-- whatever the identifier's type is.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env (Ident nm t) =
  env
    { soacs = M.insert nm faml $ soacs env,
      varsInScope =
        M.insert
          nm
          ( IsArray nm (LetName t) mempty $
              SOAC.identInput $ Ident nm t
          )
          $ varsInScope env
    }

-- | The alias set of a variable according to the environment: the
-- variable itself plus whatever aliases are recorded for it in
-- @varsInScope@ (nothing extra if it is not in scope).
varAliases :: VName -> FusionGM Names
varAliases v =
  asks $
    (oneName v <>) . maybe mempty varEntryAliases
      . M.lookup v
      . varsInScope

-- | The union of 'varAliases' over every name in the given set.
varsAliases :: Names -> FusionGM Names
varsAliases = fmap mconcat . mapM varAliases . namesToList

-- | Account for an in-place update in a fusion result.
--
-- The argument pair is @(ip_vs, other_infuse_vs)@.  Every name in both
-- lists is passed to 'addVarToInfusible'.  In addition, the aliases of
-- the names in @ip_vs@ (as given by 'varAliases', so including the
-- names themselves) are added to the @inplace@ set of every kernel in
-- the result.
updateKerInPlaces :: FusedRes -> ([VName], [VName]) -> FusionGM FusedRes
updateKerInPlaces res (ip_vs, other_infuse_vs) = do
  res' <- foldM addVarToInfusible res (ip_vs ++ other_infuse_vs)
  aliases <- mconcat <$> mapM varAliases ip_vs
  let inspectKer k = k {inplace = aliases <> inplace k}
  return res' {kernels = M.map inspectKer $ kernels res'}

-- | Update the fusion result if the expression writes to an existing
-- array, via 'updateKerInPlaces':
--
-- * For an 'Update', the updated array is the in-place variable and
--   the names free in the slice are the other variables.
--
-- * For a 'Futhark.Scatter', the in-place variables are the arrays in
--   the third component of each written-info triple.
--
-- Any other expression leaves the result unchanged.
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res (BasicOp (Update src is _)) = do
  let ifvs = namesToList $ mconcat $ map freeIn is
  updateKerInPlaces res ([src], ifvs)
checkForUpdates res (Op (Futhark.Scatter _ _ _ written_info)) = do
  let updt_arrs = map (\(_, _, x) -> x) written_info
  updateKerInPlaces res (updt_arrs, [])
checkForUpdates res _ = return res

-- | Run an action in an environment where (i) the @soacs@ map binds
--   every name of the pattern to the list of all names of the pattern
--   (its family), and (ii) @varsInScope@ contains an entry for every
--   identifier of the pattern (see 'bindingFamilyVar').
--
--   This function only extends the environment; it does not itself
--   touch the @inplace@ field of any kernel.
bindingFamily :: Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat = local bind
  where
    idents = patternIdents pat
    family = patternNames pat
    -- Strict left fold: identifiers are registered in pattern order.
    bind env = L.foldl' (bindingFamilyVar family) env idents

-- | Run an action with a pattern element bound as the result of
-- applying an array transformation to @srcname@.
--
-- If @srcname@ is in scope as an 'IsArray' entry, the pattern element
-- gets an 'IsArray' entry that reuses the source entry's name, has the
-- pattern element's own type, aliases @srcname@ and everything
-- @srcname@ aliases, and whose input is the source's input with the
-- transformation added via 'SOAC.addTransform'.
--
-- Otherwise the pattern element is bound with 'bindVar', passing the
-- element's own name as its alias set.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local $ \env ->
  case M.lookup srcname $ varsInScope env of
    Just (IsArray src' _ aliases input) ->
      env
        { varsInScope =
            M.insert
              vname
              ( IsArray src' (LetName dec) (oneName srcname <> aliases) $
                  trns `SOAC.addTransform` input
              )
              $ varsInScope env
        }
    _ -> bindVar env (patElemIdent pe, oneName vname)
  where
    vname = patElemName pe
    dec = patElemDec pe

-- | Run an action with the environment's @fusedRes@ field replaced by
-- the given fusion result.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes rrr = local (\x -> x {fusedRes = rrr})

-- | Run a 'FusionGM' computation inside any 'MonadFreshNames' monad.
--
-- The state layer is the fresh-name source: it is taken from the
-- enclosing monad and the updated source is written back afterwards
-- (via 'modifyNameSource').  The reader layer is the given
-- 'FusionGEnv', i.e. the @soacs@ family table, the variables in scope
-- and the current fusion result.  Failure is reported as a 'Left'
-- value rather than thrown.
runFusionGatherM ::
  MonadFreshNames m =>
  FusionGM a ->
  FusionGEnv ->
  m (Either Error a)
runFusionGatherM (FusionGM a) env =
  modifyNameSource $ \src -> runReader (runStateT (runExceptT a) src) env

------------------------------------------------------------------------
--- Fusion Entry Points: gather the to-be-fused kernels@pgm level    ---
---    and fuse them in a second pass!                               ---
------------------------------------------------------------------------

-- | The pass definition.  The pass fuses the program's constants with
-- 'fuseConsts' (told which names are free in the functions) and each
-- function with 'fuseFun', then renames the resulting program with
-- 'renameProg' and finally simplifies it with 'simplifySOACS'.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = \prog ->
        simplifySOACS
          =<< renameProg
          =<< intraproceduralTransformationWithConsts
            (fuseConsts (freeIn (progFuns prog)))
            fuseFun
            prog
    }

-- | Fuse the constant statements of a program.  They are processed in
-- an empty scope, and the names in @used_consts@ are supplied to
-- 'fuseStms' as the result of the statement sequence.
fuseConsts :: Names -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts used_consts consts =
  fuseStms mempty consts $ map Var $ namesToList used_consts

-- | Fuse the body of one function.  The statements are processed in a
-- scope made of the program constants and the function's parameters,
-- with the body's own result; only the statements of the body are
-- replaced in the returned function.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  stms <-
    fuseStms
      (scopeOf consts <> scopeOfFParams (funDefParams fun))
      (bodyStms $ funDefBody fun)
      (bodyResult $ funDefBody fun)
  let body = (funDefBody fun) {bodyStms = stms}
  return fun {funDefBody = body}

-- | Fuse a sequence of statements with the given result, in the given
-- scope.  This works in two passes, both started from the same initial
-- environment (empty tables, with the scope bound on top):
--
-- 1. 'fusionGatherStms' collects a fusion result, which is then passed
--    through 'cleanFusionResult'.
--
-- 2. If 'rsucc' of that result is false the original statements are
--    returned unchanged; otherwise 'fuseInStms' rewrites them with the
--    result installed via 'bindRes'.
--
-- An 'Error' from either pass is propagated through 'liftEitherM'.
fuseStms :: Scope SOACS -> Stms SOACS -> Result -> PassM (Stms SOACS)
fuseStms scope stms res = do
  let env =
        FusionGEnv
          { soacs = M.empty,
            varsInScope = mempty,
            fusedRes = mempty
          }
  k <-
    cleanFusionResult
      <$> liftEitherM
        ( runFusionGatherM
            (binding scope' $ fusionGatherStms mempty (stmsToList stms) res)
            env
        )
  if not $ rsucc k
    then return stms
    else liftEitherM $ runFusionGatherM (binding scope' $ bindRes k $ fuseInStms stms) env
  where
    -- Every name of the incoming scope, bound with an empty alias set.
    scope' = map toBind $ M.toList scope
    toBind (k, t) = (Ident k $ typeOf t, mempty)

---------------------------------------------------
---------------------------------------------------
---- RESULT's Data Structure
---------------------------------------------------
---------------------------------------------------

-- | A type used for (hopefully) uniquely referring to a producer SOAC
-- kernel.  The wrapped 'VName' is only an identifier:
-- 'addNewKerWithInfusible' creates it with @newVName "ker"@, so it is
-- not the name of any array produced by the SOAC.
newtype KernName = KernName
  { unKernName :: VName
  }
  deriving (Eq, Ord, Show)

-- | The result accumulated while gathering fusion opportunities.
data FusedRes = FusedRes
  { -- | Whether we have fused something anywhere.
    rsucc :: Bool,
    -- | Maps an array to the name of the SOAC kernel that has
    -- produced it.
    outArr :: M.Map VName KernName,
    -- | Maps an array to the names of the SOAC kernels that use it
    -- as input.  Only SOAC input arrays used as full variables are
    -- recorded, i.e., no @a[i]@.
    inpArr :: M.Map VName (S.Set KernName),
    -- | The (names of) arrays that are not fusible, i.e.,
    --
    --   1. they are used other than as input to SOAC kernels, or
    --
    --   2. they are used as input to at least two different kernels
    --      that are not located on disjoint control-flow branches, or
    --
    --   3. they are used in the lambda expression of SOACs.
    infusible :: Names,
    -- | Maps each kernel name to the kernel it denotes.
    kernels :: M.Map KernName FusedKer
  }

-- | Field-wise combination of two results.  Success flags are OR-ed,
-- the per-array user sets are unioned, the infusible names are
-- appended, and 'outArr' / 'kernels' use the left-biased 'M.union'
-- (entries of the left operand win on key clashes).
instance Semigroup FusedRes where
  lhs <> rhs =
    FusedRes
      { rsucc = rsucc lhs || rsucc rhs,
        outArr = M.union (outArr lhs) (outArr rhs),
        inpArr = M.unionWith S.union (inpArr lhs) (inpArr rhs),
        infusible = infusible lhs <> infusible rhs,
        kernels = M.union (kernels lhs) (kernels rhs)
      }

-- | The empty result: nothing fused yet, no arrays or kernels
-- recorded, and no infusible names.
instance Monoid FusedRes where
  mempty =
    FusedRes
      False -- rsucc
      M.empty -- outArr
      M.empty -- inpArr
      mempty -- infusible
      M.empty -- kernels

-- | Is the array @nm@ recorded in 'inpArr' as an input of at least one
-- kernel that is /not/ a member of @kers@?  Arrays without an 'inpArr'
-- entry yield 'False'.
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False usedElsewhere $ M.lookup nm (inpArr ress)
  where
    -- True when some user remains after discarding the kernels in @kers@.
    usedElsewhere users = not . S.null $ users `S.difference` kers

-- | The names of all kernels that 'inpArr' records as using at least
-- one of the given arrays as input.  Arrays without an entry
-- contribute nothing.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress nms =
  S.unions [users | Just users <- map (`M.lookup` inpArr ress) nms]

-- | Extend the list of names to include all the names produced via
-- SOACs (by querying the @soacs@ table of the environment): every
-- name with an entry in that table is replaced by the names recorded
-- for it, and every other name is kept unchanged.  The relative order
-- of the input names is preserved.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr nms = do
  -- The environment is only read, so one lookup of the table suffices.
  families <- asks soacs
  -- 'concatMap' avoids the quadratic cost of repeatedly appending to
  -- the end of an accumulator list.
  return $ concatMap (\nm -> M.findWithDefault [nm] nm families) nms

----------------------------------------------------------------------
----------------------------------------------------------------------

-- | Split the inputs of a SOAC into the two name lists computed by
-- @getIdentArr@ and expand each of them with 'expandSoacInpArr'.
-- (Presumably the first list holds inputs that are plain array
-- variables and the second the remaining ones -- confirm against
-- @getIdentArr@.)
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac = do
  let (arr_inps, other_inps) = getIdentArr (SOAC.inputs soac)
  expanded_arrs <- expandSoacInpArr arr_inps
  expanded_others <- expandSoacInpArr other_inps
  pure (expanded_arrs, expanded_others)

-- | Register the given SOAC as a brand new kernel in the fusion
-- result.
--
-- The kernel gets a fresh name; every pattern identifier is recorded
-- in 'outArr' as produced by it, and every input array of the SOAC is
-- recorded in 'inpArr' as used by it.  The 'infusible' set of the
-- result is /replaced/ by @ufs@, while 'rsucc' is left untouched.
addNewKerWithInfusible :: FusedRes -> ([Ident], StmAux (), SOAC, Names) -> Names -> FusionGM FusedRes
addNewKerWithInfusible res (idd, aux, soac, consumed) ufs = do
  ker_nm <- KernName <$> newVName "ker"
  scope <- askScope
  let out_nms = map identName idd
      inp_arrs = map SOAC.inputArray (SOAC.inputs soac)
      -- New entries take precedence over existing producers.
      produced = M.fromList (zip out_nms (repeat ker_nm)) `M.union` outArr res
      -- Add the new kernel to the users of each of its input arrays.
      users =
        M.unionWith
          S.union
          (M.fromList [(arr, S.singleton ker_nm) | arr <- inp_arrs])
          (inpArr res)
      ker = newKernel aux soac consumed out_nms scope
  pure
    res
      { outArr = produced,
        inpArr = users,
        infusible = ufs,
        kernels = M.insert ker_nm ker (kernels res)
      }

-- | Look up a name with 'lookupArr' in the current environment.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput = asks . lookupArr

-- | If the array underlying the input is itself known to the
-- environment as an input (see 'lookupInput'), rebase this input onto
-- that one: the known input's transforms are put in front of this
-- input's transforms, and the known input's array and type are used.
-- Otherwise the input is returned unchanged.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) =
  maybe inp rebase <$> lookupInput v
  where
    rebase (SOAC.Input ts2 v2 t2) = SOAC.Input (ts2 <> ts) v2 t2

-- | Apply 'inlineSOACInput' to every input of the SOAC and return the
-- SOAC with its inputs replaced by the results.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac =
  (`SOAC.setInputs` soac) <$> mapM inlineSOACInput (SOAC.inputs soac)

-- | Attempts to fuse the current SOAC with the kernels gathered so far.
--
-- Input:
--
--   * @rem_bnds@ are the bindings remaining in the current body after
--     @orig_soac@ (only passed on to 'horizontGreedyFuse').
--
--   * @lam_used_nms@ the infusible names; here they only take part in
--     the in-place update check.
--
--   * @res@ the fusion result (before processing the current SOAC).
--
--   * @orig_soac@ and @out_idds@ the current SOAC and its binding
--     pattern.
--
--   * @consumed@ is the set of names consumed by the SOAC.
--
-- Output: a new fusion result (after processing the current SOAC
-- binding).  If fusion is not possible the SOAC is added as a new
-- kernel and 'rsucc' is left as it was; if it succeeds, 'rsucc' is set
-- and the fused kernels are stored under the names reported by the
-- chosen strategy.
greedyFuse ::
  [Stm] ->
  Names ->
  FusedRes ->
  (Pattern, StmAux (), SOAC, Names) ->
  FusionGM FusedRes
greedyFuse rem_bnds lam_used_nms res (out_idds, aux, orig_soac, consumed) = do
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  -- Assumption: the free vars in lambda are already in @infusible res@.
  let out_nms = patternNames out_idds
      isInfusible = (`nameIn` infusible res)
      -- A Screma that 'isRedomapSOAC' or 'isScanomapSOAC' recognises,
      -- but that neither 'isReduceSOAC' nor 'isScanSOAC' does.
      is_screma = case orig_soac of
        SOAC.Screma _ form _ ->
          (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form))
            && not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
        _ -> False
  --
  -- Choice of fusion strategy (without duplicating computation in
  -- either case):
  --   IF the current soac satisfies @is_screma@ OR any of @out_idds@
  --      belongs to the infusible set
  --   THEN try applying horizontal        fusion
  --   ELSE try applying producer-consumer fusion

  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
    if is_screma || any isInfusible out_nms
      then horizontGreedyFuse rem_bnds res (out_idds, aux, soac, consumed)
      else prodconsGreedyFuse res (out_idds, aux, soac, consumed)
  --
  -- (ii) check whether fusing @soac@ will violate any in-place update
  --      restriction, e.g., would move an input array past its in-place update.
  let all_used_names =
        namesToList $
          mconcat [lam_used_nms, namesFromList inp_nms, namesFromList other_nms]
      has_inplace ker = any (`nameIn` inplace ker) all_used_names
      ok_inplace = not $ any has_inplace old_kers
  --
  -- (iii)  there are some kernels that use some of `out_idds' as inputs
  -- (iv)   and producer-consumer or horizontal fusion succeeds with those.
  let fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat
  --
  -- Start constructing the fusion's result:
  --  (i) inparr ids other than vars will be added to infusible list,
  -- (ii) will also become part of the infusible set the inparr vars
  --         that also appear as inparr of another kernel,
  --         BUT which said kernel is not the one we are fusing with (now)!
  let mod_kerS = if fusible_ker then S.fromList oldker_nms else mempty
      used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
      ufs =
        mconcat
          [ infusible res,
            namesFromList used_inps,
            namesFromList other_nms
              `namesSubtract` namesFromList (map SOAC.inputArray $ SOAC.inputs soac)
          ]

  if not fusible_ker
    then addNewKerWithInfusible res (patternIdents out_idds, aux, soac, consumed) ufs
    else do
      -- Need to suitably update `inpArr':
      --   (i) first remove the inpArr bindings of the old kernels,
      --  (ii) then add the inpArr bindings of the new kernels.
      let fused_ker_nms = zip fused_nms fused_kers
          inpArr' = L.foldl' dropKernelInputs (inpArr res) (zip old_kers oldker_nms)
          inpArr'' = L.foldl' addKernelInputs inpArr' fused_ker_nms
      -- Update the kernels map (why not delete the ones that have been fused?)
      let kernels' = M.fromList fused_ker_nms `M.union` kernels res
      -- nothing to do for `outArr' (since we have not added a new kernel)
      -- DO IMPROVEMENT: attempt to fuse the resulting kernel AGAIN until it fails,
      --                 but make sure NOT to add a new kernel!
      return $ FusedRes True (outArr res) inpArr'' ufs kernels'
  where
    -- Remove kernel @knm@ from the users of every array input of
    -- @kold@; arrays left without any user are dropped from the map.
    dropKernelInputs users (kold, knm) =
      S.foldl'
        (\acc nm -> M.update (nonEmptySet . S.delete knm) nm acc)
        users
        (arrInputs kold)

    -- 'Nothing' for the empty set, so that 'M.update' deletes the key.
    nonEmptySet s
      | S.null s = Nothing
      | otherwise = Just s

    -- Record kernel @knm@ as a user of every array input of @knew@.
    addKernelInputs users (knm, knew) =
      M.unionWith
        S.union
        (M.fromSet (const (S.singleton knm)) (arrInputs knew))
        users

-- | Try producer-consumer fusion of the given SOAC into every kernel
-- that 'inpArr' records as using one of its outputs.
--
-- The result is @(ok, fused, names, consumers, names)@: @ok@ is 'True'
-- only if 'attemptFusion' succeeded for /every/ consumer kernel, in
-- which case @fused@ holds the fused kernels (each with @aux@ appended
-- to its 'kerAux') and is empty otherwise; @consumers@ are the
-- original consumer kernels and @names@ their names, which are also
-- reported as the names of the fused kernels.
--
-- Throws an 'Error' if a consumer name has no entry in 'kernels'.
prodconsGreedyFuse ::
  FusedRes ->
  (Pattern, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, aux, soac, consumed) = do
  -- Get all consumer kernels.
  consumers <- mapM findKernel consumer_nms
  -- Try producer-consumer fusion with each of them.
  attempts <- mapM (attemptFusion mempty out_nms soac consumed) consumers
  let (all_ok, fused) =
        maybe (False, []) (\kers -> (True, map withProducerAux kers)) $
          sequence attempts
  return (all_ok, fused, consumer_nms, consumers, consumer_nms)
  where
    -- The names bound by the output pattern.
    out_nms = patternNames out_idds

    -- The kernels that use some output of the SOAC as input.
    consumer_nms = S.toList $ getKersWithInpArrs res out_nms

    findKernel k = maybe missingKernel return $ M.lookup k (kernels res)

    missingKernel =
      throwError . Error $
        "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
          ++ "kernel name not found in kernels field!"

    withProducerAux k = k {kerAux = kerAux k <> aux}

-- | Try to fuse the given SOAC horizontally with kernels already
-- recorded in the 'FusedRes'.
--
-- Arguments:
--
--  * @rem_bnds@: the statements in which the candidate kernels'
--    outputs are looked up by position.
--
--  * @res@: the fusion result gathered so far; its @kernels@ and
--    @infusible@ fields are consulted.
--
--  * @(out_idds, aux, soac, consumed)@: the pattern, auxiliary
--    information, SOAC and consumed names of the statement being
--    considered.
--
-- Candidate kernels are those sharing an input array with the SOAC
-- (or reading one of its fusible outputs) together with those whose
-- SOAC has the same width.  Candidates are ordered by the last
-- position in @rem_bnds@ at which one of their outputs is bound, and
-- fusion is then attempted one candidate at a time; after the first
-- failure no further candidate is fused.
--
-- The result is @(ok, [fused_ker], new_knms, old_kers, old_knms)@,
-- where @ok@ is true iff at least one candidate was fused,
-- @fused_ker@ is the kernel produced by the fold (the fresh kernel
-- for the SOAC if nothing was fused), @old_kers@/@old_knms@ are the
-- candidates that were fused into, and @new_knms@ holds the name of
-- the last such candidate (empty if there was none).
horizontGreedyFuse ::
  [Stm] ->
  FusedRes ->
  (Pattern, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse rem_bnds res (out_idds, aux, soac, consumed) = do
  (inp_nms, _) <- soacInputs soac
  let out_nms = patternNames out_idds
      infusible_nms = namesFromList $ filter (`nameIn` infusible res) out_nms
      out_arr_nms = case soac of
        -- the accumulator result cannot be fused!
        SOAC.Screma _ (ScremaForm scans reds _) _ ->
          drop (scanResults scans + redResults reds) out_nms
        SOAC.Stream _ _ _ nes _ -> drop (length nes) out_nms
        _ -> out_nms
      to_fuse_knms1 = S.toList $ getKersWithInpArrs res (out_arr_nms ++ inp_nms)
      to_fuse_knms2 = getKersWithSameInpSize (SOAC.width soac) res
      -- Round-trip through a set to remove duplicates.
      to_fuse_knms = S.toList $ S.fromList $ to_fuse_knms1 ++ to_fuse_knms2
      lookupKernel k = case M.lookup k (kernels res) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
                  ++ "kernel name not found in kernels field!"
              )
        Just ker -> return ker

  -- For each kernel get the index in the bindings where the kernel is
  -- located and sort based on the index so that partial fusion may
  -- succeed.  We use the last position where one of the kernel
  -- outputs occur.  Kernels none of whose outputs are bound in
  -- rem_bnds are dropped.
  let bnd_nms = map (patternNames . stmPattern) rem_bnds
  kernminds <- forM to_fuse_knms $ \ker_nm -> do
    ker <- lookupKernel ker_nm
    case mapMaybe (\out_nm -> L.findIndex (elem out_nm) bnd_nms) (outNames ker) of
      [] -> return Nothing
      is -> return $ Just (ker, ker_nm, maxinum is)

  scope <- askScope
  let kernminds' = L.sortOn (\(_, _, i) -> i) $ catMaybes kernminds
      soac_kernel = newKernel aux soac consumed out_nms scope
      -- Scope used when inspecting consumer statements: the ambient
      -- scope extended with everything bound by rem_bnds.
      use_scope = scope <> scopeOf rem_bnds

  -- now try to fuse kernels one by one (in a fold); @ok_ind@ is the index of the
  -- kernel until which fusion succeeded, and @fused_ker@ is the resulting kernel.
  (_, ok_ind, _, fused_ker, _) <-
    foldM
      ( \(cur_ok, n, prev_ind, cur_ker, ufus_nms) (ker, _ker_nm, bnd_ind) -> do
          -- check that we still try fusion and that the intermediate
          -- bindings do not use the results of cur_ker
          let curker_outnms = outNames cur_ker
              curker_outset = namesFromList curker_outnms
              new_ufus_nms = namesFromList $ outNames ker ++ namesToList ufus_nms
              -- disable horizontal fusion in the case when an output array of
              -- producer SOAC is a non-trivially transformed input of the consumer
              out_transf_ok =
                let ker_inp = SOAC.inputs $ fsoac ker
                    unfuse1 =
                      namesFromList (map SOAC.inputArray ker_inp)
                        `namesSubtract` namesFromList (mapMaybe SOAC.isVarInput ker_inp)
                    unfuse2 = namesIntersection curker_outset ufus_nms
                 in not $ unfuse1 `namesIntersect` unfuse2
              -- Disable horizontal fusion if consumer has any
              -- output transforms.
              cons_no_out_transf = SOAC.nullTransforms $ outputTransform ker

          consumer_ok <- do
            -- bnd_ind was obtained via findIndex over rem_bnds, so the
            -- index is in bounds.
            let consumer_bnd = rem_bnds !! bnd_ind
            maybesoac <- runReaderT (SOAC.fromExp $ stmExp consumer_bnd) use_scope
            case maybesoac of
              -- check that consumer's lambda body does not use
              -- directly the produced arrays (e.g., see noFusion3.fut).
              Right conssoac ->
                return $
                  not $
                    curker_outset
                      `namesIntersect` freeIn (lambdaBody $ SOAC.lambda conssoac)
              Left _ -> return True

          let -- An in-between binding is acceptable if either
              --  (i) it does not use the result of the current kernel, OR
              -- (ii) its pattern binds a result of the current kernel; in
              --      the latter case it corresponds to a kernel that has
              --      already been fused, hence it should be ignored.
              interm_bnd_ok bnd =
                not (curker_outset `namesIntersect` freeIn (stmExp bnd))
                  || not (null $ curker_outnms `L.intersect` patternNames (stmPattern bnd))
              -- Every in-between binding must pass.  This used to be a
              -- left fold over @ok && (i) || (ii)@, which parses as
              -- @(ok && (i)) || (ii)@ and so let a later binding
              -- satisfying (ii) undo an earlier failure; 'all' keeps
              -- the result False after the first failing binding.
              interm_bnds_ok =
                cur_ok && consumer_ok && out_transf_ok && cons_no_out_transf
                  && all interm_bnd_ok (drop (prev_ind + 1) $ take bnd_ind rem_bnds)
          if not interm_bnds_ok
            then return (False, n, bnd_ind, cur_ker, mempty)
            else do
              new_ker <-
                attemptFusion
                  ufus_nms
                  (outNames cur_ker)
                  (fsoac cur_ker)
                  (fusedConsumed cur_ker)
                  ker
              case new_ker of
                Nothing -> return (False, n, bnd_ind, cur_ker, mempty)
                Just krn ->
                  let krn' = krn {kerAux = aux <> kerAux krn}
                   in return (True, n + 1, bnd_ind, krn', new_ufus_nms)
      )
      (True, 0, 0, soac_kernel, infusible_nms)
      kernminds'

  -- Find the kernels we have fused into and the name of the last such
  -- kernel (if any).
  let (to_fuse_kers', to_fuse_knms', _) = unzip3 $ take ok_ind kernminds'
      new_kernms = drop (ok_ind - 1) to_fuse_knms'

  return (ok_ind > 0, [fused_ker], new_kernms, to_fuse_kers', to_fuse_knms')
  where
    -- Names of all kernels in the result whose SOAC has width @sz@,
    -- in ascending key order.
    getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
    getKersWithSameInpSize sz ress =
      M.keys $ M.filter ((== sz) . SOAC.width . fsoac) $ kernels ress

------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------
--- Fusion Gather for EXPRESSIONS and BODIES,                        ---
--- i.e., where work is being done:                                  ---
---    i) bottom-up AbSyn traversal (backward analysis)              ---
---   ii) SOACs are fused greedily iff no computation is duplicated  ---
--- E.g., (y1, y2, y3) = mapT(f, x1, x2[i])                          ---
---       (z1, z2)     = mapT(g1, y1, y2)                            ---
---       (q1, q2)     = mapT(g2, y3, z1, a, y3)                     ---
---       res          = reduce(op, ne, q1, q2, z2, y1, y3)          ---
--- can be fused if y1,y2,y3, z1,z2, q1,q2 are not used elsewhere:   ---
---       res = redomap(op, \(x1,x2i,a)->                            ---
---                             let (y1,y2,y3) = f (x1, x2i)       in---
---                             let (z1,z2)    = g1(y1, y2)        in---
---                             let (q1,q2)    = g2(y3, z1, a, y3) in---
---                             (q1, q2, z2, y1, y3)                 ---
---                     x1, x2[i], a)                                ---
------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------

-- | Gather fusion information for a body: its statements are
-- flattened to a list and handed, together with the body result, to
-- 'fusionGatherStms'.  The body's decoration is ignored.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes
fusionGatherBody fres (Body _ stms res) =
  fusionGatherStms fres (stmsToList stms) res

fusionGatherStms :: FusedRes -> [Stm] -> Result -> FusionGM FusedRes
-- Some forms of do-loops can profitably be considered streamSeqs.  We
-- are careful to ensure that the generated nested loop cannot itself
-- be considered a stream, to avoid infinite recursion.
fusionGatherStms :: FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms
  FusedRes
fres
  ( Let
      (Pattern [] [PatElem]
pes)
      StmAux (ExpDec SOACS)
bndtp
      (DoLoop [] [(FParam SOACS, SubExp)]
merge (ForLoop VName
i IntType
it SubExp
w [(LParam SOACS, VName)]
loop_vars) BodyT SOACS
body)
      : [Stm]
bnds
    )
  [SubExp]
res
    | Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(Param Type, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars = do
      let ([Param DeclType]
merge_params, [SubExp]
merge_init) = [(Param DeclType, SubExp)] -> ([Param DeclType], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge
          ([Param Type]
loop_params, [VName]
loop_arrs) = [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars
      VName
chunk_size <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"chunk_size"
      VName
offset <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"offset"
      let chunk_param :: Param Type
chunk_param = VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
chunk_size (Type -> Param Type) -> Type -> Param Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
          offset_param :: Param Type
offset_param = VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
offset (Type -> Param Type) -> Type -> Param Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it

      [Param Type]
acc_params <- [Param DeclType]
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param DeclType]
merge_params ((Param DeclType -> FusionGM (Param Type))
 -> FusionGM [Param Type])
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \Param DeclType
p ->
        VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> FusionGM VName -> FusionGM (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p) String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_outer")
          FusionGM (Type -> Param Type)
-> FusionGM Type -> FusionGM (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> FusionGM Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p)

      [Param Type]
chunked_params <- [(Param Type, VName)]
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars (((Param Type, VName) -> FusionGM (Param Type))
 -> FusionGM [Param Type])
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
        VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> FusionGM VName -> FusionGM (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_chunk")
          FusionGM (Type -> Param Type)
-> FusionGM Type -> FusionGM (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> FusionGM Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` VName -> SubExp
Futhark.Var VName
chunk_size)

      let lam_params :: [Param Type]
lam_params = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param] [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
chunked_params

      BodyT SOACS
lam_body <- Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
        Scope SOACS
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope SOACS
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type]
lam_params) (Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
          let merge' :: [(Param DeclType, SubExp)]
merge' = [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
merge_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
acc_params
          VName
j <- String -> BinderT SOACS (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"j"
          BodyT SOACS
loop_body <- Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
            [(Param Type, Param Type)]
-> ((Param Type, Param Type)
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
loop_params [Param Type]
chunked_params) (((Param Type, Param Type) -> BinderT SOACS (State VNameSource) ())
 -> BinderT SOACS (State VNameSource) ())
-> ((Param Type, Param Type)
    -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, Param Type
a_p) ->
              [VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$
                  VName -> Slice SubExp -> BasicOp
Index (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
a_p) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
                    Type -> Slice SubExp -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
a_p) [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Futhark.Var VName
j]
            [VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
i] (Exp (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
j)
            BodyT SOACS -> Binder SOACS (BodyT SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return BodyT SOACS
body
          [BinderT
   SOACS
   (State VNameSource)
   (Exp (Lore (BinderT SOACS (State VNameSource))))]
-> BinderT
     SOACS
     (State VNameSource)
     (Body (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody
            [ Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS))
-> Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall a b. (a -> b) -> a -> b
$
                [(FParam SOACS, SubExp)]
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> BodyT SOACS
-> Exp SOACS
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge' (VName
-> IntType -> SubExp -> [(LParam SOACS, VName)] -> LoopForm SOACS
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
j IntType
it (VName -> SubExp
Futhark.Var VName
chunk_size) []) BodyT SOACS
loop_body,
              Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS))
-> Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
chunk_size)
            ]
      let lam :: Lambda SOACS
lam =
            Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
              { lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
[LParam SOACS]
lam_params,
                lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
lam_body,
                lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param]
              }
          stream :: SOAC SOACS
stream = SubExp
-> [VName]
-> StreamForm SOACS
-> [SubExp]
-> Lambda SOACS
-> SOAC SOACS
forall lore.
SubExp
-> [VName]
-> StreamForm lore
-> [SubExp]
-> Lambda lore
-> SOAC lore
Futhark.Stream SubExp
w [VName]
loop_arrs StreamForm SOACS
forall lore. StreamForm lore
Sequential ([SubExp]
merge_init [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
it Integer
0]) Lambda SOACS
lam

      -- It is important that the (discarded) final-offset is not the
      -- first element in the pattern, as we use the first element to
      -- identify the SOAC in the second phase of fusion.
      VName
discard <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"discard"
      let discard_pe :: PatElemT Type
discard_pe = VName -> Type -> PatElemT Type
forall dec. VName -> dec -> PatElemT dec
PatElem VName
discard (Type -> PatElemT Type) -> Type -> PatElemT Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64

      FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms
        FusedRes
fres
        (Pattern -> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([PatElemT Type]
[PatElem]
pes [PatElemT Type] -> [PatElemT Type] -> [PatElemT Type]
forall a. Semigroup a => a -> a -> a
<> [PatElemT Type
discard_pe])) StmAux (ExpDec SOACS)
bndtp (Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op Op SOACS
SOAC SOACS
stream) Stm -> [Stm] -> [Stm]
forall a. a -> [a] -> [a]
: [Stm]
bnds)
        [SubExp]
res
-- Non-empty statement list: classify the head statement with
-- 'SOAC.fromExp' and dispatch on the result.  SOACs are handed to
-- 'greedyFuse' after the remaining statements have been gathered;
-- array transformations are recorded with 'bindingTransform'; every
-- other expression is traversed with 'fusionGatherExp'.
fusionGatherStms fres (bnd@(Let pat _ e) : bnds) res =
  SOAC.fromExp e >>= gather
  where
    -- Gather the statements following @bnd@, starting from @acc@.
    gatherRest acc = fusionGatherStms acc bnds res

    -- The names bound by the pattern of @bnd@.
    boundNames = namesFromList (patternNames pat)

    -- What @e@ consumes, from alias analysis of @e@ with an empty
    -- alias table.
    consumed = consumedInExp (Alias.analyseExp mempty e)

    -- Common final step of 'reduceLike' and 'mapLike': expand the
    -- consumed names through 'varsAliases' and hand the SOAC to
    -- 'greedyFuse' together with the statement list starting at @bnd@.
    fuseSoac soac used_lam acc = do
      consumed' <- varsAliases consumed
      greedyFuse (bnd : bnds) used_lam acc (pat, stmAux bnd, soac, consumed')

    -- The lambdas are traversed first, then the remaining statements
    -- (with the pattern bound via 'bindingFamily'), then the neutral
    -- elements.
    reduceLike soac lambdas nes = do
      (used_lam, lam_res) <- foldM fusionGatherLam (mempty, fres) lambdas
      rest_res <- bindingFamily pat (gatherRest lam_res)
      nes_res <- foldM fusionGatherSubExp rest_res nes
      fuseSoac soac used_lam nes_res

    -- The remaining statements are traversed first (with the pattern
    -- bound via 'bindingFamily'), then the single lambda.
    mapLike start soac lambda = do
      rest_res <- bindingFamily pat (gatherRest start)
      (used_lam, lam_res) <- fusionGatherLam (mempty, rest_res) lambda
      fuseSoac soac used_lam lam_res

    gather (Right soac@(SOAC.Scatter _ lam _ _)) = do
      -- The names produced by the Scatter go into the infusible set
      -- to force horizontal fusion; producer/consumer fusion of a
      -- Scatter is not possible anyway.
      marked <- addNamesToInfusible fres boundNames
      fused <- mapLike marked soac lam
      checkForUpdates fused e
    gather (Right soac@(SOAC.Hist _ _ lam _)) = do
      -- Same treatment as Scatter: the names produced by the Hist go
      -- into the infusible set to force horizontal fusion.
      marked <- addNamesToInfusible fres boundNames
      mapLike marked soac lam
    gather (Right soac@(SOAC.Screma _ (ScremaForm scans reds map_lam) _)) =
      reduceLike soac lams neutrals
      where
        lams = map scanLambda scans <> map redLambda reds <> [map_lam]
        neutrals = concatMap scanNeutral scans <> concatMap redNeutral reds
    gather (Right soac@(SOAC.Stream _ form lam nes _)) =
      -- A parallel stream additionally carries a lambda in its form;
      -- it is traversed before the stream's own lambda.  Whether the
      -- stream ends up fused or starts a kernel of its own is decided
      -- by 'greedyFuse'.
      reduceLike soac lams nes
      where
        lams = case form of
          Parallel _ _ fold_lam -> [fold_lam, lam]
          Sequential -> [lam]
    gather _
      | [pe] <- patternValueElements pat,
        Just (src, trns) <- SOAC.transformFromExp (stmCerts bnd) e =
        bindingTransform pe src trns (gatherRest fres)
      | otherwise = do
        rest_res <- gatherStmPattern pat e (gatherRest fres)
        updated <- checkForUpdates rest_res e
        foldM
          fusionGatherExp
          updated
          (e : map (BasicOp . SubExp . Var) (patternNames pat))
-- No statements left: traverse each result operand as a plain
-- 'SubExp' expression, threading the accumulated result through.
fusionGatherStms fres [] res =
  foldM gatherResult fres res
  where
    gatherResult acc se = fusionGatherExp acc (BasicOp (SubExp se))

-- | Gather fusion information from an expression that is not itself
-- handled as a SOAC by 'fusionGatherStms'.  Loops, branches and
-- @WithAcc@ are traversed; everything else just has its free names
-- added to the infusible set.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp fres (DoLoop ctx val form loop_body) = do
  -- Whatever the loop form and the merge parameters mention is
  -- infusible in the enclosing scope.
  outer <- addNamesToInfusible fres (freeIn form <> freeIn ctx <> freeIn val)
  -- The body is analysed from an empty result, with the loop index,
  -- loop-variable parameters and merge parameters in scope.
  inner <-
    binding [(ident, mempty) | ident <- loop_idents] $
      fusionGatherBody mempty loop_body
  -- Make the input arrays found inside the body infusible, so that
  -- they cannot be fused from outside the loop.
  let sealed =
        inner
          { infusible =
              infusible inner <> foldMap oneName (M.keys (inpArr inner))
          }
  pure (sealed <> outer)
  where
    merge_idents = map (paramIdent . fst) (ctx <> val)
    loop_idents = case form of
      WhileLoop {} -> merge_idents
      ForLoop i it _ loopvars ->
        Ident i (Prim (IntType it)) :
        map (paramIdent . fst) loopvars ++ merge_idents
fusionGatherExp fres (If cond e_then e_else _) = do
  -- Each branch is analysed on its own, starting from an empty result.
  then_res <- fusionGatherBody mempty e_then
  else_res <- fusionGatherBody mempty e_else
  with_cond <- fusionGatherSubExp fres cond
  mergeFusionRes with_cond (then_res <> else_res)
fusionGatherExp fres (WithAcc inps lam) = do
  -- The set of used names returned for the lambda is not needed here.
  (_, lam_res) <- fusionGatherLam (mempty, fres) lam
  addNamesToInfusible lam_res (freeIn inps)

-----------------------------------------------------------------------------------
--- Errors: SOACs are not expected here, because normalization
--- ensures they appear directly in a let expression, i.e., let x = e
-----------------------------------------------------------------------------------

fusionGatherExp _ (Op Futhark.Screma {}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter {}) = errorIllegal "write"
-----------------------------------
---- Generic Traversal         ----
-----------------------------------

fusionGatherExp fres e = addNamesToInfusible fres (freeIn e)

-- | A variable operand is added to the infusible set (via
-- 'addVarToInfusible'); any other operand leaves the result untouched.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres se =
  case se of
    Var v -> addVarToInfusible fres v
    _ -> return fres

-- | Add every name in the given set to the infusible set, one at a
-- time, using 'addVarToInfusible'.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres names =
  foldM addVarToInfusible fres (namesToList names)

-- | Add one name to the infusible set.  If 'lookupArr' knows the name
-- as an 'SOAC.Input', the array recorded in that input is added in
-- its place.
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  target <- maybe name underlying <$> asks (lookupArr name)
  return fres {infusible = oneName target <> infusible fres}
  where
    underlying (SOAC.Input _ orig _) = orig

-- | Gather fusion information from a lambda.  A lambda opens a new
-- scope: its body is analysed from an empty result with the lambda
-- parameters bound, and the input arrays found inside are made
-- infusible so that nothing is fused into the lambda from outside.
--
-- The first component of the pair accumulates the infusible names
-- contributed by the lambdas seen so far; the second is the running
-- fusion result.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (Names, FusedRes)
fusionGatherLam (u_set, fres) (Lambda idds body _) = do
  inner <- bindingParams idds (fusionGatherBody mempty body)
  in_scope <- asks (M.keys . varsInScope)
  let -- Infusible names from the body plus the body's input arrays ...
      candidates = infusible inner <> namesFromList (M.keys (inpArr inner))
      -- ... keeping only those that are in scope at this point.
      visible = candidates `namesIntersection` namesFromList in_scope
      inner' = inner {infusible = visible}
  -- Combine the restricted body result with the incoming result.
  return (u_set <> visible, inner' <> fres)

-------------------------------------------------------------
-------------------------------------------------------------
--- FINALLY, Substitute the kernels in function
-------------------------------------------------------------
-------------------------------------------------------------

-- | Rebuild a statement sequence.  The statements after the head are
-- processed first (with the head's pattern bound via 'bindingPat');
-- then the head itself goes through 'replaceSOAC', and whatever that
-- returns is put in front of the processed remainder.
fuseInStms :: Stms SOACS -> FusionGM (Stms SOACS)
fuseInStms stms =
  case stmsHead stms of
    Nothing -> pure mempty
    Just (Let pat aux e, rest) -> do
      rest' <- bindingPat pat (fuseInStms rest)
      replaced <- replaceSOAC pat aux e
      pure (replaced <> rest')

-- | Apply 'fuseInStms' to the statements of a body.  The result list
-- is kept unchanged.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) = do
  stms' <- fuseInStms stms
  pure (Body () stms' res)

-- | Apply fusion inside an expression.
fuseInExp :: Exp -> FusionGM Exp
-- Loops get special treatment: the loop index, the loop-variable
-- parameters and the merge parameters must be bound while the body is
-- processed.
fuseInExp (DoLoop ctx val form loopbody) =
  binding [(ident, mempty) | ident <- form_idents] $
    bindingParams (map fst (ctx ++ val)) $ do
      loopbody' <- fuseInBody loopbody
      pure (DoLoop ctx val form loopbody')
  where
    form_idents = case form of
      ForLoop i it _ loopvars ->
        Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
      WhileLoop {} -> []
-- Every other expression is traversed generically with 'fuseIn'.
fuseInExp e = mapExpM fuseIn e

-- | The mapper used by 'fuseInExp' for generic traversal: bodies go
-- through 'fuseInBody' (the scope argument is ignored), and the
-- lambdas of SOACs go through 'fuseInLambda'.  Everything else is
-- left to the identity mappers.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn =
  identityMapper
    { mapOnBody = \_scope body -> fuseInBody body,
      mapOnOp = mapSOACM soacMapper
    }
  where
    soacMapper = identitySOACMapper {mapOnSOACLambda = fuseInLambda}

-- | Apply fusion inside the body of a lambda.  The lambda parameters
-- are brought into scope with 'bindingParams' while the body is
-- processed; the parameters and return types are passed through
-- unchanged.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda (Lambda params body rtp) = do
  body' <- bindingParams params $ fuseInBody body
  return $ Lambda params body' rtp

-- | Produce the statements that should stand in for a statement with
-- the given pattern, auxiliary information and expression.
--
-- * A pattern whose second element list is empty yields no statements.
--
-- * Otherwise the name of the first element of that list is looked up
--   in the 'outArr' table of the current fusion result.  If it is not
--   there, the original statement is kept, with fusion applied inside
--   its expression via 'fuseInExp'.  If it is there, the corresponding
--   kernel is fetched from 'kernels' and emitted with 'insertKerSOAC'.
--
-- Throws an 'Error' if 'outArr' refers to a kernel name that is
-- missing from 'kernels', or if the kernel found has no 'fusedVars'.
replaceSOAC :: Pattern -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pattern _ []) _ _ = return mempty
replaceSOAC pat@(Pattern _ (patElem : _)) aux e = do
  fres <- asks fusedRes
  let pat_nm = patElemName patElem
      -- Only used for the error message below.
      names = patternIdents pat
  case M.lookup pat_nm (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm ->
      case M.lookup knm (kernels fres) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, replaceSOAC, outArr in ker_name "
                  ++ "which is not in Res: "
                  ++ pretty (unKernName knm)
              )
        Just ker -> do
          when (null $ fusedVars ker) $
            throwError $
              Error
                ( "In Fusion.hs, replaceSOAC, unfused kernel "
                    ++ "still in result: "
                    ++ pretty names
                )
          insertKerSOAC aux (outNames ker) ker

-- | Build the statements for a fused kernel.
--
-- The kernel's SOAC is first passed through 'finaliseSOAC' and
-- converted with 'SOAC.toSOAC'.  Its results are bound to fresh
-- identifiers whose base names are taken from @names@ and whose types
-- come from 'SOAC.typeOf'; the binding carries the kernel's own
-- auxiliary information combined with @aux@.  Finally
-- 'transformOutput' is called with the kernel's 'outputTransform',
-- @names@ and the fresh identifiers.
insertKerSOAC :: StmAux () -> [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC aux names ker = do
  new_soac' <- finaliseSOAC $ fsoac ker
  runBinder_ $ do
    f_soac <- SOAC.toSOAC new_soac'
    -- The fused kernel may consume more than the original SOACs (see
    -- issue #224).  We insert copy expressions to fix it.
    f_soac' <- copyNewlyConsumed (fusedConsumed ker) $ addOpAliases mempty f_soac
    validents <- zipWithM newIdent (map baseString names) $ SOAC.typeOf new_soac'
    auxing (kerAux ker <> aux) $ letBind (basicPattern [] validents) $ Op f_soac'
    transformOutput (outputTransform ker) names validents

-- | Perform simplification and fusion inside the lambda(s) of a SOAC.
--
-- Every lambda of the SOAC is run through 'simplifyAndFuseInLambda':
-- for a 'SOAC.Screma' that is each scan operator, each reduction
-- operator and the map lambda; for 'SOAC.Scatter', 'SOAC.Hist' and
-- 'SOAC.Stream' it is the single lambda argument of the constructor.
-- All other components (widths, neutral elements, inputs,
-- destinations, ...) are passed through unchanged.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC new_soac =
  case new_soac of
    SOAC.Screma w (ScremaForm scans reds map_lam) arrs -> do
      scans' <- forM scans $ \(Scan scan_lam scan_nes) -> do
        scan_lam' <- simplifyAndFuseInLambda scan_lam
        return $ Scan scan_lam' scan_nes

      reds' <- forM reds $ \(Reduce comm red_lam red_nes) -> do
        red_lam' <- simplifyAndFuseInLambda red_lam
        return $ Reduce comm red_lam' red_nes

      map_lam' <- simplifyAndFuseInLambda map_lam

      return $ SOAC.Screma w (ScremaForm scans' reds' map_lam') arrs
    SOAC.Scatter w lam inps dests -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Scatter w lam' inps dests
    SOAC.Hist w ops lam arrs -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Hist w ops lam' arrs
    SOAC.Stream w form lam nes inps -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Stream w form lam' nes inps

-- | Simplify a lambda and then fuse inside it.
--
-- The lambda is simplified with 'simplifyLambda'; fusion information
-- for the simplified lambda is gathered with 'fusionGatherLam',
-- starting from 'mkFreshFusionRes'; the gathered result is pruned with
-- 'cleanFusionResult'; and 'fuseInLambda' is finally run on the
-- simplified lambda with that result installed via 'bindRes'.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  lam' <- simplifyLambda lam
  (_, nfres) <- fusionGatherLam (mempty, mkFreshFusionRes) lam'
  let nfres' = cleanFusionResult nfres
  bindRes nfres' $ fuseInLambda lam'

-- | Insert copies for whatever a (fused) SOAC consumes beyond the set
-- @was_consumed@, and strip the alias annotations from the SOAC.
--
-- Only 'Futhark.Screma' is rewritten:
--
-- * each input array that is in the newly consumed set is replaced by
--   a fresh @Copy@ of it, bound before the SOAC;
--
-- * each free variable consumed by the map lambda is copied at the
--   start of the lambda body, and the body is rewritten to refer to
--   the copy;
--
-- * the scan and reduction lambdas only have their aliases removed.
--
-- Any other SOAC is returned with its aliases removed and no copies.
copyNewlyConsumed ::
  Names ->
  Futhark.SOAC (Aliases.Aliases SOACS) ->
  Binder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed was_consumed soac =
  case soac of
    Futhark.Screma w arrs (Futhark.ScremaForm scans reds map_lam) -> do
      -- Copy any arrays that are consumed now, but were not in the
      -- constituents.
      arrs' <- mapM copyConsumedArr arrs
      -- Any consumed free variables will have to be copied inside the
      -- lambda, and we have to substitute the name of the copy for
      -- the original.
      map_lam' <- copyFreeInLambda map_lam

      let scans' =
            map
              ( \scan ->
                  scan
                    { scanLambda =
                        Aliases.removeLambdaAliases
                          (scanLambda scan)
                    }
              )
              scans

      let reds' =
            map
              ( \red ->
                  red
                    { redLambda =
                        Aliases.removeLambdaAliases
                          (redLambda red)
                    }
              )
              reds

      return $ Futhark.Screma w arrs' $ Futhark.ScremaForm scans' reds' map_lam'
    _ -> return $ removeOpAliases soac
  where
    -- Everything the SOAC consumes, and the part of it that was not
    -- already consumed according to the caller.
    consumed = consumedInOp soac
    newly_consumed = consumed `namesSubtract` was_consumed

    -- Replace a newly consumed array by a copy named "<base>_copy";
    -- leave any other array alone.
    copyConsumedArr a
      | a `nameIn` newly_consumed =
        letExp (baseString a <> "_copy") $ BasicOp $ Copy a
      | otherwise = return a

    -- Copy every variable that the lambda consumes but does not bind
    -- as a parameter, prepend the copies to the lambda body, and
    -- substitute the copies for the originals in that body.
    copyFreeInLambda lam = do
      let free_consumed =
            consumedByLambda lam
              `namesSubtract` namesFromList (map paramName $ lambdaParams lam)
      (bnds, subst) <-
        foldM copyFree (mempty, mempty) $ namesToList free_consumed
      let lam' = Aliases.removeLambdaAliases lam
      return $
        if null bnds
          then lam'
          else
            lam'
              { lambdaBody =
                  insertStms bnds $
                    substituteNames subst $ lambdaBody lam'
              }

    -- One fold step: create the copy statement for @v@ and record the
    -- substitution from @v@ to the name of its copy.
    copyFree (bnds, subst) v = do
      v_copy <- newVName $ baseString v <> "_copy"
      copy <- mkLetNamesM [v_copy] $ BasicOp $ Copy v
      return (oneStm copy <> bnds, M.insert v v_copy subst)

---------------------------------------------------
---------------------------------------------------
---- HELPERS
---------------------------------------------------
---------------------------------------------------

-- | Get a new fusion result, i.e., for when entering a new scope,
--   e.g., a new lambda or a new loop.  The success flag is 'False'
--   and every table and name set is empty.
mkFreshFusionRes :: FusedRes
mkFreshFusionRes =
  FusedRes
    { rsucc = False,
      outArr = M.empty,
      inpArr = M.empty,
      infusible = mempty,
      kernels = M.empty
    }

-- | Combine two fusion results into one.
--
-- * 'rsucc' is the disjunction of the two flags.
--
-- * 'outArr' and 'kernels' are merged with 'M.union', so on a key
--   clash the entry from the first result is kept.
--
-- * 'inpArr' is merged key-wise, taking the union of the kernel sets.
--
-- * 'infusible' is the union of both 'infusible' sets, extended with
--   the names returned by 'expandSoacInpArr' for the arrays that are
--   keys of 'inpArr' in both results.
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
  let ufus_mres = infusible res1 <> infusible res2
  -- Arrays that are inputs to kernels in both results.
  inp_both <- expandSoacInpArr $ M.keys $ inpArr res1 `M.intersection` inpArr res2
  let m_unfus = ufus_mres <> mconcat (map oneName inp_both)
  return $
    FusedRes
      (rsucc res1 || rsucc res2)
      (outArr res1 `M.union` outArr res2)
      (M.unionWith S.union (inpArr res1) (inpArr res2))
      m_unfus
      (kernels res1 `M.union` kernels res2)

-- | The expression arguments are supposed to be array-type exps.
--   Returns a tuple, in which the arrays that are plain variables
--   (inputs with no transforms) are in the first element, and the
--   underlying arrays of the inputs that are indexed, transposed or
--   otherwise transformed are in the second.
--
--   E.g., for expression @mapT(f, a, b[i])@, the result should be
--   @([a],[b])@.
--
--   Names are prepended while folding from the left, so each of the
--   two lists is in the reverse of the input order.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr = L.foldl' comb ([], [])
  where
    comb (vs, os) (SOAC.Input ts idd _)
      | SOAC.nullTransforms ts = (idd : vs, os)
    -- Reached for every input that carries at least one transform.
    comb (vs, os) inp =
      (vs, SOAC.inputArray inp : os)

-- | Prune a fusion result of kernels in which nothing was fused.
--
-- Kernels whose 'fusedVars' list is empty are dropped from 'kernels';
-- 'outArr' entries pointing at a dropped kernel are removed; and
-- dropped kernels are removed from every set in 'inpArr' (an array
-- whose set becomes empty keeps its, now empty, entry).  The remaining
-- fields are left untouched.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
  let newks = M.filter (not . null . fusedVars) (kernels fres)
      newoa = M.filter (`M.member` newks) (outArr fres)
      newia = M.map (S.filter (`M.member` newks)) (inpArr fres)
   in fres {outArr = newoa, inpArr = newia, kernels = newks}

--------------
--- Errors ---
--------------

-- | Abort fusion with an 'Error' reporting that the SOAC with the
-- given name should not appear in the program.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name =
  throwError $
    Error
      ("In Fusion.hs, soac " ++ soac_name ++ " appears illegally in pgm!")