{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Optimise.Fusion.LoopKernel
( FusedKer (..),
newKernel,
inputs,
setInputs,
arrInputs,
transformOutput,
attemptFusion,
SOAC,
MapNest,
)
where
import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.HORep.MapNest as MapNest
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)
-- | The monad in which a single fusion attempt runs.  It provides
-- read access to a @'Scope' 'SOACS'@, threads a 'VNameSource' as
-- state, and is layered over 'Maybe' so that an attempt can fail
-- (via 'fail' or 'empty') and alternatives can be tried with '<|>'.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Run a 'TryFusion' action in the given scope, using the name
-- source of the surrounding monad.
--
-- On success the result is wrapped in 'Just' and the name source
-- produced by the action is installed.  On failure 'Nothing' is
-- returned and the name source is left exactly as it was before the
-- attempt.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types =
  modifyNameSource $ \src ->
    case runStateT (runReaderT m types) src of
      Just (x, src') -> (Just x, src')
      Nothing -> (Nothing, src)
-- | Lift a 'Maybe' into 'TryFusion': 'Nothing' becomes a failed
-- attempt (through 'fail'), 'Just' becomes a successful result.
liftMaybe :: Maybe a -> TryFusion a
liftMaybe Nothing = fail "Nothing"
liftMaybe (Just x) = return x
-- | The 'SOAC.SOAC' type specialised to the 'SOACS' representation.
type SOAC = SOAC.SOAC SOACS
-- | The 'MapNest.MapNest' type specialised to the 'SOACS' representation.
type MapNest = MapNest.MapNest SOACS
-- | Emit statements that apply a sequence of array transforms to a
-- list of identifiers and finally bind the results to @names@.
--
-- The transforms are consumed one at a time from the front (via
-- 'SOAC.viewf').  For each transform, every current identifier is
-- transformed with 'applyTransform', and the result is bound (under
-- the certificates of that transform) to a fresh identifier whose
-- base name is taken from the corresponding entry of @names@.  Once
-- no transforms remain, each name in @names@ is bound to the
-- corresponding remaining identifier as a plain variable copy.
--
-- The lists are paired up with 'zip', so surplus elements on either
-- side are silently ignored.
transformOutput ::
  SOAC.ArrayTransforms ->
  [VName] ->
  [Ident] ->
  Builder SOACS ()
transformOutput ts names = descend ts
  where
    descend ts' validents =
      case SOAC.viewf ts' of
        -- No transforms left: bind the final names.
        SOAC.EmptyF ->
          forM_ (zip names validents) $ \(k, valident) ->
            letBindNames [k] $ BasicOp $ SubExp $ Var $ identName valident
        -- Apply the first transform to every identifier, then recurse
        -- on the freshly bound intermediate results.
        t SOAC.:< ts'' -> do
          let (es, css) = unzip $ map (applyTransform t) validents
              mkPat (Ident nm tp) = Pat [PatElem nm tp]
          opts <- concat <$> mapM basicOpType es
          newIds <- forM (zip names opts) $ \(k, opt) ->
            newIdent (baseString k) opt
          forM_ (zip3 css newIds es) $ \(cs, ids, e) ->
            certifying cs $ letBind (mkPat ids) (BasicOp e)
          descend ts'' newIds
-- | Translate a single 'SOAC.ArrayTransform', applied to the array
-- named by the given identifier, into the corresponding 'BasicOp'
-- together with the certificates carried by the transform.
applyTransform :: SOAC.ArrayTransform -> Ident -> (BasicOp, Certs)
applyTransform (SOAC.Rearrange cs perm) v =
  (Rearrange perm' $ identName v, cs)
  where
    -- Extend the permutation with the identity on any trailing
    -- dimensions, so that it covers the full rank of the array.
    perm' = perm ++ drop (length perm) [0 .. arrayRank (identType v) - 1]
applyTransform (SOAC.Reshape cs shape) v =
  (Reshape shape $ identName v, cs)
applyTransform (SOAC.ReshapeOuter cs shape) v =
  let shapes = reshapeOuter shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.ReshapeInner cs shape) v =
  let shapes = reshapeInner shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.Replicate cs n) v =
  (Replicate n $ Var $ identName v, cs)
-- | Split the first transform off an input.  Returns that transform
-- paired with the input carrying only the remaining transforms, or
-- 'Nothing' if the input has no transforms at all.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing
-- | The state of a (possibly fused) kernel: the SOAC it currently
-- consists of plus the bookkeeping needed while fusing further SOACs
-- into it.  Use 'newKernel' to create one.
data FusedKer = FusedKer
  { -- | The SOAC the kernel currently consists of.
    fsoac :: SOAC,
    -- | Consumed names tracked for the kernel.  Initialised by
    -- 'newKernel' to the consumed set it is given; a successful
    -- fusion in 'fuseSOACwithKer' adds the consumed names of the
    -- fused-in SOAC.
    inplace :: Names,
    -- | Output names of the SOACs that have been fused into this
    -- kernel.  Empty for a freshly created kernel.
    fusedVars :: [VName],
    -- | Initialised and extended with the same consumed names as
    -- 'inplace'.
    fusedConsumed :: Names,
    -- | The scope in which fusion attempts on this kernel are run
    -- (see 'attemptFusion').
    kernelScope :: Scope SOACS,
    -- | Array transforms associated with the kernel's output;
    -- 'SOAC.noTransforms' for a freshly created kernel.
    outputTransform :: SOAC.ArrayTransforms,
    -- | The output names of the kernel.
    outNames :: [VName],
    -- | The statement auxiliary information supplied when the kernel
    -- was created.
    kerAux :: StmAux ()
  }
  deriving (Show)
-- | Create a kernel from a single SOAC that has not yet had anything
-- fused into it.  The given consumed names seed both 'inplace' and
-- 'fusedConsumed'; there are no fused variables and no output
-- transforms.
newKernel :: StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel aux soac consumed out_nms scope =
  FusedKer
    { fsoac = soac,
      inplace = consumed,
      fusedVars = [],
      fusedConsumed = consumed,
      outputTransform = SOAC.noTransforms,
      outNames = out_nms,
      kernelScope = scope,
      kerAux = aux
    }
-- | The set of array names underlying the inputs of the kernel's
-- SOAC (any transforms on the inputs are ignored).
arrInputs :: FusedKer -> S.Set VName
arrInputs = S.fromList . map SOAC.inputArray . inputs
-- | The inputs of the kernel's SOAC.
inputs :: FusedKer -> [SOAC.Input]
inputs = SOAC.inputs . fsoac
-- | Replace the inputs of the kernel's SOAC, leaving every other
-- field of the kernel unchanged.
setInputs :: [SOAC.Input] -> FusedKer -> FusedKer
setInputs inps ker = ker {fsoac = inps `SOAC.setInputs` fsoac ker}
-- | Fusion rule: first rewrite the producer SOAC with
-- 'optimizeSOAC', then retry 'applyFusionRules' with the rewritten
-- SOAC.
--
-- The output transforms resulting from the rewrite are prepended to
-- those kernel inputs whose underlying array is one of @outVars@,
-- and the types of those inputs are updated (via 'fixInputTypes') to
-- match the result types of the rewritten SOAC.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeSOAC unfus_nms outVars soac consumed ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules unfus_nms outVars soac' consumed ker''
  where
    -- Only inputs that read an output of the producer are affected.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
          SOAC.addInitialTransforms ots inp
      | otherwise =
          inp
-- | Fusion rule: first rewrite the consumer kernel with
-- 'optimizeKernel' (passing it the producer's output names), then
-- retry 'applyFusionRules' with the rewritten kernel.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeKernel unfus_nms outVars soac consumed ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules unfus_nms outVars soac consumed ker'
-- | Fusion rule: run 'exposeInputs' on the kernel for the producer's
-- output names, then try to fuse.
--
-- If 'exposeInputs' reports no transforms, the producer is fused
-- directly with 'fuseSOACwithKer'.  Otherwise the reported
-- transforms are pulled through the producer with
-- 'pullOutputTransforms'; if none are left over afterwards, the
-- kernel's input types are fixed up and 'applyFusionRules' is
-- retried with the rewritten producer.  If transforms remain, the
-- attempt fails.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryExposeInputs unfus_nms outVars soac consumed ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac consumed ker'
    else do
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules unfus_nms outVars soac' consumed ker''
        else fail "tryExposeInputs could not pull SOAC transforms"
-- | Update the recorded types of the kernel's SOAC inputs.  Every
-- input whose underlying array has the same name as one of the given
-- identifiers gets the type of (the first such) identifier; all
-- other inputs are left unchanged.
fixInputTypes :: [Ident] -> FusedKer -> FusedKer
fixInputTypes outIdents ker =
  ker {fsoac = fixInputTypes' $ fsoac ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
          SOAC.Input ts v $ identType v'
    fixInputType inp = inp
-- | Try the fusion rules in a fixed order, returning the result of
-- the first one that succeeds: 'tryOptimizeSOAC',
-- 'tryOptimizeKernel', 'fuseSOACwithKer', and finally
-- 'tryExposeInputs'.  Fails if all of them fail.
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
applyFusionRules unfus_nms outVars soac consumed ker =
  tryOptimizeSOAC unfus_nms outVars soac consumed ker
    <|> tryOptimizeKernel unfus_nms outVars soac consumed ker
    <|> fuseSOACwithKer unfus_nms outVars soac consumed ker
    <|> tryExposeInputs unfus_nms outVars soac consumed ker
-- | Attempt to fuse a SOAC into a kernel.
--
-- Runs 'applyFusionRules' in the kernel's own scope ('kernelScope').
-- On success, the resulting kernel additionally has its unused
-- lambda parameters removed (see 'removeUnusedParamsFromKer').
-- Returns 'Nothing' if no fusion rule applies.
attemptFusion ::
  MonadFreshNames m =>
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  m (Maybe FusedKer)
attemptFusion unfus_nms outVars soac consumed ker =
  fmap removeUnusedParamsFromKer
    <$> tryFusion
      (applyFusionRules unfus_nms outVars soac consumed ker)
      (kernelScope ker)
-- | If the kernel's SOAC is a 'SOAC.Screma', drop the lambda
-- parameters that are unused in the lambda body, along with their
-- corresponding inputs (see 'removeUnusedParams').  Kernels with any
-- other kind of SOAC are returned unchanged.
removeUnusedParamsFromKer :: FusedKer -> FusedKer
removeUnusedParamsFromKer ker =
  case soac of
    SOAC.Screma {} -> ker {fsoac = soac'}
    _ -> ker
  where
    soac = fsoac ker
    l = SOAC.lambda soac
    inps = SOAC.inputs soac
    (l', inps') = removeUnusedParams l inps
    soac' =
      l'
        `SOAC.setLambda` (inps' `SOAC.setInputs` soac)
-- | Remove the lambda parameters that do not occur free in the
-- lambda body, together with the inputs at the same positions.
--
-- Parameters and inputs are paired positionally with 'zip'.  If
-- every pair would be removed but at least one pair exists, the
-- first pair is kept, so a lambda that had a parameter/input pair
-- never ends up with none.
removeUnusedParams :: Lambda -> [SOAC.Input] -> (Lambda, [SOAC.Input])
removeUnusedParams l inps =
  (l {lambdaParams = ps'}, inps')
  where
    pInps = zip (lambdaParams l) inps
    (ps', inps') = case (unzip $ filter (used . fst) pInps, pInps) of
      -- Nothing is used: keep the first parameter/input pair.
      (([], []), (p, inp) : _) -> ([p], [inp])
      ((ps_, inps_), _) -> (ps_, inps_)
    used p = paramName p `nameIn` freeVars
    freeVars = freeIn $ lambdaBody l
-- | True if at least one of the given output names is a \"varish\"
-- input of the kernel, as determined by 'SOAC.isVarishInput'.
mapFusionOK :: [VName] -> FusedKer -> Bool
mapFusionOK outVars ker = any (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | True if every one of the given output names is a \"varish\"
-- input of the kernel, as determined by 'SOAC.isVarishInput'.  In
-- particular, this holds trivially when the list of names is empty.
mapWriteFusionOK :: [VName] -> FusedKer -> Bool
mapWriteFusionOK outVars ker = all (`elem` inpIds) outVars
  where
    inpIds = mapMaybe SOAC.isVarishInput (inputs ker)
-- | Attempt to fuse the SOAC @soac_p@ with the SOAC held by the kernel
-- @ker@ (called @soac_c@ below).
--
-- Arguments, in order:
--
--   * @unfus_set@: names that are filtered out of @outVars@ and kept as
--     results of the fused SOAC.  A non-empty set, together with equal
--     widths, also enables the \"horizontal\" fusion branches.
--
--   * @outVars@: names paired positionally with the result types of
--     @soac_p@.
--
--   * @soac_p@: the SOAC being fused into the kernel.
--
--   * @soac_p_consumed@: names appended to both the @inplace@ and
--     @fusedConsumed@ fields of the resulting kernel.
--
--   * @ker@: the kernel to fuse into.
--
-- On success the updated kernel is returned; every unsupported
-- combination ends in 'fail' inside 'TryFusion'.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed ker = do
  let soac_c = fsoac ker
      inp_p_arr = SOAC.inputs soac_p
      -- Horizontal fusion is only considered when something has to
      -- stay unfused and both SOACs have the same width.
      horizFuse =
        unfus_set /= mempty
          && SOAC.width soac_p == SOAC.width soac_c
      inp_c_arr = SOAC.inputs soac_c
      lam_p = SOAC.lambda soac_p
      lam_c = SOAC.lambda soac_c
      w = SOAC.width soac_p
      returned_outvars = filter (`nameIn` unfus_set) outVars
      -- Build the resulting kernel from the new output names and the
      -- fused SOAC.  The lambda of the fused SOAC is renamed first so
      -- that it gets fresh names.
      success res_outnms res_soac = do
        let fusedVars_new = fusedVars ker ++ outVars
        uniq_lam <- renameLambda $ SOAC.lambda res_soac
        return $
          ker
            { fsoac = uniq_lam `SOAC.setLambda` res_soac,
              fusedVars = fusedVars_new,
              inplace = inplace ker <> soac_p_consumed,
              fusedConsumed = fusedConsumed ker <> soac_p_consumed,
              outNames = res_outnms
            }

  -- Pair each output variable with a fresh "_elem" identifier whose
  -- type is the row type of the corresponding result of soac_p.
  outPairs <- forM (zip outVars $ map rowType $ SOAC.typeOf soac_p) $ \(outVar, t) -> do
    outVar' <- newVName $ baseString outVar ++ "_elem"
    return (outVar, Ident outVar' t)

  -- Shared by the Scatter/map and Hist/map branches below.  Being a
  -- plain let binding, it is only evaluated by the branches that use it.
  let mapLikeFusionCheck =
        let (res_lam, new_inp) =
              fuseMaps unfus_set lam_p inp_p_arr outPairs lam_c inp_c_arr
            -- Outputs of soac_p that are in unfus_set become extra
            -- results of the fused lambda.
            (extra_nms, extra_rtps) =
              unzip $
                filter ((`nameIn` unfus_set) . fst) $
                  zip outVars $ map (stripArray 1) $ SOAC.typeOf soac_p
            res_lam' =
              res_lam
                { lambdaReturnType = lambdaReturnType res_lam ++ extra_rtps
                }
         in (extra_nms, res_lam', new_inp)

  when (horizFuse && not (SOAC.nullTransforms $ outputTransform ker)) $
    fail "Horizontal fusion is invalid in the presence of output transforms."

  -- The tuple is (kernel SOAC, SOAC being fused in).  The order of the
  -- alternatives matters: earlier, more specific ones shadow later ones.
  case (soac_c, soac_p) of
    _ | SOAC.width soac_p /= SOAC.width soac_c -> fail "SOAC widths must match."
    ( SOAC.Screma _ (ScremaForm scans_c reds_c _) _,
      SOAC.Screma _ (ScremaForm scans_p reds_p _) _
      )
        | mapFusionOK
            (drop (Futhark.scanResults scans_p + Futhark.redResults reds_p) outVars)
            ker
            || horizFuse -> do
          let red_nes_p = concatMap redNeutral reds_p
              red_nes_c = concatMap redNeutral reds_c
              scan_nes_p = concatMap scanNeutral scans_p
              scan_nes_c = concatMap scanNeutral scans_c
              (res_lam', new_inp) =
                fuseRedomap
                  unfus_set
                  outVars
                  lam_p
                  scan_nes_p
                  red_nes_p
                  inp_p_arr
                  outPairs
                  lam_c
                  scan_nes_c
                  red_nes_c
                  inp_c_arr
              (soac_p_scanout, soac_p_redout, _soac_p_mapout) =
                splitAt3 (length scan_nes_p) (length red_nes_p) outVars
              (soac_c_scanout, soac_c_redout, soac_c_mapout) =
                splitAt3 (length scan_nes_c) (length red_nes_c) $ outNames ker
              -- Scan and reduce outputs of soac_p are already listed
              -- explicitly below, so drop them from the unfused ones.
              unfus_arrs = returned_outvars \\ (soac_p_scanout ++ soac_p_redout)
          success
            ( soac_p_scanout
                ++ soac_c_scanout
                ++ soac_p_redout
                ++ soac_c_redout
                ++ soac_c_mapout
                ++ unfus_arrs
            )
            $ SOAC.Screma
              w
              (ScremaForm (scans_p ++ scans_c) (reds_p ++ reds_c) res_lam')
              new_inp
    ( SOAC.Scatter _len _lam _ivs dests,
      SOAC.Screma _ form _
      )
        | isJust $ isMapSOAC form,
          not (any (`nameIn` unfus_set) outVars),
          mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (outNames ker ++ extra_nms) $
            SOAC.Scatter w res_lam' new_inp dests
    ( SOAC.Hist _ ops _ _,
      SOAC.Screma _ form _
      )
        | isJust $ isMapSOAC form,
          not (any (`nameIn` unfus_set) outVars),
          mapWriteFusionOK outVars ker -> do
          let (extra_nms, res_lam', new_inp) = mapLikeFusionCheck
          success (outNames ker ++ extra_nms) $
            SOAC.Hist w ops res_lam' new_inp
    ( SOAC.Hist _ ops_c _ _,
      SOAC.Hist _ ops_p _ _
      )
        | horizFuse -> do
          let p_num_buckets = length ops_p
              c_num_buckets = length ops_c
              (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
              -- The first c_num_buckets / p_num_buckets results of each
              -- body are placed first, followed by the remaining
              -- results of each body.
              body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult =
                      take c_num_buckets (bodyResult body_c)
                        ++ take p_num_buckets (bodyResult body_p)
                        ++ drop c_num_buckets (bodyResult body_c)
                        ++ drop p_num_buckets (bodyResult body_p)
                  }
              lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType =
                      replicate (c_num_buckets + p_num_buckets) (Prim int64)
                        ++ drop c_num_buckets (lambdaReturnType lam_c)
                        ++ drop p_num_buckets (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Hist w (ops_c <> ops_p) lam' (inp_c_arr <> inp_p_arr)
    ( SOAC.Scatter _len_c _lam_c ivs_c as_c,
      SOAC.Scatter _len_p _lam_p ivs_p as_p
      )
        | horizFuse -> do
          -- Interleave two scatter result lists: all index parts first,
          -- then all value parts, as split by 'splitScatterResults'.
          -- Used below at two element types (results and types).
          let zipW as_xs xs as_ys ys =
                xs_indices ++ ys_indices ++ xs_vals ++ ys_vals
                where
                  (xs_indices, xs_vals) = splitScatterResults as_xs xs
                  (ys_indices, ys_vals) = splitScatterResults as_ys ys
          let (body_p, body_c) = (lambdaBody lam_p, lambdaBody lam_c)
          let body' =
                Body
                  { bodyDec = bodyDec body_p,
                    bodyStms = bodyStms body_p <> bodyStms body_c,
                    bodyResult =
                      zipW as_c (bodyResult body_c) as_p (bodyResult body_p)
                  }
          let lam' =
                Lambda
                  { lambdaParams = lambdaParams lam_c ++ lambdaParams lam_p,
                    lambdaBody = body',
                    lambdaReturnType =
                      zipW
                        as_c
                        (lambdaReturnType lam_c)
                        as_p
                        (lambdaReturnType lam_p)
                  }
          success (outNames ker ++ returned_outvars) $
            SOAC.Scatter w lam' (ivs_c ++ ivs_p) (as_c ++ as_p)
    (SOAC.Scatter {}, _) ->
      fail "Cannot fuse a write with anything else than a write or a map"
    (_, SOAC.Scatter {}) ->
      fail "Cannot fuse a write with anything else than a write or a map"
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <-
          fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream _ Sequential _ _ _) ->
      fail "Fusion conditions not met for two SEQ streams!"
    (SOAC.Stream _ Sequential _ _ _, SOAC.Stream {}) ->
      fail "Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream _ Sequential _ _ _) ->
      fail "Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream _ _ _ nes _)
      | mapFusionOK (drop (length nes) outVars) ker || horizFuse -> do
        (res_nms, res_stream) <-
          fuseStreamHelper (outNames ker) unfus_set outVars outPairs soac_c soac_p
        success res_nms res_stream
    (SOAC.Stream {}, SOAC.Stream {}) ->
      fail "Fusion conditions not met for two PAR streams!"
    -- The kernel holds a stream but soac_p is not one: convert soac_p
    -- to a stream (sequential if the kernel's stream is) and retry.
    (SOAC.Stream _ form2 _ _ _, _) -> do
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      soac_p'' <- case form2 of
        Sequential {} -> toSeqStream soac_p'
        _ -> return soac_p'
      if soac_p' == soac_p
        then fail "SOAC could not be turned into stream."
        else
          fuseSOACwithKer
            unfus_set
            (map identName newacc_ids ++ outVars)
            soac_p''
            soac_p_consumed
            ker
    -- soac_p is a scan: convert it to a stream and retry.
    (_, SOAC.Screma _ form _) | Just _ <- Futhark.isScanSOAC form -> do
      (soac_p', newacc_ids) <- SOAC.soacToStream soac_p
      if soac_p' /= soac_p
        then
          fuseSOACwithKer
            unfus_set
            (map identName newacc_ids ++ outVars)
            soac_p'
            soac_p_consumed
            ker
        else fail "SOAC could not be turned into stream."
    -- soac_p is a stream but the kernel's SOAC is not: convert the
    -- kernel's SOAC to a stream (sequential if soac_p is) and retry
    -- with the updated kernel.
    (_, SOAC.Stream _ form_p _ _ _) -> do
      (soac_c', newacc_ids) <- SOAC.soacToStream soac_c
      when (soac_c' == soac_c) $
        fail "SOAC could not be turned into stream."
      soac_c'' <- case form_p of
        Sequential -> toSeqStream soac_c'
        _ -> return soac_c'
      fuseSOACwithKer unfus_set outVars soac_p soac_p_consumed $
        ker
          { fsoac = soac_c'',
            outNames = map identName newacc_ids ++ outNames ker
          }
    _ -> fail "Cannot fuse"
-- | The ordering of a stream form.  A 'Parallel' form carries its
-- ordering as its first field; a 'Sequential' form is reported as
-- 'InOrder'.
getStreamOrder :: StreamForm rep -> StreamOrd
getStreamOrder (Parallel o _ _) = o
getStreamOrder Sequential = InOrder
-- | Try to fuse two stream SOACs into a single stream.
--
-- The arguments are, in order: the output names of the existing
-- kernel, the set of names that must stay available unfused, the
-- output variables handed to 'fuseRedomap', the @(name, ident)@ output
-- pairs handed to 'fuseRedomap', and the two streams (presumably
-- consumer first, producer second -- confirm against the call site).
--
-- The result is the list of output names of the fused stream together
-- with the fused stream itself, whose width is taken from the first
-- stream and whose accumulators are @nes1 ++ nes2@.
--
-- Fusion fails when:
--
-- * the two streams do not have the same 'StreamOrd';
--
-- * either stream lambda has no parameters at all (so there is no
--   leading chunk parameter to share);
--
-- * one stream is sequential and the other parallel;
--
-- * either SOAC is not a stream.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  out_kernms
  unfus_set
  outVars
  outPairs
  (SOAC.Stream w2 form2 lam2 nes2 inp2_arr)
  (SOAC.Stream _ form1 lam1 nes1 inp1_arr)
    | getStreamOrder form2 /= getStreamOrder form1 =
      fail "fusion conditions not met!"
    -- Total replacement for the previous 'head'/'tail' calls: if either
    -- lambda lacks its leading parameter we fall through to the final
    -- equation and report a fusion failure rather than crashing.
    | chunk1 : params1 <- lambdaParams lam1,
      chunk2 : _ <- lambdaParams lam2 = do
      let -- Make the second lambda refer to the first lambda's leading
          -- parameter, so that both bodies share a single one.
          lam20 =
            substituteNames
              (M.singleton (paramName chunk2) (paramName chunk1))
              lam2
          -- Strip the leading parameter from both lambdas before
          -- composing them; it is put back on the result below.
          lam1' = lam1 {lambdaParams = params1}
          lam2' = lam20 {lambdaParams = drop 1 $ lambdaParams lam20}
          (res_lam', new_inp) =
            fuseRedomap
              unfus_set
              outVars
              lam1'
              []
              nes1
              inp1_arr
              outPairs
              lam2'
              []
              nes2
              inp2_arr
          res_lam'' =
            res_lam' {lambdaParams = chunk1 : lambdaParams res_lam'}
          unfus_accs = take (length nes1) outVars
          unfus_arrs =
            filter
              (\v -> v `nameIn` unfus_set && v `notElem` unfus_accs)
              outVars
      res_form <- mergeForms form2 form1
      return
        ( unfus_accs ++ out_kernms ++ unfus_arrs,
          SOAC.Stream w2 res_form res_lam'' (nes1 ++ nes2) new_inp
        )
    where
      -- Combine the forms of the two streams.  Two parallel forms keep
      -- the ordering of the second argument, combine the
      -- commutativities, and merge the reduction operators.
      mergeForms ::
        StreamForm SOACS ->
        StreamForm SOACS ->
        TryFusion (StreamForm SOACS)
      mergeForms Sequential Sequential = return Sequential
      mergeForms (Parallel _ comm2 lam2r) (Parallel o1 comm1 lam1r) =
        return $ Parallel o1 (comm1 <> comm2) (mergeReduceOps lam1r lam2r)
      mergeForms _ _ =
        fail "Fusing sequential to parallel stream disallowed!"
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"
-- | Turn a stream into a sequential stream.  The width, lambda,
-- accumulators and inputs are kept; only the form is replaced by
-- 'Sequential' (a stream that is already sequential is therefore
-- returned unchanged).  Fails for any SOAC that is not a stream.
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream soac = case soac of
  SOAC.Stream w _ lam accs inps ->
    pure $ SOAC.Stream w Sequential lam accs inps
  _ ->
    fail "toSeqStream expects a stream, but given a SOAC."
-- | Run 'optimizeSOAC' on the SOAC of a kernel, starting from the
-- kernel's current output transform, and store the resulting SOAC and
-- output transform back into the kernel.  Fails whenever
-- 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel inp ker = do
  (soac', trans') <- optimizeSOAC inp (fsoac ker) (outputTransform ker)
  pure ker {fsoac = soac', outputTransform = trans'}
-- | Apply every entry of 'optimizations', in order, to a SOAC and its
-- output transforms.  An optimisation that fails is skipped; one that
-- succeeds replaces the current SOAC and transforms, which are then
-- what the next optimisation sees.
--
-- Fails with "No optimisation applied" if none of the optimisations
-- succeeded; otherwise returns the final SOAC and transforms.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  (changed, soac', os') <- foldM comb (False, soac, os) optimizations
  if changed
    then return (soac', os')
    else fail "No optimisation applied"
  where
    -- One fold step: try optimisation @f@ on the current state and
    -- fall back to the unchanged state if it fails.
    comb acc@(_, cur_soac, cur_os) f =
      attempt <|> return acc
      where
        attempt = do
          -- Pass the *accumulated* transforms, not the initial @os@:
          -- @cur_soac@ may already have been rewritten by an earlier
          -- optimisation, and its matching transforms live in @cur_os@.
          (new_soac, new_os) <- f inp cur_soac cur_os
          return (True, new_soac, new_os)
-- | A rewrite of a SOAC together with its output transforms.  The
-- first argument is passed through unchanged by 'optimizeSOAC'; the
-- rewrite signals "not applicable" by failing in 'TryFusion'.
type Optimization =
  Maybe [VName] -> SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The rewrites attempted by 'optimizeSOAC', in order.
optimizations :: [Optimization]
optimizations =
  [ iswim
  ]
-- | Rewrite a scan whose operator is itself a map (as recognised by
-- 'rwimPossible') into a map whose body contains a scan.
--
-- Applies only to a 'SOAC.Screma' that is a single scan and whose
-- neutral elements are all variables.  The resulting outer map has
-- the width reported by 'rwimPossible'; its inputs are the neutral
-- elements followed by the original inputs with dimensions 0 and 1
-- transposed.  A 'SOAC.Rearrange' is appended to the output transforms
-- to compensate.  In every other case the rewrite fails.
iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim _ (SOAC.Screma w form arrs) ots
  | Just [Futhark.Scan scan_fun nes] <- Futhark.isScanSOAC form,
    Just (map_pat, map_cs, map_w, map_fun) <- rwimPossible scan_fun,
    Just nes_names <- traverse subExpVar nes = do
    let scan_ret = lambdaReturnType scan_fun
        -- The neutral elements become additional inputs of the new
        -- outer map.
        nes_inputs = map SOAC.identInput $ zipWith Ident nes_names scan_ret
        outer_inputs = nes_inputs ++ map (SOAC.transposeInput 0 1) arrs

        -- Parameters of the outer map: accumulator parameters lose
        -- their outer dimension, element parameters get outer size @w@.
        (acc_params, elem_params) =
          splitAt (length arrs) $ lambdaParams scan_fun
        outer_params =
          map removeParamOuterDim acc_params
            ++ map (setParamOuterDimTo w) elem_params
        outer_ret = map (`setOuterSize` w) scan_ret

        -- Inside the outer map, the first parameters act as the neutral
        -- elements of the inner scan and the rest as its input arrays.
        (inner_nes_names, inner_arrs) =
          splitAt (length nes_inputs) $ map paramName outer_params

    -- The inner scan operator is exactly the lambda of the original
    -- inner map.
    scan_form <- scanSOAC [Futhark.Scan map_fun (map Var inner_nes_names)]

    let inner_stm =
          Let (setPatOuterDimTo w map_pat) (defAux ()) $
            Op $ Futhark.Screma w inner_arrs scan_form
        outer_body = mkBody (oneStm inner_stm) $ varsRes $ patNames map_pat
        outer_lam = Lambda outer_params outer_body outer_ret
        perm = case lambdaReturnType map_fun of
          t : _ -> 1 : 0 : [2 .. arrayRank t]
          [] -> []

    return
      ( SOAC.Screma map_w (ScremaForm [] [] outer_lam) outer_inputs,
        ots SOAC.|> SOAC.Rearrange map_cs perm
      )
iswim _ _ _ =
  fail "ISWIM does not apply."
-- | Replace the type of a parameter by its 'rowType'.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim param =
  param {paramDec = rowType (paramType param)}
-- | Replace the type of a parameter by the same type with its outer
-- size set to @w@ (via 'setOuterSize').
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w param =
  param {paramDec = setOuterSize (paramType param) w}
-- | Apply 'setOuterSize' with size @w@ to every type stored in a
-- pattern.
setPatOuterDimTo :: SubExp -> Pat -> Pat
setPatOuterDimTo w pat = fmap resize pat
  where
    resize t = t `setOuterSize` w
-- | Tag each input with whether its underlying array is one of the
-- @interesting@ names, then hand the tagged list to
-- 'commonTransforms''.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting = commonTransforms' . map tag
  where
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)
-- | Repeatedly strip one transform off every flagged input (using
-- 'inputToOutput'), as long as all flagged inputs yield the same
-- transform.  The stripped transforms are collected, in stripping
-- order, into the returned 'SOAC.ArrayTransforms'; the returned inputs
-- are the (possibly stripped) inputs in their original order.
--
-- Stripping stops as soon as some flagged input yields no transform,
-- two flagged inputs disagree, or there is no flagged input at all.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps =
  case foldM step (Nothing, []) inps of
    Just (Just common, stripped) ->
      -- 'step' accumulates in reverse, so restore the original order
      -- before attempting another round.
      let (rest, final_inps) = commonTransforms' (reverse stripped)
       in (common SOAC.<| rest, final_inps)
    _ ->
      (SOAC.noTransforms, map snd inps)
  where
    -- The state is the transform agreed on so far (if any) and the
    -- processed inputs in reverse order.
    step (agreed, done) (True, inp) = do
      (ot, inp') <- inputToOutput inp
      -- Either this is the first flagged input, or it must agree with
      -- the transform found on the earlier ones.
      guard $ maybe True (== ot) agreed
      Just (Just ot, (True, inp') : done)
    step (agreed, done) unflagged =
      Just (agreed, unflagged : done)
-- | One more than the smaller of (a) the minimum array rank among the
-- result types and (b) the number of nesting levels.  The result
-- types are those of the first nesting level if there is one, and the
-- return types of the lambda otherwise; with no result types at all
-- the minimum rank counts as 0.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  1 + min (lowestRank resultTypes) (length levels)
  where
    resultTypes = case levels of
      outermost : _ -> MapNest.nestingReturnType outermost
      [] -> lambdaReturnType lam

    lowestRank ts
      | null ts = 0
      | otherwise = minimum $ map arrayRank ts
-- | If the SOAC is a map nest and the first output transform (as seen
-- by 'SOAC.viewf') is a 'SOAC.Rearrange' whose reach does not exceed
-- 'mapDepth' of the nest, move that rearrange onto the inputs of the
-- nest instead.  Returns the rebuilt SOAC together with the remaining
-- output transforms; fails otherwise.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  SOAC.Rearrange cs perm SOAC.:< ots' <- return $ SOAC.viewf ots
  when (rearrangeReach perm > mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Pad (or truncate) the permutation to the rank of each input.
      permFor inp =
        let r = SOAC.inputRank inp
         in take r perm ++ [length perm .. r - 1]
      rearranged inp =
        SOAC.addTransform (SOAC.Rearrange cs (permFor inp)) inp
      inputs' = map rearranged (MapNest.inputs nest)
  soac' <-
    MapNest.toSOAC $
      MapNest.setInputs inputs' (rearrangeReturnTypes nest perm)
  return (soac', ots')
-- | If the SOAC is a map nest and 'fixupInputs' finds a permutation on
-- its inputs whose reach does not exceed 'mapDepth' of the nest,
-- rebuild the SOAC with the fixed-up inputs and prepend the inverse
-- rearrange to the output transforms.  Fails otherwise.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds (MapNest.inputs nest)
  when (rearrangeReach perm > mapDepth nest) $
    fail "Cannot push transpose"
  soac' <-
    MapNest.toSOAC $
      MapNest.setInputs inputs' (rearrangeReturnTypes nest perm)
  let undoPerm = SOAC.Rearrange mempty (rearrangeInverse perm)
  return (soac', undoPerm SOAC.<| ots)
-- | Update the return types recorded in the nesting levels of a map
-- nest to reflect a permutation of the result dimensions.
--
-- The types of the nest (from 'MapNest.typeOf') are each rearranged by
-- the prefix of @perm@ matching their rank.  The first nesting level
-- then receives the row types of those, the second the row types of
-- the row types, and so on.  Width, lambda and inputs are unchanged.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    permuted =
      [ rearrangeType (take (arrayRank t) perm) t
        | t <- MapNest.typeOf nest
      ]
    -- Successive row types, starting one level below the permuted
    -- types.
    levelTypes = iterate (map rowType) (map rowType permuted)
    nestings' =
      zipWith
        (\nesting ts -> nesting {MapNest.nestingReturnType = ts})
        nestings
        levelTypes
-- | Look for a 'SOAC.Rearrange' on one of the inputs whose array is in
-- @inpIds@ (as exposed by 'SOAC.viewl' on the input's transforms).  If
-- one is found, add a rearrange by the first such permutation
-- (truncated to the input's rank) to /every/ input, and return the
-- permutation together with the new inputs.
--
-- Yields 'Nothing' if no such permutation exists, or if some input has
-- a rank smaller than the reach of the permutation.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps = do
  perm <- listToMaybe candidates
  inps' <- traverse (applyPerm (rearrangeReach perm) perm) inps
  pure (perm, inps')
  where
    -- Permutations found on the inputs we are allowed to look at, in
    -- input order.
    candidates =
      [ perm
        | inp <- inps,
          SOAC.inputArray inp `elem` inpIds,
          Just perm <- [exposedPerm inp]
      ]

    exposedPerm (SOAC.Input ts _ _) =
      case SOAC.viewl ts of
        _ SOAC.:> SOAC.Rearrange _ perm -> Just perm
        _ -> Nothing

    applyPerm reach perm inp
      | rank >= reach =
        Just $ SOAC.addTransform (SOAC.Rearrange mempty (take rank perm)) inp
      | otherwise = Nothing
      where
        rank = SOAC.inputRank inp
-- | Try to move a leading 'SOAC.Reshape' from the output transforms of a map
-- onto the inputs of that map.
--
-- The first equation applies only when all of the following hold:
--
-- * the SOAC is a 'SOAC.Screma' whose form 'Futhark.isMapSOAC' recognises as
--   a map with lambda @maplam@;
--
-- * 'SOAC.viewf' on the output transforms gives @Reshape cs shape :< ots'@;
--
-- * every return type of @maplam@ satisfies 'primType'.
--
-- In that case the result is a new SOAC applied to the original inputs, each
-- extended with @ReshapeOuter cs shape@, paired with the remaining output
-- transforms @ots'@.  The new SOAC is the original map (now with width taken
-- from the last new dimension) wrapped in one additional map per remaining
-- new dimension.  For new dimensions @[a, b, c]@ the fold below visits
-- @(b, [c])@ and then @(a, [b, c])@, so the nest has widths @a@, @b@, @c@
-- from the outside in.
--
-- In every other case the computation fails in 'TryFusion' with
-- @"Cannot pull reshape"@.
pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape :: SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullReshape (SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
inps) ArrayTransforms
ots
| Just Lambda SOACS
maplam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall rep. ScremaForm rep -> Maybe (Lambda rep)
Futhark.isMapSOAC ScremaForm SOACS
form,
SOAC.Reshape Certs
cs ShapeChange SubExp
shape SOAC.:< ArrayTransforms
ots' <- ArrayTransforms -> ViewF
SOAC.viewf ArrayTransforms
ots,
(Type -> Bool) -> [Type] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Type -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([Type] -> Bool) -> [Type] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda SOACS
maplam = do
-- Width of the innermost map: the last of the new dimensions, or the
-- Int64 constant 0 when the reshape has no new dimensions at all.
let mapw' :: SubExp
mapw' = case [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape of
[] -> IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
SubExp
d : [SubExp]
_ -> SubExp
d
-- Every original input gets the reshape (as a ReshapeOuter, keeping the
-- certificates) added to its transforms; the resulting types drive the
-- parameter types of the generated lambdas.
inputs' :: [Input]
inputs' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (ArrayTransform -> Input -> Input
SOAC.addTransform (ArrayTransform -> Input -> Input)
-> ArrayTransform -> Input -> Input
forall a b. (a -> b) -> a -> b
$ Certs -> ShapeChange SubExp -> ArrayTransform
SOAC.ReshapeOuter Certs
cs ShapeChange SubExp
shape) [Input]
inps
inputTypes :: [Type]
inputTypes = (Input -> Type) -> [Input] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Input -> Type
SOAC.inputType [Input]
inputs'
-- One step of the fold: wrap the SOAC built so far (@inner@) in a new map
-- of width @w@.  @outershape@ holds the dimensions that the wrapped SOAC
-- contributes to each result of the new lambda.
let outersoac ::
([SOAC.Input] -> SOAC) ->
(SubExp, [SubExp]) ->
TryFusion ([SOAC.Input] -> SOAC)
outersoac :: ([Input] -> SOAC)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC)
outersoac [Input] -> SOAC
inner (SubExp
w, [SubExp]
outershape) = do
let addDims :: Type -> Type
addDims Type
t = Type -> Shape -> NoUniqueness -> Type
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf Type
t ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
outershape) NoUniqueness
NoUniqueness
retTypes :: [Type]
retTypes = (Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
addDims ([Type] -> [Type]) -> [Type] -> [Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. LambdaT rep -> [Type]
lambdaReturnType Lambda SOACS
maplam
-- One fresh parameter per input; its type is the reshaped input type
-- passed through @stripArray (length shape - length outershape)@.
[Param Type]
ps <- [Type]
-> (Type -> TryFusion (Param Type)) -> TryFusion [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [Type]
inputTypes ((Type -> TryFusion (Param Type)) -> TryFusion [Param Type])
-> (Type -> TryFusion (Param Type)) -> TryFusion [Param Type]
forall a b. (a -> b) -> a -> b
$ \Type
inpt ->
String -> Type -> TryFusion (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"pullReshape_param" (Type -> TryFusion (Param Type)) -> Type -> TryFusion (Param Type)
forall a b. (a -> b) -> a -> b
$
Int -> Type -> Type
forall u. Int -> TypeBase Shape u -> TypeBase Shape u
stripArray (ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
shape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
outershape) Type
inpt
-- The body of the new lambda is just the wrapped SOAC applied to the
-- fresh parameters.
BodyT SOACS
inner_body <-
Builder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall rep (m :: * -> *) somerep.
(Buildable rep, MonadFreshNames m, HasScope somerep m,
SameScope somerep rep) =>
Builder rep (Body rep) -> m (Body rep)
runBodyBuilder (Builder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS))
-> Builder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
[BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))]
-> BuilderT
SOACS
(State VNameSource)
(Body (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBuilder m =>
[m (Exp (Rep m))] -> m (Body (Rep m))
eBody [SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBuilder m, Op (Rep m) ~ SOAC (Rep m)) =>
SOAC (Rep m) -> m (Exp (Rep m))
SOAC.toExp (SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource)))))
-> SOAC (Rep (BuilderT SOACS (State VNameSource)))
-> BuilderT
SOACS
(State VNameSource)
(Exp (Rep (BuilderT SOACS (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ [Input] -> SOAC
inner ([Input] -> SOAC) -> [Input] -> SOAC
forall a b. (a -> b) -> a -> b
$ (Param Type -> Input) -> [Param Type] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> Input
SOAC.identInput (Ident -> Input) -> (Param Type -> Ident) -> Param Type -> Input
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent) [Param Type]
ps]
let inner_fun :: Lambda SOACS
inner_fun =
Lambda :: forall rep. [LParam rep] -> BodyT rep -> [Type] -> LambdaT rep
Lambda
{ lambdaParams :: [LParam SOACS]
lambdaParams = [Param Type]
[LParam SOACS]
ps,
lambdaReturnType :: [Type]
lambdaReturnType = [Type]
retTypes,
lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
inner_body
}
([Input] -> SOAC) -> TryFusion ([Input] -> SOAC)
forall (m :: * -> *) a. Monad m => a -> m a
return (([Input] -> SOAC) -> TryFusion ([Input] -> SOAC))
-> ([Input] -> SOAC) -> TryFusion ([Input] -> SOAC)
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma SubExp
w (ScremaForm SOACS -> [Input] -> SOAC)
-> ScremaForm SOACS -> [Input] -> SOAC
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda SOACS
inner_fun
-- Start from the original map at width mapw' and wrap it once per
-- remaining new dimension, innermost first.  Each width is paired with the
-- list of new dimensions that follow it.
[Input] -> SOAC
op' <-
(([Input] -> SOAC)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC))
-> ([Input] -> SOAC)
-> [(SubExp, [SubExp])]
-> TryFusion ([Input] -> SOAC)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([Input] -> SOAC)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC)
outersoac (SubExp -> ScremaForm SOACS -> [Input] -> SOAC
forall rep. SubExp -> ScremaForm rep -> [Input] -> SOAC rep
SOAC.Screma SubExp
mapw' (ScremaForm SOACS -> [Input] -> SOAC)
-> ScremaForm SOACS -> [Input] -> SOAC
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
Futhark.mapSOAC Lambda SOACS
maplam) ([(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC))
-> [(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC)
forall a b. (a -> b) -> a -> b
$
[SubExp] -> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
1 ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape) ([[SubExp]] -> [(SubExp, [SubExp])])
-> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. (a -> b) -> a -> b
$
Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [[SubExp]] -> [[SubExp]]
forall a. [a] -> [a]
reverse ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [[SubExp]]
forall a. [a] -> [[a]]
tails ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape
(SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Input] -> SOAC
op' [Input]
inputs', ArrayTransforms
ots')
pullReshape SOAC
_ ArrayTransforms
_ = String -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull reshape"
-- | Try to rewrite a kernel so that the inputs named in @inpIds@ carry no
-- transforms of their own.
--
-- On success the result is the rewritten kernel together with the transforms
-- that 'commonTransforms' factored out of its inputs.  Three strategies are
-- tried in order, the first one that succeeds winning:
--
-- 1. push the kernel's output transform through with 'pushRearrange', then
--    expose;
--
-- 2. pull the kernel's output transform in with 'pullRearrange' (which must
--    consume it entirely), then expose;
--
-- 3. expose the kernel as it is.
--
-- If all three fail, the whole computation fails in 'TryFusion'.
exposeInputs ::
[VName] ->
FusedKer ->
TryFusion (FusedKer, SOAC.ArrayTransforms)
exposeInputs :: [VName] -> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs [VName]
inpIds FusedKer
ker =
(FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pushRearrange')
TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pullRearrange')
TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker
where
-- The output transform of the kernel we were given.
ot :: ArrayTransforms
ot = FusedKer -> ArrayTransforms
outputTransform FusedKer
ker
-- Strategy 1: run 'pushRearrange' on the kernel's SOAC and output
-- transform, storing both results back into the kernel.
pushRearrange' :: TryFusion FusedKer
pushRearrange' = do
(SOAC
soac', ArrayTransforms
ot') <- [VName]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pushRearrange [VName]
inpIds (FusedKer -> SOAC
fsoac FusedKer
ker) ArrayTransforms
ot
FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
FusedKer
ker
{ fsoac :: SOAC
fsoac = SOAC
soac',
outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
ot'
}
-- Strategy 2: run 'pullRearrange'.  This is only accepted when no output
-- transforms remain afterwards, so the kernel's output transform is reset
-- to 'SOAC.noTransforms'.
pullRearrange' :: TryFusion FusedKer
pullRearrange' = do
(SOAC
soac', ArrayTransforms
ot') <- SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullRearrange (FusedKer -> SOAC
fsoac FusedKer
ker) ArrayTransforms
ot
Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ot') (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$
String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"pullRearrange was not enough"
FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
FusedKer
ker
{ fsoac :: SOAC
fsoac = SOAC
soac',
outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
SOAC.noTransforms
}
-- Factor the common transforms out of the kernel's inputs.  Succeeds with
-- the updated kernel and the factored-out transforms only if every
-- remaining input counts as exposed; otherwise fails with "Cannot expose".
exposeInputs' :: FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker' =
case [VName] -> [Input] -> (ArrayTransforms, [Input])
commonTransforms [VName]
inpIds ([Input] -> (ArrayTransforms, [Input]))
-> [Input] -> (ArrayTransforms, [Input])
forall a b. (a -> b) -> a -> b
$ FusedKer -> [Input]
inputs FusedKer
ker' of
(ArrayTransforms
ot', [Input]
inps')
| (Input -> Bool) -> [Input] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Input -> Bool
exposed [Input]
inps' ->
(FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer
ker' {fsoac :: SOAC
fsoac = [Input]
inps' [Input] -> SOAC -> SOAC
forall rep. [Input] -> SOAC rep -> SOAC rep
`SOAC.setInputs` FusedKer -> SOAC
fsoac FusedKer
ker'}, ArrayTransforms
ot')
(ArrayTransforms, [Input])
_ -> String -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot expose"
-- An input is exposed if it has no transforms at all, or if it still has
-- transforms but its array is not one of the arrays in @inpIds@.
exposed :: Input -> Bool
exposed (SOAC.Input ArrayTransforms
ts VName
_ Type
_)
| ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ts = Bool
True
exposed Input
inp = Input -> VName
SOAC.inputArray Input
inp VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
inpIds
-- | The rewrites that 'pullOutputTransforms' tries, in order of preference:
-- first 'pullRearrange', then 'pullReshape'.  Each takes a SOAC and its
-- output transforms and, on success, returns a new SOAC together with the
-- output transforms still left to handle.
outputTransformPullers :: [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers :: [SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
outputTransformPullers = [SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullRearrange, SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullReshape]
-- | Pull as many output transforms as possible into a SOAC, using the
-- rewrites listed in 'outputTransformPullers'.
--
-- The result is the rewritten SOAC and whatever output transforms could not
-- be pulled in.  If not even one puller succeeds on the given SOAC and
-- transforms, the computation fails with @"Cannot pull anything"@.
pullOutputTransforms ::
SOAC ->
SOAC.ArrayTransforms ->
TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms :: SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullOutputTransforms = [SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
-> SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
forall t t.
[t -> t -> TryFusion (SOAC, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC, ArrayTransforms)
attempt [SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)]
outputTransformPullers
where
-- Walk the list of pullers.  With no pullers left there is nothing more to
-- try, so fail.
attempt :: [t -> t -> TryFusion (SOAC, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC, ArrayTransforms)
attempt [] t
_ t
_ = String -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull anything"
-- Apply the first puller.  If it leaves no transforms we are done.  If
-- some remain, restart from the full puller list on the new SOAC, keeping
-- the partial result when that restart fails.
attempt (t -> t -> TryFusion (SOAC, ArrayTransforms)
p : [t -> t -> TryFusion (SOAC, ArrayTransforms)]
ps) t
soac t
ots =
do
(SOAC
soac', ArrayTransforms
ots') <- t -> t -> TryFusion (SOAC, ArrayTransforms)
p t
soac t
ots
if ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ots'
then (SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC
soac', ArrayTransforms
SOAC.noTransforms)
else SOAC -> ArrayTransforms -> TryFusion (SOAC, ArrayTransforms)
pullOutputTransforms SOAC
soac' ArrayTransforms
ots' TryFusion (SOAC, ArrayTransforms)
-> TryFusion (SOAC, ArrayTransforms)
-> TryFusion (SOAC, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (SOAC, ArrayTransforms) -> TryFusion (SOAC, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC
soac', ArrayTransforms
ots')
-- Fallback to the remaining pullers on the original SOAC and transforms.
-- NOTE(review): presumably this alternative covers the whole do-block
-- above (i.e. it fires when puller @p@ itself fails); the layout is not
-- visible in this rendering, so confirm against the formatted source.
TryFusion (SOAC, ArrayTransforms)
-> TryFusion (SOAC, ArrayTransforms)
-> TryFusion (SOAC, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> [t -> t -> TryFusion (SOAC, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC, ArrayTransforms)
attempt [t -> t -> TryFusion (SOAC, ArrayTransforms)]
ps t
soac t
ots