{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Optimise.Fusion.LoopKernel
  ( FusedKer (..),
    newKernel,
    inputs,
    setInputs,
    arrInputs,
    transformOutput,
    attemptFusion,
    SOAC,
    MapNest,
  )
where

import Control.Applicative
import Control.Arrow (first)
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Data.List (find, tails, (\\))
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import qualified Futhark.Analysis.HORep.MapNest as MapNest
import qualified Futhark.Analysis.HORep.SOAC as SOAC
import Futhark.Construct
import Futhark.IR.SOACS hiding (SOAC (..))
import qualified Futhark.IR.SOACS as Futhark
import Futhark.Optimise.Fusion.Composing
import Futhark.Pass.ExtractKernels.ISRWIM (rwimPossible)
import Futhark.Transform.Rename (renameLambda)
import Futhark.Transform.Substitute
import Futhark.Util (splitAt3)

-- | The monad in which fusion attempts are run.
--
-- It is a stack of:
--
--   * a reader carrying the @'Scope' 'SOACS'@ in effect,
--
--   * a state holding the 'VNameSource' used to generate fresh names, and
--
--   * 'Maybe' at the bottom, so that an attempt can fail ('fail', 'empty')
--     and alternatives can be tried with '<|>'.
--
-- Because 'Maybe' is the innermost layer, a failing computation also
-- discards any changes it made to the name source.
newtype TryFusion a
  = TryFusion
      ( ReaderT
          (Scope SOACS)
          (StateT VNameSource Maybe)
          a
      )
  deriving
    ( Functor,
      Applicative,
      Alternative,
      Monad,
      MonadFail,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )

-- | Run a 'TryFusion' action inside any 'MonadFreshNames' monad, using the
-- given scope as the reader environment.
--
-- The caller's name source is threaded through the action.  If the action
-- succeeds, the result is returned as 'Just' and the updated name source is
-- committed.  If it fails, 'Nothing' is returned and the caller's name
-- source is left exactly as it was.
tryFusion ::
  MonadFreshNames m =>
  TryFusion a ->
  Scope SOACS ->
  m (Maybe a)
tryFusion (TryFusion m) types = modifyNameSource $ \src ->
  case runStateT (runReaderT m types) src of
    Just (x, src') -> (Just x, src')
    Nothing -> (Nothing, src)

-- | Lift a 'Maybe' into 'TryFusion': 'Just' yields the wrapped value, while
-- 'Nothing' becomes a failed fusion attempt (via 'fail' with the message
-- @"Nothing"@).
liftMaybe :: Maybe a -> TryFusion a
liftMaybe = maybe (fail "Nothing") return

-- | The high-level SOAC representation ("Futhark.Analysis.HORep.SOAC"),
-- specialised to the 'SOACS' representation.
type SOAC = SOAC.SOAC SOACS

-- | The map-nest representation ("Futhark.Analysis.HORep.MapNest"),
-- specialised to the 'SOACS' representation.
type MapNest = MapNest.MapNest SOACS

-- | Emit the bindings that apply a sequence of array transforms to a list
-- of already-available identifiers, finally binding the given names.
--
-- @transformOutput ts names idents@ consumes @ts@ from the front:
--
--   * While a transform remains, it is turned into one 'BasicOp' per current
--     identifier with 'applyTransform'.  Each operation is bound, under the
--     certificates of the transform, to a fresh identifier whose base name
--     is taken from the corresponding entry of @names@ and whose type is
--     obtained from 'primOpType'.  The fresh identifiers become the current
--     identifiers for the remaining transforms.
--
--   * When no transforms remain, each name in @names@ is bound to the
--     corresponding current identifier with a plain 'SubExp' binding.
--
-- XXX: This function is very gross.
transformOutput ::
  SOAC.ArrayTransforms ->
  [VName] ->
  [Ident] ->
  Binder SOACS ()
transformOutput ts names = descend ts
  where
    descend ts' validents =
      case SOAC.viewf ts' of
        SOAC.EmptyF ->
          -- Nothing left to apply: give the results their final names.
          forM_ (zip names validents) $ \(k, valident) ->
            letBindNames [k] $ BasicOp $ SubExp $ Var $ identName valident
        t SOAC.:< ts'' -> do
          let (es, css) = unzip $ map (applyTransform t) validents
              mkPat (Ident nm tp) = Pattern [] [PatElem nm tp]
          opts <- concat <$> mapM primOpType es
          -- Intermediate results get fresh names derived from the final
          -- names; the final names themselves are only bound at the end.
          newIds <- forM (zip names opts) $ \(k, opt) ->
            newIdent (baseString k) opt
          forM_ (zip3 css newIds es) $ \(cs, ids, e) ->
            certifying cs $ letBind (mkPat ids) (BasicOp e)
          descend ts'' newIds

-- | Turn a single 'SOAC.ArrayTransform' applied to an identifier into the
-- 'BasicOp' that performs it, paired with the certificates carried by the
-- transform.
--
--   * 'SOAC.Rearrange': the permutation is extended with the identity on
--     any trailing dimensions of the array that it does not mention, so
--     that it covers the full rank of the identifier's type.
--
--   * 'SOAC.Reshape': used as-is.
--
--   * 'SOAC.ReshapeOuter' / 'SOAC.ReshapeInner': converted to a full
--     'Reshape' using 'reshapeOuter' / 'reshapeInner' with the array's
--     current shape.
--
--   * 'SOAC.Replicate': replicates the array variable itself.
applyTransform :: SOAC.ArrayTransform -> Ident -> (BasicOp, Certificates)
applyTransform (SOAC.Rearrange cs perm) v =
  (Rearrange perm' $ identName v, cs)
  where
    perm' = perm ++ drop (length perm) [0 .. arrayRank (identType v) - 1]
applyTransform (SOAC.Reshape cs shape) v =
  (Reshape shape $ identName v, cs)
applyTransform (SOAC.ReshapeOuter cs shape) v =
  let shapes = reshapeOuter shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.ReshapeInner cs shape) v =
  let shapes = reshapeInner shape 1 $ arrayShape $ identType v
   in (Reshape shapes $ identName v, cs)
applyTransform (SOAC.Replicate cs n) v =
  (Replicate n $ Var $ identName v, cs)

-- | Split the first array transform off a SOAC input.
--
-- Returns the first transform together with the same input (same array and
-- type) carrying only the remaining transforms, or 'Nothing' if the input
-- has no transforms at all.
inputToOutput :: SOAC.Input -> Maybe (SOAC.ArrayTransform, SOAC.Input)
inputToOutput (SOAC.Input ts ia iat) =
  case SOAC.viewf ts of
    t SOAC.:< ts' -> Just (t, SOAC.Input ts' ia iat)
    SOAC.EmptyF -> Nothing

-- | A kernel under construction by the fusion pass: a SOAC together with
-- the bookkeeping needed to decide whether further SOACs can be fused into
-- it.  Fresh kernels are created with 'newKernel'.
data FusedKer = FusedKer
  { -- | the SOAC expression, e.g., mapT( f(a,b), x, y )
    fsoac :: SOAC,
    -- | Variables used in in-place updates in the kernel itself, as
    -- well as on the path to the kernel from the current position.
    -- This is used to avoid fusion that would violate in-place
    -- restrictions.
    inplace :: Names,
    -- | Used to tell whether at least one fusion has been performed;
    -- starts out empty in 'newKernel'.  Presumably holds the variables
    -- that have been fused away so far -- confirm against the fusion
    -- rules that update it.
    fusedVars :: [VName],
    -- | The set of variables that were consumed by the SOACs
    -- contributing to this kernel.  Note that, by the type rules, the
    -- final SOAC may actually consume _more_ than its original
    -- contributors, which implies the need for 'Copy' expressions.
    fusedConsumed :: Names,
    -- | The names in scope at the kernel.
    kernelScope :: Scope SOACS,
    -- | Array transforms associated with the kernel's output; starts out
    -- as 'SOAC.noTransforms' in 'newKernel'.  Presumably these are what
    -- 'transformOutput' is given when the kernel is finalised -- confirm
    -- at the use site.
    outputTransform :: SOAC.ArrayTransforms,
    -- | The output names given to 'newKernel' when the kernel was created.
    outNames :: [VName],
    -- | The statement auxiliary information given to 'newKernel' when the
    -- kernel was created.
    kerAux :: StmAux ()
  }
  deriving (Show)

-- | Create a fresh, not-yet-fused kernel from a single SOAC.
--
-- Arguments, in order: the statement auxiliary information, the SOAC, the
-- names it consumes, its output names, and the scope at the kernel.
--
-- The consumed names initialise both 'inplace' and 'fusedConsumed';
-- 'fusedVars' starts out empty and 'outputTransform' starts out as
-- 'SOAC.noTransforms'.
newKernel :: StmAux () -> SOAC -> Names -> [VName] -> Scope SOACS -> FusedKer
newKernel aux soac consumed out_nms scope =
  FusedKer
    { fsoac = soac,
      inplace = consumed,
      fusedVars = [],
      fusedConsumed = consumed,
      outputTransform = SOAC.noTransforms,
      outNames = out_nms,
      kernelScope = scope,
      kerAux = aux
    }

-- | The set of array variables underlying the inputs of the kernel's SOAC
-- (i.e. 'SOAC.inputArray' of every element of 'inputs').
arrInputs :: FusedKer -> S.Set VName
arrInputs = S.fromList . map SOAC.inputArray . inputs

-- | The inputs of the kernel's SOAC.
inputs :: FusedKer -> [SOAC.Input]
inputs = SOAC.inputs . fsoac

-- | Replace the inputs of the kernel's SOAC, leaving every other part of
-- the kernel unchanged.
setInputs :: [SOAC.Input] -> FusedKer -> FusedKer
setInputs inps ker = ker {fsoac = inps `SOAC.setInputs` fsoac ker}

-- | Fusion strategy: first rewrite the producer SOAC, then retry fusion.
--
-- The producer @soac@ is passed to 'optimizeSOAC', which yields a rewritten
-- SOAC and the array transforms that must now be applied to its outputs.
-- Those transforms are prepended to every kernel input that reads one of
-- the producer's outputs (@outVars@), the types recorded on those inputs
-- are refreshed from the rewritten SOAC, and 'applyFusionRules' is tried
-- again with the rewritten SOAC and adjusted kernel.
--
-- Fails if any step fails.  NOTE(review): this and 'applyFusionRules' are
-- mutually recursive; termination presumably relies on 'optimizeSOAC'
-- failing when there is nothing left to rewrite -- confirm.
tryOptimizeSOAC ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeSOAC unfus_nms outVars soac consumed ker = do
  (soac', ots) <- optimizeSOAC Nothing soac mempty
  let ker' = map (addInitialTransformIfRelevant ots) (inputs ker) `setInputs` ker
      outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
      ker'' = fixInputTypes outIdents ker'
  applyFusionRules unfus_nms outVars soac' consumed ker''
  where
    -- Only inputs that actually read an output of the producer need the
    -- extra transforms; all other inputs are left alone.
    addInitialTransformIfRelevant ots inp
      | SOAC.inputArray inp `elem` outVars =
        SOAC.addInitialTransforms ots inp
      | otherwise =
        inp

-- | Fusion strategy: first rewrite the consumer kernel, then retry fusion.
--
-- The kernel is passed to 'optimizeKernel' together with the producer's
-- output names (@outVars@); if that succeeds, 'applyFusionRules' is tried
-- again with the unchanged producer SOAC and the rewritten kernel.
tryOptimizeKernel ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryOptimizeKernel unfus_nms outVars soac consumed ker = do
  ker' <- optimizeKernel (Just outVars) ker
  applyFusionRules unfus_nms outVars soac consumed ker'

-- | Fusion strategy: expose the kernel inputs that read the producer's
-- outputs, then fuse.
--
-- 'exposeInputs' returns an adjusted kernel and the array transforms that
-- remain on the relevant inputs:
--
--   * If none remain, the producer is fused directly with
--     'fuseSOACwithKer'.
--
--   * Otherwise 'pullOutputTransforms' is asked to absorb those transforms
--     into the producer SOAC.  If it absorbs all of them, the kernel's
--     input types are refreshed from the rewritten producer and
--     'applyFusionRules' is retried; if some transforms are left over, the
--     attempt fails.
tryExposeInputs ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
tryExposeInputs unfus_nms outVars soac consumed ker = do
  (ker', ots) <- exposeInputs outVars ker
  if SOAC.nullTransforms ots
    then fuseSOACwithKer unfus_nms outVars soac consumed ker'
    else do
      (soac', ots') <- pullOutputTransforms soac ots
      let outIdents = zipWith Ident outVars $ SOAC.typeOf soac'
          ker'' = fixInputTypes outIdents ker'
      if SOAC.nullTransforms ots'
        then applyFusionRules unfus_nms outVars soac' consumed ker''
        else fail "tryExposeInputs could not pull SOAC transforms"

-- | Refresh the types recorded on the inputs of the kernel's SOAC.
--
-- For every input whose underlying array has the same name as one of the
-- given identifiers, the input's recorded type is replaced by that
-- identifier's type (the first match wins).  The input's transforms and
-- array name are kept, and inputs without a matching identifier are left
-- untouched.
fixInputTypes :: [Ident] -> FusedKer -> FusedKer
fixInputTypes outIdents ker =
  ker {fsoac = fixInputTypes' $ fsoac ker}
  where
    fixInputTypes' soac =
      map fixInputType (SOAC.inputs soac) `SOAC.setInputs` soac
    fixInputType (SOAC.Input ts v _)
      | Just v' <- find ((== v) . identName) outIdents =
        SOAC.Input ts v $ identType v'
    fixInputType inp = inp

-- | Try to fuse a producer SOAC into a kernel, returning the fused kernel
-- on success.
--
-- The strategies are tried left to right, and the first one to succeed
-- determines the result:
--
--   1. 'tryOptimizeSOAC' -- rewrite the producer, then retry;
--
--   2. 'tryOptimizeKernel' -- rewrite the consumer kernel, then retry;
--
--   3. 'fuseSOACwithKer' -- fuse the two as they are;
--
--   4. 'tryExposeInputs' -- expose the relevant kernel inputs, then fuse.
--
-- All five arguments are passed unchanged to each strategy.
applyFusionRules ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
applyFusionRules unfus_nms outVars soac consumed ker =
  tryOptimizeSOAC unfus_nms outVars soac consumed ker
    <|> tryOptimizeKernel unfus_nms outVars soac consumed ker
    <|> fuseSOACwithKer unfus_nms outVars soac consumed ker
    <|> tryExposeInputs unfus_nms outVars soac consumed ker

-- | Entry point of this module: attempt to fuse a producer SOAC into an
-- existing kernel.
--
-- Runs 'applyFusionRules' via 'tryFusion', using the kernel's own
-- 'kernelScope' as the scope.  On success the fused kernel is passed
-- through 'removeUnusedParamsFromKer' and returned as 'Just'; if no fusion
-- rule applies the result is 'Nothing'.
attemptFusion ::
  MonadFreshNames m =>
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  m (Maybe FusedKer)
attemptFusion unfus_nms outVars soac consumed ker =
  fmap removeUnusedParamsFromKer
    <$> tryFusion
      (applyFusionRules unfus_nms outVars soac consumed ker)
      (kernelScope ker)

-- | Prune the lambda parameters and inputs of the kernel's SOAC using
-- 'removeUnusedParams'.
--
-- This is only done when the SOAC is a 'SOAC.Screma'; a kernel holding any
-- other kind of SOAC is returned unchanged.
removeUnusedParamsFromKer :: FusedKer -> FusedKer
removeUnusedParamsFromKer ker =
  case soac of
    SOAC.Screma {} -> ker {fsoac = soac'}
    _ -> ker
  where
    soac = fsoac ker
    l = SOAC.lambda soac
    inps = SOAC.inputs soac
    (l', inps') = removeUnusedParams l inps
    -- Lazily built: only forced in the 'SOAC.Screma' case above.
    soac' =
      l'
        `SOAC.setLambda` (inps' `SOAC.setInputs` soac)

-- | Given a lambda and the inputs positionally paired with its
-- parameters, keep only those (parameter, input) pairs whose parameter
-- name occurs free in the lambda body.  Returns the lambda with its
-- parameter list replaced, together with the surviving inputs.
--
-- If no parameter is used but at least one pair exists, the first pair
-- is kept, so a non-empty parameter list never becomes empty.
-- NOTE(review): presumably this keeps the SOAC from ending up with no
-- inputs at all -- confirm against callers.
removeUnusedParams :: Lambda -> [SOAC.Input] -> (Lambda, [SOAC.Input])
removeUnusedParams l inps =
  (l {lambdaParams = keptParams}, keptInps)
  where
    -- Names occurring free in the body decide which parameters survive.
    bodyFree = freeIn (lambdaBody l)
    isUsed (p, _) = paramName p `nameIn` bodyFree
    -- 'zip' truncates to the shorter of the two lists.
    paired = zip (lambdaParams l) inps
    (keptParams, keptInps) =
      unzip $ case (filter isUsed paired, paired) of
        ([], firstPair : _) -> [firstPair]
        (usedPairs, _) -> usedPairs

-- | Check that the consumer kernel takes at least one of the producer's
-- outputs as an unmodified input, i.e. that some name in @outVars@ is
-- among the kernel inputs recognised by 'SOAC.isVarishInput'.
mapFusionOK :: [VName] -> FusedKer -> Bool
mapFusionOK outVars ker = any isDirectInput outVars
  where
    -- Kernel inputs for which 'SOAC.isVarishInput' yields a name.
    directInputs = mapMaybe SOAC.isVarishInput $ inputs ker
    isDirectInput v = v `elem` directInputs

-- | Check that the consumer kernel takes every one of the producer's
-- outputs as an unmodified input, i.e. that each name in @outVars@ is
-- among the kernel inputs recognised by 'SOAC.isVarishInput'.
mapWriteFusionOK :: [VName] -> FusedKer -> Bool
mapWriteFusionOK outVars ker = all isDirectInput outVars
  where
    -- Kernel inputs for which 'SOAC.isVarishInput' yields a name.
    directInputs = mapMaybe SOAC.isVarishInput $ inputs ker
    isDirectInput v = v `elem` directInputs

-- | The brain of this module: Fusing a SOAC with a Kernel.
fuseSOACwithKer ::
  Names ->
  [VName] ->
  SOAC ->
  Names ->
  FusedKer ->
  TryFusion FusedKer
fuseSOACwithKer :: Names
-> [VName] -> SOAC SOACS -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set [VName]
outVars SOAC SOACS
soac_p Names
soac_p_consumed FusedKer
ker = do
  -- We are fusing soac_p into soac_c, i.e., the output of soac_p is going
  -- into soac_c.
  let soac_c :: SOAC SOACS
soac_c = FusedKer -> SOAC SOACS
fsoac FusedKer
ker
      inp_p_arr :: [Input]
inp_p_arr = SOAC SOACS -> [Input]
forall lore. SOAC lore -> [Input]
SOAC.inputs SOAC SOACS
soac_p
      horizFuse :: Bool
horizFuse =
        Names
unfus_set Names -> Names -> Bool
forall a. Eq a => a -> a -> Bool
/= Names
forall a. Monoid a => a
mempty
          Bool -> Bool -> Bool
&& SOAC SOACS -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC SOACS -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC SOACS
soac_c
      inp_c_arr :: [Input]
inp_c_arr = SOAC SOACS -> [Input]
forall lore. SOAC lore -> [Input]
SOAC.inputs SOAC SOACS
soac_c
      lam_p :: Lambda SOACS
lam_p = SOAC SOACS -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC SOACS
soac_p
      lam_c :: Lambda SOACS
lam_c = SOAC SOACS -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC SOACS
soac_c
      w :: SubExp
w = SOAC SOACS -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC SOACS
soac_p
      returned_outvars :: [VName]
returned_outvars = (VName -> Bool) -> [VName] -> [VName]
forall a. (a -> Bool) -> [a] -> [a]
filter (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars
      success :: [VName] -> SOAC SOACS -> TryFusion FusedKer
success [VName]
res_outnms SOAC SOACS
res_soac = do
        let fusedVars_new :: [VName]
fusedVars_new = FusedKer -> [VName]
fusedVars FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars
        -- Avoid name duplication, because the producer lambda is not
        -- removed from the program until much later.
        Lambda SOACS
uniq_lam <- Lambda SOACS -> TryFusion (Lambda SOACS)
forall lore (m :: * -> *).
(Renameable lore, MonadFreshNames m) =>
Lambda lore -> m (Lambda lore)
renameLambda (Lambda SOACS -> TryFusion (Lambda SOACS))
-> Lambda SOACS -> TryFusion (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> Lambda SOACS
forall lore. SOAC lore -> Lambda lore
SOAC.lambda SOAC SOACS
res_soac
        FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer -> TryFusion FusedKer) -> FusedKer -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
          FusedKer
ker
            { fsoac :: SOAC SOACS
fsoac = Lambda SOACS
uniq_lam Lambda SOACS -> SOAC SOACS -> SOAC SOACS
forall lore. Lambda lore -> SOAC lore -> SOAC lore
`SOAC.setLambda` SOAC SOACS
res_soac,
              fusedVars :: [VName]
fusedVars = [VName]
fusedVars_new,
              inplace :: Names
inplace = FusedKer -> Names
inplace FusedKer
ker Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
soac_p_consumed,
              fusedConsumed :: Names
fusedConsumed = FusedKer -> Names
fusedConsumed FusedKer
ker Names -> Names -> Names
forall a. Semigroup a => a -> a -> a
<> Names
soac_p_consumed,
              outNames :: [VName]
outNames = [VName]
res_outnms
            }

  [(VName, Ident)]
outPairs <- [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ((VName, TypeBase (ShapeBase SubExp) NoUniqueness)
    -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ (TypeBase (ShapeBase SubExp) NoUniqueness
 -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u.
ArrayShape shape =>
TypeBase shape u -> TypeBase shape u
rowType ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
SOAC lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p) (((VName, TypeBase (ShapeBase SubExp) NoUniqueness)
  -> TryFusion (VName, Ident))
 -> TryFusion [(VName, Ident)])
-> ((VName, TypeBase (ShapeBase SubExp) NoUniqueness)
    -> TryFusion (VName, Ident))
-> TryFusion [(VName, Ident)]
forall a b. (a -> b) -> a -> b
$ \(VName
outVar, TypeBase (ShapeBase SubExp) NoUniqueness
t) -> do
    VName
outVar' <- String -> TryFusion VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> TryFusion VName) -> String -> TryFusion VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
outVar String -> ShowS
forall a. [a] -> [a] -> [a]
++ String
"_elem"
    (VName, Ident) -> TryFusion (VName, Ident)
forall (m :: * -> *) a. Monad m => a -> m a
return (VName
outVar, VName -> TypeBase (ShapeBase SubExp) NoUniqueness -> Ident
Ident VName
outVar' TypeBase (ShapeBase SubExp) NoUniqueness
t)

  let mapLikeFusionCheck :: ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck =
        let (Lambda SOACS
res_lam, [Input]
new_inp) = Names
-> Lambda SOACS
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [Input]
-> (Lambda SOACS, [Input])
forall lore.
Bindable lore =>
Names
-> Lambda lore
-> [Input]
-> [(VName, Ident)]
-> Lambda lore
-> [Input]
-> (Lambda lore, [Input])
fuseMaps Names
unfus_set Lambda SOACS
lam_p [Input]
inp_p_arr [(VName, Ident)]
outPairs Lambda SOACS
lam_c [Input]
inp_c_arr
            ([VName]
extra_nms, [TypeBase (ShapeBase SubExp) NoUniqueness]
extra_rtps) =
              [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([VName], [TypeBase (ShapeBase SubExp) NoUniqueness])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
 -> ([VName], [TypeBase (ShapeBase SubExp) NoUniqueness]))
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([VName], [TypeBase (ShapeBase SubExp) NoUniqueness])
forall a b. (a -> b) -> a -> b
$
                ((VName, TypeBase (ShapeBase SubExp) NoUniqueness) -> Bool)
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. (a -> Bool) -> [a] -> [a]
filter ((VName -> Names -> Bool
`nameIn` Names
unfus_set) (VName -> Bool)
-> ((VName, TypeBase (ShapeBase SubExp) NoUniqueness) -> VName)
-> (VName, TypeBase (ShapeBase SubExp) NoUniqueness)
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (VName, TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall a b. (a, b) -> a
fst) ([(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
 -> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)])
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$
                  [VName]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
outVars ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [(VName, TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ (TypeBase (ShapeBase SubExp) NoUniqueness
 -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray Int
1) ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ SOAC SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
SOAC lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
SOAC.typeOf SOAC SOACS
soac_p
            res_lam' :: Lambda SOACS
res_lam' = Lambda SOACS
res_lam {lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
res_lam [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. [a] -> [a] -> [a]
++ [TypeBase (ShapeBase SubExp) NoUniqueness]
extra_rtps}
         in ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp)

  Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (Bool
horizFuse Bool -> Bool -> Bool
&& Bool -> Bool
not (ArrayTransforms -> Bool
SOAC.nullTransforms (ArrayTransforms -> Bool) -> ArrayTransforms -> Bool
forall a b. (a -> b) -> a -> b
$ FusedKer -> ArrayTransforms
outputTransform FusedKer
ker)) (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$
    String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Horizontal fusion is invalid in the presence of output transforms."

  case (SOAC SOACS
soac_c, SOAC SOACS
soac_p) of
    (SOAC SOACS, SOAC SOACS)
_ | SOAC SOACS -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC SOACS
soac_p SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS -> SubExp
forall lore. SOAC lore -> SubExp
SOAC.width SOAC SOACS
soac_c -> String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC widths must match."
    ( SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_c [Reduce SOACS]
reds_c Lambda SOACS
_) [Input]
_,
      SOAC.Screma SubExp
_ (ScremaForm [Scan SOACS]
scans_p [Reduce SOACS]
reds_p Lambda SOACS
_) [Input]
_
      )
        | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Scan SOACS] -> Int
forall lore. [Scan lore] -> Int
Futhark.scanResults [Scan SOACS]
scans_p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ [Reduce SOACS] -> Int
forall lore. [Reduce lore] -> Int
Futhark.redResults [Reduce SOACS]
reds_p) [VName]
outVars) FusedKer
ker
            Bool -> Bool -> Bool
|| Bool
horizFuse -> do
          let red_nes_p :: [SubExp]
red_nes_p = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall lore. Reduce lore -> [SubExp]
redNeutral [Reduce SOACS]
reds_p
              red_nes_c :: [SubExp]
red_nes_c = (Reduce SOACS -> [SubExp]) -> [Reduce SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Reduce SOACS -> [SubExp]
forall lore. Reduce lore -> [SubExp]
redNeutral [Reduce SOACS]
reds_c
              scan_nes_p :: [SubExp]
scan_nes_p = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall lore. Scan lore -> [SubExp]
scanNeutral [Scan SOACS]
scans_p
              scan_nes_c :: [SubExp]
scan_nes_c = (Scan SOACS -> [SubExp]) -> [Scan SOACS] -> [SubExp]
forall (t :: * -> *) a b. Foldable t => (a -> [b]) -> t a -> [b]
concatMap Scan SOACS -> [SubExp]
forall lore. Scan lore -> [SubExp]
scanNeutral [Scan SOACS]
scans_c
              (Lambda SOACS
res_lam', [Input]
new_inp) =
                Names
-> [VName]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda SOACS
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda SOACS, [Input])
forall lore.
Bindable lore =>
Names
-> [VName]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> [(VName, Ident)]
-> Lambda lore
-> [SubExp]
-> [SubExp]
-> [Input]
-> (Lambda lore, [Input])
fuseRedomap
                  Names
unfus_set
                  [VName]
outVars
                  Lambda SOACS
lam_p
                  [SubExp]
scan_nes_p
                  [SubExp]
red_nes_p
                  [Input]
inp_p_arr
                  [(VName, Ident)]
outPairs
                  Lambda SOACS
lam_c
                  [SubExp]
scan_nes_c
                  [SubExp]
red_nes_c
                  [Input]
inp_c_arr
              ([VName]
soac_p_scanout, [VName]
soac_p_redout, [VName]
_soac_p_mapout) =
                Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_p) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_p) [VName]
outVars
              ([VName]
soac_c_scanout, [VName]
soac_c_redout, [VName]
soac_c_mapout) =
                Int -> Int -> [VName] -> ([VName], [VName], [VName])
forall a. Int -> Int -> [a] -> ([a], [a], [a])
splitAt3 ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
scan_nes_c) ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
red_nes_c) ([VName] -> ([VName], [VName], [VName]))
-> [VName] -> ([VName], [VName], [VName])
forall a b. (a -> b) -> a -> b
$ FusedKer -> [VName]
outNames FusedKer
ker
              unfus_arrs :: [VName]
unfus_arrs = [VName]
returned_outvars [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ ([VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout)
          [VName] -> SOAC SOACS -> TryFusion FusedKer
success
            ( [VName]
soac_p_scanout [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_scanout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_p_redout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_redout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
soac_c_mapout
                [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
unfus_arrs
            )
            (SOAC SOACS -> TryFusion FusedKer)
-> SOAC SOACS -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma
              SubExp
w
              ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm ([Scan SOACS]
scans_p [Scan SOACS] -> [Scan SOACS] -> [Scan SOACS]
forall a. [a] -> [a] -> [a]
++ [Scan SOACS]
scans_c) ([Reduce SOACS]
reds_p [Reduce SOACS] -> [Reduce SOACS] -> [Reduce SOACS]
forall a. [a] -> [a] -> [a]
++ [Reduce SOACS]
reds_c) Lambda SOACS
res_lam')
              [Input]
new_inp

    ------------------
    -- Scatter fusion --
    ------------------

    -- Map-Scatter fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Scatter is not writing to any array used in
    -- the Map.
    ( SOAC.Scatter SubExp
_len Lambda SOACS
_lam [Input]
_ivs [(ShapeBase SubExp, Int, VName)]
dests,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the scatter, i.e., not used elsewhere.
          Bool -> Bool
not ((VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars),
          -- 2. all arrays produced by the map are input to the scatter.
          [VName] -> FusedKer -> Bool
mapWriteFusionOK [VName]
outVars FusedKer
ker -> do
          let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
          [VName] -> SOAC SOACS -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedKer)
-> SOAC SOACS -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp
-> Lambda SOACS
-> [Input]
-> [(ShapeBase SubExp, Int, VName)]
-> SOAC SOACS
forall lore.
SubExp
-> Lambda lore
-> [Input]
-> [(ShapeBase SubExp, Int, VName)]
-> SOAC lore
SOAC.Scatter SubExp
w Lambda SOACS
res_lam' [Input]
new_inp [(ShapeBase SubExp, Int, VName)]
dests

    -- Map-Hist fusion.
    --
    -- The 'inplace' mechanism for kernels already takes care of
    -- checking that the Hist is not writing to any array used in
    -- the Map.
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops Lambda SOACS
_ [Input]
_,
      SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_
      )
        | Maybe (Lambda SOACS) -> Bool
forall a. Maybe a -> Bool
isJust (Maybe (Lambda SOACS) -> Bool) -> Maybe (Lambda SOACS) -> Bool
forall a b. (a -> b) -> a -> b
$ ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
isMapSOAC ScremaForm SOACS
form,
          -- 1. all arrays produced by the map are ONLY used (consumed)
          --    by the hist, i.e., not used elsewhere.
          Bool -> Bool
not ((VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any (VName -> Names -> Bool
`nameIn` Names
unfus_set) [VName]
outVars),
          -- 2. all arrays produced by the map are input to the hist.
          [VName] -> FusedKer -> Bool
mapWriteFusionOK [VName]
outVars FusedKer
ker -> do
          let ([VName]
extra_nms, Lambda SOACS
res_lam', [Input]
new_inp) = ([VName], Lambda SOACS, [Input])
mapLikeFusionCheck
          [VName] -> SOAC SOACS -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
extra_nms) (SOAC SOACS -> TryFusion FusedKer)
-> SOAC SOACS -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall lore.
SubExp -> [HistOp lore] -> Lambda lore -> [Input] -> SOAC lore
SOAC.Hist SubExp
w [HistOp SOACS]
ops Lambda SOACS
res_lam' [Input]
new_inp

    -- Hist-Hist fusion
    ( SOAC.Hist SubExp
_ [HistOp SOACS]
ops_c Lambda SOACS
_ [Input]
_,
      SOAC.Hist SubExp
_ [HistOp SOACS]
ops_p Lambda SOACS
_ [Input]
_
      )
        | Bool
horizFuse -> do
          let p_num_buckets :: Int
p_num_buckets = [HistOp SOACS] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_p
              c_num_buckets :: Int
c_num_buckets = [HistOp SOACS] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [HistOp SOACS]
ops_c
              (BodyT SOACS
body_p, BodyT SOACS
body_c) = (Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_c)
              body' :: BodyT SOACS
body' =
                Body :: forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body
                  { bodyDec :: BodyDec SOACS
bodyDec = BodyT SOACS -> BodyDec SOACS
forall lore. BodyT lore -> BodyDec lore
bodyDec BodyT SOACS
body_p, -- body_p and body_c have the same lores
                    bodyStms :: Stms SOACS
bodyStms = BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_c,
                    bodyResult :: [SubExp]
bodyResult =
                      Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
c_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
take Int
p_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c)
                        [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                  }
              lam' :: Lambda SOACS
lam' =
                Lambda :: forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda
                  { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_p,
                    lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
body',
                    lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType =
                      Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> a -> [a]
replicate (Int
c_num_buckets Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
p_num_buckets) (PrimType -> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64)
                        [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
c_num_buckets (Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c)
                        [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. [a] -> [a] -> [a]
++ Int
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a. Int -> [a] -> [a]
drop Int
p_num_buckets (Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                  }
          [VName] -> SOAC SOACS -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedKer)
-> SOAC SOACS -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp -> [HistOp SOACS] -> Lambda SOACS -> [Input] -> SOAC SOACS
forall lore.
SubExp -> [HistOp lore] -> Lambda lore -> [Input] -> SOAC lore
SOAC.Hist SubExp
w ([HistOp SOACS]
ops_c [HistOp SOACS] -> [HistOp SOACS] -> [HistOp SOACS]
forall a. Semigroup a => a -> a -> a
<> [HistOp SOACS]
ops_p) Lambda SOACS
lam' ([Input]
inp_c_arr [Input] -> [Input] -> [Input]
forall a. Semigroup a => a -> a -> a
<> [Input]
inp_p_arr)

    -- Scatter-write fusion.
    ( SOAC.Scatter SubExp
_len_c Lambda SOACS
_lam_c [Input]
ivs_c [(ShapeBase SubExp, Int, VName)]
as_c,
      SOAC.Scatter SubExp
_len_p Lambda SOACS
_lam_p [Input]
ivs_p [(ShapeBase SubExp, Int, VName)]
as_p
      )
        | Bool
horizFuse -> do
          let zipW :: [(ShapeBase SubExp, Int, array)]
-> [a] -> [(ShapeBase SubExp, Int, array)] -> [a] -> [a]
zipW [(ShapeBase SubExp, Int, array)]
as_xs [a]
xs [(ShapeBase SubExp, Int, array)]
as_ys [a]
ys = [a]
xs_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_indices [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
xs_vals [a] -> [a] -> [a]
forall a. [a] -> [a] -> [a]
++ [a]
ys_vals
                where
                  ([a]
xs_indices, [a]
xs_vals) = [(ShapeBase SubExp, Int, array)] -> [a] -> ([a], [a])
forall array a.
[(ShapeBase SubExp, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(ShapeBase SubExp, Int, array)]
as_xs [a]
xs
                  ([a]
ys_indices, [a]
ys_vals) = [(ShapeBase SubExp, Int, array)] -> [a] -> ([a], [a])
forall array a.
[(ShapeBase SubExp, Int, array)] -> [a] -> ([a], [a])
splitScatterResults [(ShapeBase SubExp, Int, array)]
as_ys [a]
ys
          let (BodyT SOACS
body_p, BodyT SOACS
body_c) = (Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_p, Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
lam_c)
          let body' :: BodyT SOACS
body' =
                Body :: forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body
                  { bodyDec :: BodyDec SOACS
bodyDec = BodyT SOACS -> BodyDec SOACS
forall lore. BodyT lore -> BodyDec lore
bodyDec BodyT SOACS
body_p, -- body_p and body_c have the same lores
                    bodyStms :: Stms SOACS
bodyStms = BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_p Stms SOACS -> Stms SOACS -> Stms SOACS
forall a. Semigroup a => a -> a -> a
<> BodyT SOACS -> Stms SOACS
forall lore. BodyT lore -> Stms lore
bodyStms BodyT SOACS
body_c,
                    bodyResult :: [SubExp]
bodyResult = [(ShapeBase SubExp, Int, VName)]
-> [SubExp]
-> [(ShapeBase SubExp, Int, VName)]
-> [SubExp]
-> [SubExp]
forall {array} {a} {array}.
[(ShapeBase SubExp, Int, array)]
-> [a] -> [(ShapeBase SubExp, Int, array)] -> [a] -> [a]
zipW [(ShapeBase SubExp, Int, VName)]
as_c (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_c) [(ShapeBase SubExp, Int, VName)]
as_p (BodyT SOACS -> [SubExp]
forall lore. BodyT lore -> [SubExp]
bodyResult BodyT SOACS
body_p)
                  }
          let lam' :: Lambda SOACS
lam' =
                Lambda :: forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda
                  { lambdaParams :: [LParam SOACS]
lambdaParams = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_c [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
lam_p,
                    lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
body',
                    lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = [(ShapeBase SubExp, Int, VName)]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [(ShapeBase SubExp, Int, VName)]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall {array} {a} {array}.
[(ShapeBase SubExp, Int, array)]
-> [a] -> [(ShapeBase SubExp, Int, array)] -> [a] -> [a]
zipW [(ShapeBase SubExp, Int, VName)]
as_c (Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
lam_c) [(ShapeBase SubExp, Int, VName)]
as_p (Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
lam_p)
                  }
          [VName] -> SOAC SOACS -> TryFusion FusedKer
success (FusedKer -> [VName]
outNames FusedKer
ker [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
returned_outvars) (SOAC SOACS -> TryFusion FusedKer)
-> SOAC SOACS -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
            SubExp
-> Lambda SOACS
-> [Input]
-> [(ShapeBase SubExp, Int, VName)]
-> SOAC SOACS
forall lore.
SubExp
-> Lambda lore
-> [Input]
-> [(ShapeBase SubExp, Int, VName)]
-> SOAC lore
SOAC.Scatter SubExp
w Lambda SOACS
lam' ([Input]
ivs_c [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ [Input]
ivs_p) ([(ShapeBase SubExp, Int, VName)]
as_c [(ShapeBase SubExp, Int, VName)]
-> [(ShapeBase SubExp, Int, VName)]
-> [(ShapeBase SubExp, Int, VName)]
forall a. [a] -> [a] -> [a]
++ [(ShapeBase SubExp, Int, VName)]
as_p)
    (SOAC.Scatter {}, SOAC SOACS
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a write with anything else than a write or a map"
    (SOAC SOACS
_, SOAC.Scatter {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a write with anything else than a write or a map"
    ----------------------------
    -- Stream-Stream Fusions: --
    ----------------------------
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
nes [Input]
_)
      | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
outVars) FusedKer
ker Bool -> Bool -> Bool
|| Bool
horizFuse -> do
        -- fuse two SEQUENTIAL streams
        ([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedKer -> [VName]
outNames FusedKer
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
        [VName] -> SOAC SOACS -> TryFusion FusedKer
success [VName]
res_nms SOAC SOACS
res_stream
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Fusion conditions not met for two SEQ streams!"
    (SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC.Stream {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream SubExp
_ StreamForm SOACS
Sequential Lambda SOACS
_ [SubExp]
_ [Input]
_) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse a parallel with a sequential Stream!"
    (SOAC.Stream {}, SOAC.Stream SubExp
_ StreamForm SOACS
_ Lambda SOACS
_ [SubExp]
nes [Input]
_)
      | [VName] -> FusedKer -> Bool
mapFusionOK (Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
nes) [VName]
outVars) FusedKer
ker Bool -> Bool -> Bool
|| Bool
horizFuse -> do
        -- fuse two PARALLEL streams
        ([VName]
res_nms, SOAC SOACS
res_stream) <- [VName]
-> Names
-> [VName]
-> [(VName, Ident)]
-> SOAC SOACS
-> SOAC SOACS
-> TryFusion ([VName], SOAC SOACS)
fuseStreamHelper (FusedKer -> [VName]
outNames FusedKer
ker) Names
unfus_set [VName]
outVars [(VName, Ident)]
outPairs SOAC SOACS
soac_c SOAC SOACS
soac_p
        [VName] -> SOAC SOACS -> TryFusion FusedKer
success [VName]
res_nms SOAC SOACS
res_stream
    (SOAC.Stream {}, SOAC.Stream {}) ->
      String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Fusion conditions not met for two PAR streams!"
    -------------------------------------------------------------------
    --- If one is a stream, translate the other to a stream as well.---
    --- This does not get in trouble (infinite computation) because ---
    ---   scan's translation to Stream introduces a hindrance to    ---
    ---   (horizontal fusion), hence repeated application is for the---
    ---   moment impossible. However, if with a dependence-graph rep---
    ---   we could run in an infinite recursion, i.e., repeatedly   ---
    ---   fusing map o scan into an infinity of Stream levels!      ---
    -------------------------------------------------------------------
    (SOAC.Stream SubExp
_ StreamForm SOACS
form2 Lambda SOACS
_ [SubExp]
_ [Input]
_, SOAC SOACS
_) -> do
      -- If this rule is matched then soac_p is NOT a stream.
      -- To fuse a stream kernel, we transform soac_p to a stream, which
      -- borrows the sequential/parallel property of the soac_c Stream,
      -- and recursively perform stream-stream fusion.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      SOAC SOACS
soac_p'' <- case StreamForm SOACS
form2 of
        Sequential {} -> SOAC SOACS -> TryFusion (SOAC SOACS)
toSeqStream SOAC SOACS
soac_p'
        StreamForm SOACS
_ -> SOAC SOACS -> TryFusion (SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return SOAC SOACS
soac_p'
      if SOAC SOACS
soac_p' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC SOACS
soac_p
        then String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
        else Names
-> [VName] -> SOAC SOACS -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars) SOAC SOACS
soac_p'' Names
soac_p_consumed FusedKer
ker
    (SOAC SOACS
_, SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
_) | Just [Scan SOACS]
_ <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall lore. ScremaForm lore -> Maybe [Scan lore]
Futhark.isScanSOAC ScremaForm SOACS
form -> do
      -- A Scan soac can be currently only fused as a (sequential) stream,
      -- hence it is first translated to a (sequential) Stream and then
      -- fusion with a kernel is attempted.
      (SOAC SOACS
soac_p', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC SOACS
soac_p
      if SOAC SOACS
soac_p' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
/= SOAC SOACS
soac_p
        then Names
-> [VName] -> SOAC SOACS -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set ((Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
outVars) SOAC SOACS
soac_p' Names
soac_p_consumed FusedKer
ker
        else String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
    (SOAC SOACS
_, SOAC.Stream SubExp
_ StreamForm SOACS
form_p Lambda SOACS
_ [SubExp]
_ [Input]
_) -> do
      -- If it reached this case then soac_c is NOT a Stream kernel,
      -- hence transform the kernel's soac to a stream and attempt
      -- stream-stream fusion recursively.
      -- The newly created stream corresponding to soac_c borrows the
      -- sequential/parallel property of the soac_p stream.
      (SOAC SOACS
soac_c', [Ident]
newacc_ids) <- SOAC SOACS -> TryFusion (SOAC SOACS, [Ident])
forall (m :: * -> *) lore.
(MonadFreshNames m, Bindable lore, Op lore ~ SOAC lore) =>
SOAC lore -> m (SOAC lore, [Ident])
SOAC.soacToStream SOAC SOACS
soac_c
      Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
when (SOAC SOACS
soac_c' SOAC SOACS -> SOAC SOACS -> Bool
forall a. Eq a => a -> a -> Bool
== SOAC SOACS
soac_c) (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$ String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"SOAC could not be turned into stream."
      SOAC SOACS
soac_c'' <- case StreamForm SOACS
form_p of
        StreamForm SOACS
Sequential -> SOAC SOACS -> TryFusion (SOAC SOACS)
toSeqStream SOAC SOACS
soac_c'
        StreamForm SOACS
_ -> SOAC SOACS -> TryFusion (SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return SOAC SOACS
soac_c'

      Names
-> [VName] -> SOAC SOACS -> Names -> FusedKer -> TryFusion FusedKer
fuseSOACwithKer Names
unfus_set [VName]
outVars SOAC SOACS
soac_p Names
soac_p_consumed (FusedKer -> TryFusion FusedKer) -> FusedKer -> TryFusion FusedKer
forall a b. (a -> b) -> a -> b
$
        FusedKer
ker {fsoac :: SOAC SOACS
fsoac = SOAC SOACS
soac_c'', outNames :: [VName]
outNames = (Ident -> VName) -> [Ident] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> VName
identName [Ident]
newacc_ids [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ FusedKer -> [VName]
outNames FusedKer
ker}

    ---------------------------------
    --- DEFAULT, CANNOT FUSE CASE ---
    ---------------------------------
    (SOAC SOACS, SOAC SOACS)
_ -> String -> TryFusion FusedKer
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot fuse"

-- | The ordering attached to a stream form.  A 'Parallel' form
-- carries its own 'StreamOrd'; a 'Sequential' form is reported as
-- 'InOrder'.
getStreamOrder :: StreamForm lore -> StreamOrd
getStreamOrder form = case form of
  Parallel ord _ _ -> ord
  Sequential -> InOrder

-- | Fuse two streams into a single stream.
--
-- At the call sites in this module the first SOAC argument is the
-- kernel's (consumer) stream and the second is the producer stream.
-- The arguments are, in order: the kernel's current output names, the
-- set of names that must remain available after fusion, the
-- producer's output names, the producer's output name/identifier
-- pairs, the consumer stream and the producer stream.
--
-- The result is the output names of the fused stream together with
-- the fused stream itself.  Fusion fails if the two streams disagree
-- on their 'StreamOrd', if their forms cannot be merged, or if either
-- argument is not a stream.
fuseStreamHelper ::
  [VName] ->
  Names ->
  [VName] ->
  [(VName, Ident)] ->
  SOAC ->
  SOAC ->
  TryFusion ([VName], SOAC)
fuseStreamHelper
  kernel_outs
  unfusable
  prod_outs
  prod_out_pairs
  (SOAC.Stream w_c form_c lam_c nes_c arrs_c)
  (SOAC.Stream _ form_p lam_p nes_p arrs_p)
    | getStreamOrder form_c /= getStreamOrder form_p =
      fail "fusion conditions not met!"
    | otherwise = do
      fused_form <- combineForms form_c form_p
      return
        ( acc_outs ++ kernel_outs ++ unfused_outs,
          SOAC.Stream w_c fused_form fused_lam (nes_p ++ nes_c) fused_arrs
        )
    where
      -- Very similar to redomap-redomap composition, except that the
      -- leading `chunk' parameter of each stream lambda must first be
      -- removed and then reinstated on the resulting lambda.
      chunk_p = head (lambdaParams lam_p)
      chunk_c = head (lambdaParams lam_c)

      -- Make the consumer lambda refer to the producer's chunk
      -- parameter, so that only one chunk parameter survives.
      lam_c_renamed =
        substituteNames
          (M.singleton (paramName chunk_c) (paramName chunk_p))
          lam_c

      -- Drop the leading chunk parameter of a stream lambda.
      stripChunk lam = lam {lambdaParams = tail (lambdaParams lam)}

      (fused_body_lam, fused_arrs) =
        fuseRedomap
          unfusable
          prod_outs
          (stripChunk lam_p)
          []
          nes_p
          arrs_p
          prod_out_pairs
          (stripChunk lam_c_renamed)
          []
          nes_c
          arrs_c

      -- Put the (producer's) chunk parameter back in front.
      fused_lam =
        fused_body_lam {lambdaParams = chunk_p : lambdaParams fused_body_lam}

      -- One producer output per producer accumulator comes first.
      acc_outs = take (length nes_p) prod_outs
      -- Producer outputs that must still be returned after fusion.
      unfused_outs = filter (`nameIn` unfusable) prod_outs

      -- Merge the consumer form (first) with the producer form
      -- (second); mixing sequential and parallel is refused.
      combineForms f_c f_p = case (f_c, f_p) of
        (Sequential, Sequential) ->
          return Sequential
        (Parallel _ comm_c red_c, Parallel ord_p comm_p red_p) ->
          return $
            Parallel ord_p (comm_p <> comm_c) (mergeReduceOps red_p red_c)
        _ ->
          fail "Fusing sequential to parallel stream disallowed!"
fuseStreamHelper _ _ _ _ _ _ = fail "Cannot Fuse Streams!"

-- | Turn a stream into a sequential stream.  A stream that is already
-- 'Sequential' is returned unchanged; a 'Parallel' stream is rebuilt
-- with the 'Sequential' form and otherwise identical components.  Any
-- SOAC that is not a stream makes the fusion attempt fail.
toSeqStream :: SOAC -> TryFusion SOAC
toSeqStream soac = case soac of
  SOAC.Stream _ Sequential _ _ _ ->
    return soac
  SOAC.Stream w Parallel {} lam accs arrs ->
    return (SOAC.Stream w Sequential lam accs arrs)
  _ ->
    fail "toSeqStream expects a stream, but given a SOAC."

-- Here follow optimizations and transforms to expose fusability.

-- | Run 'optimizeSOAC' on the SOAC stored in a kernel, starting from
-- the kernel's current output transform, and store the rewritten SOAC
-- and resulting output transform back into the kernel.  Fails when
-- 'optimizeSOAC' fails.
optimizeKernel :: Maybe [VName] -> FusedKer -> TryFusion FusedKer
optimizeKernel inp ker = do
  (soac', trans') <- optimizeSOAC inp (fsoac ker) (outputTransform ker)
  return ker {fsoac = soac', outputTransform = trans'}

-- | Apply each entry of 'optimizations' in turn to a SOAC and its
-- output transforms.
--
-- The first argument is passed unchanged to every optimisation.  An
-- optimisation that fails is skipped and the SOAC and transforms
-- accumulated so far are kept.  The overall attempt fails with
-- @"No optimisation applied"@ only if no optimisation succeeded;
-- otherwise the final SOAC and output transforms are returned.
optimizeSOAC ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
optimizeSOAC inp soac os = do
  res <- foldM comb (False, soac, os) optimizations
  case res of
    (False, _, _) -> fail "No optimisation applied"
    (True, soac', os') -> return (soac', os')
  where
    -- One fold step: try optimisation @f@ on the current state and
    -- fall back to that state, unchanged, if @f@ fails.
    comb (changed, soac', os') f =
      do
        -- Pass the accumulated transforms @os'@, not the initial
        -- @os@: otherwise a later optimisation would be paired with
        -- stale transforms and discard those of an earlier one.
        (soac'', os'') <- f inp soac' os'
        return (True, soac'', os'')
        <|> return (changed, soac', os')

-- | The shape of a rewrite usable in 'optimizations': given an
-- optional list of names, a SOAC and its output array transforms, it
-- either produces a rewritten SOAC with new output transforms or
-- fails in 'TryFusion'.
type Optimization =
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)

-- | The rewrites that 'optimizeSOAC' tries, in order.  Currently only
-- 'iswim'.
optimizations :: [Optimization]
optimizations =
  [ iswim
  ]

iswim ::
  Maybe [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
iswim :: Maybe [VName]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
iswim Maybe [VName]
_ (SOAC.Screma SubExp
w ScremaForm SOACS
form [Input]
arrs) ArrayTransforms
ots
  | Just [Futhark.Scan Lambda SOACS
scan_fun [SubExp]
nes] <- ScremaForm SOACS -> Maybe [Scan SOACS]
forall lore. ScremaForm lore -> Maybe [Scan lore]
Futhark.isScanSOAC ScremaForm SOACS
form,
    Just (Pattern
map_pat, Certificates
map_cs, SubExp
map_w, Lambda SOACS
map_fun) <- Lambda SOACS -> Maybe (Pattern, Certificates, SubExp, Lambda SOACS)
rwimPossible Lambda SOACS
scan_fun,
    Just [VName]
nes_names <- (SubExp -> Maybe VName) -> [SubExp] -> Maybe [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM SubExp -> Maybe VName
subExpVar [SubExp]
nes = do
    let nes_idents :: [Ident]
nes_idents = (VName -> TypeBase (ShapeBase SubExp) NoUniqueness -> Ident)
-> [VName] -> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [Ident]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith VName -> TypeBase (ShapeBase SubExp) NoUniqueness -> Ident
Ident [VName]
nes_names ([TypeBase (ShapeBase SubExp) NoUniqueness] -> [Ident])
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> [Ident]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
scan_fun
        map_nes :: [Input]
map_nes = (Ident -> Input) -> [Ident] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map Ident -> Input
SOAC.identInput [Ident]
nes_idents
        map_arrs' :: [Input]
map_arrs' = [Input]
map_nes [Input] -> [Input] -> [Input]
forall a. [a] -> [a] -> [a]
++ (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Int -> Int -> Input -> Input
SOAC.transposeInput Int
0 Int
1) [Input]
arrs
        ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
scan_acc_params, [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
scan_elem_params) =
          Int
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)],
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a. Int -> [a] -> ([a], [a])
splitAt ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
arrs) ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
 -> ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)],
     [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]))
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> ([Param (TypeBase (ShapeBase SubExp) NoUniqueness)],
    [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
scan_fun
        map_params :: [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
map_params =
          (Param (TypeBase (ShapeBase SubExp) NoUniqueness)
 -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase (ShapeBase SubExp) NoUniqueness)
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
LParam SOACS -> LParam SOACS
removeParamOuterDim [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
scan_acc_params
            [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a. [a] -> [a] -> [a]
++ (Param (TypeBase (ShapeBase SubExp) NoUniqueness)
 -> Param (TypeBase (ShapeBase SubExp) NoUniqueness))
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> [a] -> [b]
map (SubExp -> LParam SOACS -> LParam SOACS
setParamOuterDimTo SubExp
w) [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
scan_elem_params
        map_rettype :: [TypeBase (ShapeBase SubExp) NoUniqueness]
map_rettype = (TypeBase (ShapeBase SubExp) NoUniqueness
 -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map (TypeBase (ShapeBase SubExp) NoUniqueness
-> SubExp -> TypeBase (ShapeBase SubExp) NoUniqueness
forall d u.
ArrayShape (ShapeBase d) =>
TypeBase (ShapeBase d) u -> d -> TypeBase (ShapeBase d) u
`setOuterSize` SubExp
w) ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
scan_fun

        scan_params :: [LParam SOACS]
scan_params = Lambda SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams Lambda SOACS
map_fun
        scan_body :: BodyT SOACS
scan_body = Lambda SOACS -> BodyT SOACS
forall lore. LambdaT lore -> BodyT lore
lambdaBody Lambda SOACS
map_fun
        scan_rettype :: [TypeBase (ShapeBase SubExp) NoUniqueness]
scan_rettype = Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
map_fun
        scan_fun' :: Lambda SOACS
scan_fun' = [LParam SOACS]
-> BodyT SOACS
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda SOACS
forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda [LParam SOACS]
scan_params BodyT SOACS
scan_body [TypeBase (ShapeBase SubExp) NoUniqueness]
scan_rettype
        nes' :: [SubExp]
nes' = (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
take ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
map_params
        arrs' :: [VName]
arrs' = Int -> [VName] -> [VName]
forall a. Int -> [a] -> [a]
drop ([Input] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [Input]
map_nes) ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> VName
forall dec. Param dec -> VName
paramName [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
map_params

    ScremaForm SOACS
scan_form <- [Scan SOACS] -> TryFusion (ScremaForm SOACS)
forall lore (m :: * -> *).
(Bindable lore, MonadFreshNames m) =>
[Scan lore] -> m (ScremaForm lore)
scanSOAC [Lambda SOACS -> [SubExp] -> Scan SOACS
forall lore. Lambda lore -> [SubExp] -> Scan lore
Futhark.Scan Lambda SOACS
scan_fun' [SubExp]
nes']

    let map_body :: BodyT SOACS
map_body =
          Stms SOACS -> [SubExp] -> BodyT SOACS
forall lore. Bindable lore => Stms lore -> [SubExp] -> Body lore
mkBody
            ( Stm SOACS -> Stms SOACS
forall lore. Stm lore -> Stms lore
oneStm (Stm SOACS -> Stms SOACS) -> Stm SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$
                Pattern -> StmAux (ExpDec SOACS) -> ExpT SOACS -> Stm SOACS
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let (SubExp -> Pattern -> Pattern
setPatternOuterDimTo SubExp
w Pattern
map_pat) (() -> StmAux ()
forall dec. dec -> StmAux dec
defAux ()) (ExpT SOACS -> Stm SOACS) -> ExpT SOACS -> Stm SOACS
forall a b. (a -> b) -> a -> b
$
                  Op SOACS -> ExpT SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> ExpT SOACS) -> Op SOACS -> ExpT SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Futhark.Screma SubExp
w [VName]
arrs' ScremaForm SOACS
scan_form
            )
            ([SubExp] -> BodyT SOACS) -> [SubExp] -> BodyT SOACS
forall a b. (a -> b) -> a -> b
$ (VName -> SubExp) -> [VName] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map VName -> SubExp
Var ([VName] -> [SubExp]) -> [VName] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ PatternT (TypeBase (ShapeBase SubExp) NoUniqueness) -> [VName]
forall dec. PatternT dec -> [VName]
patternNames PatternT (TypeBase (ShapeBase SubExp) NoUniqueness)
Pattern
map_pat
        map_fun' :: Lambda SOACS
map_fun' = [LParam SOACS]
-> BodyT SOACS
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> Lambda SOACS
forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
[LParam SOACS]
map_params BodyT SOACS
map_body [TypeBase (ShapeBase SubExp) NoUniqueness]
map_rettype
        perm :: [Int]
perm = case Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
map_fun of
          [] -> []
          TypeBase (ShapeBase SubExp) NoUniqueness
t : [TypeBase (ShapeBase SubExp) NoUniqueness]
_ -> Int
1 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: Int
0 Int -> [Int] -> [Int]
forall a. a -> [a] -> [a]
: [Int
2 .. TypeBase (ShapeBase SubExp) NoUniqueness -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank TypeBase (ShapeBase SubExp) NoUniqueness
t]

    (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return
      ( SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
map_w ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall lore.
[Scan lore] -> [Reduce lore] -> Lambda lore -> ScremaForm lore
ScremaForm [] [] Lambda SOACS
map_fun') [Input]
map_arrs',
        ArrayTransforms
ots ArrayTransforms -> ArrayTransform -> ArrayTransforms
SOAC.|> Certificates -> [Int] -> ArrayTransform
SOAC.Rearrange Certificates
map_cs [Int]
perm
      )
iswim Maybe [VName]
_ SOAC SOACS
_ ArrayTransforms
_ =
  String -> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"ISWIM does not apply."

-- | Replace the type decoration of a lambda parameter with the row
-- type of its current type (i.e. the type with the outermost
-- dimension peeled off).  The rest of the parameter is left untouched.
removeParamOuterDim :: LParam -> LParam
removeParamOuterDim p = p {paramDec = rowType (paramType p)}

-- | Overwrite the outermost size of a lambda parameter's type with
-- the given size.  The rest of the parameter is left untouched.
setParamOuterDimTo :: SubExp -> LParam -> LParam
setParamOuterDimTo w p = p {paramDec = setOuterSize (paramType p) w}

-- | Overwrite the outermost size of every type stored in a pattern
-- with the given size.
setPatternOuterDimTo :: SubExp -> Pattern -> Pattern
setPatternOuterDimTo w pat = fmap resize pat
  where
    resize t = setOuterSize t w

-- Now for fiddling with transpositions...

-- | Tag every input with whether its underlying array is one of the
-- @interesting@ names, and hand the tagged list to
-- 'commonTransforms''.
commonTransforms ::
  [VName] ->
  [SOAC.Input] ->
  (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms interesting inps = commonTransforms' (map tag inps)
  where
    tag inp = (SOAC.inputArray inp `elem` interesting, inp)

-- | Repeatedly strip one transform from every input tagged 'True',
-- as long as 'inputToOutput' yields the same transform for all of
-- them.  The stripped transforms are returned (in the order they were
-- removed) together with the remaining inputs.  Inputs tagged 'False'
-- are passed through unchanged.  When a round fails, or no tagged
-- input contributed a transform, the inputs of that round are
-- returned as they were.
commonTransforms' :: [(Bool, SOAC.Input)] -> (SOAC.ArrayTransforms, [SOAC.Input])
commonTransforms' inps =
  case foldM step (Nothing, []) inps of
    Just (Just ot, stripped) ->
      -- The fold accumulates in reverse, so restore the input order
      -- before trying to strip another layer.
      let (ots, rest) = commonTransforms' (reverse stripped)
       in (ot SOAC.<| ots, rest)
    _ -> (SOAC.noTransforms, map snd inps)
  where
    -- A tagged input must yield a transform, and it must agree with
    -- the one seen so far (if any); otherwise the whole round fails.
    step (seen, acc) (True, inp) = do
      (ot, inp') <- inputToOutput inp
      case seen of
        Just prevOt | not (prevOt == ot) -> Nothing
        _ -> Just (Just ot, (True, inp') : acc)
    step (seen, acc) other = Just (seen, other : acc)

-- | One plus the smaller of (a) the number of nesting levels and
-- (b) the minimum array rank among the result types.  The result
-- types are taken from the first nesting level if there is one, and
-- from the lambda's return type otherwise.  An empty list of result
-- types counts as rank 0.
mapDepth :: MapNest -> Int
mapDepth (MapNest.MapNest _ lam levels _) =
  1 + min smallestRank (length levels)
  where
    resultTypes = case levels of
      [] -> lambdaReturnType lam
      nest : _ -> MapNest.nestingReturnType nest
    smallestRank = case map arrayRank resultTypes of
      [] -> 0
      r : rs -> foldr min r rs

-- | If the first output transform is a rearrangement whose reach does
-- not exceed the map depth of the SOAC (viewed as a map nest), move
-- it onto the inputs of the nest and return the rebuilt SOAC together
-- with the remaining output transforms.  Fails in 'TryFusion' if the
-- SOAC is not a map nest, if the first transform is not a
-- rearrangement, or if the permutation reaches too deep.
pullRearrange ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullRearrange soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  -- A failing pattern match here aborts the attempt via MonadFail.
  SOAC.Rearrange cs perm SOAC.:< ots' <- pure $ SOAC.viewf ots
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot pull transpose"
  let -- Pad the permutation with identity entries up to the rank of
      -- the input it is applied to.
      widen inp =
        let rank = SOAC.inputRank inp
         in take rank perm ++ [length perm .. rank - 1]
      permuted =
        [ SOAC.addTransform (SOAC.Rearrange cs (widen inp)) inp
          | inp <- MapNest.inputs nest
        ]
  soac' <-
    MapNest.toSOAC $
      MapNest.setInputs permuted (rearrangeReturnTypes nest perm)
  pure (soac', ots')

-- | Take the permutation that 'fixupInputs' finds on the inputs named
-- in @inpIds@ and, if its reach does not exceed the map depth of the
-- SOAC (viewed as a map nest), rebuild the SOAC with the fixed-up
-- inputs and rearranged return types.  The inverse permutation is
-- prepended to the output transforms.  Fails in 'TryFusion' if the
-- SOAC is not a map nest, if 'fixupInputs' finds nothing, or if the
-- permutation reaches too deep.
pushRearrange ::
  [VName] ->
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pushRearrange inpIds soac ots = do
  nest <- liftMaybe =<< MapNest.fromSOAC soac
  (perm, inputs') <- liftMaybe $ fixupInputs inpIds $ MapNest.inputs nest
  unless (rearrangeReach perm <= mapDepth nest) $
    fail "Cannot push transpose"
  soac' <-
    MapNest.toSOAC $
      MapNest.setInputs inputs' (rearrangeReturnTypes nest perm)
  let undoPerm = SOAC.Rearrange mempty (rearrangeInverse perm)
  pure (soac', undoPerm SOAC.<| ots)

-- | Rebuild a map nest so that the return types recorded on its
-- nesting levels reflect the given permutation.  The width, body and
-- inputs are carried over unchanged.
rearrangeReturnTypes :: MapNest -> [Int] -> MapNest
rearrangeReturnTypes nest@(MapNest.MapNest w body nestings inps) perm =
  MapNest.MapNest w body nestings' inps
  where
    -- The permutation may be deeper than the rank of the type, but it
    -- is required to be an identity permutation beyond that.  Callers
    -- of rearrangeReturnTypes are expected to uphold this invariant.
    permuteType t = rearrangeType (take (arrayRank t) perm) t
    outerTypes = map permuteType (MapNest.typeOf nest)

    -- The k'th nesting level (counting from zero) gets the permuted
    -- types with k+1 outer dimensions stripped.
    strippedTypes = drop 1 $ iterate (map rowType) outerTypes

    nestings' =
      [ nesting {MapNest.nestingReturnType = ts}
        | (nesting, ts) <- zip nestings strippedTypes
      ]

-- | Pick the permutation of the first input that (a) refers to one of
-- the arrays in @inpIds@ and (b) has a rearrangement at the end of
-- its transforms that 'SOAC.viewl' exposes, then add that permutation
-- (truncated to each input's rank) as a transform on /every/ input.
-- Returns 'Nothing' if no such permutation exists, or if some input
-- has a rank smaller than the reach of the permutation.
fixupInputs :: [VName] -> [SOAC.Input] -> Maybe ([Int], [SOAC.Input])
fixupInputs inpIds inps = do
  perm <- listToMaybe $ mapMaybe inputRearrange $ filter exposable inps
  inps' <- traverse (fixupInput (rearrangeReach perm) perm) inps
  pure (perm, inps')
  where
    exposable inp = SOAC.inputArray inp `elem` inpIds

    inputRearrange (SOAC.Input ts _ _) =
      case SOAC.viewl ts of
        _ SOAC.:> SOAC.Rearrange _ perm -> Just perm
        _ -> Nothing

    -- An input can only take the permutation if it has at least as
    -- many dimensions as the permutation actually touches.
    fixupInput d perm inp
      | rank < d = Nothing
      | otherwise =
        Just $ SOAC.addTransform (SOAC.Rearrange mempty (take rank perm)) inp
      where
        rank = SOAC.inputRank inp

pullReshape :: SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)
pullReshape :: SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullReshape (SOAC.Screma SubExp
_ ScremaForm SOACS
form [Input]
inps) ArrayTransforms
ots
  | Just Lambda SOACS
maplam <- ScremaForm SOACS -> Maybe (Lambda SOACS)
forall lore. ScremaForm lore -> Maybe (Lambda lore)
Futhark.isMapSOAC ScremaForm SOACS
form,
    SOAC.Reshape Certificates
cs ShapeChange SubExp
shape SOAC.:< ArrayTransforms
ots' <- ArrayTransforms -> ViewF
SOAC.viewf ArrayTransforms
ots,
    (TypeBase (ShapeBase SubExp) NoUniqueness -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all TypeBase (ShapeBase SubExp) NoUniqueness -> Bool
forall shape u. TypeBase shape u -> Bool
primType ([TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool)
-> [TypeBase (ShapeBase SubExp) NoUniqueness] -> Bool
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
maplam = do
    let mapw' :: SubExp
mapw' = case [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape of
          [] -> IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0
          SubExp
d : [SubExp]
_ -> SubExp
d
        inputs' :: [Input]
inputs' = (Input -> Input) -> [Input] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (ArrayTransform -> Input -> Input
SOAC.addTransform (ArrayTransform -> Input -> Input)
-> ArrayTransform -> Input -> Input
forall a b. (a -> b) -> a -> b
$ Certificates -> ShapeChange SubExp -> ArrayTransform
SOAC.ReshapeOuter Certificates
cs ShapeChange SubExp
shape) [Input]
inps
        inputTypes :: [TypeBase (ShapeBase SubExp) NoUniqueness]
inputTypes = (Input -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [Input] -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map Input -> TypeBase (ShapeBase SubExp) NoUniqueness
SOAC.inputType [Input]
inputs'

    let outersoac ::
          ([SOAC.Input] -> SOAC) ->
          (SubExp, [SubExp]) ->
          TryFusion ([SOAC.Input] -> SOAC)
        outersoac :: ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac [Input] -> SOAC SOACS
inner (SubExp
w, [SubExp]
outershape) = do
          let addDims :: TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
addDims TypeBase (ShapeBase SubExp) NoUniqueness
t = TypeBase (ShapeBase SubExp) NoUniqueness
-> ShapeBase SubExp
-> NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u_unused u.
ArrayShape shape =>
TypeBase shape u_unused -> shape -> u -> TypeBase shape u
arrayOf TypeBase (ShapeBase SubExp) NoUniqueness
t ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
outershape) NoUniqueness
NoUniqueness
              retTypes :: [TypeBase (ShapeBase SubExp) NoUniqueness]
retTypes = (TypeBase (ShapeBase SubExp) NoUniqueness
 -> TypeBase (ShapeBase SubExp) NoUniqueness)
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> [a] -> [b]
map TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
addDims ([TypeBase (ShapeBase SubExp) NoUniqueness]
 -> [TypeBase (ShapeBase SubExp) NoUniqueness])
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [TypeBase (ShapeBase SubExp) NoUniqueness]
forall lore.
LambdaT lore -> [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType Lambda SOACS
maplam

          [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ps <- [TypeBase (ShapeBase SubExp) NoUniqueness]
-> (TypeBase (ShapeBase SubExp) NoUniqueness
    -> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [TypeBase (ShapeBase SubExp) NoUniqueness]
inputTypes ((TypeBase (ShapeBase SubExp) NoUniqueness
  -> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
 -> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)])
-> (TypeBase (ShapeBase SubExp) NoUniqueness
    -> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TryFusion [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
forall a b. (a -> b) -> a -> b
$ \TypeBase (ShapeBase SubExp) NoUniqueness
inpt ->
            String
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"pullReshape_param" (TypeBase (ShapeBase SubExp) NoUniqueness
 -> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness)))
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TryFusion (Param (TypeBase (ShapeBase SubExp) NoUniqueness))
forall a b. (a -> b) -> a -> b
$
              Int
-> TypeBase (ShapeBase SubExp) NoUniqueness
-> TypeBase (ShapeBase SubExp) NoUniqueness
forall shape u.
ArrayShape shape =>
Int -> TypeBase shape u -> TypeBase shape u
stripArray (ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
shape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
outershape) TypeBase (ShapeBase SubExp) NoUniqueness
inpt

          BodyT SOACS
inner_body <-
            Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall lore (m :: * -> *) somelore.
(Bindable lore, MonadFreshNames m, HasScope somelore m,
 SameScope somelore lore) =>
Binder lore (Body lore) -> m (Body lore)
runBodyBinder (Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS))
-> Binder SOACS (BodyT SOACS) -> TryFusion (BodyT SOACS)
forall a b. (a -> b) -> a -> b
$
              [BinderT
   SOACS
   (State VNameSource)
   (Exp (Lore (BinderT SOACS (State VNameSource))))]
-> BinderT
     SOACS
     (State VNameSource)
     (Body (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
MonadBinder m =>
[m (Exp (Lore m))] -> m (Body (Lore m))
eBody [SOAC (Lore (BinderT SOACS (State VNameSource)))
-> BinderT
     SOACS
     (State VNameSource)
     (Exp (Lore (BinderT SOACS (State VNameSource))))
forall (m :: * -> *).
(MonadBinder m, Op (Lore m) ~ SOAC (Lore m)) =>
SOAC (Lore m) -> m (Exp (Lore m))
SOAC.toExp (SOAC (Lore (BinderT SOACS (State VNameSource)))
 -> BinderT
      SOACS
      (State VNameSource)
      (Exp (Lore (BinderT SOACS (State VNameSource)))))
-> SOAC (Lore (BinderT SOACS (State VNameSource)))
-> BinderT
     SOACS
     (State VNameSource)
     (Exp (Lore (BinderT SOACS (State VNameSource))))
forall a b. (a -> b) -> a -> b
$ [Input] -> SOAC SOACS
inner ([Input] -> SOAC SOACS) -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Input)
-> [Param (TypeBase (ShapeBase SubExp) NoUniqueness)] -> [Input]
forall a b. (a -> b) -> [a] -> [b]
map (Ident -> Input
SOAC.identInput (Ident -> Input)
-> (Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Ident)
-> Param (TypeBase (ShapeBase SubExp) NoUniqueness)
-> Input
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param (TypeBase (ShapeBase SubExp) NoUniqueness) -> Ident
forall dec. Typed dec => Param dec -> Ident
paramIdent) [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
ps]
          let inner_fun :: Lambda SOACS
inner_fun =
                Lambda :: forall lore.
[LParam lore]
-> BodyT lore
-> [TypeBase (ShapeBase SubExp) NoUniqueness]
-> LambdaT lore
Lambda
                  { lambdaParams :: [LParam SOACS]
lambdaParams = [Param (TypeBase (ShapeBase SubExp) NoUniqueness)]
[LParam SOACS]
ps,
                    lambdaReturnType :: [TypeBase (ShapeBase SubExp) NoUniqueness]
lambdaReturnType = [TypeBase (ShapeBase SubExp) NoUniqueness]
retTypes,
                    lambdaBody :: BodyT SOACS
lambdaBody = BodyT SOACS
inner_body
                  }
          ([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall (m :: * -> *) a. Monad m => a -> m a
return (([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS) -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$ SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
w (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda SOACS
inner_fun

    [Input] -> SOAC SOACS
op' <-
      (([Input] -> SOAC SOACS)
 -> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS))
-> ([Input] -> SOAC SOACS)
-> [(SubExp, [SubExp])]
-> TryFusion ([Input] -> SOAC SOACS)
forall (t :: * -> *) (m :: * -> *) b a.
(Foldable t, Monad m) =>
(b -> a -> m b) -> b -> t a -> m b
foldM ([Input] -> SOAC SOACS)
-> (SubExp, [SubExp]) -> TryFusion ([Input] -> SOAC SOACS)
outersoac (SubExp -> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall lore. SubExp -> ScremaForm lore -> [Input] -> SOAC lore
SOAC.Screma SubExp
mapw' (ScremaForm SOACS -> [Input] -> SOAC SOACS)
-> ScremaForm SOACS -> [Input] -> SOAC SOACS
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
Futhark.mapSOAC Lambda SOACS
maplam) ([(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS))
-> [(SubExp, [SubExp])] -> TryFusion ([Input] -> SOAC SOACS)
forall a b. (a -> b) -> a -> b
$
        [SubExp] -> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. [a] -> [b] -> [(a, b)]
zip (Int -> [SubExp] -> [SubExp]
forall a. Int -> [a] -> [a]
drop Int
1 ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp]
forall a. [a] -> [a]
reverse ([SubExp] -> [SubExp]) -> [SubExp] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape) ([[SubExp]] -> [(SubExp, [SubExp])])
-> [[SubExp]] -> [(SubExp, [SubExp])]
forall a b. (a -> b) -> a -> b
$
          Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [[SubExp]] -> [[SubExp]]
forall a. [a] -> [a]
reverse ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ Int -> [[SubExp]] -> [[SubExp]]
forall a. Int -> [a] -> [a]
drop Int
1 ([[SubExp]] -> [[SubExp]]) -> [[SubExp]] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [[SubExp]]
forall a. [a] -> [[a]]
tails ([SubExp] -> [[SubExp]]) -> [SubExp] -> [[SubExp]]
forall a b. (a -> b) -> a -> b
$ ShapeChange SubExp -> [SubExp]
forall d. ShapeChange d -> [d]
newDims ShapeChange SubExp
shape
    (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return ([Input] -> SOAC SOACS
op' [Input]
inputs', ArrayTransforms
ots')
pullReshape SOAC SOACS
_ ArrayTransforms
_ = String -> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull reshape"

-- Tie it all together in exposeInputs (for making inputs to a
-- consumer available) and pullOutputTransforms (for moving
-- output-transforms of a producer to its inputs instead).

-- | Try to rewrite the kernel @ker@ so that each of its inputs either
-- carries no array transforms or refers to an array that is not in
-- @inpIds@.
--
-- Three strategies are combined with '<|>', in this order:
--
--   1. apply 'pushRearrange' to the kernel and then expose its inputs;
--   2. apply 'pullRearrange' to the kernel and then expose its inputs;
--   3. expose the inputs of the unmodified kernel.
--
-- On success the result is the updated kernel paired with the
-- transforms that 'commonTransforms' split off from the inputs.  If
-- no strategy succeeds, the computation fails (the last alternative
-- fails with the message @"Cannot expose"@).
exposeInputs ::
  [VName] ->
  FusedKer ->
  TryFusion (FusedKer, SOAC.ArrayTransforms)
exposeInputs :: [VName] -> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs [VName]
inpIds FusedKer
ker =
  (FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pushRearrange')
    TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' (FusedKer -> TryFusion (FusedKer, ArrayTransforms))
-> TryFusion FusedKer -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TryFusion FusedKer
pullRearrange')
    TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker
  where
    -- The output transform currently recorded on the kernel.
    ot :: ArrayTransforms
ot = FusedKer -> ArrayTransforms
outputTransform FusedKer
ker

    -- Run 'pushRearrange' on the kernel's SOAC and output transform,
    -- storing both results back into the kernel.
    pushRearrange' :: TryFusion FusedKer
pushRearrange' = do
      (SOAC SOACS
soac', ArrayTransforms
ot') <- [VName]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
pushRearrange [VName]
inpIds (FusedKer -> SOAC SOACS
fsoac FusedKer
ker) ArrayTransforms
ot
      FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
        FusedKer
ker
          { fsoac :: SOAC SOACS
fsoac = SOAC SOACS
soac',
            outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
ot'
          }

    -- Run 'pullRearrange' on the kernel's SOAC and output transform.
    -- This only counts as a success when no transforms remain
    -- afterwards; the kernel's output transform is then reset to
    -- 'SOAC.noTransforms'.
    pullRearrange' :: TryFusion FusedKer
pullRearrange' = do
      (SOAC SOACS
soac', ArrayTransforms
ot') <- SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullRearrange (FusedKer -> SOAC SOACS
fsoac FusedKer
ker) ArrayTransforms
ot
      Bool -> TryFusion () -> TryFusion ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless (ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ot') (TryFusion () -> TryFusion ()) -> TryFusion () -> TryFusion ()
forall a b. (a -> b) -> a -> b
$
        String -> TryFusion ()
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"pullRearrange was not enough"
      FusedKer -> TryFusion FusedKer
forall (m :: * -> *) a. Monad m => a -> m a
return
        FusedKer
ker
          { fsoac :: SOAC SOACS
fsoac = SOAC SOACS
soac',
            outputTransform :: ArrayTransforms
outputTransform = ArrayTransforms
SOAC.noTransforms
          }

    -- Split the kernel's inputs with 'commonTransforms'.  If every
    -- resulting input satisfies 'exposed', install those inputs on the
    -- kernel's SOAC and return the split-off transforms alongside it;
    -- otherwise fail.
    exposeInputs' :: FusedKer -> TryFusion (FusedKer, ArrayTransforms)
exposeInputs' FusedKer
ker' =
      case [VName] -> [Input] -> (ArrayTransforms, [Input])
commonTransforms [VName]
inpIds ([Input] -> (ArrayTransforms, [Input]))
-> [Input] -> (ArrayTransforms, [Input])
forall a b. (a -> b) -> a -> b
$ FusedKer -> [Input]
inputs FusedKer
ker' of
        (ArrayTransforms
ot', [Input]
inps')
          | (Input -> Bool) -> [Input] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all Input -> Bool
exposed [Input]
inps' ->
            (FusedKer, ArrayTransforms)
-> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (FusedKer
ker' {fsoac :: SOAC SOACS
fsoac = [Input]
inps' [Input] -> SOAC SOACS -> SOAC SOACS
forall lore. [Input] -> SOAC lore -> SOAC lore
`SOAC.setInputs` FusedKer -> SOAC SOACS
fsoac FusedKer
ker'}, ArrayTransforms
ot')
        (ArrayTransforms, [Input])
_ -> String -> TryFusion (FusedKer, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot expose"

    -- An input counts as exposed when it has no transforms at all, or
    -- when the array it reads is not one of @inpIds@.
    exposed :: Input -> Bool
exposed (SOAC.Input ArrayTransforms
ts VName
_ TypeBase (ShapeBase SubExp) NoUniqueness
_)
      | ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ts = Bool
True
    exposed Input
inp = Input -> VName
SOAC.inputArray Input
inp VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`notElem` [VName]
inpIds

-- | The output-transform pullers that 'pullOutputTransforms' tries,
-- in list order: 'pullRearrange' first, then 'pullReshape'.  Each
-- takes a SOAC and its output transforms and, on success, yields a
-- rewritten SOAC together with the transforms that are still left.
outputTransformPullers :: [SOAC -> SOAC.ArrayTransforms -> TryFusion (SOAC, SOAC.ArrayTransforms)]
outputTransformPullers :: [SOAC SOACS
 -> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
outputTransformPullers = [SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullRearrange, SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullReshape]

-- | Repeatedly apply the pullers in 'outputTransformPullers' to a
-- SOAC and its output transforms.
--
-- The result is the rewritten SOAC together with whatever transforms
-- could not be pulled; when a puller leaves nothing behind, the
-- second component is 'SOAC.noTransforms'.  Fails with
-- @"Cannot pull anything"@ when no puller succeeds on the given
-- arguments.
pullOutputTransforms ::
  SOAC ->
  SOAC.ArrayTransforms ->
  TryFusion (SOAC, SOAC.ArrayTransforms)
pullOutputTransforms :: SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullOutputTransforms = [SOAC SOACS
 -> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> SOAC SOACS
-> ArrayTransforms
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall {t} {t}.
[t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [SOAC SOACS
 -> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)]
outputTransformPullers
  where
    -- No pullers left to try: give up.
    attempt :: [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [] t
_ t
_ = String -> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. MonadFail m => String -> m a
fail String
"Cannot pull anything"
    -- Try puller @p@.  If it leaves no transforms we are done.
    -- Otherwise restart 'pullOutputTransforms' on its result, keeping
    -- the partial result if that restart fails.  The trailing '<|>'
    -- falls back to the remaining pullers, using the original
    -- arguments.
    attempt (t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
p : [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
ps) t
soac t
ots =
      do
        (SOAC SOACS
soac', ArrayTransforms
ots') <- t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
p t
soac t
ots
        if ArrayTransforms -> Bool
SOAC.nullTransforms ArrayTransforms
ots'
          then (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS
soac', ArrayTransforms
SOAC.noTransforms)
          else SOAC SOACS
-> ArrayTransforms -> TryFusion (SOAC SOACS, ArrayTransforms)
pullOutputTransforms SOAC SOACS
soac' ArrayTransforms
ots' TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (m :: * -> *) a. Monad m => a -> m a
return (SOAC SOACS
soac', ArrayTransforms
ots')
        TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
-> TryFusion (SOAC SOACS, ArrayTransforms)
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
-> t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)
attempt [t -> t -> TryFusion (SOAC SOACS, ArrayTransforms)]
ps t
soac t
ots