{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}

-- | It is well known that fully parallel loops can always be
-- interchanged inwards with a sequential loop.  This module
-- implements that transformation.
--
-- This is also where we implement loop-switching (for branches) and
-- the corresponding interchange for @WithAcc@ statements, both of
-- which are semantically similar to loop interchange.
module Futhark.Pass.ExtractKernels.Interchange
  ( SeqLoop (..),
    interchangeLoops,
    Branch (..),
    interchangeBranch,
    WithAccStm (..),
    interchangeWithAcc,
  )
where

import Control.Monad.Identity
import Data.List (find)
import Data.Maybe
import Futhark.IR.SOACS
import Futhark.MonadFreshNames
import Futhark.Pass.ExtractKernels.Distribution
  ( KernelNest,
    LoopNesting (..),
    kernelNestLoops,
  )
import Futhark.Tools
import Futhark.Transform.Rename

-- | An encoding of a sequential do-loop with no existential context,
-- alongside its result pattern.
data SeqLoop
  = SeqLoop
      [Int]
      -- ^ A permutation.  When interchanging, it is applied (via
      -- 'rearrangeShape') to the value elements of the enclosing map
      -- nesting's pattern; it is ignored when turning the loop back into
      -- a statement.
      Pattern
      -- ^ The pattern bound by the loop.
      [(FParam, SubExp)]
      -- ^ Merge parameters paired with their initial values.
      (LoopForm SOACS)
      -- ^ The loop form.
      Body
      -- ^ The loop body.

-- | Rebuild a single @DoLoop@ statement from a 'SeqLoop', bound to the
-- loop's result pattern.  The permutation component plays no role here
-- and is discarded.  Since a 'SeqLoop' has no existential context, the
-- context merge list of the loop is empty.
seqLoopStm :: SeqLoop -> Stm
seqLoopStm (SeqLoop _perm loop_pat loop_merge loop_form loop_body) =
  Let loop_pat (defAux ()) loop_exp
  where
    loop_exp = DoLoop [] loop_merge loop_form loop_body

-- | Interchange a single map nesting level with a sequential loop, so
-- that the map ends up inside the loop.
--
-- Every merge parameter is expanded into an array with an extra outer
-- dimension of the map width @w@, the loop pattern is expanded the same
-- way, and the new loop body consists of a single @map@ whose lambda is
-- the original loop body.  Statements needed to build expanded initial
-- values (replicates) are emitted into the surrounding binder.
--
-- The first argument looks up whether a name is a map parameter
-- somewhere in the enclosing nest, returning the array it is drawn from;
-- such merge initial values are replaced by that array rather than
-- replicated.
interchangeLoop ::
  (MonadBinder m, LocalScope SOACS m) =>
  (VName -> Maybe VName) ->
  SeqLoop ->
  LoopNesting ->
  m SeqLoop
interchangeLoop
  isMapParameter
  (SeqLoop perm loop_pat merge form body)
  (MapNesting pat aux w params_and_arrs) = do
    -- Expand the merge parameters (and their initial values) with the
    -- map's parameters in scope; this may bind replicate statements.
    merge_expanded <-
      localScope (scopeOfLParams $ map fst params_and_arrs) $
        mapM expand merge

    let loop_pat_expanded =
          Pattern [] $ map expandPatElem $ patternElements loop_pat
        -- The original (unexpanded) merge parameters become lambda
        -- parameters of the inner map, fed by the expanded merge arrays.
        new_params =
          [Param pname $ fromDecl ptype | (Param pname ptype, _) <- merge]
        new_arrs = map (paramName . fst) merge_expanded
        rettype = map rowType $ patternTypes loop_pat_expanded

        -- Map parameters that the loop body does not reference are
        -- removed outright, together with their input arrays.  This
        -- happens e.g. when a parameter was used only as the initial
        -- value of a merge parameter.  No copies are made here.
        (params', arrs') = unzip $ mapMaybe keepIfUsed params_and_arrs

        lam = Lambda (params' <> new_params) body rettype
        map_bnd =
          Let loop_pat_expanded aux $
            Op $ Screma w (arrs' <> new_arrs) (mapSOAC lam)
        res = map Var $ patternNames loop_pat_expanded
        pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat

    return $
      SeqLoop perm pat' merge_expanded form $
        mkBody (oneStm map_bnd) res
    where
      free_in_body = freeIn body

      -- Keep a (parameter, array) pair only if the loop body uses the
      -- parameter.
      keepIfUsed (param, arr)
        | paramName param `nameIn` free_in_body = Just (param, arr)
        | otherwise = Nothing

      -- A merge initial value that is itself a map parameter is replaced
      -- by the array that parameter is drawn from; anything else is
      -- replicated across the map width.
      expandedInit _ (Var v)
        | Just arr <- isMapParameter v =
          return $ Var arr
      expandedInit param_name se =
        letSubExp (param_name <> "_expanded_init") $
          BasicOp $ Replicate (Shape [w]) se

      -- Create a fresh merge parameter whose type gains an outer
      -- dimension of size @w@ (keeping the original uniqueness), paired
      -- with the correspondingly expanded initial value.
      expand (merge_param, merge_init) = do
        expanded_param <-
          newParam (param_name <> "_expanded") $
            arrayOf (paramDeclType merge_param) (Shape [w]) $
              uniqueness $ declTypeOf merge_param
        expanded_init <- expandedInit param_name merge_init
        return (expanded_param, expanded_init)
        where
          param_name = baseString $ paramName merge_param

      expandPatElem (PatElem name t) =
        PatElem name $ arrayOfRow t w

-- | Move every map of a (parallel) map nesting inside an inner
-- sequential loop.  The nest is processed from the innermost map
-- outwards.  The result is a sequence of statements: any statements
-- emitted while interchanging, followed by the loop itself, whose body
-- now contains statements with @map@ expressions.
interchangeLoops ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  SeqLoop ->
  m (Stms SOACS)
interchangeLoops nest loop = do
  (interchanged, prelude_stms) <-
    runBinder $ foldM (interchangeLoop isMapParameter) loop innermost_first
  return $ prelude_stms <> oneStm (seqLoopStm interchanged)
  where
    nestings = kernelNestLoops nest

    -- Interchange with the innermost map first, then work outwards.
    innermost_first = reverse nestings

    -- Every (parameter, array) pair bound anywhere in the nest.
    all_params_and_arrs = concatMap loopNestingParamsAndArrs nestings

    -- If @v@ is a map parameter somewhere in the nest, the array it is
    -- drawn from (first match wins).
    isMapParameter v =
      snd <$> find (\(p, _) -> paramName p == v) all_params_and_arrs

-- | An encoding of a branch (an @if@ expression) alongside its result
-- pattern.
data Branch
  = Branch
      [Int]
      -- ^ A permutation.  When interchanging, it reorders the value
      -- elements of the enclosing map nesting's pattern and the map
      -- lambda's return types; it is ignored when turning the branch
      -- back into a statement.
      Pattern
      -- ^ The pattern bound by the branch.
      SubExp
      -- ^ The branch condition.
      Body
      -- ^ The true branch.
      Body
      -- ^ The false branch.
      (IfDec (BranchType SOACS))
      -- ^ Return types and sort of the branch.

-- | Rebuild a single @If@ statement from a 'Branch', bound to the
-- branch's pattern.  The permutation component is discarded.
branchStm :: Branch -> Stm
branchStm (Branch _perm if_pat if_cond true_body false_body if_dec) =
  Let if_pat (defAux ()) if_exp
  where
    if_exp = If if_cond true_body false_body if_dec

-- | Interchange one map nesting level with a branch, so that the @if@
-- ends up outside the map.  Each arm becomes a @map@ over the nesting's
-- width and input arrays, whose lambda is the original arm body.  The
-- branch's return types and pattern gain an outer dimension of size @w@.
-- The resulting branch binds the nesting's (permuted) pattern and
-- carries the identity permutation.
interchangeBranch1 ::
  (MonadBinder m) =>
  Branch ->
  LoopNesting ->
  m Branch
interchangeBranch1
  (Branch perm branch_pat cond tbranch fbranch (IfDec ret if_sort))
  (MapNesting pat aux w params_and_arrs) = do
    tbranch' <- wrapInMap tbranch
    fbranch' <- wrapInMap fbranch
    return $
      Branch identity_perm pat' cond tbranch' fbranch' $
        IfDec ret' if_sort
    where
      (params, arrs) = unzip params_and_arrs

      -- Each arm now produces all @w@ results at once.
      ret' = map (`arrayOfRow` Free w) ret
      branch_pat' =
        Pattern [] $ map (fmap (`arrayOfRow` w)) $ patternElements branch_pat
      branch_res = map Var $ patternNames branch_pat'

      -- The map lambda returns rows of the nesting's pattern types,
      -- reordered by the incoming permutation.
      lam_ret = rearrangeShape perm $ map rowType $ patternTypes pat

      pat' = Pattern [] $ rearrangeShape perm $ patternValueElements pat
      identity_perm = [0 .. patternSize pat - 1]

      -- Both arms bind the same 'branch_pat'' names, so each new arm
      -- body is renamed.
      wrapInMap branch_body =
        renameBody $ mkBody (oneStm map_stm) branch_res
        where
          map_stm =
            Let branch_pat' aux $
              Op $ Screma w arrs $ mapSOAC $ Lambda params branch_body lam_ret

-- | Move every map of a (parallel) map nesting inside both arms of an
-- inner branch.  The nest is processed from the innermost map outwards.
-- The result is any statements emitted while interchanging, followed by
-- the branch itself.
interchangeBranch ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  Branch ->
  m (Stms SOACS)
interchangeBranch nest branch = do
  let innermost_first = reverse $ kernelNestLoops nest
  (branch', prelude_stms) <-
    runBinder $ foldM interchangeBranch1 branch innermost_first
  return $ prelude_stms <> oneStm (branchStm branch')

-- | An encoding of a @WithAcc@ expression alongside its result pattern.
data WithAccStm
  = WithAccStm
      [Int]
      -- ^ A permutation; ignored when turning this back into a statement.
      Pattern
      -- ^ The pattern bound by the @WithAcc@.
      [(Shape, [VName], Maybe (Lambda, [SubExp]))]
      -- ^ Accumulator inputs: a shape, the arrays, and an optional
      -- lambda with a list of subexpressions (presumably the combining
      -- operator and its neutral elements -- confirm against 'WithAcc').
      Lambda
      -- ^ The lambda in which the accumulators are used.

-- | Rebuild a single @WithAcc@ statement from a 'WithAccStm', bound to
-- its pattern.  The permutation component is discarded.
withAccStm :: WithAccStm -> Stm
withAccStm (WithAccStm _perm withacc_pat acc_inputs acc_lam) =
  Let withacc_pat (defAux ()) withacc_exp
  where
    withacc_exp = WithAcc acc_inputs acc_lam

interchangeWithAcc1 ::
  (MonadBinder m, Lore m ~ SOACS) =>
  WithAccStm ->
  LoopNesting ->
  m WithAccStm
interchangeWithAcc1 :: forall (m :: * -> *).
(MonadBinder m, Lore m ~ SOACS) =>
WithAccStm -> LoopNesting -> m WithAccStm
interchangeWithAcc1
  (WithAccStm [Int]
perm Pattern
_withacc_pat [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
inputs LambdaT SOACS
acc_lam)
  (MapNesting PatternT Type
map_pat StmAux ()
map_aux SubExp
w [(Param Type, VName)]
params_and_arrs) = do
    [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
inputs' <- ((ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
 -> m (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp])))
-> [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
-> m [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
-> m (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
onInput [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
inputs
    let lam_params :: [LParam SOACS]
lam_params = LambdaT SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT SOACS
acc_lam
    Param Type
iota_p <- String -> Type -> m (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"iota_p" (Type -> m (Param Type)) -> Type -> m (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
    LambdaT SOACS
acc_lam' <- SubExp -> LambdaT SOACS -> m (LambdaT SOACS)
trLam (VName -> SubExp
Var (Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
iota_p)) (LambdaT SOACS -> m (LambdaT SOACS))
-> (m [SubExp] -> m (LambdaT SOACS))
-> m [SubExp]
-> m (LambdaT SOACS)
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< [LParam (Lore m)] -> m [SubExp] -> m (Lambda (Lore m))
forall (m :: * -> *).
MonadBinder m =>
[LParam (Lore m)] -> m [SubExp] -> m (Lambda (Lore m))
mkLambda [LParam (Lore m)]
[LParam SOACS]
lam_params (m [SubExp] -> m (LambdaT SOACS))
-> m [SubExp] -> m (LambdaT SOACS)
forall a b. (a -> b) -> a -> b
$ do
      VName
iota_w <-
        String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"acc_inter_iota" (Exp SOACS -> m VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> m VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> m VName) -> BasicOp -> m VName
forall a b. (a -> b) -> a -> b
$
          SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
      let ([Param Type]
params, [VName]
arrs) = [(Param Type, VName)] -> ([Param Type], [VName])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param Type, VName)]
params_and_arrs
          maplam_ret :: [Type]
maplam_ret = LambdaT SOACS -> [Type]
forall lore. LambdaT lore -> [Type]
lambdaReturnType LambdaT SOACS
acc_lam
          maplam :: LambdaT SOACS
maplam = [LParam SOACS] -> Body -> [Type] -> LambdaT SOACS
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda (Param Type
iota_p Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
params) (LambdaT SOACS -> Body
forall lore. LambdaT lore -> BodyT lore
lambdaBody LambdaT SOACS
acc_lam) [Type]
maplam_ret
      StmAux () -> m [SubExp] -> m [SubExp]
forall (m :: * -> *) anylore a.
MonadBinder m =>
StmAux anylore -> m a -> m a
auxing StmAux ()
map_aux (m [SubExp] -> m [SubExp])
-> (Exp SOACS -> m [SubExp]) -> Exp SOACS -> m [SubExp]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Lore m) -> m [SubExp]
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m [SubExp]
letTupExp' String
"withacc_inter" (Exp SOACS -> m [SubExp]) -> Exp SOACS -> m [SubExp]
forall a b. (a -> b) -> a -> b
$
        Op SOACS -> Exp SOACS
forall lore. Op lore -> ExpT lore
Op (Op SOACS -> Exp SOACS) -> Op SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall lore. SubExp -> [VName] -> ScremaForm lore -> SOAC lore
Screma SubExp
w (VName
iota_w VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
arrs) (LambdaT SOACS -> ScremaForm SOACS
forall lore. Lambda lore -> ScremaForm lore
mapSOAC LambdaT SOACS
maplam)
    let pat :: PatternT Type
pat = [PatElemT Type] -> [PatElemT Type] -> PatternT Type
forall dec. [PatElemT dec] -> [PatElemT dec] -> PatternT dec
Pattern [] ([PatElemT Type] -> PatternT Type)
-> [PatElemT Type] -> PatternT Type
forall a b. (a -> b) -> a -> b
$ [Int] -> [PatElemT Type] -> [PatElemT Type]
forall a. [Int] -> [a] -> [a]
rearrangeShape [Int]
perm ([PatElemT Type] -> [PatElemT Type])
-> [PatElemT Type] -> [PatElemT Type]
forall a b. (a -> b) -> a -> b
$ PatternT Type -> [PatElemT Type]
forall dec. PatternT dec -> [PatElemT dec]
patternValueElements PatternT Type
map_pat
        perm' :: [Int]
perm' = [Int
0 .. PatternT Type -> Int
forall dec. PatternT dec -> Int
patternSize PatternT Type
pat Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
    WithAccStm -> m WithAccStm
forall (f :: * -> *) a. Applicative f => a -> f a
pure (WithAccStm -> m WithAccStm) -> WithAccStm -> m WithAccStm
forall a b. (a -> b) -> a -> b
$ [Int]
-> Pattern
-> [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
-> LambdaT SOACS
-> WithAccStm
WithAccStm [Int]
perm' PatternT Type
Pattern
pat [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
inputs' LambdaT SOACS
acc_lam'
    where
      num_accs :: Int
num_accs = [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
-> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
inputs
      acc_certs :: [VName]
acc_certs = (Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName ([Param Type] -> [VName]) -> [Param Type] -> [VName]
forall a b. (a -> b) -> a -> b
$ Int -> [Param Type] -> [Param Type]
forall a. Int -> [a] -> [a]
take Int
num_accs ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ LambdaT SOACS -> [LParam SOACS]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT SOACS
acc_lam
      onArr :: VName -> m VName
onArr VName
v =
        VName -> m VName
forall (f :: * -> *) a. Applicative f => a -> f a
pure (VName -> m VName)
-> (Maybe (Param Type, VName) -> VName)
-> Maybe (Param Type, VName)
-> m VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName
-> ((Param Type, VName) -> VName)
-> Maybe (Param Type, VName)
-> VName
forall b a. b -> (a -> b) -> Maybe a -> b
maybe VName
v (Param Type, VName) -> VName
forall a b. (a, b) -> b
snd (Maybe (Param Type, VName) -> m VName)
-> Maybe (Param Type, VName) -> m VName
forall a b. (a -> b) -> a -> b
$
          ((Param Type, VName) -> Bool)
-> [(Param Type, VName)] -> Maybe (Param Type, VName)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
v) (VName -> Bool)
-> ((Param Type, VName) -> VName) -> (Param Type, VName) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
params_and_arrs
      onInput :: (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
-> m (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
onInput (ShapeBase SubExp
shape, [VName]
arrs, Maybe (LambdaT SOACS, [SubExp])
op) =
        ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
w] ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> ShapeBase SubExp
shape,,) ([VName]
 -> Maybe (LambdaT SOACS, [SubExp])
 -> (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp])))
-> m [VName]
-> m (Maybe (LambdaT SOACS, [SubExp])
      -> (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp])))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> m VName) -> [VName] -> m [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> m VName
onArr [VName]
arrs m (Maybe (LambdaT SOACS, [SubExp])
   -> (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp])))
-> m (Maybe (LambdaT SOACS, [SubExp]))
-> m (ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ((LambdaT SOACS, [SubExp]) -> m (LambdaT SOACS, [SubExp]))
-> Maybe (LambdaT SOACS, [SubExp])
-> m (Maybe (LambdaT SOACS, [SubExp]))
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (LambdaT SOACS, [SubExp]) -> m (LambdaT SOACS, [SubExp])
forall {m :: * -> *} {lore} {shape} {u} {b}.
(MonadFreshNames m, LParamInfo lore ~ TypeBase shape u) =>
(LambdaT lore, b) -> m (LambdaT lore, b)
onOp Maybe (LambdaT SOACS, [SubExp])
op

      onOp :: (LambdaT lore, b) -> m (LambdaT lore, b)
onOp (LambdaT lore
op_lam, b
nes) = do
        -- We need to add an additional index parameter because we are
        -- extending the index space of the accumulator.
        Param (TypeBase shape u)
idx_p <- String -> TypeBase shape u -> m (Param (TypeBase shape u))
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"idx" (TypeBase shape u -> m (Param (TypeBase shape u)))
-> TypeBase shape u -> m (Param (TypeBase shape u))
forall a b. (a -> b) -> a -> b
$ PrimType -> TypeBase shape u
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
        (LambdaT lore, b) -> m (LambdaT lore, b)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (LambdaT lore
op_lam {lambdaParams :: [LParam lore]
lambdaParams = Param (TypeBase shape u)
idx_p Param (TypeBase shape u)
-> [Param (TypeBase shape u)] -> [Param (TypeBase shape u)]
forall a. a -> [a] -> [a]
: LambdaT lore -> [LParam lore]
forall lore. LambdaT lore -> [LParam lore]
lambdaParams LambdaT lore
op_lam}, b
nes)

      trType :: TypeBase shape u -> TypeBase shape u
      trType :: forall shape u. TypeBase shape u -> TypeBase shape u
trType (Acc VName
acc ShapeBase SubExp
ispace [Type]
ts u
u)
        | VName
acc VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
acc_certs =
          VName -> ShapeBase SubExp -> [Type] -> u -> TypeBase shape u
forall shape u.
VName -> ShapeBase SubExp -> [Type] -> u -> TypeBase shape u
Acc VName
acc ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp
w] ShapeBase SubExp -> ShapeBase SubExp -> ShapeBase SubExp
forall a. Semigroup a => a -> a -> a
<> ShapeBase SubExp
ispace) [Type]
ts u
u
      trType TypeBase shape u
t = TypeBase shape u
t

      trParam :: Param (TypeBase shape u) -> Param (TypeBase shape u)
      trParam :: forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam = (TypeBase shape u -> TypeBase shape u)
-> Param (TypeBase shape u) -> Param (TypeBase shape u)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap TypeBase shape u -> TypeBase shape u
forall shape u. TypeBase shape u -> TypeBase shape u
trType

      trLam :: SubExp -> LambdaT SOACS -> m (LambdaT SOACS)
trLam SubExp
i (Lambda [LParam SOACS]
params Body
body [Type]
ret) =
        Scope SOACS -> m (LambdaT SOACS) -> m (LambdaT SOACS)
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope ([Param Type] -> Scope SOACS
forall lore dec.
(LParamInfo lore ~ dec) =>
[Param dec] -> Scope lore
scopeOfLParams [Param Type]
[LParam SOACS]
params) (m (LambdaT SOACS) -> m (LambdaT SOACS))
-> m (LambdaT SOACS) -> m (LambdaT SOACS)
forall a b. (a -> b) -> a -> b
$
          [LParam SOACS] -> Body -> [Type] -> LambdaT SOACS
forall lore. [LParam lore] -> BodyT lore -> [Type] -> LambdaT lore
Lambda ((Param Type -> Param Type) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> Param Type
forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam [Param Type]
[LParam SOACS]
params) (Body -> [Type] -> LambdaT SOACS)
-> m Body -> m ([Type] -> LambdaT SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Body -> m Body
trBody SubExp
i Body
body m ([Type] -> LambdaT SOACS) -> m [Type] -> m (LambdaT SOACS)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Type] -> m [Type]
forall (f :: * -> *) a. Applicative f => a -> f a
pure ((Type -> Type) -> [Type] -> [Type]
forall a b. (a -> b) -> [a] -> [b]
map Type -> Type
forall shape u. TypeBase shape u -> TypeBase shape u
trType [Type]
ret)

      trBody :: SubExp -> Body -> m Body
trBody SubExp
i (Body BodyDec SOACS
dec Stms SOACS
stms [SubExp]
res) =
        Stms SOACS -> m Body -> m Body
forall lore a (m :: * -> *) b.
(Scoped lore a, LocalScope lore m) =>
a -> m b -> m b
inScopeOf Stms SOACS
stms (m Body -> m Body) -> m Body -> m Body
forall a b. (a -> b) -> a -> b
$ BodyDec SOACS -> Stms SOACS -> [SubExp] -> Body
forall lore. BodyDec lore -> Stms lore -> [SubExp] -> BodyT lore
Body BodyDec SOACS
dec (Stms SOACS -> [SubExp] -> Body)
-> m (Stms SOACS) -> m ([SubExp] -> Body)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (Stm -> m Stm) -> Stms SOACS -> m (Stms SOACS)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
traverse (SubExp -> Stm -> m Stm
trStm SubExp
i) Stms SOACS
stms m ([SubExp] -> Body) -> m [SubExp] -> m Body
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [SubExp] -> m [SubExp]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [SubExp]
res

      trStm :: SubExp -> Stm -> m Stm
trStm SubExp
i (Let Pattern
pat StmAux (ExpDec SOACS)
aux Exp SOACS
e) =
        Pattern -> StmAux (ExpDec SOACS) -> Exp SOACS -> Stm
forall lore.
Pattern lore -> StmAux (ExpDec lore) -> Exp lore -> Stm lore
Let ((Type -> Type) -> PatternT Type -> PatternT Type
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap Type -> Type
forall shape u. TypeBase shape u -> TypeBase shape u
trType PatternT Type
Pattern
pat) StmAux (ExpDec SOACS)
aux (Exp SOACS -> Stm) -> m (Exp SOACS) -> m Stm
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> Exp SOACS -> m (Exp SOACS)
trExp SubExp
i Exp SOACS
e

      trSOAC :: SubExp -> SOAC SOACS -> m (SOAC SOACS)
trSOAC SubExp
i = SOACMapper SOACS SOACS m -> SOAC SOACS -> m (SOAC SOACS)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
SOACMapper flore tlore m -> SOAC flore -> m (SOAC tlore)
mapSOACM SOACMapper SOACS SOACS m
mapper
        where
          mapper :: SOACMapper SOACS SOACS m
mapper =
            SOACMapper Any Any m
forall (m :: * -> *) lore. Monad m => SOACMapper lore lore m
identitySOACMapper {mapOnSOACLambda :: LambdaT SOACS -> m (LambdaT SOACS)
mapOnSOACLambda = SubExp -> LambdaT SOACS -> m (LambdaT SOACS)
trLam SubExp
i}

      trExp :: SubExp -> Exp SOACS -> m (Exp SOACS)
trExp SubExp
i (WithAcc [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
acc_inputs LambdaT SOACS
lam) =
        [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
-> LambdaT SOACS -> Exp SOACS
forall lore.
[(ShapeBase SubExp, [VName], Maybe (Lambda lore, [SubExp]))]
-> Lambda lore -> ExpT lore
WithAcc [(ShapeBase SubExp, [VName], Maybe (LambdaT SOACS, [SubExp]))]
acc_inputs (LambdaT SOACS -> Exp SOACS) -> m (LambdaT SOACS) -> m (Exp SOACS)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> LambdaT SOACS -> m (LambdaT SOACS)
trLam SubExp
i LambdaT SOACS
lam
      trExp SubExp
i (BasicOp (UpdateAcc VName
acc [SubExp]
is [SubExp]
ses)) = do
        Type
acc_t <- VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
acc
        Exp SOACS -> m (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Exp SOACS -> m (Exp SOACS)) -> Exp SOACS -> m (Exp SOACS)
forall a b. (a -> b) -> a -> b
$ case Type
acc_t of
          Acc VName
cert ShapeBase SubExp
_ [Type]
_ NoUniqueness
_
            | VName
cert VName -> [VName] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` [VName]
acc_certs ->
              BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc (SubExp
i SubExp -> [SubExp] -> [SubExp]
forall a. a -> [a] -> [a]
: [SubExp]
is) [SubExp]
ses
          Type
_ ->
            BasicOp -> Exp SOACS
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> [SubExp] -> [SubExp] -> BasicOp
UpdateAcc VName
acc [SubExp]
is [SubExp]
ses
      trExp SubExp
i Exp SOACS
e = Mapper SOACS SOACS m -> Exp SOACS -> m (Exp SOACS)
forall (m :: * -> *) flore tlore.
(Applicative m, Monad m) =>
Mapper flore tlore m -> Exp flore -> m (Exp tlore)
mapExpM Mapper SOACS SOACS m
mapper Exp SOACS
e
        where
          mapper :: Mapper SOACS SOACS m
mapper =
            Mapper Any Any m
forall (m :: * -> *) lore. Monad m => Mapper lore lore m
identityMapper
              { mapOnBody :: Scope SOACS -> Body -> m Body
mapOnBody = \Scope SOACS
scope -> Scope SOACS -> m Body -> m Body
forall lore (m :: * -> *) a.
LocalScope lore m =>
Scope lore -> m a -> m a
localScope Scope SOACS
scope (m Body -> m Body) -> (Body -> m Body) -> Body -> m Body
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> Body -> m Body
trBody SubExp
i,
                mapOnRetType :: RetType SOACS -> m (RetType SOACS)
mapOnRetType = DeclExtType -> m DeclExtType
forall (f :: * -> *) a. Applicative f => a -> f a
pure (DeclExtType -> m DeclExtType)
-> (DeclExtType -> DeclExtType) -> DeclExtType -> m DeclExtType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DeclExtType -> DeclExtType
forall shape u. TypeBase shape u -> TypeBase shape u
trType,
                mapOnBranchType :: BranchType SOACS -> m (BranchType SOACS)
mapOnBranchType = ExtType -> m ExtType
forall (f :: * -> *) a. Applicative f => a -> f a
pure (ExtType -> m ExtType)
-> (ExtType -> ExtType) -> ExtType -> m ExtType
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ExtType -> ExtType
forall shape u. TypeBase shape u -> TypeBase shape u
trType,
                mapOnFParam :: FParam -> m FParam
mapOnFParam = Param DeclType -> m (Param DeclType)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param DeclType -> m (Param DeclType))
-> (Param DeclType -> Param DeclType)
-> Param DeclType
-> m (Param DeclType)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> Param DeclType
forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam,
                mapOnLParam :: LParam SOACS -> m (LParam SOACS)
mapOnLParam = Param Type -> m (Param Type)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Param Type -> m (Param Type))
-> (Param Type -> Param Type) -> Param Type -> m (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> Param Type
forall shape u.
Param (TypeBase shape u) -> Param (TypeBase shape u)
trParam,
                mapOnOp :: Op SOACS -> m (Op SOACS)
mapOnOp = SubExp -> SOAC SOACS -> m (SOAC SOACS)
trSOAC SubExp
i
              }

-- | Move the loops of a 'KernelNest' inside a 'WithAccStm'.
--
-- 'interchangeWithAcc1' is applied once per loop nesting. The nestings
-- are visited in the reverse of the order returned by 'kernelNestLoops'
-- (presumably innermost first -- TODO confirm against
-- "Futhark.Pass.ExtractKernels.Distribution"). The updated 'WithAccStm'
-- is threaded from one step to the next. All steps run inside a single
-- 'runBinder', and any statements they emit are placed before the final
-- @WithAcc@ statement built from the resulting 'WithAccStm'.
interchangeWithAcc ::
  (MonadFreshNames m, HasScope SOACS m) =>
  KernelNest ->
  WithAccStm ->
  m (Stms SOACS)
interchangeWithAcc nest withacc = do
  let nestings = reverse (kernelNestLoops nest)
  (withacc', prelude) <-
    runBinder (foldM interchangeWithAcc1 withacc nestings)
  pure (prelude <> oneStm (withAccStm withacc'))