{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion.LoopKernel
( FusedKer (..),
newKernel,
inputs,
setInputs,
arrInputs,
transformOutput,
attemptFusion,
SOAC,
MapNest,
)
where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.HORep.MapNest as MapNest
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)
-- | The monad in which individual fusion rules are attempted.
--
-- It reads the type scope of the kernel, threads a name source for
-- fresh names, and can fail (the base monad is 'Maybe'), so alternative
-- rules can be tried with '<|>'.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Run a 'TryFusion' computation in the given scope.
--
-- On success the name source is advanced and the result is returned in
-- 'Just'.  On failure the original name source is kept and 'Nothing'
-- is returned.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    -- Discard any names consumed by the failed attempt.
    Nothing -> (Nothing, src)
-- | Embed a 'Maybe' into 'TryFusion'.  'Nothing' becomes a failed
-- fusion attempt (via 'fail').
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") return
-- | Map nests over the SOACS representation.
type MapNest = MapNest.MapNest SOACS

-- | SOACs (from "Futhark.Analysis.HORep.SOAC") over the SOACS
-- representation.
type SOAC = SOAC.SOAC SOACS
-- | @transformOutput ts names idents@ binds every name in @names@ to the
-- result of applying the array transforms @ts@, in the order that
-- 'SOAC.viewf' yields them, to the corresponding identifier in @idents@.
--
-- Each intermediate stage binds fresh identifiers whose base names are
-- taken from @names@.  Every generated statement is wrapped in the
-- certificates of the transform that produced it.  When no transforms
-- remain, the names are bound directly to the current identifiers.
transformOutput ::
  SOAC.ArrayTransforms ->
  [VName] ->
  [Ident] ->
  Binder SOACS ()
transformOutput ts names = descend ts
  where
    descend :: SOAC.ArrayTransforms -> [Ident] -> Binder SOACS ()
    descend ts' validents =
      case SOAC.viewf ts' of
        SOAC.EmptyF ->
          forM_ (zip names validents) $ \(k, valident) ->
            letBindNames [k] $ BasicOp $ SubExp $ Var $ identName valident
        t SOAC.:< ts'' -> do
          let (es, css) = unzip $ map (applyTransform t) validents
              mkPat (Ident nm tp) = Pattern [] [PatElem nm tp]
          -- 'primOpType' yields a list per operation; the flattened
          -- types are paired positionally with @names@.
          opts <- concat <$> mapM primOpType es
          newIds <- forM (zip names opts) $ \(k, opt) ->
            newIdent (baseString k) opt
          forM_ (zip3 css newIds es) $ \(cs, ids, e) ->
            certifying cs $ letBind (mkPat ids) (BasicOp e)
          descend ts'' newIds
-- | Turn a single SOAC array transform into the corresponding basic
-- operation on the given identifier, together with the transform's
-- certificates.
applyTransform :: SOAC.ArrayTransform -> Ident -> (BasicOp, Certificates)
applyTransform (SOAC.Rearrange cs perm) v =
  (Rearrange perm' $ identName v, cs)
  where
    -- Extend a permutation of the leading dimensions with the identity
    -- on the remaining dimensions of @v@.
    perm' = perm ++ drop (length perm) [0 .. arrayRank (identType v) - 1]
applyTransform (SOAC.Reshape cs shape) v =
  (Reshape shape $ identName v, cs)
applyTransform (SOAC.ReshapeOuter cs shape) v =
  let shapes = reshapeOuter shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.ReshapeInner cs shape) v =
  let shapes = reshapeInner shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.Replicate cs n) v =
  (Replicate n $ Var $ identName v, cs)
-- | Split off the first transform of an input (as given by 'SOAC.viewf').
-- The result pairs that transform with the input that carries the
-- remaining transforms.  Returns 'Nothing' when the input has no
-- transforms.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing
-- | A fusion kernel under construction: a SOAC plus the bookkeeping
-- needed while fusing producers into it.
data FusedKer = FusedKer
  { -- | The SOAC the kernel currently consists of.
    fsoac :: SOAC,
    -- | Initialised by 'newKernel' from its consumed-names argument.
    -- NOTE(review): presumably names involved in in-place updates;
    -- confirm against callers.
    inplace :: Names,
    -- | Output variables of producers fused into this kernel.  Starts
    -- empty and is extended with the producer's output variables when
    -- 'fuseSOACwithKer' succeeds.
    fusedVars :: [VName],
    -- | Names consumed by the kernel.  Initialised by 'newKernel' from
    -- its consumed-names argument.
    fusedConsumed :: Names,
    -- | Scope in which fusion attempts on this kernel are run
    -- (see 'attemptFusion').
    kernelScope :: Scope SOACS,
    -- | Array transforms associated with the kernel's outputs.  Starts as
    -- 'SOAC.noTransforms'.
    outputTransform :: SOAC.ArrayTransforms,
    -- | Output names supplied to 'newKernel'.
    outNames :: [VName],
    -- | Statement auxiliary information supplied to 'newKernel'.
    kerAux :: StmAux ()
  }
  deriving (Show)
-- | Create a kernel consisting of a single, not yet fused SOAC.
--
-- The consumed names seed both 'inplace' and 'fusedConsumed'.  No
-- variables have been fused yet, and the kernel has no output
-- transforms.
newKernel :: StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel aux soac consumed out_nms scope =
  FusedKer
    { fsoac = soac,
      inplace = consumed,
      fusedVars = [],
      fusedConsumed = consumed,
      outputTransform = SOAC.noTransforms,
      outNames = out_nms,
      kernelScope = scope,
      kerAux = aux
    }
-- | The set of array names read by the kernel's SOAC inputs, ignoring
-- any transforms applied to them.
arrInputs :: FusedKer -> S.Set VName
arrInputs = S.fromList . map SOAC.inputArray . inputs
-- | The inputs of the kernel's SOAC.
inputs :: FusedKer -> [SOAC.Input]
inputs = SOAC.inputs . fsoac
-- | Replace the inputs of the kernel's SOAC, leaving everything else
-- unchanged.
setInputs :: [SOAC.Input] -> FusedKer -> FusedKer
setInputs inps ker = ker {fsoac = inps `SOAC.setInputs` fsoac ker}
-- | Try to simplify the producer SOAC on its own before fusing it.
--
-- The steps are:
--
-- * 'optimizeSOAC' may rewrite the producer and return output
--   transforms @ots@.
-- * Those transforms are added (via 'SOAC.addInitialTransforms') to
--   every kernel input that reads one of @outVars@.
-- * The affected inputs are re-typed to match the rewritten producer.
-- * The fusion rules are retried.
--
-- NOTE(review): termination presumably relies on 'optimizeSOAC' failing
-- once no further simplification applies.  Its definition is not visible
-- here.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeSOAC unfus_nms outVars soac consumed ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules unfus_nms outVars soac' consumed ker''
  where
    -- Only inputs produced by this SOAC see its new output transforms.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
        SOAC.addInitialTransforms ots inp
      | otherwise =
        inp
-- | Try to simplify the consumer kernel with respect to the producer's
-- outputs (via 'optimizeKernel'), then retry the fusion rules.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeKernel unfus_nms outVars soac consumed ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules unfus_nms outVars soac consumed ker'
-- | Try to expose the kernel's inputs that read @outVars@ (via
-- 'exposeInputs').
--
-- * If no transforms remain afterwards, fuse directly.
-- * Otherwise the remaining transforms are pulled into the producer with
--   'pullOutputTransforms', the kernel inputs are re-typed, and the
--   fusion rules are retried.
--
-- Fails if the producer cannot absorb all the transforms.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryExposeInputs unfus_nms outVars soac consumed ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac consumed ker'
    else do
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules unfus_nms outVars soac' consumed ker''
        else fail "tryExposeInputs could not pull SOAC transforms"
-- | For every kernel input whose array is named by one of the given
-- identifiers, replace the input's recorded type with that identifier's
-- type.  Transforms and all other inputs are left unchanged.
fixInputTypes :: [Ident] -> FusedKer -> FusedKer
fixInputTypes outIdents ker =
  ker {fsoac = fixInputTypes' $ fsoac ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
        SOAC.Input ts v $ identType v'
    fixInputType inp = inp
-- | Try the fusion rules in order and return the first one that
-- succeeds.
--
-- The order is:
--
-- 1. Simplify the producer ('tryOptimizeSOAC').
-- 2. Simplify the kernel ('tryOptimizeKernel').
-- 3. Fuse directly ('fuseSOACwithKer').
-- 4. Expose the kernel's inputs first ('tryExposeInputs').
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
applyFusionRules unfus_nms outVars soac consumed ker =
  tryOptimizeSOAC unfus_nms outVars soac consumed ker
    <|> tryOptimizeKernel unfus_nms outVars soac consumed ker
    <|> fuseSOACwithKer unfus_nms outVars soac consumed ker
    <|> tryExposeInputs unfus_nms outVars soac consumed ker
-- | Attempt to fuse the producer @soac@ (with output variables
-- @outVars@) into the kernel @ker@.
--
-- The attempt runs in the kernel's scope.  On success, lambda
-- parameters left unused by the fused kernel are removed (see
-- 'removeUnusedParamsFromKer').  Returns 'Nothing' if no rule applies.
attemptFusion ::
  MonadFreshNames m =>
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  m (Maybe FusedKer)
attemptFusion unfus_nms outVars soac consumed ker =
  fmap removeUnusedParamsFromKer
    <$> tryFusion
      (applyFusionRules unfus_nms outVars soac consumed ker)
      (kernelScope ker)
-- | Drop lambda parameters (and their paired inputs) that the kernel's
-- lambda body does not use.  Only 'SOAC.Screma' kernels are pruned;
-- all other SOACs are returned unchanged.
removeUnusedParamsFromKer :: FusedKer -> FusedKer
removeUnusedParamsFromKer ker =
  case soac of
    SOAC.Screma {} -> ker {fsoac = soac'}
    _ -> ker
  where
    soac = fsoac ker
    l = SOAC.lambda soac
    inps = SOAC.inputs soac
    (l', inps') = removeUnusedParams l inps
    soac' =
      l'
        `SOAC.setLambda` (inps' `SOAC.setInputs` soac)
-- | Remove lambda parameters whose names are not free in the lambda
-- body, together with the inputs they are paired with (positionally,
-- via 'zip').
--
-- If no parameter is used, the first parameter/input pair (if any) is
-- kept.  NOTE(review): presumably this ensures the SOAC keeps at least
-- one input; confirm.
removeUnusedParams :: Lambda -> [SOAC.Input] -> (Lambda, [SOAC.Input])
removeUnusedParams l inps =
  (l {lambdaParams = ps'}, inps')
  where
    pInps = zip (lambdaParams l) inps
    (ps', inps') = case filter (used . fst) pInps of
      [] -> unzip $ take 1 pInps
      kept -> unzip kept
    used p = paramName p `nameIn` freeVars
    freeVars = freeIn $ lambdaBody l
-- | True when at least one of @outVars@ is read by the kernel as a
-- "varish" input (as determined by 'SOAC.isVarishInput').
mapFusionOK :: [VName] -> FusedKer -> Bool
mapFusionOK outVars ker = any (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | True when every one of @outVars@ is read by the kernel as a
-- "varish" input (as determined by 'SOAC.isVarishInput').  This is
-- vacuously true when @outVars@ is empty.
mapWriteFusionOK :: [VName] -> FusedKer -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | Attempt to fuse the SOAC @soac_p@, whose results are bound to the
-- names @outVars@, into the SOAC held by the kernel @ker@ (called
-- @soac_c@ below).  Failure is reported through 'fail' in 'TryFusion'.
--
-- Arguments, in order:
--
-- * @unfus_set@: names whose values the fused kernel must still
--   return.  Members of @outVars@ found in this set are added to the
--   kernel's output names.  A non-empty set, together with equal
--   widths, also enables horizontal fusion.
--
-- * @outVars@: the names bound to the results of @soac_p@.
--
-- * @soac_p@: the SOAC being fused into the kernel.
--
-- * @soac_p_consumed@: added to both the 'inplace' and 'fusedConsumed'
--   sets of the resulting kernel.
--
-- * @ker@: the kernel being fused into.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed ker = do
  let soac_c = fsoac ker
      inp_p_arr = SOAC.inputs soac_p
      -- Horizontal fusion is only considered when some producer results
      -- must be kept and both SOACs have the same width.
      horizFuse =
        unfus_set /= mempty
          && SOAC.width soac_p == SOAC.width soac_c
      inp_c_arr = SOAC.inputs soac_c
      lam_p = SOAC.lambda soac_p
      lam_c = SOAC.lambda soac_c
      w = SOAC.width soac_p
      -- Producer results that must remain visible after fusion.
      returned_outvars = filter (`nameIn` unfus_set) outVars
      -- Store a successfully fused SOAC (with output names res_outnms)
      -- back into the kernel.
      success res_outnms res_soac = do
        let fusedVars_new = fusedVars ker ++ outVars
        -- NOTE(review): renaming presumably avoids duplicate bindings
        -- with the lambdas that were fused together; confirm.
        uniq_lam <- renameLambda $ SOAC.lambda res_soac
        return $
          ker
            { fsoac = uniq_lam `SOAC.setLambda` res_soac,
              fusedVars = fusedVars_new,
              inplace = inplace ker <> soac_p_consumed,
              fusedConsumed = fusedConsumed ker <> soac_p_consumed,
              outNames = res_outnms
            }

  -- Pair each producer result with a fresh "<name>_elem" identifier of
  -- the result's row type.
  outPairs <- forM (zip outVars $ map rowType $ SOAC.typeOf soac_p) $ \(outVar, t) -> do
    outVar' <- newVName $ baseString outVar ++ "_elem"
    return (outVar, Ident outVar' t)

  -- Fuse the producer lambda into the consumer lambda with 'fuseMaps'.
  -- Producer results in unfus_set become extra results of the fused
  -- lambda: their names are returned, and their element types are
  -- appended to the return type.
  let mapLikeFusionCheck =
        let (res_lam, new_inp) = fuseMaps unfus_set lam_p inp_p_arr outPairs lam_c inp_c_arr
            (extra_nms, extra_rtps) =
              unzip $
                filter ((`nameIn` unfus_set) . fst) $
                  zip outVars $ map (stripArray 1) $ SOAC.typeOf soac_p
            res_lam' = res_lam {lambdaReturnType = lambdaReturnType res_lam ++ extra_rtps}
         in (extra_nms, res_lam', new_inp)

  when (horizFuse && not (SOAC.nullTransforms $ outputTransform ker)) $
    fail "Horizontal fusion is invalid in the presence of output transforms."

  case (soac_c, soac_p) of
    _
      | SOAC.width soac_p /= SOAC.width soac_c ->
        fail "SOAC widths must match."
    -- Screma into Screma: the producer's scans and reductions come
    -- before the consumer's in the combined form.
    ( SOAC.Screma _ (ScremaForm scans_c reds_c _) _,
      SOAC.Screma _ (ScremaForm scans_p reds_p _) _
      )
        | mapFusionOK (drop (Futhark.scanResults scans_p + Futhark.redResults reds_p) outVars) ker
            || horizFuse -> do
          let red_nes_p = concatMap redNeutral reds_p
              red_nes_c = concatMap redNeutral reds_c
              scan_nes_p = concatMap scanNeutral scans_p
              scan_nes_c = concatMap scanNeutral scans_c
              (res_lam', new_inp) =
                fuseRedomap
                  unfus_set
                  outVars
                  lam_p
                  scan_nes_p
                  red_nes_p
                  inp_p_arr
                  outPairs
                  lam_c
                  scan_nes_c
                  red_nes_c
                  inp_c_arr
              (soac_p_scanout, soac_p_redout, _soac_p_mapout) =
                splitAt3 (length scan_nes_p) (length red_nes_p) outVars
              (soac_c_scanout, soac_c_redout, soac_c_mapout) =
                splitAt3 (length scan_nes_c) (length red_nes_c) $ outNames ker
              -- Kept producer results that are not already returned as
              -- producer scan or reduction results.
              unfus_arrs = returned_outvars \\ (soac_p_scanout ++ soac_p_redout)
          -- Output order: scan results (producer, then consumer),
          -- reduction results (producer, then consumer), consumer map
          -- results, then the remaining kept producer results.
          success
            ( soac_p_scanout ++ soac_c_scanout
                ++ soac_p_redout
                ++ soac_c_redout
                ++ soac_c_mapout
                ++ unfus_arrs
            )
            $ SOAC.Screma
              w
              (ScremaForm (scans_p ++ scans_c) (reds_p ++ reds_c) res_lam')
              new_inp
    -- Map into Scatter.
    ( SOAC.Scatter _len _lam _ivs dests,
      SOAC.Screma _ form _
      )
        | isJust $ isMapSOAC form,
          -- None of the producer's results has to be kept ...
          not (any (`nameIn` unfus_set) outVars),
          -- ... and every producer result is an input of the kernel.
          mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (outNames ker ++ extra_nms) $
            SOAC.Scatter w res_lam' new_inp dests
    -- Map into Hist, under the same conditions as map into Scatter.
    ( SOAC.Hist _ ops _ _,
      SOAC.Screma _ form _
      )
        | isJust $ isMapSOAC form,
          not (any (`nameIn` unfus_set) outVars),
          mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (outNames ker ++ extra_nms) $
            SOAC.Hist w ops res_lam' new_inp
    -- Horizontal fusion of two Hists: both lambdas run side by side.
    -- The results are all bucket indices (consumer's, then producer's)
    -- followed by all values (consumer's, then producer's).
    ( SOAC.Hist _ ops_c _ _,
      SOAC.Hist _ ops_p _ _
      )
        | horizFuse -> do
          let p_num_buckets = length ops_p
              c_num_buckets = length ops_c
              (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
              body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult =
                      take c_num_buckets (bodyResult body_c)
                        ++ take p_num_buckets (bodyResult body_p)
                        ++ drop c_num_buckets (bodyResult body_c)
                        ++ drop p_num_buckets (bodyResult body_p)
                  }
              lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType =
                      replicate (c_num_buckets + p_num_buckets) (Prim int64)
                        ++ drop c_num_buckets (lambdaReturnType lam_c)
                        ++ drop p_num_buckets (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Hist w (ops_c <> ops_p) lam' (inp_c_arr <> inp_p_arr)
    -- Horizontal fusion of two Scatters: the index results of both
    -- lambdas come first, then the value results of both.
    ( SOAC.Scatter _len_c _lam_c ivs_c as_c,
      SOAC.Scatter _len_p _lam_p ivs_p as_p
      )
        | horizFuse -> do
          let zipW as_xs xs as_ys ys = xs_indices ++ ys_indices ++ xs_vals ++ ys_vals
                where
                  (xs_indices, xs_vals) = splitScatterResults as_xs xs
                  (ys_indices, ys_vals) = splitScatterResults as_ys ys
          let (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
          let body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult = zipW as_c (bodyResult body_c) as_p (bodyResult body_p)
                  }
          let lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType = zipW as_c (lambdaReturnType lam_c) as_p (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Scatter w lam' (ivs_c ++ ivs_p) (as_c ++ as_p)
    (SOAC.Scatter {}, _) ->
      fail "Cannot fuse a write with anything else than a write or a map"
    (_, SOAC.Scatter {}) ->
      fail "Cannot fuse a write with anything else than a write or a map"
    -- Two sequential streams.
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <- fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ _ _) ->
      fail "Fusion conditions not met for two SEQ streams!"
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream {}) ->
      fail "Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream _ Sequential _ _ _) ->
      fail "Cannot fuse a parallel with a sequential Stream!"
    -- Two non-sequential streams.
    (SOAC.Stream {}, SOAC.Stream _ _ _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <- fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream {}, SOAC.Stream {}) ->
      fail "Fusion conditions not met for two PAR streams!"
    -- Stream consumer, non-stream producer: turn the producer into a
    -- stream (sequential if the consumer is sequential) and retry.  An
    -- unchanged result from soacToStream is treated as a failed
    -- conversion.
    (SOAC.Stream _ form2 _ _ _, _) -> do
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      soac_p'' <- case form2 of
        Sequential {} -> toSeqStream soac_p'
        _ -> return soac_p'
      if soac_p' == soac_p
        then fail "SOAC could not be turned into stream."
        else fuseSOACwithKer unfus_set (map identName newacc_ids ++ outVars) soac_p'' soac_p_consumed ker
    -- Scan producer: turn it into a stream and retry.
    (_, SOAC.Screma _ form _) | Just _ <- Futhark.isScanSOAC form -> do
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      if soac_p' /= soac_p
        then fuseSOACwithKer unfus_set (map identName newacc_ids ++ outVars) soac_p' soac_p_consumed ker
        else fail "SOAC could not be turned into stream."
    -- Stream producer, non-stream consumer: turn the consumer into a
    -- stream (sequential if the producer is sequential) and retry.
    (_, SOAC.Stream _ form_p _ _ _) -> do
      (soac_c', newacc_ids) <- SOAC.soacToStream soac_c
      when (soac_c' == soac_c) $ fail "SOAC could not be turned into stream."
      soac_c'' <- case form_p of
        Sequential -> toSeqStream soac_c'
        _ -> return soac_c'
      fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed $
        ker {fsoac = soac_c'', outNames = map identName newacc_ids ++ outNames ker}
    _ -> fail "Cannot fuse"
-- | The element order of a stream form: the order stored in a
-- 'Parallel' form, and 'InOrder' for a 'Sequential' one.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder (Parallel o _ _) = o
getStreamOrder Sequential = InOrder
-- | Fuse two streams in the manner of redomap-redomap composition.
--
-- The first parameter of each stream lambda is its chunk parameter.
-- Before calling 'fuseRedomap':
--
-- * the chunk parameter of the first stream's lambda (@lam2@) is renamed
--   to that of the second stream's lambda (@lam1@);
-- * both chunk parameters are stripped.
--
-- @lam1@'s chunk parameter is then put back at the front of the fused
-- lambda.
--
-- The result pairs two things:
--
-- * the output names: the first @length nes1@ of @outVars@, then
--   @out_kernms@, then those @outVars@ that are in @unfus_set@;
-- * the fused stream, which has the first stream's width and the
--   accumulators @nes1 ++ nes2@.
--
-- Fails when:
--
-- * the two streams have different orders;
-- * either lambda has no chunk parameter;
-- * the stream forms cannot be merged (sequential vs. parallel);
-- * either SOAC is not a stream.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 form2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ form1 lam1 nes1 inp1_arr)
    | getStreamOrder form2 /= getStreamOrder form1 =
      fail "fusion conditions not met!"
    -- Both lambdas need a leading chunk parameter.  Matching here instead
    -- of using 'head'/'tail' makes a malformed stream fall through to the
    -- catch-all equation below rather than crash.
    | chunk1 : lam1_params <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
        let -- 'substituteNames' rewrites the whole lambda, so the remaining
            -- parameters of lam20 may differ from those of lam2; that is
            -- why lam2' drops the chunk from lam20, not from lam2.
            lam20 = substituteNames (M.fromList [(paramName chunk2, paramName chunk1)]) lam2
            lam1' = lam1 {lambdaParams = lam1_params}
            lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
            (res_lam', new_inp) =
              fuseRedomap
                unfus_set
                outVars
                lam1'
                []
                nes1
                inp1_arr
                outPairs
                lam2'
                []
                nes2
                inp2_arr
            res_lam'' = res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
            unfus_accs = take (length nes1) outVars
            unfus_arrs = filter (`nameIn` unfus_set) outVars
        res_form <- mergeForms form2 form1
        return
          ( unfus_accs ++ out_kernms ++ unfus_arrs,
            SOAC.Stream w2 res_form res_lam'' (nes1 ++ nes2) new_inp
          )
    where
      -- Two sequential forms stay sequential.  Two parallel forms keep
      -- the second's order, combine commutativities and merge the reduce
      -- operators.  Any other combination is rejected.
      mergeForms :: StreamForm SOACS -> StreamForm SOACS -> TryFusion (StreamForm SOACS)
      mergeForms Sequential Sequential = return Sequential
      mergeForms (Parallel _ comm2 lam2r) (Parallel o1 comm1 lam1r) =
        return $ Parallel o1 (comm1 <> comm2) (mergeReduceOps lam1r lam2r)
      mergeForms _ _ = fail "Fusing sequential to parallel stream disallowed!"
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"
-- | Turn a stream SOAC into a sequential stream.
--
-- * A stream that is already sequential is returned unchanged.
-- * A parallel stream keeps its width, lambda, accumulators and inputs,
--   and only its form becomes 'Sequential'.
-- * Any other SOAC makes the fusion attempt fail.
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream soac = case soac of
  SOAC.Stream _ Sequential _ _ _ -> return soac
  SOAC.Stream w Parallel {} lam accs inps ->
    return $ SOAC.Stream w Sequential lam accs inps
  _ -> fail "toSeqStream expects a stream, but given a SOAC."
-- | Run 'optimizeSOAC' on the kernel's SOAC and its current output
-- transform, and store both results back into the kernel.  Fails
-- whenever 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel inp ker =
  withResult <$> optimizeSOAC inp (fsoac ker) (outputTransform ker)
  where
    withResult (soac', trans') = ker {fsoac = soac', outputTransform = trans'}
-- | Try each rewrite in 'optimizations', in order, on a SOAC and its
-- output transforms.
--
-- * A rewrite that fails is skipped.
-- * A rewrite that succeeds passes its result (both the SOAC and the
--   transforms) on to the next rewrite.
--
-- Fails with "No optimisation applied" if no rewrite succeeded.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> return (soac', os')
  where
    -- Apply one rewrite to the *accumulated* state.  Passing os' (not the
    -- initial os) matters: a rewrite may add output transforms (iswim
    -- appends a Rearrange), and dropping them would mismatch soac'.
    -- On failure the state is kept as-is.
    comb (changed, soac', os') f =
      ( do
          (soac'', os'') <- f inp soac' os'
          return (True, soac'', os'')
      )
        <|> return (changed, soac', os')
-- | A rewrite of a SOAC together with its output transforms.  The
-- @Maybe [VName]@ argument is passed through unchanged by
-- 'optimizeSOAC'.  Failing in 'TryFusion' means the rewrite does not
-- apply.
type Optimization =
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The rewrites 'optimizeSOAC' tries, in this order.
optimizations :: [Optimization]
optimizations =
  [ iswim
  ]
-- | Interchange a single scan with the map that 'rwimPossible' extracts
-- from its operator.
--
-- The result is an outer map (of width @map_w@) whose body is a scan
-- (of width @w@) over that map's lambda.  To compensate:
--
-- * the original array inputs are transposed with
--   @SOAC.transposeInput 0 1@;
-- * a 'SOAC.Rearrange' swapping the two outermost dimensions is appended
--   to the output transforms.
--
-- Applies only when:
--
-- * the SOAC is a 'SOAC.Screma' holding exactly one scan;
-- * 'rwimPossible' succeeds on the scan operator;
-- * every neutral element is a variable.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- mapM subExpVar nes = do
    let scan_ret = lambdaReturnType scan_fun
        -- The neutral elements become extra, untransposed map inputs that
        -- precede the transposed arrays.
        map_nes = [SOAC.identInput (Ident v t) | (v, t) <- zip nes_names scan_ret]
        map_arrs' = map_nes ++ map (SOAC.transposeInput 0 1) arrs
        (scan_acc_params, scan_elem_params) =
          splitAt (length arrs) (lambdaParams scan_fun)
        map_params =
          map removeParamOuterDim scan_acc_params
            ++ map (setParamOuterDimTo w) scan_elem_params
        map_rettype = map (`setOuterSize` w) scan_ret
        -- Parameters matching the neutral elements seed the inner scan;
        -- the rest are the arrays it scans over.
        (acc_names, elem_names) =
          splitAt (length map_nes) (map paramName map_params)
    -- The inner scan operator is map_fun itself.
    scan_form <- scanSOAC [Futhark.Scan map_fun (map Var acc_names)]
    let inner_stm =
          Let (setPatternOuterDimTo w map_pat) (defAux ()) $
            Op $ Futhark.Screma w elem_names scan_form
        map_body = mkBody (oneStm inner_stm) (map Var (patternNames map_pat))
        map_fun' = Lambda map_params map_body map_rettype
        perm = case lambdaReturnType map_fun of
          [] -> []
          t : _ -> 1 : 0 : [2 .. arrayRank t]
    return
      ( SOAC.Screma map_w (ScremaForm [] [] map_fun') map_arrs',
        ots SOAC.|> SOAC.Rearrange map_cs perm
      )
iswim _ _ _ =
  fail "ISWIM does not apply."
-- | Replace a parameter's type with its 'rowType', i.e. drop the outermost
-- array dimension.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim param = param {paramDec = rowType (paramType param)}
-- | Set the outermost dimension of a parameter's type to @w@.
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w param = param {paramDec = setOuterSize (paramType param) w}
-- | Set the outermost dimension of every type in a pattern to @w@.
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo w = fmap (\t -> setOuterSize t w)
-- | Flag each input whose array is listed in @interesting@, then hand the
-- flagged list to 'commonTransforms'' to extract the transforms shared
-- by all flagged inputs.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting = commonTransforms' . map flag
  where
    flag inp = (SOAC.inputArray inp `elem` interesting, inp)
-- | Repeatedly peel one transform off every flagged input using
-- 'inputToOutput'.
--
-- * While all flagged inputs yield the same transform, it is prepended to
--   the result and removed from those inputs.
-- * Peeling stops when some flagged input yields a different transform or
--   none at all, or when no input is flagged.
--
-- Unflagged inputs are passed through untouched, and input order is
-- preserved.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps
  | Just (Just ot, stripped) <- foldM peel (Nothing, []) inps =
    -- peel accumulates in reverse order, hence the 'reverse'.
    first (ot SOAC.<|) $ commonTransforms' $ reverse stripped
  | otherwise = (SOAC.noTransforms, map snd inps)
  where
    peel (found, acc) entry@(flagged, inp)
      | not flagged = Just (found, entry : acc)
      | otherwise = do
        (ot, inp') <- inputToOutput inp
        case found of
          Just prev_ot | prev_ot /= ot -> Nothing
          _ -> Just (Just ot, (True, inp') : acc)
-- | One more than the smaller of:
--
-- * the number of nesting levels;
-- * the minimal rank among the result types of the first nesting level
--   (or of the lambda when there are no levels).
--
-- 'pullRearrange' and 'pushRearrange' compare 'rearrangeReach' against
-- this value.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  1 + min resDims (length levels)
  where
    resTypes = case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    -- Smallest rank among the result types; 0 when there are none.
    resDims
      | null resTypes = 0
      | otherwise = minimum $ map arrayRank resTypes
-- | Move a leading 'SOAC.Rearrange' output transform into the SOAC.
--
-- * The SOAC must be expressible as a map nest.
-- * The permutation must not reach deeper than 'mapDepth'.
--
-- When it applies:
--
-- * the permutation, padded with identity indices up to each input's
--   rank, is added to every input;
-- * the nest's return types are rearranged to match;
-- * the remaining output transforms are returned.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- return $ SOAC.viewf ots
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Extend perm to cover every dimension of the given input.
      extendPerm inp =
        let r = SOAC.inputRank inp
         in take r perm ++ [length perm .. r - 1]
      permuteInput inp = SOAC.addTransform (SOAC.Rearrange cs (extendPerm inp)) inp
      nest' =
        MapNest.setInputs
          (map permuteInput (MapNest.inputs nest))
          (rearrangeReturnTypes nest perm)
  soac' <- MapNest.toSOAC nest'
  return (soac', ots')
-- | Push a rearrangement out of the SOAC's inputs.
--
-- * 'fixupInputs' computes a permutation and new inputs for the map nest
--   from @inpIds@.
-- * If that permutation does not reach deeper than 'mapDepth', the SOAC
--   is rebuilt with those inputs and rearranged return types.
-- * The inverse permutation is prepended to the output transforms.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  soac' <- MapNest.toSOAC $ MapNest.setInputs inputs' $ rearrangeReturnTypes nest perm
  -- Undo the pushed permutation on the SOAC's result.
  let undoPerm = SOAC.Rearrange mempty $ rearrangeInverse perm
  return (soac', undoPerm SOAC.<| ots)
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest :: MapNest
nest@(MapNest.MapNest SubExp
w Lambda SOACS
body [Nesting SOACS]
nestings [Input]
inps) [Int]
perm =
SubExp -> Lambda SOACS -> [Nesting SOACS] -> [Input] -> MapNest
forall lore.
SubExp -> Lambda lore -> [Nesting lore] -> [Input] -> MapNest lore
MapNest.MapNest
SubExp
w
Lambda SOACS
body
( (Nesting SOACS
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Nesting SOACS)
-> [Nesting SOACS]
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
-> [Nesting SOACS]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith
Nesting SOACS
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Nesting SOACS
forall {lore} {lore}.
Nesting lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Nesting lore
setReturnType
[Nesting SOACS]
nestings
([[TypeBase (ShapeBase SubExp) NoUniqueness]] -> [Nesting SOACS])
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]] -> [Nesting SOACS]
forall a b. (a -> b) -> a -> b
$ Int
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[TypeBase (ShapeBase SubExp) NoUniqueness]]
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]])
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
forall a b. (a -> b) -> a -> b
$ ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [[TypeBase (ShapeBase SubExp) NoUniqueness]]
forall a. (a -> a) -> a -> [a]
iterate ((TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType) [TypeBase (ShapeBase SubExp) NoUniqueness]
ts
)
[Input]
inps
where
origts :: [TypeBase (ShapeBase SubExp) NoUniqueness]
origts = MapNest -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
MapNest lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
MapNest.typeOf MapNest
nest
rearrangeType' :: TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
rearrangeType' TypeBase (ShapeBase SubExp) NoUniqueness
t = [Int]
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
rearrangeType (Int -> [Int] -> [Int]
forall a. Int -> [a] -> [a]
take (TypeBase (ShapeBase SubExp) NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase (ShapeBase SubExp) NoUniqueness
t) [Int]
perm) TypeBase (ShapeBase SubExp) NoUniqueness
t
ts :: [TypeBase (ShapeBase SubExp) NoUniqueness]
ts = (TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
rearrangeType' [TypeBase (ShapeBase SubExp) NoUniqueness]
origts
setReturnType :: Nesting lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Nesting lore
setReturnType Nesting lore
nesting [TypeBase (ShapeBase SubExp) NoUniqueness]
t' =
Nesting lore
nesting {nestingReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
MapNest.nestingReturnType = [TypeBase (ShapeBase SubExp) NoUniqueness]
t'}
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs :: [VName] -> [Input] -> Maybe ([Int], [Input])
fixupInputs [VName]
inpIds [Input]
inps =
case (Input -> Maybe [Int]) -> [Input] -> [[Int]]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe Input -> Maybe [Int]
inputRearrange ([Input] -> [[Int]]) -> [Input] -> [[Int]]
forall a b. (a -> b) -> a -> b
$ (Input -> Bool) -> [Input] -> [Input]
forall a. (a -> Bool) -> [a] -> [a]
filter Input -> Bool
exposable [Input]
inps of
[Int]
perm : [[Int]]
_ -> do
[Input]
inps' <- (Input -> Maybe Input) -> [Input] -> Maybe [Input]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (Int -> [Int] -> Input -> Maybe Input
fixupInput ([Int] -> Int
rearrangeReach [Int]
perm) [Int]
perm) [Input]
inps
([Int], [Input]) -> Maybe ([Int], [Input])
forall (m :: * -> *) a. Monad m => a -> m a
return ([Int]
perm, [Input]
inps')
[[Int]]
_ -> Maybe ([Int], [Input])
forall a. Maybe a
Nothing
where
exposable :: Input -> Bool
exposable = (VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
inpIds) (VName -> Bool) -> (Input -> VName) -> Input -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Input -> VName
SOAC.inputArray
inputRearrange :: Input -> Maybe [Int]
inputRearrange (SOAC.Input ArrayTransforms
ts VName
_ TypeBase (ShapeBase SubExp) NoUniqueness
_)
| ArrayTransforms
_ SOAC.:> SOAC.Rearrange Certificates
_ [Int]
perm <- ArrayTransforms -> ViewL
SOAC.viewl ArrayTransforms
ts = [Int] -> Maybe [Int]
forall a. a -> Maybe a
Just [Int]
perm
inputRearrange Input
_ = Maybe [Int]
forall a. Maybe a
Nothing
fixupInput :: Int -> [Int] -> Input -> Maybe Input
fixupInput Int
d [Int]
perm Input
inp
| Int
r <- Input -> Int
SOAC.inputRank Input
inp,
Int
r Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
d =
Input -> Maybe Input
forall a. a -> Maybe a
Just (Input -> Maybe Input) -> Input -> Maybe Input
forall a b. (a -> b) -> a -> b
$ ArrayTransform -> Input -> Input
SOAC.addTransform (Certificates -> [Int] -> ArrayTransform
SOAC.Rearrange Certificates
forall a. Monoid a => a
mempty ([Int] -> ArrayTransform) -> [Int] -> ArrayTransform
forall a b. (a -> b) -> a -> b
$ Int -> [Int] -> [Int]
forall a. Int -> [a] -> [a]
take Int
r [Int]
perm) Input
inp
| Bool
otherwise = Maybe Input
forall a. Maybe a
Nothing
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape :: SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullReshape (SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
inps) ArrayTransforms
ots
| Just Lambda SOACS
maplam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
Futhark.isMapSOAC ScremaForm SOACS
form,
SOAC.Reshape Certificates
cs ShapeChange SubExp
shape SOAC.:< ArrayTransforms
ots' <- ArrayTransforms -> ViewF
SOAC.viewf ArrayTransforms
ots,
(TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
maplam = do
let mapw' :: SubExp
mapw' = case [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape of
[] -> IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
SubExp
d : [SubExp]
_ -> SubExp
d
inputs' :: [Input]
inputs' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (ArrayTransform -> Input -> Input
SOAC.addTransform (ArrayTransform -> Input -> Input)
-> ArrayTransform -> Input -> Input
forall a b. (a -> b) -> a -> b
$ Certificates -> ShapeChange SubExp -> ArrayTransform
SOAC.ReshapeOuter Certificates
cs ShapeChange SubExp
shape) [Input]
inps
inputTypes :: [TypeBase (ShapeBase SubExp) NoUniqueness]
inputTypes = (Input -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [Input] -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map Input -> TypeBase (ShapeBase SubExp) NoUniqueness
SOAC.inputType [Input]
inputs'
let outersoac ::
([SOAC.Input] -> SOAC) ->
(SubExp, [SubExp]) ->
TryFusion ([SOAC.Input] -> SOAC)
outersoac :: ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac [Input] -> SOAC SOACS
inner (SubExp
w, [SubExp]
outershape) = do
let addDims :: TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
addDims TypeBase (ShapeBase SubExp) NoUniqueness
t = TypeBase (ShapeBase SubExp) NoUniqueness
-> ShapeBase SubExp
-> NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase (ShapeBase SubExp) NoUniqueness
t ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
outershape) NoUniqueness
NoUniqueness
retTypes :: [TypeBase (ShapeBase SubExp) NoUniqueness]
retTypes = (TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
addDims ([TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
maplam
[Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ps <- [TypeBase (ShapeBase SubExp) NoUniqueness]
-> (TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TypeBase (ShapeBase SubExp) NoUniqueness]
inputTypes ((TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> (TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ \TypeBase (ShapeBase SubExp) NoUniqueness
inpt ->
String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"pullReshape_param" (TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall a b. (a -> b) -> a -> b
$
Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray (ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
shape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
outershape) TypeBase (ShapeBase SubExp) NoUniqueness
inpt
BodyT SOACS
inner_body <-
Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
[BinderT
SOACS
(State VNameSource)
(Exp (Lore (BinderT SOACS (State VNameSource))))]
-> BinderT
SOACS
(State VNameSource)
(Body (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [SOAC (Lore (BinderT SOACS (State VNameSource)))
-> BinderT
SOACS
(State VNameSource)
(Exp (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m)) =>
SOAC (Lore m) -> m (Exp (Lore m))
SOAC.toExp (SOAC (Lore (BinderT SOACS (State VNameSource)))
-> BinderT
SOACS
(State VNameSource)
(Exp (Lore (BinderT SOACS (State VNameSource)))))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> BinderT
SOACS
(State VNameSource)
(Exp (Lore (BinderT SOACS (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ [Input] -> SOAC SOACS
inner ([Input] -> SOAC SOACS) -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Input)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> Input
SOAC.identInput (Ident -> Input)
-> (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Ident)
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
-> Input
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent) [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ps]
let inner_fun :: Lambda SOACS
inner_fun =
Lambda :: forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
[LParam SOACS]
ps,
lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = [TypeBase (ShapeBase SubExp) NoUniqueness]
retTypes,
lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
inner_body
}
([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return (([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
w (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda SOACS
inner_fun
[Input] -> SOAC SOACS
op' <-
(([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS)
-> [(SubExp, [SubExp])]
-> TryFusion ([Input] -> SOAC SOACS)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac (SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
mapw' (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda SOACS
maplam) ([(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS))
-> [(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$
[SubExp] -> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
1 ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape) ([[SubExp]] -> [(SubExp, [SubExp])])
-> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. (a -> b) -> a -> b
$
Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [[SubExp]] -> [[SubExp]]
forall a. [a] -> [a]
reverse ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [[SubExp]]
forall a. [a] -> [[a]]
tails ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape
(SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Input] -> SOAC SOACS
op' [Input]
inputs', ArrayTransforms
ots')
pullReshape SOAC SOACS
_ ArrayTransforms
_ = String -> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull reshape"
exposeInputs ::
[VName] ->
FusedKer ->
TryFusion (FusedKer, SOAC.ArrayTransforms)
exposeInputs :: [VName] -> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs [VName]
inpIds FusedKer
ker =
(FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pushRearrange')
TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pullRearrange')
TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker
where
ot :: ArrayTransforms
ot = FusedKer -> ArrayTransforms
outputTransform FusedKer
ker
pushRearrange' :: TryFusion FusedKer
pushRearrange' = do
(SOAC SOACS
soac', ArrayTransforms
ot') <- [VName]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
pushRearrange [VName]
inpIds (FusedKer -> SOAC SOACS
fsoac FusedKer
ker) ArrayTransforms
ot
FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
FusedKer
ker
{ fsoac :: SOAC SOACS
fsoac = SOAC SOACS
soac',
outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
ot'
}
pullRearrange' :: TryFusion FusedKer
pullRearrange' = do
(SOAC SOACS
soac', ArrayTransforms
ot') <- SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullRearrange (FusedKer -> SOAC SOACS
fsoac FusedKer
ker) ArrayTransforms
ot
Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ot') (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$
String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"pullRearrange was not enough"
FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
FusedKer
ker
{ fsoac :: SOAC SOACS
fsoac = SOAC SOACS
soac',
outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
SOAC.noTransforms
}
exposeInputs' :: FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker' =
case [VName] -> [Input] -> (ArrayTransforms, [Input])
commonTransforms [VName]
inpIds ([Input] -> (ArrayTransforms, [Input]))
-> [Input] -> (ArrayTransforms, [Input])
forall a b. (a -> b) -> a -> b
$ FusedKer -> [Input]
inputs FusedKer
ker' of
(ArrayTransforms
ot', [Input]
inps')
| (Input -> Bool) -> [Input] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Input -> Bool
exposed [Input]
inps' ->
(FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer
ker' {fsoac :: SOAC SOACS
fsoac = [Input]
inps' [Input] -> SOAC SOACS -> SOAC SOACS
forall lore. [Input] -> SOAC lore -> SOAC lore
`SOAC.setInputs` FusedKer -> SOAC SOACS
fsoac FusedKer
ker'}, ArrayTransforms
ot')
(ArrayTransforms, [Input])
_ -> String -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot expose"
exposed :: Input -> Bool
exposed (SOAC.Input ArrayTransforms
ts VName
_ TypeBase (ShapeBase SubExp) NoUniqueness
_)
| ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ts = Bool
True
exposed Input
inp = Input -> VName
SOAC.inputArray Input
inp VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
inpIds
outputTransformPullers :: [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers :: [SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
outputTransformPullers = [SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullRearrange, SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullReshape]
pullOutputTransforms ::
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms :: SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullOutputTransforms = [SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall {t} {t}.
[t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
outputTransformPullers
where
attempt :: [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [] t
_ t
_ = String -> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull anything"
attempt (t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
p : [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
ps) t
soac t
ots =
do
(SOAC SOACS
soac', ArrayTransforms
ots') <- t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
p t
soac t
ots
if ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ots'
then (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS
soac', ArrayTransforms
SOAC.noTransforms)
else SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullOutputTransforms SOAC SOACS
soac' ArrayTransforms
ots' TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS
soac', ArrayTransforms
ots')
TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
ps t
soac t
ots