{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Loop (diffLoop, stripmineStms) where

import Control.Monad
import Data.Foldable (toList)
import Data.List (nub, (\\))
import qualified Data.Map as M
import Data.Maybe
import Futhark.AD.Rev.Monad
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (traverseFold)

-- | Deconstruct a for-loop and hand its pieces to a continuation.
--
-- The continuation receives, in order: the loop parameters paired with
-- their initial values, the whole 'LoopForm', the loop index name, the
-- index's integer type, the iteration bound, the loop variables paired
-- with the arrays they iterate over, and the loop body.
--
-- Calls 'error' (with the pretty-printed expression) when the 'Exp' is
-- anything other than a 'DoLoop' with a 'ForLoop' form.
bindForLoop ::
  PrettyRep rep =>
  Exp rep ->
  ( [(Param (FParamInfo rep), SubExp)] ->
    LoopForm rep ->
    VName ->
    IntType ->
    SubExp ->
    [(Param (LParamInfo rep), VName)] ->
    Body rep ->
    a
  ) ->
  a
bindForLoop e k =
  case e of
    DoLoop params lform@(ForLoop idx idx_t trip_count lvars) loop_body ->
      k params lform idx idx_t trip_count lvars loop_body
    _ ->
      error $ "bindForLoop: not a for-loop:\n" <> pretty e

-- | Rename a for-loop with 'renameExp', then deconstruct the renamed
-- loop via 'bindForLoop'.
--
-- The continuation receives the renamed loop expression itself,
-- followed by the same components that 'bindForLoop' would pass, all
-- taken from the renamed copy. Like 'bindForLoop', this calls 'error'
-- if the expression is not a for-loop.
renameForLoop ::
  (MonadFreshNames m, Renameable rep, PrettyRep rep) =>
  Exp rep ->
  ( Exp rep ->
    [(Param (FParamInfo rep), SubExp)] ->
    LoopForm rep ->
    VName ->
    IntType ->
    SubExp ->
    [(Param (LParamInfo rep), VName)] ->
    Body rep ->
    m a
  ) ->
  m a
renameForLoop loop f = do
  renamed <- renameExp loop
  bindForLoop renamed (f renamed)

-- | Is the expression a 'DoLoop' whose form is a 'WhileLoop'?
-- Returns 'False' for for-loops and for every non-loop expression.
isWhileLoop :: Exp rep -> Bool
isWhileLoop e =
  case e of
    DoLoop _ WhileLoop {} _ -> True
    _ -> False

-- | Transforms a 'ForLoop' into a 'ForLoop' with an empty list of
-- loop variables.
--
-- Each loop variable @(x, xs)@ is eliminated as follows:
--
-- * a fresh name (based on @x@'s base name) is bound at the start of
--   the body to @xs@ indexed at the loop index @i@ (via 'fullSlice');
-- * every occurrence of @x@ in the original body is replaced by that
--   fresh name.
--
-- The iteration bound, index name/type and loop parameters are kept.
-- Calls 'error' (through 'bindForLoop') if the expression is not a
-- for-loop.
removeLoopVars :: MonadBuilder m => Exp (Rep m) -> m (Exp (Rep m))
removeLoopVars loop =
  bindForLoop loop $ \val_pats _form i it bound loop_vars body -> do
    let indexify (x_param, xs) = do
          xs_t <- lookupType xs
          x' <-
            letExp (baseString $ paramName x_param) . BasicOp . Index xs $
              fullSlice xs_t [DimFix (Var i)]
          pure (paramName x_param, x')
    -- Collect the indexing statements instead of emitting them here:
    -- they mention the loop index and must live inside the loop body.
    (substs_list, subst_stms) <- collectStms $ mapM indexify loop_vars
    let Body aux' stms' res' = substituteNames (M.fromList substs_list) body
    -- The former loop variables are now bound by 'subst_stms' inside
    -- the body, so the rebuilt form carries no loop variables. Reusing
    -- the original form here would leave dead loop-variable bindings
    -- behind and contradict this function's contract.
    pure $
      DoLoop val_pats (ForLoop i it bound []) $
        Body aux' (subst_stms <> stms') res'

-- | Augments a while-loop to also compute the number of iterations.
--
-- A fresh @i64@ loop parameter, initialised to zero and incremented
-- once per iteration, is prepended to the loop's parameters. The
-- augmented loop is bound in the current builder context, and the
-- final value of that counter is returned. The other results of the
-- augmented loop are bound but not returned.
--
-- Calls 'error' if the expression is not a while-loop.
computeWhileIters :: Exp SOACS -> ADM SubExp
computeWhileIters (DoLoop val_pats (WhileLoop b) body) = do
  bound_v <- newVName "bound"
  let t = Prim $ IntType Int64
      bound_param = Param mempty bound_v t
  bound_init <- letSubExp "bound_init" $ zeroExp t
  body' <- localScope (scopeOfFParams [bound_param]) $
    buildBody_ $ do
      bound_plus_one <-
        let one = Constant $ IntValue $ intValue Int64 (1 :: Int)
         in letSubExp "bound+1" $
              BasicOp $
                BinOp (Add Int64 OverflowUndef) (Var bound_v) one
      addStms $ bodyStms body
      -- The counter comes first, matching its position in the params.
      pure (pure (subExpRes bound_plus_one) <> bodyResult body)
  res <-
    letTupExp' "loop" $
      DoLoop ((bound_param, bound_init) : val_pats) (WhileLoop b) body'
  case res of
    iters : _ -> pure iters
    -- Unreachable: the loop always has the counter as its first result.
    [] -> error "computeWhileIters: augmented loop produced no results"
computeWhileIters e =
  error $ "computeWhileIters: not a while-loop:\n" <> pretty e

-- | Converts a 'WhileLoop' into a 'ForLoop'.
--
-- @bound_se@ must be an upper bound on the number of iterations of the
-- while-loop. It is presumably obtained by the caller from a
-- @#[bound(n)]@ attribute on the surrounding 'DoLoop'. The resulting
-- for-loop uses an @i64@ index and executes @bound_se@ iterations on
-- all inputs, so the tighter the bound the better.
--
-- Each iteration tests the while-loop's condition variable. If it is
-- true, the original body runs; otherwise the loop parameters are
-- passed through unchanged, so the extra iterations have no effect.
--
-- Calls 'error' if the expression is not a while-loop.
convertWhileLoop :: SubExp -> Exp SOACS -> ADM (Exp SOACS)
convertWhileLoop bound_se (DoLoop val_pats (WhileLoop cond) body) =
  localScope (scopeOfFParams $ map fst val_pats) $ do
    i <- newVName "i"
    body' <-
      eBody
        [ eIf
            (pure $ BasicOp $ SubExp $ Var cond)
            (pure body)
            (resultBodyM $ map (Var . paramName . fst) val_pats)
        ]
    pure $ DoLoop val_pats (ForLoop i Int64 bound_se mempty) body'
convertWhileLoop _ e =
  error $ "convertWhileLoop: not a while-loop:\n" <> pretty e

-- | @nestifyLoop bound n loop@ transforms a for-loop into a depth-@n@
-- loop nest of @bound@-iteration loops. This transformation does not
-- preserve the original semantics of the loop: @n@ and @bound@ may be
-- arbitrary and have no relation to the number of iterations of @loop@.
--
-- The nest runs the original body @bound^n@ times. Inside the innermost
-- body, the original loop index is replaced by the flattened index
-- @i_1 * bound^(n-1) + i_2 * bound^(n-2) + ... + i_n@, where @i_k@ is
-- the index of the @k@-th loop (outermost first). The body therefore
-- sees the indices @0 .. bound^n - 1@ in order.
--
-- Loop parameters are threaded through every level: each inner loop
-- starts from the enclosing loop's current parameter values, and the
-- enclosing loop returns the inner loop's results.
--
-- For @n == 1@ only the iteration bound is replaced. For @n < 1@ the
-- loop is returned unchanged. Calls 'error' (via 'bindForLoop') if
-- @loop@ is not a for-loop.
--
-- NOTE(review): loop variables are kept at every level and are indexed
-- by that level's own index, not the flattened one. Callers presumably
-- remove them first (e.g. with 'removeLoopVars'); confirm.
nestifyLoop ::
  SubExp ->
  Integer ->
  Exp SOACS ->
  ADM (Exp SOACS)
nestifyLoop bound_se = nestifyLoop'
  where
    nestifyLoop' :: Integer -> Exp SOACS -> ADM (Exp SOACS)
    nestifyLoop' n loop = bindForLoop loop nestify
      where
        nestify val_pats _form i it _bound loop_vars body
          | n > 1 =
              renameForLoop loop $ \_loop' val_pats' _form' i' it' _bound' loop_vars' body' -> do
                let loop_params = map fst val_pats
                    loop_params' = map fst val_pats'
                    loop_inits' = map (Var . paramName) loop_params
                    val_pats'' = zip loop_params' loop_inits'
                outer_body <-
                  buildBody_ $ do
                    -- Each outer iteration covers a whole inner nest of
                    -- depth n-1, i.e. bound^(n-1) original iterations.
                    -- (The previous code multiplied the enclosing
                    -- level's offset by i, which is only correct for
                    -- n == 2.)
                    inner_iters <-
                      letSubExp "stride" . BasicOp $
                        BinOp
                          (Pow it)
                          bound_se
                          (Constant $ IntValue $ intValue it (n - 1))
                    outer_offset <-
                      letSubExp "offset" . BasicOp $
                        BinOp (Mul it OverflowUndef) inner_iters (Var i)

                    inner_body <- insertStmsM $ do
                      i_inner <-
                        letExp "i_inner" . BasicOp $
                          BinOp (Add it OverflowUndef) outer_offset (Var i')
                      pure $ substituteNames (M.singleton i' i_inner) body'

                    -- The inner loop's index is itself expanded
                    -- recursively, so it ranges over 0 .. bound^(n-1) - 1.
                    inner_loop <-
                      letTupExp "inner_loop"
                        =<< nestifyLoop'
                          (n - 1)
                          ( DoLoop
                              val_pats''
                              (ForLoop i' it' bound_se loop_vars')
                              inner_body
                          )
                    pure $ varsRes inner_loop
                pure $
                  DoLoop val_pats (ForLoop i it bound_se loop_vars) outer_body
          | n == 1 =
              pure $ DoLoop val_pats (ForLoop i it bound_se loop_vars) body
          | otherwise = pure loop

-- | @stripmine n pat loop@ stripmines a loop into a depth-@n@ loop nest.
-- An additional @bound - (floor(bound^(1/n)))^n@-iteration remainder loop is
-- inserted after the stripmined loop which executes the remaining iterations
-- so that the stripmined loop is semantically equivalent to the original loop.
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine :: Integer -> Pat Type -> Exp SOACS -> ADM (Stms SOACS)
stripmine Integer
n Pat Type
pat Exp SOACS
loop = do
  Exp SOACS
loop' <- Exp (Rep ADM) -> ADM (Exp (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
Exp (Rep m) -> m (Exp (Rep m))
removeLoopVars Exp (Rep ADM)
Exp SOACS
loop
  Exp SOACS
-> ([(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM (Stms SOACS))
-> ADM (Stms SOACS)
forall rep a.
PrettyRep rep =>
Exp rep
-> ([(Param (FParamInfo rep), SubExp)]
    -> LoopForm rep
    -> VName
    -> IntType
    -> SubExp
    -> [(Param (LParamInfo rep), VName)]
    -> Body rep
    -> a)
-> a
bindForLoop Exp SOACS
loop' (([(FParam SOACS, SubExp)]
  -> LoopForm SOACS
  -> VName
  -> IntType
  -> SubExp
  -> [(LParam SOACS, VName)]
  -> Body SOACS
  -> ADM (Stms SOACS))
 -> ADM (Stms SOACS))
-> ([(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM (Stms SOACS))
-> ADM (Stms SOACS)
forall a b. (a -> b) -> a -> b
$ \[(FParam SOACS, SubExp)]
_val_pats LoopForm SOACS
_form VName
_i IntType
it SubExp
bound [(LParam SOACS, VName)]
_loop_vars Body SOACS
_body -> do
    let n_root :: SubExp
n_root = PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ FloatValue -> PrimValue
FloatValue (FloatValue -> PrimValue) -> FloatValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ FloatType -> Double -> FloatValue
forall num. Real num => FloatType -> num -> FloatValue
floatValue FloatType
Float64 (Double
1 Double -> Double -> Double
forall a. Fractional a => a -> a -> a
/ Integer -> Double
forall a b. (Integral a, Num b) => a -> b
fromIntegral Integer
n :: Double)
    SubExp
bound_float <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bound_f64" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (IntType -> FloatType -> ConvOp
UIToFP IntType
it FloatType
Float64) SubExp
bound
    SubExp
bound' <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bound" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (FloatType -> BinOp
FPow FloatType
Float64) SubExp
bound_float SubExp
n_root
    SubExp
bound_int <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"bound_int" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ ConvOp -> SubExp -> BasicOp
ConvOp (FloatType -> IntType -> ConvOp
FPToUI FloatType
Float64 IntType
it) SubExp
bound'
    SubExp
total_iters <-
      String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"total_iters" (Exp SOACS -> ADM SubExp)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM SubExp) -> BasicOp -> ADM SubExp
forall a b. (a -> b) -> a -> b
$
        BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> BinOp
Pow IntType
it) SubExp
bound_int (PrimValue -> SubExp
Constant (PrimValue -> SubExp) -> PrimValue -> SubExp
forall a b. (a -> b) -> a -> b
$ IntValue -> PrimValue
IntValue (IntValue -> PrimValue) -> IntValue -> PrimValue
forall a b. (a -> b) -> a -> b
$ IntType -> Integer -> IntValue
forall int. Integral int => IntType -> int -> IntValue
intValue IntType
it Integer
n)
    SubExp
remain_iters <-
      String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"remain_iters" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
it Overflow
OverflowUndef) SubExp
bound SubExp
total_iters
    Exp SOACS
mined_loop <- SubExp -> Integer -> Exp SOACS -> ADM (Exp SOACS)
nestifyLoop SubExp
bound_int Integer
n Exp SOACS
loop
    Pat Type
pat' <- Pat Type -> ADM (Pat Type)
forall dec (m :: * -> *).
(Rename dec, MonadFreshNames m) =>
Pat dec -> m (Pat dec)
renamePat Pat Type
pat
    Exp SOACS
-> (Exp SOACS
    -> [(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM (Stms SOACS))
-> ADM (Stms SOACS)
forall (m :: * -> *) rep a.
(MonadFreshNames m, Renameable rep, PrettyRep rep) =>
Exp rep
-> (Exp rep
    -> [(Param (FParamInfo rep), SubExp)]
    -> LoopForm rep
    -> VName
    -> IntType
    -> SubExp
    -> [(Param (LParamInfo rep), VName)]
    -> Body rep
    -> m a)
-> m a
renameForLoop Exp SOACS
loop ((Exp SOACS
  -> [(FParam SOACS, SubExp)]
  -> LoopForm SOACS
  -> VName
  -> IntType
  -> SubExp
  -> [(LParam SOACS, VName)]
  -> Body SOACS
  -> ADM (Stms SOACS))
 -> ADM (Stms SOACS))
-> (Exp SOACS
    -> [(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM (Stms SOACS))
-> ADM (Stms SOACS)
forall a b. (a -> b) -> a -> b
$ \Exp SOACS
_loop' [(FParam SOACS, SubExp)]
val_pats' LoopForm SOACS
_form' VName
i' IntType
it' SubExp
_bound' [(LParam SOACS, VName)]
loop_vars' Body SOACS
body' -> do
      Body SOACS
remain_body <- ADM (Body (Rep ADM)) -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m (Body (Rep m)) -> m (Body (Rep m))
insertStmsM (ADM (Body (Rep ADM)) -> ADM (Body (Rep ADM)))
-> ADM (Body (Rep ADM)) -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
        VName
i_remain <-
          String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"i_remain" (Exp SOACS -> ADM VName)
-> (BasicOp -> Exp SOACS) -> BasicOp -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> ADM VName) -> BasicOp -> ADM VName
forall a b. (a -> b) -> a -> b
$
            BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
it Overflow
OverflowUndef) SubExp
total_iters (VName -> SubExp
Var VName
i')
        Body SOACS -> ADM (Body SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Body SOACS -> ADM (Body SOACS)) -> Body SOACS -> ADM (Body SOACS)
forall a b. (a -> b) -> a -> b
$ Map VName VName -> Body SOACS -> Body SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames (VName -> VName -> Map VName VName
forall k a. k -> a -> Map k a
M.singleton VName
i' VName
i_remain) Body SOACS
body'
      let loop_params_rem :: [Param DeclType]
loop_params_rem = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val_pats'
          loop_inits_rem :: [SubExp]
loop_inits_rem = (PatElem Type -> SubExp) -> [PatElem Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp)
-> (PatElem Type -> VName) -> PatElem Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName) ([PatElem Type] -> [SubExp]) -> [PatElem Type] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat'
          val_pats_rem :: [(Param DeclType, SubExp)]
val_pats_rem = [Param DeclType] -> [SubExp] -> [(Param DeclType, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
loop_params_rem [SubExp]
loop_inits_rem
          remain_loop :: Exp SOACS
remain_loop = [(FParam SOACS, SubExp)]
-> LoopForm SOACS -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val_pats_rem (VName
-> IntType -> SubExp -> [(LParam SOACS, VName)] -> LoopForm SOACS
forall rep.
VName -> IntType -> SubExp -> [(LParam rep, VName)] -> LoopForm rep
ForLoop VName
i' IntType
it' SubExp
remain_iters [(LParam SOACS, VName)]
loop_vars') Body SOACS
remain_body
      ADM () -> ADM (Stms (Rep ADM))
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (ADM () -> ADM (Stms (Rep ADM))) -> ADM () -> ADM (Stms (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
        Pat (LetDec (Rep ADM)) -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat Type
Pat (LetDec (Rep ADM))
pat' Exp (Rep ADM)
Exp SOACS
mined_loop
        Pat (LetDec (Rep ADM)) -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat Type
Pat (LetDec (Rep ADM))
pat Exp (Rep ADM)
Exp SOACS
remain_loop

-- | Stripmine a single statement.
--
-- A statement is transformed only when it binds a for-loop annotated
-- with a @#[stripmine(n)]@ attribute. The @n@ of the first such
-- attribute (in the order 'mapAttrs' yields them) is handed to
-- 'stripmine' as the nesting depth. Any other statement is returned
-- unchanged as a one-element 'Stms'.
stripmineStm :: Stm SOACS -> ADM (Stms SOACS)
stripmineStm (Let pat aux loop@(DoLoop _ ForLoop {} _))
  | Just depth <- firstDepth = stripmine depth pat loop
  where
    -- First @stripmine(n)@ attribute wins; later ones are ignored.
    firstDepth :: Maybe Integer
    firstDepth = msum $ mapAttrs stripmineDepth $ stmAuxAttrs aux

    stripmineDepth :: Attr -> Maybe Integer
    stripmineDepth (AttrComp "stripmine" [AttrInt k]) = Just k
    stripmineDepth _ = Nothing
-- No attribute, or not a for-loop: leave the statement alone.
stripmineStm stm = pure $ oneStm stm

-- | Stripmine every statement of a sequence with 'stripmineStm', in
-- order, and concatenate the resulting statement sequences.
stripmineStms :: Stms SOACS -> ADM (Stms SOACS)
stripmineStms = fmap mconcat . traverse stripmineStm . toList

-- | Forward-pass transformation of a for-loop.
--
-- The loop is re-emitted, with the original statement attributes @aux@,
-- so that it also records each loop parameter's per-iteration value on
-- a tape:
--
-- * Parameters carrying the @#[true_dep]@ attribute are not taped.
--
-- * Every other parameter @p@ gets an extra loop parameter: a unique
--   array of @bound@ elements of @p@'s type, initialised with 'eBlank'.
--   In iteration @i@ the value of @p@ is written in place at index @i@.
--
-- * The finished tape is bound by a fresh pattern element appended to
--   @pat@, and that name is registered for @p@ with 'setLoopTape'.
--
-- Before the original body statements, 'copyConsumedArrsInBody' is run
-- with the @true_dep@ parameter names excluded. Its substitution is
-- applied to the value written to the tape. The intent, per the
-- original design notes, is to copy arrays consumed by the body, because
-- the originals are needed later in the reverse pass. NOTE(review): the
-- exact behaviour of 'copyConsumedArrsInBody' is defined elsewhere.
fwdLoop :: Pat Type -> StmAux () -> Exp SOACS -> ADM ()
fwdLoop pat aux loop =
  bindForLoop loop $ \val_pats form i _it bound _loop_vars body -> do
    bound64 <- asIntS Int64 bound
    let params = map fst val_pats
        isTrueDep = inAttrs (AttrName "true_dep") . paramAttrs
        true_dep_params = filter isTrueDep params
        -- Parameters whose per-iteration values go on the tape.
        taped_params = params \\ true_dep_params

    -- Initial (blank) tape arrays, one per taped parameter.
    tape_inits <- forM taped_params $ \p ->
      eBlank (arrayOf (paramDec p) (Shape [bound64]) NoUniqueness)
        >>= letSubExp (baseString (paramName p) <> "_empty_saved")

    (body', (tape_pat_elems, tape_params)) <-
      buildBody . localScope (scopeOfFParams params) . inScopeOf form $ do
        copy_substs <-
          copyConsumedArrsInBody (map paramName true_dep_params) body
        addStms $ bodyStms body
        i_idx64 <- asIntS Int64 (Var i)
        entries <- forM taped_params $ \p -> do
          let v = paramName p
              elem_t = paramDec p
          tape_param_name <- newVName $ baseString v <> "_saved"
          tape_pat_name <- newVName $ baseString v <> "_saved"
          setLoopTape v tape_pat_name
          let tape_shape = Shape [bound64]
              tape_param =
                Param mempty tape_param_name $ arrayOf elem_t tape_shape Unique
              tape_pat_elem =
                PatElem tape_pat_name $ arrayOf elem_t tape_shape NoUniqueness
              -- This iteration's value of v, redirected through any copies.
              to_store = substituteNames copy_substs $ BasicOp $ SubExp $ Var v
          tape_update <-
            localScope (scopeOfFParams [tape_param]) $
              letInPlace
                (baseString v <> "_saved_update")
                tape_param_name
                (fullSlice (fromDecl $ paramDec tape_param) [DimFix i_idx64])
                to_store
          pure (tape_update, (tape_pat_elem, tape_param))
        let (tape_updates, pats_params) = unzip entries
        -- The updated tapes are returned after the original results.
        pure (bodyResult body <> varsRes tape_updates, unzip pats_params)

    let pat' = pat <> Pat tape_pat_elems
        val_pats' = val_pats <> zip tape_params tape_inits
    addStm $ Let pat' aux $ DoLoop val_pats' form body'

-- | Build the loop parameter and initial value that carry the adjoint
-- of @v@ through a loop.
--
-- The parameter is named by 'adjVName' and has the unique declared form
-- of the type of @v@'s current adjoint ('lookupAdjVal'). The initial
-- value is that current adjoint variable.
valPatAdj :: VName -> ADM (Param DeclType, SubExp)
valPatAdj v = do
  adj_name <- adjVName v
  adj_init <- lookupAdjVal v
  adj_t <- lookupType adj_init
  let adj_param = Param mempty adj_name $ toDecl adj_t Unique
  pure (adj_param, Var adj_init)

-- | Apply 'valPatAdj' to every variable in every component of a
-- 'LoopInfo', keeping its structure.
valPatAdjs :: LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs = traverse (traverse valPatAdj)

-- | Prepare the reversal of a for-loop.
--
-- The result is a triple:
--
-- 1. a substitution mapping each array bound by a loop variable to a
--    reversed view of it (a slice starting at @bound - 1@, @bound@
--    elements long, stride @-1@);
--
-- 2. a substitution mapping the loop index @i@ to @bound - 1 - i@;
--
-- 3. the statements computing that reversed index. These are collected
--    rather than emitted, so the caller decides where they go.
--
-- The @bound - 1@ statement and the reversed array slices are emitted
-- directly into the current builder.
reverseIndices :: Exp SOACS -> ADM (Substitutions, Substitutions, Stms SOACS)
reverseIndices loop =
  bindForLoop loop $ \_val_pats form i it bound loop_vars _body -> do
    let one = Constant $ IntValue $ intValue it (1 :: Int)
        neg_one64 = Constant $ IntValue $ Int64Value (-1)

    last_idx <-
      inScopeOf form . letSubExp "bound-1" . BasicOp $
        BinOp (Sub it OverflowUndef) bound one

    arr_substs <- fmap M.fromList . inScopeOf form $
      forM loop_vars $ \(_, arr) -> do
        arr_t <- lookupType arr
        arr_rev <-
          letExp "reverse" . BasicOp . Index arr $
            fullSlice arr_t [DimSlice last_idx bound neg_one64]
        pure (arr, arr_rev)

    (i_rev, i_stms) <-
      collectStms . inScopeOf form . letExp (baseString i <> "_rev") . BasicOp $
        BinOp (Sub it OverflowWrap) last_idx (Var i)

    pure (arr_substs, M.singleton i i_rev, i_stms)

-- | Returns a substitution which substitutes values in the reverse
-- loop body with values from the tape.
--
-- For every parameter in @loop_params'@ that is not marked with the
-- @true_dep@ attribute and that has an entry on the loop tape, the tape
-- array is indexed at iteration @i'@ and the parameter is mapped to the
-- resulting value.  If the parameter is an array that is consumed somewhere
-- in @stms_adj@, the restored value is additionally copied.  Any statements
-- needed to compute the restored values are added to the current builder.
restore :: Stms SOACS -> [Param DeclType] -> VName -> ADM Substitutions
restore stms_adj loop_params' i' =
  M.fromList . catMaybes <$> mapM f loop_params'
  where
    -- Parameters carrying the @true_dep@ attribute are never restored.
    dont_copy :: [VName]
    dont_copy =
      map paramName $
        filter (inAttrs (AttrName "true_dep") . paramAttrs) loop_params'

    -- Names consumed by @stms_adj@.  This does not depend on the parameter
    -- being restored, so it is defined here to run the alias analysis once
    -- instead of once per loop parameter.
    consumed :: [VName]
    consumed =
      namesToList . consumedInStms . fst $ Alias.analyseStms mempty stms_adj

    f :: Param DeclType -> ADM (Maybe (VName, VName))
    f p
      | v `notElem` dont_copy = do
          m_vs <- lookupLoopTape v
          case m_vs of
            Nothing -> pure Nothing
            Just vs -> do
              vs_t <- lookupType vs
              i_i64' <- asIntS Int64 $ Var i'
              v' <-
                letExp "restore" $
                  BasicOp $
                    Index vs $
                      fullSlice vs_t [DimFix i_i64']
              t <- lookupType v
              v'' <- case (t, v `elem` consumed) of
                -- NOTE(review): presumably copied so that consuming the
                -- restored value does not consume the tape array it was
                -- indexed from.
                (Array {}, True) -> letExp "restore_copy" $ BasicOp $ Copy v'
                _ -> pure v'
              pure $ Just (v, v'')
      | otherwise = pure Nothing
      where
        v = paramName p

-- | A container that keeps one value for each of the distinct parts of a
-- loop, so that the parts can be tracked separately but processed uniformly.
--
-- The order of the fields is significant: the derived 'Foldable' instance
-- visits them in declaration order (results, free variables, loop-variable
-- arrays, initial values).  Code that flattens a 'LoopInfo' with 'toList' and
-- later splits the flat list back apart depends on this order.
data LoopInfo a = LoopInfo
  { -- | Component associated with the loop body results.
    loopRes :: a,
    -- | Component associated with the variables free in the loop.
    loopFree :: a,
    -- | Component associated with the arrays iterated over by the loop.
    loopVars :: a,
    -- | Component associated with the initial values of the loop parameters.
    loopVals :: a
  }
  deriving (Functor, Foldable, Traversable, Show)

-- | Transforms a for-loop into its reverse-mode derivative.
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> Exp SOACS -> ADM ()
revLoop Stms SOACS -> ADM ()
diffStms Pat Type
pat Exp SOACS
loop =
  Exp SOACS
-> ([(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM ())
-> ADM ()
forall rep a.
PrettyRep rep =>
Exp rep
-> ([(Param (FParamInfo rep), SubExp)]
    -> LoopForm rep
    -> VName
    -> IntType
    -> SubExp
    -> [(Param (LParamInfo rep), VName)]
    -> Body rep
    -> a)
-> a
bindForLoop Exp SOACS
loop (([(FParam SOACS, SubExp)]
  -> LoopForm SOACS
  -> VName
  -> IntType
  -> SubExp
  -> [(LParam SOACS, VName)]
  -> Body SOACS
  -> ADM ())
 -> ADM ())
-> ([(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM ())
-> ADM ()
forall a b. (a -> b) -> a -> b
$ \[(FParam SOACS, SubExp)]
val_pats LoopForm SOACS
_form VName
_i IntType
_it SubExp
_bound [(LParam SOACS, VName)]
_loop_vars Body SOACS
_body ->
    Exp SOACS
-> (Exp SOACS
    -> [(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM ())
-> ADM ()
forall (m :: * -> *) rep a.
(MonadFreshNames m, Renameable rep, PrettyRep rep) =>
Exp rep
-> (Exp rep
    -> [(Param (FParamInfo rep), SubExp)]
    -> LoopForm rep
    -> VName
    -> IntType
    -> SubExp
    -> [(Param (LParamInfo rep), VName)]
    -> Body rep
    -> m a)
-> m a
renameForLoop Exp SOACS
loop ((Exp SOACS
  -> [(FParam SOACS, SubExp)]
  -> LoopForm SOACS
  -> VName
  -> IntType
  -> SubExp
  -> [(LParam SOACS, VName)]
  -> Body SOACS
  -> ADM ())
 -> ADM ())
-> (Exp SOACS
    -> [(FParam SOACS, SubExp)]
    -> LoopForm SOACS
    -> VName
    -> IntType
    -> SubExp
    -> [(LParam SOACS, VName)]
    -> Body SOACS
    -> ADM ())
-> ADM ()
forall a b. (a -> b) -> a -> b
$
      \Exp SOACS
loop' [(FParam SOACS, SubExp)]
val_pats' LoopForm SOACS
form' VName
i' IntType
_it' SubExp
_bound' [(LParam SOACS, VName)]
loop_vars' Body SOACS
body' -> do
        let loop_params :: [Param DeclType]
loop_params = ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val_pats
            ([Param DeclType]
loop_params', [SubExp]
loop_vals') = [(Param DeclType, SubExp)] -> ([Param DeclType], [SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val_pats'
            loop_var_arrays' :: [VName]
loop_var_arrays' = ((Param Type, VName) -> VName) -> [(Param Type, VName)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param Type, VName) -> VName
forall a b. (a, b) -> b
snd [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars'
            getVName :: SubExp -> Maybe VName
getVName Constant {} = Maybe VName
forall a. Maybe a
Nothing
            getVName (Var VName
v) = VName -> Maybe VName
forall a. a -> Maybe a
Just VName
v
            loop_vnames :: LoopInfo [VName]
loop_vnames =
              LoopInfo :: forall a. a -> a -> a -> a -> LoopInfo a
LoopInfo
                { loopRes :: [VName]
loopRes = (SubExpRes -> Maybe VName) -> [SubExpRes] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SubExpRes -> Maybe VName
subExpResVName ([SubExpRes] -> [VName]) -> [SubExpRes] -> [VName]
forall a b. (a -> b) -> a -> b
$ Body SOACS -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body SOACS
body',
                  loopFree :: [VName]
loopFree =
                    (Names -> [VName]
namesToList (Exp SOACS -> Names
forall a. FreeIn a => a -> Names
freeIn Exp SOACS
loop') [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ [VName]
loop_var_arrays') [VName] -> [VName] -> [VName]
forall a. Eq a => [a] -> [a] -> [a]
\\ (SubExp -> Maybe VName) -> [SubExp] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SubExp -> Maybe VName
getVName [SubExp]
loop_vals',
                  loopVars :: [VName]
loopVars = [VName]
loop_var_arrays',
                  loopVals :: [VName]
loopVals = [VName] -> [VName]
forall a. Eq a => [a] -> [a]
nub ([VName] -> [VName]) -> [VName] -> [VName]
forall a b. (a -> b) -> a -> b
$ (SubExp -> Maybe VName) -> [SubExp] -> [VName]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe SubExp -> Maybe VName
getVName [SubExp]
loop_vals'
                }

        Map VName VName -> ADM ()
renameLoopTape (Map VName VName -> ADM ()) -> Map VName VName -> ADM ()
forall a b. (a -> b) -> a -> b
$ [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$ [VName] -> [VName] -> [(VName, VName)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
loop_params) ((Param DeclType -> VName) -> [Param DeclType] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param DeclType -> VName
forall dec. Param dec -> VName
paramName [Param DeclType]
loop_params')

        [(SubExpRes, PatElem Type)]
-> ((SubExpRes, PatElem Type) -> ADM ()) -> ADM ()
forall (t :: * -> *) (m :: * -> *) a b.
(Foldable t, Monad m) =>
t a -> (a -> m b) -> m ()
forM_ ([SubExpRes] -> [PatElem Type] -> [(SubExpRes, PatElem Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip (Body SOACS -> [SubExpRes]
forall rep. Body rep -> [SubExpRes]
bodyResult Body SOACS
body') ([PatElem Type] -> [(SubExpRes, PatElem Type)])
-> [PatElem Type] -> [(SubExpRes, PatElem Type)]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat) (((SubExpRes, PatElem Type) -> ADM ()) -> ADM ())
-> ((SubExpRes, PatElem Type) -> ADM ()) -> ADM ()
forall a b. (a -> b) -> a -> b
$ \(SubExpRes
se_res, PatElem Type
pe) ->
          case SubExpRes -> Maybe VName
subExpResVName SubExpRes
se_res of
            Just VName
v -> VName -> Adj -> ADM ()
setAdj VName
v (Adj -> ADM ()) -> ADM Adj -> ADM ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< VName -> ADM Adj
lookupAdj (PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName PatElem Type
pe)
            Maybe VName
Nothing -> () -> ADM ()
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

        (Map VName VName
var_array_substs, Map VName VName
i_subst, Stms SOACS
i_stms) <-
          Exp SOACS -> ADM (Map VName VName, Map VName VName, Stms SOACS)
reverseIndices Exp SOACS
loop'

        LoopInfo [(Param DeclType, SubExp)]
val_pat_adjs <- LoopInfo [VName] -> ADM (LoopInfo [(Param DeclType, SubExp)])
valPatAdjs LoopInfo [VName]
loop_vnames
        let val_pat_adjs_list :: [(Param DeclType, SubExp)]
val_pat_adjs_list = [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)])
-> [[(Param DeclType, SubExp)]] -> [(Param DeclType, SubExp)]
forall a b. (a -> b) -> a -> b
$ LoopInfo [(Param DeclType, SubExp)] -> [[(Param DeclType, SubExp)]]
forall (t :: * -> *) a. Foldable t => t a -> [a]
toList LoopInfo [(Param DeclType, SubExp)]
val_pat_adjs

        (LoopInfo [VName]
loop_adjs, Stms SOACS
stms_adj) <- ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM)))
-> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM))
forall a b. (a -> b) -> a -> b
$
          LoopForm SOACS -> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf LoopForm SOACS
form' (ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName]))
-> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall a b. (a -> b) -> a -> b
$
            Scope SOACS -> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams (((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
val_pat_adjs_list [Param DeclType] -> [Param DeclType] -> [Param DeclType]
forall a. Semigroup a => a -> a -> a
<> [Param DeclType]
loop_params')) (ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName]))
-> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall a b. (a -> b) -> a -> b
$ do
              Stms (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms Stms (Rep ADM)
Stms SOACS
i_stms
              (LoopInfo [VName]
loop_adjs, Stms SOACS
stms_adj) <- ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM)))
-> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName], Stms (Rep ADM))
forall a b. (a -> b) -> a -> b
$
                ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall a. ADM a -> ADM a
subAD (ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName]))
-> ADM (LoopInfo [VName]) -> ADM (LoopInfo [VName])
forall a b. (a -> b) -> a -> b
$ do
                  ((Param DeclType, SubExp) -> VName -> ADM ())
-> [(Param DeclType, SubExp)] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_
                    (\(Param DeclType, SubExp)
val_pat VName
v -> VName -> VName -> ADM ()
insAdj VName
v (Param DeclType -> VName
forall dec. Param dec -> VName
paramName (Param DeclType -> VName) -> Param DeclType -> VName
forall a b. (a -> b) -> a -> b
$ (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst (Param DeclType, SubExp)
val_pat))
                    [(Param DeclType, SubExp)]
val_pat_adjs_list
                    ([[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[VName]] -> [VName]) -> [[VName]] -> [VName]
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [[VName]]
forall (t :: * -> *) a. Foldable t => t a -> [a]
toList LoopInfo [VName]
loop_vnames)
                  Stms SOACS -> ADM ()
diffStms (Stms SOACS -> ADM ()) -> Stms SOACS -> ADM ()
forall a b. (a -> b) -> a -> b
$ Body SOACS -> Stms SOACS
forall rep. Body rep -> Stms rep
bodyStms Body SOACS
body'

                  let update_var_arrays :: VName -> VName -> ADM ()
update_var_arrays VName
v VName
vs = do
                        Type
vs_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
vs
                        VName
v_adj <- VName -> ADM VName
lookupAdjVal VName
v
                        Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
vs_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> SubExp -> DimIndex SubExp
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
i']) VName
vs VName
v_adj
                  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_
                    VName -> VName -> ADM ()
update_var_arrays
                    (((Param Type, VName) -> VName) -> [(Param Type, VName)] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map (Param Type -> VName
forall dec. Param dec -> VName
paramName (Param Type -> VName)
-> ((Param Type, VName) -> Param Type)
-> (Param Type, VName)
-> VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Param Type, VName) -> Param Type
forall a b. (a, b) -> a
fst) [(Param Type, VName)]
[(LParam SOACS, VName)]
loop_vars')
                    (LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVars LoopInfo [VName]
loop_vnames)

                  [VName]
loop_res_adjs <- (Param DeclType -> ADM VName) -> [Param DeclType] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (VName -> ADM VName
lookupAdjVal (VName -> ADM VName)
-> (Param DeclType -> VName) -> Param DeclType -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> VName
forall dec. Param dec -> VName
paramName) [Param DeclType]
loop_params'
                  [VName]
loop_free_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopFree LoopInfo [VName]
loop_vnames
                  [VName]
loop_vars_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVars LoopInfo [VName]
loop_vnames
                  [VName]
loop_vals_adjs <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal ([VName] -> ADM [VName]) -> [VName] -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVals LoopInfo [VName]
loop_vnames

                  LoopInfo [VName] -> ADM (LoopInfo [VName])
forall (f :: * -> *) a. Applicative f => a -> f a
pure (LoopInfo [VName] -> ADM (LoopInfo [VName]))
-> LoopInfo [VName] -> ADM (LoopInfo [VName])
forall a b. (a -> b) -> a -> b
$
                    LoopInfo :: forall a. a -> a -> a -> a -> LoopInfo a
LoopInfo
                      { loopRes :: [VName]
loopRes = [VName]
loop_res_adjs,
                        loopFree :: [VName]
loopFree = [VName]
loop_free_adjs,
                        loopVars :: [VName]
loopVars = [VName]
loop_vars_adjs,
                        loopVals :: [VName]
loopVals = [VName]
loop_vals_adjs
                      }
              (Map VName VName
substs, Stms SOACS
restore_stms) <-
                ADM (Map VName VName) -> ADM (Map VName VName, Stms (Rep ADM))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (ADM (Map VName VName) -> ADM (Map VName VName, Stms (Rep ADM)))
-> ADM (Map VName VName) -> ADM (Map VName VName, Stms (Rep ADM))
forall a b. (a -> b) -> a -> b
$ Stms SOACS -> [Param DeclType] -> VName -> ADM (Map VName VName)
restore Stms SOACS
stms_adj [Param DeclType]
loop_params' VName
i'
              Stms (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms (Rep ADM) -> ADM ()) -> Stms (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Map VName VName -> Stms SOACS -> Stms SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
i_subst Stms SOACS
restore_stms
              Stms (Rep ADM) -> ADM ()
forall (m :: * -> *). MonadBuilder m => Stms (Rep m) -> m ()
addStms (Stms (Rep ADM) -> ADM ()) -> Stms (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ Map VName VName -> Stms SOACS -> Stms SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
i_subst (Stms SOACS -> Stms SOACS) -> Stms SOACS -> Stms SOACS
forall a b. (a -> b) -> a -> b
$ Map VName VName -> Stms SOACS -> Stms SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames Map VName VName
substs Stms SOACS
stms_adj
              LoopInfo [VName] -> ADM (LoopInfo [VName])
forall (f :: * -> *) a. Applicative f => a -> f a
pure LoopInfo [VName]
loop_adjs

        Stms SOACS -> ADM () -> ADM ()
forall rep a (m :: * -> *) b.
(Scoped rep a, LocalScope rep m) =>
a -> m b -> m b
inScopeOf Stms SOACS
stms_adj (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$
          Scope SOACS -> ADM () -> ADM ()
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope ([Param DeclType] -> Scope SOACS
forall rep dec. (FParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfFParams ([Param DeclType] -> Scope SOACS)
-> [Param DeclType] -> Scope SOACS
forall a b. (a -> b) -> a -> b
$ ((Param DeclType, SubExp) -> Param DeclType)
-> [(Param DeclType, SubExp)] -> [Param DeclType]
forall a b. (a -> b) -> [a] -> [b]
map (Param DeclType, SubExp) -> Param DeclType
forall a b. (a, b) -> a
fst [(Param DeclType, SubExp)]
val_pat_adjs_list) (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
            let body_adj :: Body SOACS
body_adj = Stms SOACS -> [SubExpRes] -> Body SOACS
forall rep. Buildable rep => Stms rep -> [SubExpRes] -> Body rep
mkBody Stms SOACS
stms_adj ([SubExpRes] -> Body SOACS) -> [SubExpRes] -> Body SOACS
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExpRes]
varsRes ([VName] -> [SubExpRes]) -> [VName] -> [SubExpRes]
forall a b. (a -> b) -> a -> b
$ [[VName]] -> [VName]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat ([[VName]] -> [VName]) -> [[VName]] -> [VName]
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [[VName]]
forall (t :: * -> *) a. Foldable t => t a -> [a]
toList LoopInfo [VName]
loop_adjs
                restore_true_deps :: Map VName VName
restore_true_deps = [(VName, VName)] -> Map VName VName
forall k a. Ord k => [(k, a)] -> Map k a
M.fromList ([(VName, VName)] -> Map VName VName)
-> [(VName, VName)] -> Map VName VName
forall a b. (a -> b) -> a -> b
$
                  (((Param DeclType, PatElem Type) -> Maybe (VName, VName))
 -> [(Param DeclType, PatElem Type)] -> [(VName, VName)])
-> [(Param DeclType, PatElem Type)]
-> ((Param DeclType, PatElem Type) -> Maybe (VName, VName))
-> [(VName, VName)]
forall a b c. (a -> b -> c) -> b -> a -> c
flip ((Param DeclType, PatElem Type) -> Maybe (VName, VName))
-> [(Param DeclType, PatElem Type)] -> [(VName, VName)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe ([Param DeclType]
-> [PatElem Type] -> [(Param DeclType, PatElem Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [Param DeclType]
loop_params' ([PatElem Type] -> [(Param DeclType, PatElem Type)])
-> [PatElem Type] -> [(Param DeclType, PatElem Type)]
forall a b. (a -> b) -> a -> b
$ Pat Type -> [PatElem Type]
forall dec. Pat dec -> [PatElem dec]
patElems Pat Type
pat) (((Param DeclType, PatElem Type) -> Maybe (VName, VName))
 -> [(VName, VName)])
-> ((Param DeclType, PatElem Type) -> Maybe (VName, VName))
-> [(VName, VName)]
forall a b. (a -> b) -> a -> b
$ \(Param DeclType
p, PatElem Type
pe) ->
                    if Param DeclType
p Param DeclType -> [Param DeclType] -> Bool
forall (t :: * -> *) a. (Foldable t, Eq a) => a -> t a -> Bool
`elem` (Param DeclType -> Bool) -> [Param DeclType] -> [Param DeclType]
forall a. (a -> Bool) -> [a] -> [a]
filter (Attr -> Attrs -> Bool
inAttrs (Name -> Attr
AttrName Name
"true_dep") (Attrs -> Bool)
-> (Param DeclType -> Attrs) -> Param DeclType -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param DeclType -> Attrs
forall dec. Param dec -> Attrs
paramAttrs) [Param DeclType]
loop_params'
                      then (VName, VName) -> Maybe (VName, VName)
forall a. a -> Maybe a
Just (Param DeclType -> VName
forall dec. Param dec -> VName
paramName Param DeclType
p, PatElem Type -> VName
forall dec. PatElem dec -> VName
patElemName PatElem Type
pe)
                      else Maybe (VName, VName)
forall a. Maybe a
Nothing
            [VName]
adjs' <-
              String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"loop_adj" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
                Map VName VName -> Exp SOACS -> Exp SOACS
forall a. Substitute a => Map VName VName -> a -> a
substituteNames (Map VName VName
restore_true_deps Map VName VName -> Map VName VName -> Map VName VName
forall a. Semigroup a => a -> a -> a
<> Map VName VName
var_array_substs) (Exp SOACS -> Exp SOACS) -> Exp SOACS -> Exp SOACS
forall a b. (a -> b) -> a -> b
$
                  [(FParam SOACS, SubExp)]
-> LoopForm SOACS -> Body SOACS -> Exp SOACS
forall rep.
[(FParam rep, SubExp)] -> LoopForm rep -> Body rep -> Exp rep
DoLoop [(Param DeclType, SubExp)]
[(FParam SOACS, SubExp)]
val_pat_adjs_list LoopForm SOACS
form' Body SOACS
body_adj
            let ([VName]
loop_res_adjs, [VName]
loop_free_var_val_adjs) =
                  Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> [VName] -> Int
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopRes LoopInfo [VName]
loop_adjs) [VName]
adjs'
                ([VName]
loop_free_adjs, [VName]
loop_var_val_adjs) =
                  Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> [VName] -> Int
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopFree LoopInfo [VName]
loop_adjs) [VName]
loop_free_var_val_adjs
                ([VName]
loop_var_adjs, [VName]
loop_val_adjs) =
                  Int -> [VName] -> ([VName], [VName])
forall a. Int -> [a] -> ([a], [a])
splitAt ([VName] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([VName] -> Int) -> [VName] -> Int
forall a b. (a -> b) -> a -> b
$ LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVars LoopInfo [VName]
loop_adjs) [VName]
loop_var_val_adjs
            ADM () -> ADM ()
forall a. ADM a -> ADM a
returnSweepCode (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ do
              (SubExp -> VName -> ADM ()) -> [SubExp] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ SubExp -> VName -> ADM ()
updateSubExpAdj [SubExp]
loop_vals' [VName]
loop_res_adjs
              (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopFree LoopInfo [VName]
loop_vnames) [VName]
loop_free_adjs
              (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
insAdj (LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVars LoopInfo [VName]
loop_vnames) [VName]
loop_var_adjs
              (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj (LoopInfo [VName] -> [VName]
forall a. LoopInfo a -> a
loopVals LoopInfo [VName]
loop_vnames) [VName]
loop_val_adjs

-- | Transforms a loop into its reverse-mode derivative.
--
-- A while-loop is first rewritten into a for-loop by 'convertWhileLoop', and
-- the result is differentiated recursively. The iteration bound for that
-- conversion comes from one of two places:
--
--   * the first @bound@ attribute (@AttrComp "bound" [AttrInt b]@) attached
--     to the statement, used as an 'Int64' constant; or
--   * otherwise, the count produced by 'computeWhileIters' on the loop.
--
-- A for-loop is handled in three steps: 'fwdLoop' emits the forward pass,
-- then the continuation @m@ runs, and finally 'revLoop' emits the reverse pass
-- using @diffStms@.
diffLoop :: (Stms SOACS -> ADM ()) -> Pat Type -> StmAux () -> Exp SOACS -> ADM () -> ADM ()
diffLoop diffStms pat aux loop m
  | isWhileLoop loop = do
      for_loop <- case boundAttr of
        Just b ->
          convertWhileLoop (Constant $ IntValue $ intValue Int64 b) loop
        Nothing -> do
          iters <- computeWhileIters loop
          -- NOTE(review): in this branch the loop is renamed before conversion,
          -- presumably so it does not share binder names with whatever
          -- 'computeWhileIters' emitted for the same loop. Confirm this.
          convertWhileLoop iters =<< renameExp loop
      diffLoop diffStms pat aux for_loop m
  | otherwise = do
      fwdLoop pat aux loop
      m
      revLoop diffStms pat loop
  where
    -- First user-supplied iteration bound on the statement, if any.
    boundAttr :: Maybe Integer
    boundAttr = listToMaybe $ catMaybes $ mapAttrs boundOf $ stmAuxAttrs aux

    boundOf :: Attr -> Maybe Integer
    boundOf (AttrComp "bound" [AttrInt b]) = Just b
    boundOf _ = Nothing