{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion (fuseSOACs) where
import Control.Monad.Except
import Control.Monad.Reader
import Control.Monad.State
import qualified Data.List as L
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import qualified Futhark.IR.Aliases as Aliases
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.IR.SOACS.Simplify
import Futhark.Optimise.Fusion.LoopKernel
import Futhark.Pass
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (maxinum)
-- | What the fusion environment knows about a variable in scope.
data VarEntry
  = -- | An array variable.  The fields are, in order: a variable name
    -- (the variable itself when created by 'bindVar'; inherited from
    -- the source array when created by 'bindingTransform'), the
    -- variable's scope information, its aliases, and the SOAC input
    -- through which it is read.
    IsArray VName (NameInfo SOACS) Names SOAC.Input
  | -- | Any other variable; only its scope information is kept.
    IsNotArray (NameInfo SOACS)
-- | The scope information ('NameInfo') stored in a 'VarEntry',
-- regardless of which constructor built it.
varEntryType :: VarEntry -> NameInfo SOACS
varEntryType (IsArray _ dec _ _) = dec
varEntryType (IsNotArray dec) = dec
-- | The aliases recorded for a variable.  Only 'IsArray' entries carry
-- aliases; every other entry yields the empty set.
varEntryAliases :: VarEntry -> Names
varEntryAliases (IsArray _ _ als _) = als
varEntryAliases _ = mempty
-- | The read-only environment of the fusion monad.
data FusionGEnv = FusionGEnv
  { -- | Maps a name to its "family": 'bindingFamily' associates every
    -- identifier of a pattern with the names of all elements of that
    -- same pattern.
    soacs :: M.Map VName [VName],
    -- | The variables currently in scope and what is known about them.
    varsInScope :: M.Map VName VarEntry,
    -- | The fusion result currently installed in the environment
    -- (see 'bindRes').
    fusedRes :: FusedRes
  }
-- | Look up the SOAC input associated with a variable.  Returns
-- 'Nothing' if the variable is not in scope or is not recorded as an
-- array.
lookupArr :: VName -> FusionGEnv -> Maybe SOAC.Input
lookupArr v env = asArray =<< M.lookup v (varsInScope env)
  where
    asArray (IsArray _ _ _ input) = Just input
    asArray IsNotArray {} = Nothing
-- | An error raised by the fusion monad; wraps a textual message.
newtype Error = Error String
-- | Renders the message prefixed by a @Fusion error:@ header line.
instance Show Error where
  show (Error msg) = "Fusion error:\n" ++ msg
-- | The monad in which fusion information is gathered.  It is a stack
-- of 'ExceptT' (failure with an 'Error'), 'StateT' (the 'VNameSource'
-- used for fresh names) and 'Reader' (the 'FusionGEnv' environment).
newtype FusionGM a = FusionGM (ExceptT Error (StateT VNameSource (Reader FusionGEnv)) a)
  deriving
    ( Monad,
      Applicative,
      Functor,
      MonadError Error,
      MonadState VNameSource,
      MonadReader FusionGEnv
    )
-- | The name source is exactly the state component of 'FusionGM'.
instance MonadFreshNames FusionGM where
  getNameSource = get
  putNameSource = put
-- | The scope is derived from 'varsInScope' by keeping only the
-- 'NameInfo' of each entry.
instance HasScope SOACS FusionGM where
  askScope = asks $ toScope . varsInScope
    where
      toScope = M.map varEntryType
-- | Add one identifier, together with its direct aliases, to the
-- variables in scope.
--
-- Identifiers of array type become 'IsArray' entries whose SOAC input
-- is the identifier itself.  Everything else becomes 'IsNotArray' and
-- the supplied aliases are dropped.
bindVar :: FusionGEnv -> (Ident, Names) -> FusionGEnv
bindVar env (Ident name t, aliases) =
  env {varsInScope = M.insert name entry $ varsInScope env}
  where
    entry = case t of
      Array {} ->
        IsArray name (LetName t) aliases' $ SOAC.identInput $ Ident name t
      _ -> IsNotArray $ LetName t
    -- Aliases already recorded for a name in the current scope; empty
    -- when the name is unknown.
    expand = maybe mempty varEntryAliases . flip M.lookup (varsInScope env)
    -- The direct aliases plus whatever each of them is already known
    -- to alias (one lookup per direct alias).
    aliases' = aliases <> foldMap expand (namesToList aliases)
-- | Add several identifiers to the environment, left to right, so
-- that later bindings see (and may overwrite) earlier ones.
bindVars :: FusionGEnv -> [(Ident, Names)] -> FusionGEnv
bindVars = L.foldl' bindVar
-- | Run an action with the given identifiers (and their aliases)
-- additionally bound in the environment.
binding :: [(Ident, Names)] -> FusionGM a -> FusionGM a
binding vs = local (`bindVars` vs)
-- | Run an action with the identifiers of a statement's pattern in
-- scope.  Context names get no aliases; the remaining alias sets are
-- those of the bound expression, computed by alias analysis with an
-- empty alias table.
gatherStmPattern :: Pattern -> Exp -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPattern pat e = binding $ zip idents aliases
  where
    idents = patternIdents pat
    aliases =
      replicate (length (patternContextNames pat)) mempty
        ++ expAliases (Alias.analyseExp mempty e)
-- | Run an action with the identifiers of a pattern in scope, none of
-- them having any aliases.
bindingPat :: Pattern -> FusionGM a -> FusionGM a
bindingPat = binding . (`zip` repeat mempty) . patternIdents
-- | Run an action with the given parameters in scope, none of them
-- having any aliases.
bindingParams :: Typed t => [Param t] -> FusionGM a -> FusionGM a
bindingParams = binding . (`zip` repeat mempty) . map paramIdent
-- | Record an identifier as a member of the given family: the name is
-- mapped to the family in 'soacs' and added to 'varsInScope'.
--
-- Note that the entry is always an alias-free 'IsArray'; unlike
-- 'bindVar', the identifier's type is not inspected.
bindingFamilyVar :: [VName] -> FusionGEnv -> Ident -> FusionGEnv
bindingFamilyVar faml env (Ident nm t) =
  env
    { soacs = M.insert nm faml $ soacs env,
      varsInScope = M.insert nm entry $ varsInScope env
    }
  where
    entry = IsArray nm (LetName t) mempty $ SOAC.identInput $ Ident nm t
-- | The aliases of a variable according to the environment.  The
-- result always includes the variable itself, even when it is not in
-- scope.
varAliases :: VName -> FusionGM Names
varAliases v =
  asks $
    (oneName v <>)
      . maybe mempty varEntryAliases
      . M.lookup v
      . varsInScope
-- | The union of 'varAliases' over every name in the given set.
varsAliases :: Names -> FusionGM Names
varsAliases = fmap mconcat . mapM varAliases . namesToList
-- | Account for an in-place update in a fusion result.
--
-- The pair holds the arrays being updated in place and other
-- variables involved in the update.  All of them are passed to
-- 'addVarToInfusible'.  The aliases of the updated arrays only (not of
-- the other variables) are then added to the @inplace@ set of every
-- kernel in the result.
updateKerInPlaces :: FusedRes -> ([VName], [VName]) -> FusionGM FusedRes
updateKerInPlaces res (ip_vs, other_infuse_vs) = do
  res' <- foldM addVarToInfusible res (ip_vs ++ other_infuse_vs)
  aliases <- mconcat <$> mapM varAliases ip_vs
  let inspectKer k = k {inplace = aliases <> inplace k}
  return res' {kernels = M.map inspectKer $ kernels res'}
-- | If the expression writes to an array in place, update the fusion
-- result via 'updateKerInPlaces'; otherwise return it unchanged.
--
-- * For an @Update@, the updated array is the source array and the
--   other variables are those free in the slice.
--
-- * For a @Scatter@, the updated arrays are the destination arrays
--   (the third component of each written-info triple).
checkForUpdates :: FusedRes -> Exp -> FusionGM FusedRes
checkForUpdates res (BasicOp (Update src is _)) = do
  let ifvs = namesToList $ foldMap freeIn is
  updateKerInPlaces res ([src], ifvs)
checkForUpdates res (Op (Futhark.Scatter _ _ _ written_info)) = do
  let updt_arrs = [arr | (_, _, arr) <- written_info]
  updateKerInPlaces res (updt_arrs, [])
checkForUpdates res _ = return res
-- | Run an action in an environment where every identifier of the
-- pattern has been registered (via 'bindingFamilyVar') with the names
-- of all elements of the pattern as its family.
bindingFamily :: Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily pat = local bind
  where
    idents = patternIdents pat
    family = patternNames pat
    bind env = L.foldl' (bindingFamilyVar family) env idents
-- | Run an action with a pattern element bound as an array transform
-- of a source variable.
--
-- If the source is a known array, the new variable reuses the name
-- stored in the source's entry, aliases the source and all of the
-- source's aliases, and reads through the source's input with the
-- transform added.  Otherwise the pattern element is bound like any
-- ordinary variable via 'bindVar', with its own name as its alias set.
bindingTransform :: PatElem -> VName -> SOAC.ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform pe srcname trns = local $ \env ->
  case M.lookup srcname $ varsInScope env of
    Just (IsArray src' _ aliases input) ->
      env
        { varsInScope =
            M.insert
              vname
              ( IsArray src' (LetName dec) (oneName srcname <> aliases) $
                  trns `SOAC.addTransform` input
              )
              $ varsInScope env
        }
    _ -> bindVar env (patElemIdent pe, oneName vname)
  where
    vname = patElemName pe
    dec = patElemDec pe
-- | Run an action with the given fusion result installed as the
-- 'fusedRes' of the environment.
bindRes :: FusedRes -> FusionGM a -> FusionGM a
bindRes rrr = local (\env -> env {fusedRes = rrr})
-- | Run a 'FusionGM' action in the given environment.  The name source
-- of the surrounding monad is threaded through the action (via
-- 'modifyNameSource'), and failure is reported as a 'Left'.
runFusionGatherM ::
  MonadFreshNames m =>
  FusionGM a ->
  FusionGEnv ->
  m (Either Error a)
runFusionGatherM (FusionGM a) env =
  modifyNameSource $ \src -> runReader (runStateT (runExceptT a) src) env
-- | The SOAC fusion pass.  Constants are processed with 'fuseConsts'
-- (given the names free in the program's functions) and each function
-- with 'fuseFun'; the resulting program is then renamed and finally
-- simplified.
fuseSOACs :: Pass SOACS SOACS
fuseSOACs =
  Pass
    { passName = "Fuse SOACs",
      passDescription = "Perform higher-order optimisation, i.e., fusion.",
      passFunction = \prog ->
        simplifySOACS
          =<< renameProg
          =<< intraproceduralTransformationWithConsts
            (fuseConsts (freeIn (progFuns prog)))
            fuseFun
            prog
    }
-- | Fuse the constant-defining statements of a program.  They are
-- processed in an empty scope, and the used constants (as variables)
-- are passed as the result of the statement sequence.
fuseConsts :: Names -> Stms SOACS -> PassM (Stms SOACS)
fuseConsts used_consts consts =
  fuseStms mempty consts $ map Var $ namesToList used_consts
-- | Fuse the body of a single function.  The scope consists of the
-- program constants and the function's parameters.  Only the
-- statements of the body are replaced; its result is kept as is.
fuseFun :: Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
fuseFun consts fun = do
  stms <-
    fuseStms
      (scopeOf consts <> scopeOfFParams (funDefParams fun))
      (bodyStms $ funDefBody fun)
      (bodyResult $ funDefBody fun)
  let body = (funDefBody fun) {bodyStms = stms}
  return fun {funDefBody = body}
-- | Fuse a sequence of statements whose result is @res@, given the
-- names that are in scope.
--
-- Works in two phases, both run in a fresh 'FusionGEnv':
--
-- 1. 'fusionGatherStms' collects fusion information, which is then
--    post-processed by 'cleanFusionResult'.
--
-- 2. If the gathered result reports success ('rsucc'), the statements
--    are rewritten by 'fuseInStms' under 'bindRes'.  Otherwise the
--    original statements are returned untouched.
--
-- Errors raised inside the fusion monad are propagated via
-- 'liftEitherM'.
fuseStms :: Scope SOACS -> Stms SOACS -> Result -> PassM (Stms SOACS)
fuseStms scope stms res = do
  gathered <-
    liftEitherM $
      runFusionGatherM
        (binding scope_idents $ fusionGatherStms mempty (stmsToList stms) res)
        initial_env
  let k = cleanFusionResult gathered
  if rsucc k
    then
      liftEitherM $
        runFusionGatherM
          (binding scope_idents $ bindRes k $ fuseInStms stms)
          initial_env
    else return stms
  where
    -- Both phases start from the same empty environment.
    initial_env =
      FusionGEnv
        { soacs = M.empty,
          varsInScope = mempty,
          fusedRes = mempty
        }
    -- Every name in scope, as an identifier paired with an empty alias
    -- set, which is the shape 'binding' expects.
    scope_idents =
      [ (Ident name (typeOf info), mempty)
        | (name, info) <- M.toList scope
      ]
-- | The name of a fused kernel.  A thin wrapper around 'VName' so that
-- kernel names cannot be confused with ordinary variable names; it is
-- used as the key of the 'kernels' map in 'FusedRes'.
newtype KernName = KernName
  { -- | The underlying variable name.
    unKernName :: VName
  }
  deriving (Eq, Ord, Show)
-- | The information accumulated while gathering fusion opportunities.
data FusedRes = FusedRes
  { -- | Whether any fusion has taken place.  'greedyFuse' sets this to
    -- 'True' when it fuses a SOAC into existing kernels, and
    -- 'fuseStms' leaves the program unchanged when it is 'False'.
    rsucc :: Bool,
    -- | Maps an array to the kernel that produces it.
    outArr :: M.Map VName KernName,
    -- | Maps an array to the set of kernels that take it as a SOAC
    -- input.
    inpArr :: M.Map VName (S.Set KernName),
    -- | Names that are treated as infusible: 'greedyFuse' does not
    -- attempt producer-consumer fusion for a SOAC when any of its
    -- outputs is in this set.
    infusible :: Names,
    -- | The kernels themselves, keyed by kernel name.
    kernels :: M.Map KernName FusedKer
  }
-- | Combining two results merges them field by field: success flags
-- are or-ed, the per-array kernel sets are united, and the remaining
-- maps are merged with left bias (the entry of the left operand wins
-- on a key clash).
instance Semigroup FusedRes where
  res1 <> res2 =
    FusedRes
      { rsucc = rsucc res1 || rsucc res2,
        -- '<>' on maps is the left-biased 'M.union'.
        outArr = outArr res1 <> outArr res2,
        inpArr = M.unionWith S.union (inpArr res1) (inpArr res2),
        infusible = infusible res1 <> infusible res2,
        kernels = kernels res1 <> kernels res2
      }
-- | The empty result: nothing has been fused, no arrays or kernels are
-- recorded, and nothing is marked infusible.
instance Monoid FusedRes where
  mempty = FusedRes False M.empty M.empty mempty M.empty
-- | Is the array @nm@ recorded in 'inpArr' as an input of at least
-- one kernel that is /not/ a member of @kers@?  Arrays with no entry
-- in 'inpArr' yield 'False'.
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
  maybe False usedElsewhere $ M.lookup nm $ inpArr ress
  where
    usedElsewhere users = not $ S.null $ users `S.difference` kers
-- | The set of all kernels that, according to 'inpArr', take at least
-- one of the given arrays as input.  Arrays without an entry
-- contribute nothing.
getKersWithInpArrs :: FusedRes -> [VName] -> S.Set KernName
getKersWithInpArrs ress arrs =
  S.unions $ mapMaybe usersOf arrs
  where
    usersOf arr = M.lookup arr $ inpArr ress
-- | Expand a list of names through the 'soacs' table of the
-- environment: a name that has an entry is replaced by the list of
-- names bound to it, and any other name is kept as it is.  The
-- relative order of the input is preserved.
--
-- The table is read once and the pieces are joined with 'concatMap',
-- which avoids both the repeated environment lookups and the
-- quadratic cost of growing the result by appending on the right.
expandSoacInpArr :: [VName] -> FusionGM [VName]
expandSoacInpArr nms = do
  known <- asks soacs
  return $ concatMap (\nm -> fromMaybe [nm] $ M.lookup nm known) nms
-- | Split the inputs of a SOAC into the two groups computed by
-- 'getIdentArr' and expand each group with 'expandSoacInpArr'.  The
-- components of the result are in the same order as those returned by
-- 'getIdentArr'.
soacInputs :: SOAC -> FusionGM ([VName], [VName])
soacInputs soac =
  (,) <$> expandSoacInpArr inp_idds <*> expandSoacInpArr other_idds
  where
    (inp_idds, other_idds) = getIdentArr $ SOAC.inputs soac
-- | Register a SOAC as a brand new kernel in the fusion result.
--
-- The second argument holds the identifiers bound by the SOAC
-- statement, its auxiliary information, the SOAC itself and the names
-- it consumes.  The last argument @ufs@ becomes the infusible set of
-- the returned result, replacing (not extending) the one in @res@.
--
-- A fresh kernel name is generated; every output array is mapped to
-- it in 'outArr', and it is added to the user set of every SOAC input
-- array in 'inpArr'.  The success flag of @res@ is carried over
-- unchanged.
addNewKerWithInfusible ::
  FusedRes ->
  ([Ident], StmAux (), SOAC, Names) ->
  Names ->
  FusionGM FusedRes
addNewKerWithInfusible res (idd, aux, soac, consumed) ufs = do
  nm_ker <- KernName <$> newVName "ker"
  scope <- askScope
  let out_nms = map identName idd
      inp_arrs = map SOAC.inputArray $ SOAC.inputs soac
      new_ker = newKernel aux soac consumed out_nms scope
      -- New output bindings take precedence over existing ones.
      new_outs = M.fromList (zip out_nms (repeat nm_ker))
      new_inps = M.fromList (zip inp_arrs (repeat (S.singleton nm_ker)))
  return
    FusedRes
      { rsucc = rsucc res,
        outArr = new_outs `M.union` outArr res,
        inpArr = M.unionWith S.union new_inps (inpArr res),
        infusible = ufs,
        kernels = M.insert nm_ker new_ker (kernels res)
      }
-- | Look up a name in the current environment with 'lookupArr',
-- yielding the SOAC input recorded for it, if any.
lookupInput :: VName -> FusionGM (Maybe SOAC.Input)
lookupInput = asks . lookupArr
-- | Replace the array of a SOAC input by the input recorded for it in
-- the environment, if there is one.
--
-- When 'lookupInput' finds an input for the underlying array, the
-- result refers to that input's array and type, and its transforms
-- are placed in front of the transforms of the original input.
-- Otherwise the input is returned unchanged.
inlineSOACInput :: SOAC.Input -> FusionGM SOAC.Input
inlineSOACInput inp@(SOAC.Input ts v _) =
  maybe inp viaRecorded <$> lookupInput v
  where
    viaRecorded (SOAC.Input ts2 v2 t2) = SOAC.Input (ts2 <> ts) v2 t2
-- | Apply 'inlineSOACInput' to every input of a SOAC and return the
-- SOAC with its inputs replaced by the results.
inlineSOACInputs :: SOAC -> FusionGM SOAC
inlineSOACInputs soac =
  (`SOAC.setInputs` soac) <$> mapM inlineSOACInput (SOAC.inputs soac)
-- | Attempt to fuse a SOAC statement with the kernels gathered so far.
--
-- Arguments, in order: the remaining statements (only passed on to
-- 'horizontGreedyFuse'), the names used inside the SOAC's lambda, the
-- fusion result so far, and the statement itself as a tuple of its
-- pattern, auxiliary information, SOAC and consumed names.
--
-- The inputs of the SOAC are first inlined with 'inlineSOACInputs'.
-- Horizontal fusion is tried when the original SOAC is a 'SOAC.Screma'
-- that is a redomap or scanomap but not a plain reduce or scan, or
-- when any of its outputs is infusible.  Otherwise producer-consumer
-- fusion is tried.
--
-- Fusion is accepted only if there was at least one candidate kernel,
-- the chosen strategy reported success, and no candidate kernel has
-- an in-place name among the names used by this SOAC.  On acceptance
-- the fused kernels replace the candidates in 'inpArr' and 'kernels',
-- and 'rsucc' becomes 'True'.  Otherwise the SOAC is registered as a
-- new kernel with 'addNewKerWithInfusible'.
greedyFuse ::
  [Stm] ->
  Names ->
  FusedRes ->
  (Pattern, StmAux (), SOAC, Names) ->
  FusionGM FusedRes
greedyFuse rem_bnds lam_used_nms res (out_idds, aux, orig_soac, consumed) = do
  soac <- inlineSOACInputs orig_soac
  (inp_nms, other_nms) <- soacInputs soac
  let out_nms = patternNames out_idds
      isInfusible = (`nameIn` infusible res)
      is_screma = case orig_soac of
        SOAC.Screma _ form _ ->
          (isJust (isRedomapSOAC form) || isJust (isScanomapSOAC form))
            && not (isJust (isReduceSOAC form) || isJust (isScanSOAC form))
        _ -> False

  -- Choose the fusion strategy; both return the same five-tuple.
  (ok_kers_compat, fused_kers, fused_nms, old_kers, oldker_nms) <-
    if is_screma || any isInfusible out_nms
      then horizontGreedyFuse rem_bnds res (out_idds, aux, soac, consumed)
      else prodconsGreedyFuse res (out_idds, aux, soac, consumed)

  -- Reject fusion if a candidate kernel has an in-place name among
  -- the names used by this SOAC.
  let all_used_names =
        namesToList $
          mconcat
            [ lam_used_nms,
              namesFromList inp_nms,
              namesFromList other_nms
            ]
      has_inplace ker = any (`nameIn` inplace ker) all_used_names
      ok_inplace = not $ any has_inplace old_kers
      fusible_ker = not (null old_kers) && ok_inplace && ok_kers_compat

  -- The new infusible set: the old one, the inputs that are also used
  -- by kernels other than the ones being fused with, and the other
  -- names that are not SOAC input arrays.
  let mod_kerS = if fusible_ker then S.fromList oldker_nms else mempty
      used_inps = filter (isInpArrInResModKers res mod_kerS) inp_nms
      ufs =
        mconcat
          [ infusible res,
            namesFromList used_inps,
            namesFromList other_nms
              `namesSubtract` namesFromList
                (map SOAC.inputArray $ SOAC.inputs soac)
          ]

  if not fusible_ker
    then
      addNewKerWithInfusible
        res
        (patternIdents out_idds, aux, soac, consumed)
        ufs
    else do
      let fused_ker_nms = zip fused_nms fused_kers
          -- First drop the old kernels from the user sets of their
          -- inputs, then record the fused kernels as users of theirs.
          inpArr' =
            L.foldl' dropOldKernel (inpArr res) (zip old_kers oldker_nms)
          inpArr'' = L.foldl' addNewKernel inpArr' fused_ker_nms
          -- Fused kernels override the entries with the same name.
          kernels' = M.fromList fused_ker_nms `M.union` kernels res
      -- 'outArr' is untouched because no new kernel name was created.
      return $ FusedRes True (outArr res) inpArr'' ufs kernels'
  where
    -- Remove kernel @knm@ from the user set of each input of @kold@.
    dropOldKernel inpa (kold, knm) =
      S.foldl' (flip $ M.update $ withoutKernel knm) inpa (arrInputs kold)

    -- Delete a kernel from a user set; 'Nothing' (which makes
    -- 'M.update' drop the key) if the set becomes empty.
    withoutKernel knm users =
      let users' = S.delete knm users
       in if S.null users' then Nothing else Just users'

    -- Add kernel @knm@ to the user set of each input of @knew@.
    addNewKernel inpa (knm, knew) =
      M.unionWith
        S.union
        (M.fromSet (const $ S.singleton knm) (arrInputs knew))
        inpa
-- | Attempt producer-consumer fusion of a SOAC with every kernel that
-- takes one of the SOAC's outputs as input.
--
-- The result is a five-tuple: whether fusion succeeded with /all/
-- candidate kernels; the fused kernels (empty on failure); the names
-- of the fused kernels; the candidate kernels; and the names of the
-- candidate kernels.  The third and fifth components are the same
-- list, i.e. a fused kernel keeps the name of the kernel it replaces.
--
-- The auxiliary information of the SOAC statement is appended to that
-- of each fused kernel.  An error is thrown if a candidate kernel
-- name has no entry in 'kernels'.
prodconsGreedyFuse ::
  FusedRes ->
  (Pattern, StmAux (), SOAC, Names) ->
  FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
prodconsGreedyFuse res (out_idds, aux, soac, consumed) = do
  to_fuse_kers <- mapM lookup_kern to_fuse_knms
  attempts <- mapM (attemptFusion mempty out_nms soac consumed) to_fuse_kers
  -- 'sequence' is 'Nothing' as soon as a single attempt has failed.
  let (ok_kers_compat, fused_kers) =
        maybe (False, []) (\kers -> (True, map certifyKer kers)) $
          sequence attempts
  return (ok_kers_compat, fused_kers, to_fuse_knms, to_fuse_kers, to_fuse_knms)
  where
    out_nms = patternNames out_idds
    to_fuse_knms = S.toList $ getKersWithInpArrs res out_nms

    lookup_kern k =
      maybe (throwError kernelNotFound) return $ M.lookup k $ kernels res

    kernelNotFound =
      Error
        ( "In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
            ++ "kernel name not found in kernels field!"
        )

    certifyKer k = k {kerAux = kerAux k <> aux}
horizontGreedyFuse ::
[Stm] ->
FusedRes ->
(Pattern, StmAux (), SOAC, Names) ->
FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse :: [Stm]
-> FusedRes
-> (Pattern, StmAux (), SOAC, Names)
-> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
horizontGreedyFuse [Stm]
rem_bnds FusedRes
res (Pattern
out_idds, StmAux ()
aux, SOAC
soac, Names
consumed) = do
([VName]
inp_nms, [VName]
_) <- SOAC -> FusionGM ([VName], [VName])
soacInputs SOAC
soac
let out_nms :: [VName]
out_nms = PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern
out_idds
infusible_nms :: Names
infusible_nms = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` FusedRes -> Names
infusible FusedRes
res) [VName]
out_nms
out_arr_nms :: [VName]
out_arr_nms = case SOAC
soac of
SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans [Reduce SOACS]
reds Lambda SOACS
_) [Input]
_ ->
Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Scan SOACS] -> Int
forall lore. [Scan lore] -> Int
scanResults [Scan SOACS]
scans Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall lore. [Reduce lore] -> Int
redResults [Reduce SOACS]
reds) [VName]
out_nms
SOAC.Stream SubExp
_ StreamForm SOACS
_ Lambda SOACS
_ [SubExp]
nes [Input]
_ -> Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
out_nms
SOAC
_ -> [VName]
out_nms
to_fuse_knms1 :: [KernName]
to_fuse_knms1 = Set KernName -> [KernName]
forall a. Set a -> [a]
S.toList (Set KernName -> [KernName]) -> Set KernName -> [KernName]
forall a b. (a -> b) -> a -> b
$ FusedRes -> [VName] -> Set KernName
getKersWithInpArrs FusedRes
res ([VName]
out_arr_nms [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
inp_nms)
to_fuse_knms2 :: [KernName]
to_fuse_knms2 = SubExp -> FusedRes -> [KernName]
getKersWithSameInpSize (SOAC -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC
soac) FusedRes
res
to_fuse_knms :: [KernName]
to_fuse_knms = Set KernName -> [KernName]
forall a. Set a -> [a]
S.toList (Set KernName -> [KernName]) -> Set KernName -> [KernName]
forall a b. (a -> b) -> a -> b
$ [KernName] -> Set KernName
forall a. Ord a => [a] -> Set a
S.fromList ([KernName] -> Set KernName) -> [KernName] -> Set KernName
forall a b. (a -> b) -> a -> b
$ [KernName]
to_fuse_knms1 [KernName] -> [KernName] -> [KernName]
forall a. [a] -> [a] -> [a]
++ [KernName]
to_fuse_knms2
lookupKernel :: KernName -> FusionGM FusedKer
lookupKernel KernName
k = case KernName -> Map KernName FusedKer -> Maybe FusedKer
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup KernName
k (FusedRes -> Map KernName FusedKer
kernels FusedRes
res) of
Maybe FusedKer
Nothing ->
Error -> FusionGM FusedKer
forall e (m :: * -> *) a. MonadError e m => e -> m a
throwError (Error -> FusionGM FusedKer) -> Error -> FusionGM FusedKer
forall a b. (a -> b) -> a -> b
$
String -> Error
Error
( String
"In Fusion.hs, greedyFuse, comp of to_fuse_kers: "
String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"kernel name not found in kernels field!"
)
Just FusedKer
ker -> FusedKer -> FusionGM FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return FusedKer
ker
let bnd_nms :: [[VName]]
bnd_nms = (Stm -> [VName]) -> [Stm] -> [[VName]]
forall a b. (a -> b) -> [a] -> [b]
map (PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (PatternT Type -> [VName])
-> (Stm -> PatternT Type) -> Stm -> [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Stm -> PatternT Type
forall lore. Stm lore -> Pattern lore
stmPattern) [Stm]
rem_bnds
[Maybe (FusedKer, KernName, Int)]
kernminds <- [KernName]
-> (KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [KernName]
to_fuse_knms ((KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)])
-> (KernName -> FusionGM (Maybe (FusedKer, KernName, Int)))
-> FusionGM [Maybe (FusedKer, KernName, Int)]
forall a b. (a -> b) -> a -> b
$ \KernName
ker_nm -> do
FusedKer
ker <- KernName -> FusionGM FusedKer
lookupKernel KernName
ker_nm
case (VName -> Maybe Int) -> [VName] -> [Int]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe (\VName
out_nm -> ([VName] -> Bool) -> [[VName]] -> Maybe Int
forall a. (a -> Bool) -> [a] -> Maybe Int
L.findIndex (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
elem VName
out_nm) [[VName]]
bnd_nms) (FusedKer -> [VName]
outNames FusedKer
ker) of
[] -> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return Maybe (FusedKer, KernName, Int)
forall a. Maybe a
Nothing
[Int]
is -> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall (m :: * -> *) a. Monad m => a -> m a
return (Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int)))
-> Maybe (FusedKer, KernName, Int)
-> FusionGM (Maybe (FusedKer, KernName, Int))
forall a b. (a -> b) -> a -> b
$ (FusedKer, KernName, Int) -> Maybe (FusedKer, KernName, Int)
forall a. a -> Maybe a
Just (FusedKer
ker, KernName
ker_nm, [Int] -> Int
forall a (f :: * -> *). (Num a, Ord a, Foldable f) => f a -> a
maxinum [Int]
is)
Scope SOACS
scope <- FusionGM (Scope SOACS)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
let kernminds' :: [(FusedKer, KernName, Int)]
kernminds' = ((FusedKer, KernName, Int)
-> (FusedKer, KernName, Int) -> Ordering)
-> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
L.sortBy (\(FusedKer
_, KernName
_, Int
i1) (FusedKer
_, KernName
_, Int
i2) -> Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare Int
i1 Int
i2) ([(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)])
-> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a b. (a -> b) -> a -> b
$ [Maybe (FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. [Maybe a] -> [a]
catMaybes [Maybe (FusedKer, KernName, Int)]
kernminds
soac_kernel :: FusedKer
soac_kernel = StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel StmAux ()
aux SOAC
soac Names
consumed [VName]
out_nms Scope SOACS
scope
Scope SOACS
use_scope <- (Scope SOACS -> Scope SOACS -> Scope SOACS
forall a. Semigroup a => a -> a -> a
<> [Stm] -> Scope SOACS
forall lore a. Scoped lore a => a -> Scope lore
scopeOf [Stm]
rem_bnds) (Scope SOACS -> Scope SOACS)
-> FusionGM (Scope SOACS) -> FusionGM (Scope SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> FusionGM (Scope SOACS)
forall lore (m :: * -> *). HasScope lore m => m (Scope lore)
askScope
(Bool
_, Int
ok_ind, Int
_, FusedKer
fused_ker, Names
_) <-
((Bool, Int, Int, FusedKer, Names)
-> (FusedKer, KernName, Int)
-> FusionGM (Bool, Int, Int, FusedKer, Names))
-> (Bool, Int, Int, FusedKer, Names)
-> [(FusedKer, KernName, Int)]
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM
( \(Bool
cur_ok, Int
n, Int
prev_ind, FusedKer
cur_ker, Names
ufus_nms) (FusedKer
ker, KernName
_ker_nm, Int
bnd_ind) -> do
let curker_outnms :: [VName]
curker_outnms = FusedKer -> [VName]
outNames FusedKer
cur_ker
curker_outset :: Names
curker_outset = [VName] -> Names
namesFromList [VName]
curker_outnms
new_ufus_nms :: Names
new_ufus_nms = [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ Names -> [VName]
namesToList Names
ufus_nms
out_transf_ok :: Bool
out_transf_ok =
let ker_inp :: [Input]
ker_inp = SOAC -> [Input]
forall lore. SOAC lore -> [Input]
SOAC.inputs (SOAC -> [Input]) -> SOAC -> [Input]
forall a b. (a -> b) -> a -> b
$ FusedKer -> SOAC
fsoac FusedKer
ker
unfuse1 :: Names
unfuse1 =
[VName] -> Names
namesFromList ((Input -> VName) -> [Input] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Input -> VName
SOAC.inputArray [Input]
ker_inp)
Names -> Names -> Names
`namesSubtract` [VName] -> Names
namesFromList ((Input -> Maybe VName) -> [Input] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Input -> Maybe VName
SOAC.isVarInput [Input]
ker_inp)
unfuse2 :: Names
unfuse2 = Names -> Names -> Names
namesIntersection Names
curker_outset Names
ufus_nms
in Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ Names
unfuse1 Names -> Names -> Bool
`namesIntersect` Names
unfuse2
cons_no_out_transf :: Bool
cons_no_out_transf = ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedKer -> ArrayTransforms
outputTransform FusedKer
ker
Bool
consumer_ok <- do
let consumer_bnd :: Stm
consumer_bnd = [Stm]
rem_bnds [Stm] -> Int -> Stm
forall a. [a] -> Int -> a
!! Int
bnd_ind
Either NotSOAC SOAC
maybesoac <- ReaderT (Scope SOACS) FusionGM (Either NotSOAC SOAC)
-> Scope SOACS -> FusionGM (Either NotSOAC SOAC)
forall r (m :: * -> *) a. ReaderT r m a -> r -> m a
runReaderT (Exp SOACS -> ReaderT (Scope SOACS) FusionGM (Either NotSOAC SOAC)
forall lore (m :: * -> *).
(Op lore ~ SOAC lore, HasScope lore m) =>
Exp lore -> m (Either NotSOAC (SOAC lore))
SOAC.fromExp (Exp SOACS -> ReaderT (Scope SOACS) FusionGM (Either NotSOAC SOAC))
-> Exp SOACS
-> ReaderT (Scope SOACS) FusionGM (Either NotSOAC SOAC)
forall a b. (a -> b) -> a -> b
$ Stm -> Exp SOACS
forall lore. Stm lore -> Exp lore
stmExp Stm
consumer_bnd) Scope SOACS
use_scope
case Either NotSOAC SOAC
maybesoac of
Right SOAC
conssoac ->
Bool -> FusionGM Bool
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool -> FusionGM Bool) -> Bool -> FusionGM Bool
forall a b. (a -> b) -> a -> b
$
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$
Names
curker_outset
Names -> Names -> Bool
`namesIntersect` BodyT SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn (Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody (Lambda SOACS -> BodyT SOACS) -> Lambda SOACS -> BodyT SOACS
forall a b. (a -> b) -> a -> b
$ SOAC -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC
conssoac)
Left NotSOAC
_ -> Bool -> FusionGM Bool
forall (m :: * -> *) a. Monad m => a -> m a
return Bool
True
let interm_bnds_ok :: Bool
interm_bnds_ok =
Bool
cur_ok Bool -> Bool -> Bool
&& Bool
consumer_ok Bool -> Bool -> Bool
&& Bool
out_transf_ok Bool -> Bool -> Bool
&& Bool
cons_no_out_transf
Bool -> Bool -> Bool
&& (Bool -> Stm -> Bool) -> Bool -> [Stm] -> Bool
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
( \Bool
ok Stm
bnd ->
Bool
ok
Bool -> Bool -> Bool
&& Bool -> Bool
not (Names
curker_outset Names -> Names -> Bool
`namesIntersect` Exp SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn (Stm -> Exp SOACS
forall lore. Stm lore -> Exp lore
stmExp Stm
bnd))
Bool -> Bool -> Bool
||
Bool -> Bool
not
( [VName] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$
[VName]
curker_outnms
[VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
`L.intersect` PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames (Stm -> Pattern
forall lore. Stm lore -> Pattern lore
stmPattern Stm
bnd)
)
)
Bool
True
(Int -> [Stm] -> [Stm]
forall a. Int -> [a] -> [a]
drop (Int
prev_ind Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ([Stm] -> [Stm]) -> [Stm] -> [Stm]
forall a b. (a -> b) -> a -> b
$ Int -> [Stm] -> [Stm]
forall a. Int -> [a] -> [a]
take Int
bnd_ind [Stm]
rem_bnds)
if Bool -> Bool
not Bool
interm_bnds_ok
then (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
False, Int
n, Int
bnd_ind, FusedKer
cur_ker, Names
forall a. Monoid a => a
mempty)
else do
Maybe FusedKer
new_ker <-
Names
-> [VName]
-> SOAC
-> Names
-> FusedKer
-> FusionGM (Maybe FusedKer)
forall (m :: * -> *).
MonadFreshNames m =>
Names -> [VName] -> SOAC -> Names -> FusedKer -> m (Maybe FusedKer)
attemptFusion
Names
ufus_nms
(FusedKer -> [VName]
outNames FusedKer
cur_ker)
(FusedKer -> SOAC
fsoac FusedKer
cur_ker)
(FusedKer -> Names
fusedConsumed FusedKer
cur_ker)
FusedKer
ker
case Maybe FusedKer
new_ker of
Maybe FusedKer
Nothing -> (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
False, Int
n, Int
bnd_ind, FusedKer
cur_ker, Names
forall a. Monoid a => a
mempty)
Just FusedKer
krn ->
let krn' :: FusedKer
krn' = FusedKer
krn {kerAux :: StmAux ()
kerAux = StmAux ()
aux StmAux () -> StmAux () -> StmAux ()
forall a. Semigroup a => a -> a -> a
<> FusedKer -> StmAux ()
kerAux FusedKer
krn}
in (Bool, Int, Int, FusedKer, Names)
-> FusionGM (Bool, Int, Int, FusedKer, Names)
forall (m :: * -> *) a. Monad m => a -> m a
return (Bool
True, Int
n Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1, Int
bnd_ind, FusedKer
krn', Names
new_ufus_nms)
)
(Bool
True, Int
0, Int
0, FusedKer
soac_kernel, Names
infusible_nms)
[(FusedKer, KernName, Int)]
kernminds'
let ([FusedKer]
to_fuse_kers', [KernName]
to_fuse_knms', [Int]
_) = [(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int])
forall a b c. [(a, b, c)] -> ([a], [b], [c])
unzip3 ([(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int]))
-> [(FusedKer, KernName, Int)] -> ([FusedKer], [KernName], [Int])
forall a b. (a -> b) -> a -> b
$ Int -> [(FusedKer, KernName, Int)] -> [(FusedKer, KernName, Int)]
forall a. Int -> [a] -> [a]
take Int
ok_ind [(FusedKer, KernName, Int)]
kernminds'
new_kernms :: [KernName]
new_kernms = Int -> [KernName] -> [KernName]
forall a. Int -> [a] -> [a]
drop (Int
ok_ind Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1) [KernName]
to_fuse_knms'
(Bool, [FusedKer], [KernName], [FusedKer], [KernName])
-> FusionGM (Bool, [FusedKer], [KernName], [FusedKer], [KernName])
forall (m :: * -> *) a. Monad m => a -> m a
return (Int
ok_ind Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
0, [FusedKer
fused_ker], [KernName]
new_kernms, [FusedKer]
to_fuse_kers', [KernName]
to_fuse_knms')
where
-- | Names of all kernels recorded in the 'kernels' map of the given
-- result whose fused SOAC ('fsoac') has a width equal to @sz@.
--
-- The names are returned in ascending key order of the map.
getKersWithSameInpSize :: SubExp -> FusedRes -> [KernName]
getKersWithSameInpSize sz ress =
  M.keys $ M.filter ((sz ==) . SOAC.width . fsoac) $ kernels ress
-- | Gather fusion information for an entire body.
--
-- The body is unpacked into its statement list and its result (the body
-- decoration is ignored), and both are handed to 'fusionGatherStms'
-- together with the fusion result accumulated so far.
fusionGatherBody :: FusedRes -> Body -> FusionGM FusedRes
fusionGatherBody fres (Body _ stms res) =
  fusionGatherStms fres (stmsToList stms) res
fusionGatherStms :: FusedRes -> [Stm] -> Result -> FusionGM FusedRes
fusionGatherStms :: FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms
FusedRes
fres
( Let
(Pattern [] [PatElem]
pes)
StmAux (ExpDec SOACS)
bndtp
(DoLoop [] [(FParam SOACS, SubExp)]
merge (ForLoop VName
i IntType
it SubExp
w [(LParam SOACS, VName)]
loop_vars) BodyT SOACS
body)
: [Stm]
bnds
)
[SubExp]
res
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(Param Type, VName)] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars = do
let ([Param DeclType]
merge_params, [SubExp]
merge_init) = [(Param DeclType, SubExp)] -> ([Param DeclType], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge
([Param Type]
loop_params, [VName]
loop_arrs) = [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars
VName
chunk_size <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"chunk_size"
VName
offset <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"offset"
let chunk_param :: Param Type
chunk_param = VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
chunk_size (Type -> Param Type) -> Type -> Param Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
offset_param :: Param Type
offset_param = VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param VName
offset (Type -> Param Type) -> Type -> Param Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ IntType -> PrimType
IntType IntType
it
[Param Type]
acc_params <- [Param DeclType]
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Param DeclType]
merge_params ((Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type])
-> (Param DeclType -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \Param DeclType
p ->
VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> FusionGM VName -> FusionGM (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p) String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_outer")
FusionGM (Type -> Param Type)
-> FusionGM Type -> FusionGM (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> FusionGM Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param DeclType -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param DeclType
p)
[Param Type]
chunked_params <- [(Param Type, VName)]
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars (((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type])
-> ((Param Type, VName) -> FusionGM (Param Type))
-> FusionGM [Param Type]
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, VName
arr) ->
VName -> Type -> Param Type
forall dec. VName -> dec -> Param dec
Param (VName -> Type -> Param Type)
-> FusionGM VName -> FusionGM (Type -> Param Type)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (VName -> String
baseString VName
arr String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_chunk")
FusionGM (Type -> Param Type)
-> FusionGM Type -> FusionGM (Param Type)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Type -> FusionGM Type
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
p Type -> SubExp -> Type
forall d.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) NoUniqueness
-> d -> TypeBase (ShapeBase d) NoUniqueness
`arrayOfRow` VName -> SubExp
Futhark.Var VName
chunk_size)
let lam_params :: [Param Type]
lam_params = Param Type
chunk_param Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param] [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
chunked_params
BodyT SOACS
lam_body <- Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> FusionGM (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
Scope SOACS
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope SOACS
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type]
lam_params) (Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
let merge' :: [(Param DeclType, SubExp)]
merge' = [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
merge_params ([SubExp] -> [(Param DeclType, SubExp)])
-> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ (Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Futhark.Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
acc_params
VName
j <- String -> BinderT SOACS (State VNameSource) VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"j"
BodyT SOACS
loop_body <- Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> Binder SOACS (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$ do
[(Param Type, Param Type)]
-> ((Param Type, Param Type)
-> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([Param Type] -> [Param Type] -> [(Param Type, Param Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param Type]
loop_params [Param Type]
chunked_params) (((Param Type, Param Type) -> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ())
-> ((Param Type, Param Type)
-> BinderT SOACS (State VNameSource) ())
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ \(Param Type
p, Param Type
a_p) ->
[VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
p] (Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$
VName -> Slice SubExp -> BasicOp
Index (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
a_p) (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$
Type -> Slice SubExp -> Slice SubExp
fullSlice (Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType Param Type
a_p) [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Futhark.Var VName
j]
[VName]
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall (m :: * -> *).
MonadBinder m =>
[VName] -> Exp (Lore m) -> m ()
letBindNames [VName
i] (Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ())
-> Exp (Lore (BinderT SOACS (State VNameSource)))
-> BinderT SOACS (State VNameSource) ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
j)
BodyT SOACS -> Binder SOACS (BodyT SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return BodyT SOACS
body
[BinderT
SOACS
(State VNameSource)
(Exp (Lore (BinderT SOACS (State VNameSource))))]
-> BinderT
SOACS
(State VNameSource)
(Body (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody
[ Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS))
-> Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall a b. (a -> b) -> a -> b
$
[(FParam SOACS, SubExp)]
-> [(FParam SOACS, SubExp)]
-> LoopForm SOACS
-> BodyT SOACS
-> Exp SOACS
forall lore.
[(FParam lore, SubExp)]
-> [(FParam lore, SubExp)]
-> LoopForm lore
-> BodyT lore
-> ExpT lore
DoLoop [] [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
merge' (VName
-> IntType -> SubExp -> [(LParam SOACS, VName)] -> LoopForm SOACS
forall lore.
VName
-> IntType -> SubExp -> [(LParam lore, VName)] -> LoopForm lore
ForLoop VName
j IntType
it (VName -> SubExp
Futhark.Var VName
chunk_size) []) BodyT SOACS
loop_body,
Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS))
-> Exp SOACS -> BinderT SOACS (State VNameSource) (Exp SOACS)
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowUndef) (VName -> SubExp
Futhark.Var VName
offset) (VName -> SubExp
Futhark.Var VName
chunk_size)
]
let lam :: Lambda SOACS
lam =
Lambda :: forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
[LParam SOACS]
lam_params,
lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
lam_body,
lambdaReturnType :: [Type]
lambdaReturnType = (Param Type -> Type) -> [Param Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Type
forall dec. Typed dec => Param dec -> Type
paramType ([Param Type] -> [Type]) -> [Param Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ [Param Type]
acc_params [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type
offset_param]
}
stream :: SOAC SOACS
stream = SubExp
-> [VName]
-> StreamForm SOACS
-> [SubExp]
-> Lambda SOACS
-> SOAC SOACS
forall lore.
SubExp
-> [VName]
-> StreamForm lore
-> [SubExp]
-> Lambda lore
-> SOAC lore
Futhark.Stream SubExp
w [VName]
loop_arrs StreamForm SOACS
forall lore. StreamForm lore
Sequential ([SubExp]
merge_init [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ [IntType -> Integer -> SubExp
intConst IntType
it Integer
0]) Lambda SOACS
lam
VName
discard <- String -> FusionGM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName String
"discard"
let discard_pe :: PatElemT Type
discard_pe = VName -> Type -> PatElemT Type
forall dec. VName -> dec -> PatElemT dec
PatElem VName
discard (Type -> PatElemT Type) -> Type -> PatElemT Type
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms
FusedRes
fres
(Pattern -> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ([PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([PatElemT Type]
[PatElem]
pes [PatElemT Type] -> [PatElemT Type] -> [PatElemT Type]
forall a. Semigroup a => a -> a -> a
<> [PatElemT Type
discard_pe])) StmAux (ExpDec SOACS)
bndtp (Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op Op SOACS
SOAC SOACS
stream) Stm -> [Stm] -> [Stm]
forall a. a -> [a] -> [a]
: [Stm]
bnds)
[SubExp]
res
fusionGatherStms FusedRes
fres (bnd :: Stm
bnd@(Let Pattern
pat StmAux (ExpDec SOACS)
_ Exp SOACS
e) : [Stm]
bnds) [SubExp]
res = do
Either NotSOAC SOAC
maybesoac <- Exp SOACS -> FusionGM (Either NotSOAC SOAC)
forall lore (m :: * -> *).
(Op lore ~ SOAC lore, HasScope lore m) =>
Exp lore -> m (Either NotSOAC (SOAC lore))
SOAC.fromExp Exp SOACS
e
case Either NotSOAC SOAC
maybesoac of
Right soac :: SOAC
soac@(SOAC.Scatter SubExp
_len Lambda SOACS
lam [Input]
_ivs [(ShapeBase SubExp, Int, VName)]
_as) -> do
FusedRes
fres' <- FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible FusedRes
fres (Names -> FusionGM FusedRes) -> Names -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern
pat
FusedRes
fres'' <- FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lam
FusedRes -> Exp SOACS -> FusionGM FusedRes
checkForUpdates FusedRes
fres'' Exp SOACS
e
Right soac :: SOAC
soac@(SOAC.Hist SubExp
_ [HistOp SOACS]
_ Lambda SOACS
lam [Input]
_) -> do
FusedRes
fres' <- FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible FusedRes
fres (Names -> FusionGM FusedRes) -> Names -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ [VName] -> Names
namesFromList ([VName] -> Names) -> [VName] -> Names
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern
pat
FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lam
Right soac :: SOAC
soac@(SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans [Reduce SOACS]
reds Lambda SOACS
map_lam) [Input]
_) ->
SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac ((Scan SOACS -> Lambda SOACS) -> [Scan SOACS] -> [Lambda SOACS]
forall a b. (a -> b) -> [a] -> [b]
map Scan SOACS -> Lambda SOACS
forall lore. Scan lore -> Lambda lore
scanLambda [Scan SOACS]
scans [Lambda SOACS] -> [Lambda SOACS] -> [Lambda SOACS]
forall a. Semigroup a => a -> a -> a
<> (Reduce SOACS -> Lambda SOACS) -> [Reduce SOACS] -> [Lambda SOACS]
forall a b. (a -> b) -> [a] -> [b]
map Reduce SOACS -> Lambda SOACS
forall lore. Reduce lore -> Lambda lore
redLambda [Reduce SOACS]
reds [Lambda SOACS] -> [Lambda SOACS] -> [Lambda SOACS]
forall a. Semigroup a => a -> a -> a
<> [Lambda SOACS
map_lam]) ([SubExp] -> FusionGM FusedRes) -> [SubExp] -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$
(Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall lore. Scan lore -> [SubExp]
scanNeutral [Scan SOACS]
scans [SubExp] -> [SubExp] -> [SubExp]
forall a. Semigroup a => a -> a -> a
<> (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall lore. Reduce lore -> [SubExp]
redNeutral [Reduce SOACS]
reds
Right soac :: SOAC
soac@(SOAC.Stream SubExp
_ StreamForm SOACS
form Lambda SOACS
lam [SubExp]
nes [Input]
_) -> do
let lambdas :: [Lambda SOACS]
lambdas = case StreamForm SOACS
form of
Parallel StreamOrd
_ Commutativity
_ Lambda SOACS
lout -> [Lambda SOACS
lout, Lambda SOACS
lam]
StreamForm SOACS
Sequential -> [Lambda SOACS
lam]
SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac [Lambda SOACS]
lambdas [SubExp]
nes
Either NotSOAC SOAC
_
| [PatElemT Type
pe] <- PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements PatternT Type
Pattern
pat,
Just (VName
src, ArrayTransform
trns) <- Certificates -> Exp SOACS -> Maybe (VName, ArrayTransform)
forall lore.
Certificates -> Exp lore -> Maybe (VName, ArrayTransform)
SOAC.transformFromExp (Stm -> Certificates
forall lore. Stm lore -> Certificates
stmCerts Stm
bnd) Exp SOACS
e ->
PatElem
-> VName
-> ArrayTransform
-> FusionGM FusedRes
-> FusionGM FusedRes
forall a.
PatElem -> VName -> ArrayTransform -> FusionGM a -> FusionGM a
bindingTransform PatElemT Type
PatElem
pe VName
src ArrayTransform
trns (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms FusedRes
fres [Stm]
bnds [SubExp]
res
| Bool
otherwise -> do
let pat_vars :: [Exp SOACS]
pat_vars = (VName -> Exp SOACS) -> [VName] -> [Exp SOACS]
forall a b. (a -> b) -> [a] -> [b]
map (BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> (VName -> BasicOp) -> VName -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp (SubExp -> BasicOp) -> (VName -> SubExp) -> VName -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var) ([VName] -> [Exp SOACS]) -> [VName] -> [Exp SOACS]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT Type
Pattern
pat
FusedRes
bres <- Pattern -> Exp SOACS -> FusionGM FusedRes -> FusionGM FusedRes
gatherStmPattern Pattern
pat Exp SOACS
e (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms FusedRes
fres [Stm]
bnds [SubExp]
res
FusedRes
bres' <- FusedRes -> Exp SOACS -> FusionGM FusedRes
checkForUpdates FusedRes
bres Exp SOACS
e
(FusedRes -> Exp SOACS -> FusionGM FusedRes)
-> FusedRes -> [Exp SOACS] -> FusionGM FusedRes
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM FusedRes -> Exp SOACS -> FusionGM FusedRes
fusionGatherExp FusedRes
bres' (Exp SOACS
e Exp SOACS -> [Exp SOACS] -> [Exp SOACS]
forall a. a -> [a] -> [a]
: [Exp SOACS]
pat_vars)
where
aux :: StmAux (ExpDec SOACS)
aux = Stm -> StmAux (ExpDec SOACS)
forall lore. Stm lore -> StmAux (ExpDec lore)
stmAux Stm
bnd
rem_bnds :: [Stm]
rem_bnds = Stm
bnd Stm -> [Stm] -> [Stm]
forall a. a -> [a] -> [a]
: [Stm]
bnds
consumed :: Names
consumed = Exp (Aliases SOACS) -> Names
forall lore. Aliased lore => Exp lore -> Names
consumedInExp (Exp (Aliases SOACS) -> Names) -> Exp (Aliases SOACS) -> Names
forall a b. (a -> b) -> a -> b
$ AliasTable -> Exp SOACS -> Exp (Aliases SOACS)
forall lore.
(ASTLore lore, CanBeAliased (Op lore)) =>
AliasTable -> Exp lore -> Exp (Aliases lore)
Alias.analyseExp AliasTable
forall a. Monoid a => a
mempty Exp SOACS
e
reduceLike :: SOAC -> [Lambda SOACS] -> [SubExp] -> FusionGM FusedRes
reduceLike SOAC
soac [Lambda SOACS]
lambdas [SubExp]
nes = do
(Names
used_lam, FusedRes
lres) <- ((Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes))
-> (Names, FusedRes)
-> [Lambda SOACS]
-> FusionGM (Names, FusedRes)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM (Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes)
fusionGatherLam (Names
forall a. Monoid a => a
mempty, FusedRes
fres) [Lambda SOACS]
lambdas
FusedRes
bres <- Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily Pattern
pat (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms FusedRes
lres [Stm]
bnds [SubExp]
res
FusedRes
bres' <- (FusedRes -> SubExp -> FusionGM FusedRes)
-> FusedRes -> [SubExp] -> FusionGM FusedRes
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp FusedRes
bres [SubExp]
nes
Names
consumed' <- Names -> FusionGM Names
varsAliases Names
consumed
[Stm]
-> Names
-> FusedRes
-> (Pattern, StmAux (), SOAC, Names)
-> FusionGM FusedRes
greedyFuse [Stm]
rem_bnds Names
used_lam FusedRes
bres' (Pattern
pat, StmAux ()
StmAux (ExpDec SOACS)
aux, SOAC
soac, Names
consumed')
mapLike :: FusedRes -> SOAC -> Lambda SOACS -> FusionGM FusedRes
mapLike FusedRes
fres' SOAC
soac Lambda SOACS
lambda = do
FusedRes
bres <- Pattern -> FusionGM FusedRes -> FusionGM FusedRes
bindingFamily Pattern
pat (FusionGM FusedRes -> FusionGM FusedRes)
-> FusionGM FusedRes -> FusionGM FusedRes
forall a b. (a -> b) -> a -> b
$ FusedRes -> [Stm] -> [SubExp] -> FusionGM FusedRes
fusionGatherStms FusedRes
fres' [Stm]
bnds [SubExp]
res
(Names
used_lam, FusedRes
blres) <- (Names, FusedRes) -> Lambda SOACS -> FusionGM (Names, FusedRes)
fusionGatherLam (Names
forall a. Monoid a => a
mempty, FusedRes
bres) Lambda SOACS
lambda
Names
consumed' <- Names -> FusionGM Names
varsAliases Names
consumed
[Stm]
-> Names
-> FusedRes
-> (Pattern, StmAux (), SOAC, Names)
-> FusionGM FusedRes
greedyFuse [Stm]
rem_bnds Names
used_lam FusedRes
blres (Pattern
pat, StmAux ()
StmAux (ExpDec SOACS)
aux, SOAC
soac, Names
consumed')
-- Base case: no statements remain, so each result 'SubExp' is wrapped as a
-- @BasicOp (SubExp ..)@ expression and folded through 'fusionGatherExp'.
fusionGatherStms fres [] res =
  foldM fusionGatherExp fres $ map (BasicOp . SubExp) res
-- | Gather fusion information from a single expression, extending @fres@.
--
-- * 'DoLoop': everything free in the loop form and in the context/value
--   parameter lists is added to the infusible set.  The loop body is then
--   traversed from an empty result with the loop-bound identifiers in scope,
--   every key of the body's 'inpArr' is additionally marked infusible, and
--   the two results are combined.
--
-- * 'If': both branches are traversed from an empty result and combined;
--   the combination is merged (via 'mergeFusionRes') into the result of
--   processing the condition.
--
-- * @Op Futhark.Screma@ and @Op Futhark.Scatter@ are reported through
--   'errorIllegal'.
--
-- * Any other expression: all of its free names become infusible.
fusionGatherExp :: FusedRes -> Exp -> FusionGM FusedRes
fusionGatherExp fres (DoLoop ctx val form loop_body) = do
  fres' <- addNamesToInfusible fres $ freeIn form <> freeIn ctx <> freeIn val
  let form_idents =
        case form of
          ForLoop i it _ loopvars ->
            Ident i (Prim (IntType it)) : map (paramIdent . fst) loopvars
          WhileLoop {} -> []
      loop_idents = form_idents ++ map (paramIdent . fst) (ctx <> val)
  new_res <-
    binding (zip loop_idents $ repeat mempty) $
      fusionGatherBody mempty loop_body
  -- Every array recorded as a kernel input inside the loop body is added to
  -- the body's infusible set before merging with the outer result.
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      new_res' = new_res {infusible = infusible new_res <> inp_arrs}
  return $ new_res' <> fres'
fusionGatherExp fres (If cond e_then e_else _) = do
  then_res <- fusionGatherBody mempty e_then
  else_res <- fusionGatherBody mempty e_else
  let both_res = then_res <> else_res
  fres' <- fusionGatherSubExp fres cond
  mergeFusionRes fres' both_res
fusionGatherExp _ (Op Futhark.Screma {}) = errorIllegal "screma"
fusionGatherExp _ (Op Futhark.Scatter {}) = errorIllegal "write"
fusionGatherExp fres e = addNamesToInfusible fres $ freeIn e
-- | Record a sub-expression in the fusion result: a variable is passed to
-- 'addVarToInfusible'; anything else leaves the result unchanged.
fusionGatherSubExp :: FusedRes -> SubExp -> FusionGM FusedRes
fusionGatherSubExp fres (Var idd) = addVarToInfusible fres idd
fusionGatherSubExp fres _ = return fres
-- | Apply 'addVarToInfusible' to every name in the given set, threading the
-- fusion result through.
addNamesToInfusible :: FusedRes -> Names -> FusionGM FusedRes
addNamesToInfusible fres = foldM addVarToInfusible fres . namesToList
-- | Add a single name to the infusible set of the fusion result.
--
-- The name is first looked up with 'lookupArr' in the environment.  When it
-- resolves to a 'SOAC.Input', the 'VName' stored in that input is recorded
-- instead of the queried name (presumably the array the input was derived
-- from -- confirm against 'SOAC.Input').
addVarToInfusible :: FusedRes -> VName -> FusionGM FusedRes
addVarToInfusible fres name = do
  trns <- asks $ lookupArr name
  let name' = case trns of
        Nothing -> name
        Just (SOAC.Input _ orig _) -> orig
  return fres {infusible = oneName name' <> infusible fres}
-- | Gather fusion information from the body of a lambda.
--
-- The body is traversed from an empty result with the lambda parameters
-- bound.  The body's infusible names together with every key of its
-- 'inpArr' are then restricted to the names currently in 'varsInScope';
-- that restricted set replaces the body's infusible set and is also
-- accumulated into the first component of the returned pair.  The second
-- component is the body's result combined with the incoming @fres@.
fusionGatherLam :: (Names, FusedRes) -> Lambda -> FusionGM (Names, FusedRes)
fusionGatherLam (u_set, fres) (Lambda idds body _) = do
  new_res <- bindingParams idds $ fusionGatherBody mempty body
  let inp_arrs = namesFromList $ M.keys $ inpArr new_res
      unfus = infusible new_res <> inp_arrs
  bnds <- asks $ M.keys . varsInScope
  -- Keep only names that are bound in the scope surrounding the lambda.
  let unfus' = unfus `namesIntersection` namesFromList bnds
      new_res' = new_res {infusible = unfus'}
  return (u_set <> unfus', new_res' <> fres)
-- | Rewrite a sequence of statements using the fusion result in the
-- environment.
--
-- The tail of the sequence is processed first, with the head's pattern
-- bound; the head is then handed to 'replaceSOAC', and the statements it
-- yields are placed in front of the processed tail.  An empty sequence
-- yields no statements.
fuseInStms :: Stms SOACS -> FusionGM (Stms SOACS)
fuseInStms stms
  | Just (Let pat aux e, stms') <- stmsHead stms = do
    stms'' <- bindingPat pat $ fuseInStms stms'
    soac_bnds <- replaceSOAC pat aux e
    pure $ soac_bnds <> stms''
  | otherwise =
    pure mempty
-- | Rewrite the statements of a body with 'fuseInStms'; the body result is
-- kept as is and the body decoration is rebuilt as @()@.
fuseInBody :: Body -> FusionGM Body
fuseInBody (Body _ stms res) =
  Body () <$> fuseInStms stms <*> pure res
-- | Rewrite the sub-bodies of an expression using the fusion result.
--
-- Loops are handled explicitly: the identifiers introduced by the loop form
-- and the context/value parameters are brought into scope before the loop
-- body is rewritten with 'fuseInBody'.  Every other expression is traversed
-- with 'mapExpM' using the 'fuseIn' mapper.
fuseInExp :: Exp -> FusionGM Exp
fuseInExp (DoLoop ctx val form loopbody) =
  binding (zip form_idents $ repeat mempty) $
    bindingParams (map fst $ ctx ++ val) $
      DoLoop ctx val form <$> fuseInBody loopbody
  where
    -- Identifiers bound by the loop form itself: none for a while-loop; the
    -- iteration variable and the loop-variable parameters for a for-loop.
    form_idents = case form of
      WhileLoop {} -> []
      ForLoop i it _ loopvars ->
        Ident i (Prim $ IntType it) :
        map (paramIdent . fst) loopvars
fuseInExp e = mapExpM fuseIn e
-- | Expression mapper used by 'fuseInExp': bodies are rewritten with
-- 'fuseInBody' (the scope argument is ignored) and the lambdas of SOAC
-- operations with 'fuseInLambda'.  Everything else is left to
-- 'identityMapper' / 'identitySOACMapper'.
fuseIn :: Mapper SOACS SOACS FusionGM
fuseIn =
  identityMapper
    { mapOnBody = const fuseInBody,
      mapOnOp = mapSOACM identitySOACMapper {mapOnSOACLambda = fuseInLambda}
    }
-- | Rewrite the body of a lambda with 'fuseInBody' while its parameters are
-- in scope.  Parameters and return types are kept unchanged.
fuseInLambda :: Lambda -> FusionGM Lambda
fuseInLambda (Lambda params body rtp) = do
  body' <- bindingParams params $ fuseInBody body
  return $ Lambda params body' rtp
-- | Produce the statements that replace a single binding.
--
-- * A pattern without value elements yields no statements.
--
-- * If the first value element's name is not a key of 'outArr' in the
--   current fusion result, the binding is kept and only its expression is
--   rewritten with 'fuseInExp'.
--
-- * Otherwise the kernel name found in 'outArr' is looked up in 'kernels'
--   and the kernel is emitted with 'insertKerSOAC'.
--
-- Throws an 'Error' when the kernel name is missing from 'kernels', or when
-- the kernel's 'fusedVars' is empty.
replaceSOAC :: Pattern -> StmAux () -> Exp -> FusionGM (Stms SOACS)
replaceSOAC (Pattern _ []) _ _ = return mempty
replaceSOAC pat@(Pattern _ (patElem : _)) aux e = do
  fres <- asks fusedRes
  let pat_nm = patElemName patElem
      names = patternIdents pat
  case M.lookup pat_nm (outArr fres) of
    Nothing ->
      oneStm . Let pat aux <$> fuseInExp e
    Just knm ->
      case M.lookup knm (kernels fres) of
        Nothing ->
          throwError $
            Error
              ( "In Fusion.hs, replaceSOAC, outArr in ker_name "
                  ++ "which is not in Res: "
                  ++ pretty (unKernName knm)
              )
        Just ker -> do
          when (null $ fusedVars ker) $
            throwError $
              Error
                ( "In Fusion.hs, replaceSOAC, unfused kernel "
                    ++ "still in result: "
                    ++ pretty names
                )
          insertKerSOAC aux (outNames ker) ker
-- | Emit the statements for a fused kernel.
--
-- The kernel's SOAC is first passed through 'finaliseSOAC'.  Inside a
-- binder it is then converted with 'SOAC.toSOAC', annotated with aliases
-- and passed to 'copyNewlyConsumed' together with the kernel's
-- 'fusedConsumed' set.  The result is bound to fresh identifiers (one per
-- entry of @names@, reusing each base string and typed by 'SOAC.typeOf')
-- under the kernel's auxiliary information combined with @aux@.  Finally
-- 'transformOutput' is run with the kernel's 'outputTransform', the
-- requested @names@ and the fresh identifiers.
insertKerSOAC :: StmAux () -> [VName] -> FusedKer -> FusionGM (Stms SOACS)
insertKerSOAC aux names ker = do
  new_soac' <- finaliseSOAC $ fsoac ker
  runBinder_ $ do
    f_soac <- SOAC.toSOAC new_soac'
    f_soac' <- copyNewlyConsumed (fusedConsumed ker) $ addOpAliases mempty f_soac
    validents <- zipWithM newIdent (map baseString names) $ SOAC.typeOf new_soac'
    auxing (kerAux ker <> aux) $ letBind (basicPattern [] validents) $ Op f_soac'
    transformOutput (outputTransform ker) names validents
-- | Run 'simplifyAndFuseInLambda' on the lambdas of a SOAC and rebuild it.
--
-- * 'SOAC.Screma': every scan lambda, every reduce lambda and the map
--   lambda are processed; neutral elements and commutativity are kept.
--
-- * 'SOAC.Scatter', 'SOAC.Hist' and 'SOAC.Stream': only the main lambda is
--   processed.  The remaining arguments (destinations, histogram
--   operations, stream form, neutral elements, inputs) are passed through
--   unchanged.
finaliseSOAC :: SOAC.SOAC SOACS -> FusionGM (SOAC.SOAC SOACS)
finaliseSOAC new_soac =
  case new_soac of
    SOAC.Screma w (ScremaForm scans reds map_lam) arrs -> do
      scans' <- forM scans $ \(Scan scan_lam scan_nes) -> do
        scan_lam' <- simplifyAndFuseInLambda scan_lam
        return $ Scan scan_lam' scan_nes
      reds' <- forM reds $ \(Reduce comm red_lam red_nes) -> do
        red_lam' <- simplifyAndFuseInLambda red_lam
        return $ Reduce comm red_lam' red_nes
      map_lam' <- simplifyAndFuseInLambda map_lam
      return $ SOAC.Screma w (ScremaForm scans' reds' map_lam') arrs
    SOAC.Scatter w lam inps dests -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Scatter w lam' inps dests
    SOAC.Hist w ops lam arrs -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Hist w ops lam' arrs
    SOAC.Stream w form lam nes inps -> do
      lam' <- simplifyAndFuseInLambda lam
      return $ SOAC.Stream w form lam' nes inps
-- | Simplify a lambda and then perform fusion inside it.
--
-- The lambda is simplified with 'simplifyLambda'.  Fusion information is
-- gathered from the simplified lambda starting from 'mkFreshFusionRes', and
-- the result is passed through 'cleanFusionResult'.  With that result
-- installed via 'bindRes', the simplified lambda is rewritten by
-- 'fuseInLambda'.
simplifyAndFuseInLambda :: Lambda -> FusionGM Lambda
simplifyAndFuseInLambda lam = do
  lam' <- simplifyLambda lam
  (_, nfres) <- fusionGatherLam (mempty, mkFreshFusionRes) lam'
  let nfres' = cleanFusionResult nfres
  bindRes nfres' $ fuseInLambda lam'
-- | Strip alias information from a SOAC, inserting copies for anything
-- it consumes that is not already in @was_consumed@.
--
-- For a 'Futhark.Screma':
--
--   * every input array that is in the newly-consumed set is replaced
--     by a fresh @<name>_copy@ bound via 'letExp' to a 'Copy' of it;
--   * the map lambda is rewritten by @copyFreeInLambda@ (see below);
--   * the scan and reduce lambdas only have their aliases removed.
--
-- Every other SOAC is returned through 'removeOpAliases' unchanged;
-- no copies are inserted for those.
copyNewlyConsumed ::
  Names ->
  Futhark.SOAC (Aliases.Aliases SOACS) ->
  Binder SOACS (Futhark.SOAC SOACS)
copyNewlyConsumed was_consumed soac =
  case soac of
    Futhark.Screma w arrs (Futhark.ScremaForm scans reds map_lam) -> do
      -- Copy the input arrays that this SOAC consumes but that were
      -- not listed in was_consumed.
      arrs' <- mapM copyConsumedArr arrs
      -- Consumed free variables of the map lambda are copied inside
      -- the lambda body, and the body is rewritten to use the copies.
      map_lam' <- copyFreeInLambda map_lam

      let scans' =
            map
              ( \scan ->
                  scan
                    { scanLambda =
                        Aliases.removeLambdaAliases
                          (scanLambda scan)
                    }
              )
              scans

      let reds' =
            map
              ( \red ->
                  red
                    { redLambda =
                        Aliases.removeLambdaAliases
                          (redLambda red)
                    }
              )
              reds

      return $ Futhark.Screma w arrs' $ Futhark.ScremaForm scans' reds' map_lam'
    _ -> return $ removeOpAliases soac
  where
    -- Everything the SOAC consumes according to its alias annotations.
    consumed = consumedInOp soac

    -- The part of that which the caller did not already account for.
    newly_consumed = consumed `namesSubtract` was_consumed

    -- Bind a copy of the array if it is newly consumed; otherwise
    -- return the name untouched.
    copyConsumedArr a
      | a `nameIn` newly_consumed =
        letExp (baseString a <> "_copy") $ BasicOp $ Copy a
      | otherwise = return a

    -- Remove aliases from the lambda.  If the lambda consumes any name
    -- that is not one of its own parameters, prepend a copy statement
    -- for each such name to the body and substitute the copy for the
    -- original throughout the body.
    copyFreeInLambda lam = do
      let free_consumed =
            consumedByLambda lam
              `namesSubtract` namesFromList (map paramName $ lambdaParams lam)
      (bnds, subst) <-
        foldM copyFree (mempty, mempty) $ namesToList free_consumed
      let lam' = Aliases.removeLambdaAliases lam
      return $
        if null bnds
          then lam'
          else
            lam'
              { lambdaBody =
                  insertStms bnds $
                    substituteNames subst $ lambdaBody lam'
              }

    -- Fold step: create a fresh "<v>_copy" name, build the statement
    -- binding it to a Copy of v, put that statement in front of the
    -- accumulated ones, and record the v -> v_copy substitution.
    copyFree (bnds, subst) v = do
      v_copy <- newVName $ baseString v <> "_copy"
      copy <- mkLetNamesM [v_copy] $ BasicOp $ Copy v
      return (oneStm copy <> bnds, M.insert v v_copy subst)
-- | The empty fusion result: 'rsucc' is 'False', and every map and the
-- 'infusible' name set are empty.
mkFreshFusionRes :: FusedRes
mkFreshFusionRes =
  FusedRes
    { rsucc = False,
      outArr = M.empty,
      inpArr = M.empty,
      infusible = mempty,
      kernels = M.empty
    }
-- | Combine two fusion results into one.
--
--   * 'rsucc' is the disjunction of the two flags.
--   * 'outArr' and 'kernels' are merged with 'M.union', which is
--     left-biased: on a key clash the entry from @res1@ is kept.
--   * 'inpArr' is merged key-wise, taking the union of the kernel sets.
--   * 'infusible' is the union of both 'infusible' sets, plus the names
--     produced by 'expandSoacInpArr' when given the keys that occur in
--     the 'inpArr' of /both/ results.
mergeFusionRes :: FusedRes -> FusedRes -> FusionGM FusedRes
mergeFusionRes res1 res2 = do
  let ufus_mres = infusible res1 <> infusible res2
  -- Arrays that are inputs to kernels in both results.
  inp_both <-
    expandSoacInpArr $ M.keys $ inpArr res1 `M.intersection` inpArr res2
  let m_unfus = ufus_mres <> foldMap oneName inp_both
  return $
    FusedRes
      { rsucc = rsucc res1 || rsucc res2,
        outArr = outArr res1 `M.union` outArr res2,
        inpArr = M.unionWith S.union (inpArr res1) (inpArr res2),
        infusible = m_unfus,
        kernels = kernels res1 `M.union` kernels res2
      }
-- | Partition SOAC inputs by whether they carry array transforms.
--
-- The first component holds the array names of inputs for which
-- 'SOAC.nullTransforms' is true; the second holds 'SOAC.inputArray' of
-- every other input.  Because names are consed onto the accumulators
-- during a left fold, both lists come out in the reverse of the order
-- in which the inputs were given.
getIdentArr :: [SOAC.Input] -> ([VName], [VName])
getIdentArr = L.foldl' comb ([], [])
  where
    -- Input without transforms: record its array in the first list.
    comb (vs, os) (SOAC.Input ts idd _)
      | SOAC.nullTransforms ts = (idd : vs, os)
    -- Anything else: record the underlying array in the second list.
    comb (vs, os) inp =
      (vs, SOAC.inputArray inp : os)
-- | Drop kernels that have no 'fusedVars' from a fusion result, and
-- remove references to the dropped kernels from 'outArr' and 'inpArr'.
--
-- Entries of 'outArr' that point to a dropped kernel are deleted.
-- Entries of 'inpArr' are kept, with the dropped kernels filtered out
-- of their sets, so a key may remain with an empty set.  'rsucc' and
-- 'infusible' are left as they were.
cleanFusionResult :: FusedRes -> FusedRes
cleanFusionResult fres =
  let -- Kernels that still have at least one fused variable.
      newks = M.filter (not . null . fusedVars) (kernels fres)
      newoa = M.filter (`M.member` newks) (outArr fres)
      newia = M.map (S.filter (`M.member` newks)) (inpArr fres)
   in fres {outArr = newoa, inpArr = newia, kernels = newks}
-- | Abort the fusion computation with 'throwError', reporting that the
-- SOAC named by the argument appears illegally in the program.
errorIllegal :: String -> FusionGM FusedRes
errorIllegal soac_name =
  throwError $
    Error
      ("In Fusion.hs, soac " ++ soac_name ++ " appears illegally in pgm!")