module Futhark.AD.Rev.Scan (diffScan) where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (pairs, unpairs)

data FirstOrSecond = WrtFirst | WrtSecond

-- computes `d(x op y)/dx` or d(x op y)/dy
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam VjpOps
ops Lambda SOACS
lam0 FirstOrSecond
which = do
  let len :: Int
len = [Type] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ([Type] -> Int) -> [Type] -> Int
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam0
  Lambda SOACS
lam <- Lambda SOACS -> ADM (Lambda SOACS)
forall rep (m :: * -> *).
(Renameable rep, MonadFreshNames m) =>
Lambda rep -> m (Lambda rep)
renameLambda Lambda SOACS
lam0
  let p2diff :: [Param Type]
p2diff =
        case FirstOrSecond
which of
          FirstOrSecond
WrtFirst -> Int -> [Param Type] -> [Param Type]
forall a. Int -> [a] -> [a]
take Int
len ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam
          FirstOrSecond
WrtSecond -> Int -> [Param Type] -> [Param Type]
forall a. Int -> [a] -> [a]
drop Int
len ([Param Type] -> [Param Type]) -> [Param Type] -> [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [LParam SOACS]
forall rep. Lambda rep -> [LParam rep]
lambdaParams Lambda SOACS
lam
  [Adj]
p_adjs <- (Type -> ADM Adj) -> [Type] -> ADM [Adj]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Type -> ADM Adj
unitAdjOfType (Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam)
  VjpOps -> [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
vjpLambda VjpOps
ops [Adj]
p_adjs ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
p2diff) Lambda SOACS
lam

-- Should generate something like:
-- `\ j -> let i = n - 1 - j
--         if i < n-1 then ( ys_adj[i], df2dx ys[i] xs[i+1]) else (0,1) )`
-- where `ys` is  the result of scan
--       `xs` is  the input  of scan
--       `ys_adj` is the known adjoint of ys
--       `j` draw values from `iota n`
mkScanFusedMapLam :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM (Lambda SOACS)
mkScanFusedMapLam :: VjpOps
-> SubExp
-> Lambda SOACS
-> [VName]
-> [VName]
-> [VName]
-> ADM (Lambda SOACS)
mkScanFusedMapLam VjpOps
ops SubExp
w Lambda SOACS
scn_lam [VName]
xs [VName]
ys [VName]
ys_adj = do
  Lambda SOACS
lam <- VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam VjpOps
ops Lambda SOACS
scn_lam FirstOrSecond
WrtFirst
  [Type]
ys_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
ys
  Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i
  [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i] (ADM Result -> ADM (Lambda (Rep ADM)))
-> ADM Result -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
    ([VName] -> Result) -> ADM [VName] -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (ADM [VName] -> ADM Result)
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"x"
      (Exp SOACS -> ADM Result) -> ADM (Exp SOACS) -> ADM Result
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
        (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
        ( ADM Result -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (ADM Result -> ADM (Body (Rep ADM)))
-> ADM Result -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
            [SubExp]
zs <- (Type -> ADM SubExp) -> [Type] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ct_zero" (Exp SOACS -> ADM SubExp)
-> (Type -> Exp SOACS) -> Type -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp (Type -> Exp SOACS) -> (Type -> Type) -> Type -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType) [Type]
ys_ts
            [SubExp]
os <- (Type -> ADM SubExp) -> [Type] -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ct_one" (Exp SOACS -> ADM SubExp)
-> (Type -> Exp SOACS) -> Type -> ADM SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Exp SOACS
forall rep. Type -> Exp rep
oneExp (Type -> Exp SOACS) -> (Type -> Type) -> Type -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType) [Type]
ys_ts
            Result -> ADM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> ADM Result) -> Result -> ADM Result
forall a b. (a -> b) -> a -> b
$ [SubExp] -> Result
subExpsRes ([SubExp] -> Result) -> [SubExp] -> Result
forall a b. (a -> b) -> a -> b
$ [(SubExp, SubExp)] -> [SubExp]
forall a. [(a, a)] -> [a]
unpairs ([(SubExp, SubExp)] -> [SubExp]) -> [(SubExp, SubExp)] -> [SubExp]
forall a b. (a -> b) -> a -> b
$ [SubExp] -> [SubExp] -> [(SubExp, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [SubExp]
zs [SubExp]
os
        )
        ( ADM Result -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (ADM Result -> ADM (Body (Rep ADM)))
-> ADM Result -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
            SubExp
j <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"j" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1))
            SubExp
j1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"j1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i)
            let index :: SubExp -> VName -> Type -> Exp rep
index SubExp
idx VName
arr Type
t = BasicOp -> Exp rep
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp rep) -> BasicOp -> Exp rep
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
arr (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
idx]
            [SubExp]
y_s <- [(VName, Type)] -> ((VName, Type) -> ADM SubExp) -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
ys_adj [Type]
ys_ts) (((VName, Type) -> ADM SubExp) -> ADM [SubExp])
-> ((VName, Type) -> ADM SubExp) -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ \(VName
y_, Type
t) ->
              String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
y_ String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_j") (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ SubExp -> VName -> Type -> Exp SOACS
forall rep. SubExp -> VName -> Type -> Exp rep
index SubExp
j VName
y_ Type
t
            Result
lam_rs <-
              Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep ADM)
Lambda SOACS
lam ([ADM (Exp SOACS)] -> ADM Result)
-> ([Exp SOACS] -> [ADM (Exp SOACS)]) -> [Exp SOACS] -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (Exp SOACS -> ADM (Exp SOACS)) -> [Exp SOACS] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map Exp SOACS -> ADM (Exp SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure ([Exp SOACS] -> ADM Result) -> [Exp SOACS] -> ADM Result
forall a b. (a -> b) -> a -> b
$
                (VName -> Type -> Exp SOACS) -> [VName] -> [Type] -> [Exp SOACS]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (SubExp -> VName -> Type -> Exp SOACS
forall rep. SubExp -> VName -> Type -> Exp rep
index SubExp
j) [VName]
ys [Type]
ys_ts [Exp SOACS] -> [Exp SOACS] -> [Exp SOACS]
forall a. [a] -> [a] -> [a]
++ (VName -> Type -> Exp SOACS) -> [VName] -> [Type] -> [Exp SOACS]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith (SubExp -> VName -> Type -> Exp SOACS
forall rep. SubExp -> VName -> Type -> Exp rep
index SubExp
j1) [VName]
xs [Type]
ys_ts
            Result -> ADM Result
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Result -> ADM Result) -> Result -> ADM Result
forall a b. (a -> b) -> a -> b
$ [(SubExpRes, SubExpRes)] -> Result
forall a. [(a, a)] -> [a]
unpairs ([(SubExpRes, SubExpRes)] -> Result)
-> [(SubExpRes, SubExpRes)] -> Result
forall a b. (a -> b) -> a -> b
$ Result -> Result -> [(SubExpRes, SubExpRes)]
forall a b. [a] -> [b] -> [(a, b)]
zip ([SubExp] -> Result
subExpsRes [SubExp]
y_s) Result
lam_rs
        )

-- \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b1 * b2)
mkScanLinFunO :: Type -> ADM (Scan SOACS)
mkScanLinFunO :: Type -> ADM (Scan SOACS)
mkScanLinFunO Type
t = do
  let pt :: PrimType
pt = Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
t
  SubExp
zero <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"zeros" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
zeroExp Type
t
  SubExp
one <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"ones" (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ Type -> Exp SOACS
forall rep. Type -> Exp rep
oneExp Type
t
  [VName]
tmp <- (String -> ADM VName) -> [String] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM String -> ADM VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName [String
"a1", String
"b1", String
"a2", String
"b2"]
  let [VName
a1, VName
b1, VName
a2, VName
b2] = [VName]
tmp
      pet :: VName -> PrimExp VName
pet = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp PrimType
pt (SubExp -> PrimExp VName)
-> (VName -> SubExp) -> VName -> PrimExp VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var
  Lambda SOACS
lam <- [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ((VName -> Param Type) -> [VName] -> [Param Type]
forall a b. (a -> b) -> [a] -> [b]
map (\VName
v -> Attrs -> VName -> Type -> Param Type
forall dec. Attrs -> VName -> dec -> Param dec
Param Attrs
forall a. Monoid a => a
mempty VName
v Type
t) [VName
a1, VName
b1, VName
a2, VName
b2]) (ADM Result -> ADM (Lambda SOACS))
-> (ADM [VName] -> ADM Result) -> ADM [VName] -> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([VName] -> Result) -> ADM [VName] -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (ADM [VName] -> ADM (Lambda SOACS))
-> ADM [VName] -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$
    Int
-> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest (Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank Type
t) [VName
a1, VName
b1, VName
a2, VName
b2] (([VName] -> [VName] -> ADM [VName]) -> ADM [VName])
-> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ \[VName]
_ [VName
a1', VName
b1', VName
a2', VName
b2'] -> do
      VName
x <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"x" (Exp SOACS -> ADM VName)
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM VName) -> PrimExp VName -> ADM VName
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp VName
pet VName
a2' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~+~ VName -> PrimExp VName
pet VName
b2' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ VName -> PrimExp VName
pet VName
a1'
      VName
y <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"y" (Exp SOACS -> ADM VName)
-> (PrimExp VName -> ADM (Exp SOACS)) -> PrimExp VName -> ADM VName
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM VName) -> PrimExp VName -> ADM VName
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp VName
pet VName
b1' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ VName -> PrimExp VName
pet VName
b2'
      [VName] -> ADM [VName]
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName
x, VName
y]
  Scan SOACS -> ADM (Scan SOACS)
forall (f :: * -> *) a. Applicative f => a -> f a
pure (Scan SOACS -> ADM (Scan SOACS)) -> Scan SOACS -> ADM (Scan SOACS)
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
lam [SubExp
zero, SubExp
one]

-- build the map following the scan with linear-function-composition:
-- for each (ds,cs) length-n array results of the scan, combine them as:
--    `let rs = map2 (\ d_i c_i -> d_i + c_i * y_adj[n-1]) d c |> reverse`
-- but insert explicit indexing to reverse inside the map.
mkScan2ndMaps :: SubExp -> (Type, VName, (VName, VName)) -> ADM VName
mkScan2ndMaps :: SubExp -> (Type, VName, (VName, VName)) -> ADM VName
mkScan2ndMaps SubExp
w (Type
arr_tp, VName
y_adj, (VName
ds, VName
cs)) = do
  SubExp
nm1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"nm1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
  VName
y_adj_last <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
y_adj String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_last") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$
      BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
y_adj (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_tp [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
nm1]

  Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  Lambda SOACS
lam <- [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda [Param Type
LParam (Rep ADM)
par_i] (ADM Result -> ADM (Lambda (Rep ADM)))
-> ADM Result -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
    let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i
    SubExp
j <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"j" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (SubExp -> TPrimExp Int64 VName
pe64 SubExp
w TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
+ TPrimExp Int64 VName
1))
    VName
dj <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
ds String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_dj") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
ds (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_tp [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
j]
    VName
cj <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
cs String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_cj") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
cs (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
arr_tp [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
j]

    let pet :: VName -> PrimExp VName
pet = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
arr_tp) (SubExp -> PrimExp VName)
-> (VName -> SubExp) -> VName -> PrimExp VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var
    ([VName] -> Result) -> ADM [VName] -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (ADM [VName] -> ADM Result)
-> (([VName] -> [VName] -> ADM [VName]) -> ADM [VName])
-> ([VName] -> [VName] -> ADM [VName])
-> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Int
-> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest (Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
arr_tp)) [VName
y_adj_last, VName
dj, VName
cj] (([VName] -> [VName] -> ADM [VName]) -> ADM Result)
-> ([VName] -> [VName] -> ADM [VName]) -> ADM Result
forall a b. (a -> b) -> a -> b
$ \[VName]
_ [VName
y_adj_last', VName
dj', VName
cj'] ->
      String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [VName])
-> (PrimExp VName -> ADM (Exp SOACS))
-> PrimExp VName
-> ADM [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM [VName]) -> PrimExp VName -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp VName
pet VName
dj' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~+~ VName -> PrimExp VName
pet VName
cj' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ VName -> PrimExp VName
pet VName
y_adj_last'

  VName
iota <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
  String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"after_scan" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
iota] ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [] [] Lambda SOACS
lam))

-- perform the final map, which is fusable with the maps obtained from `mkScan2ndMaps`
-- let xs_contribs =
--    map3 (\ i a r -> if i==0 then r else (df2dy (ys[i-1]) a) \bar{*} r)
--         (iota n) xs rs
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap :: VjpOps
-> SubExp
-> Lambda SOACS
-> [VName]
-> [VName]
-> [VName]
-> ADM [VName]
mkScanFinalMap VjpOps
ops SubExp
w Lambda SOACS
scan_lam [VName]
xs [VName]
ys [VName]
rs = do
  let eltps :: [Type]
eltps = Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
scan_lam
  Lambda SOACS
lam <- VjpOps -> Lambda SOACS -> FirstOrSecond -> ADM (Lambda SOACS)
mkScanAdjointLam VjpOps
ops Lambda SOACS
scan_lam FirstOrSecond
WrtSecond
  Param Type
par_i <- String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam String
"i" (Type -> ADM (Param Type)) -> Type -> ADM (Param Type)
forall a b. (a -> b) -> a -> b
$ PrimType -> Type
forall shape u. PrimType -> TypeBase shape u
Prim PrimType
int64
  let i :: VName
i = Param Type -> VName
forall dec. Param dec -> VName
paramName Param Type
par_i
  [Param Type]
par_x <- ((VName, Type) -> ADM (Param Type))
-> [(VName, Type)] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\(VName
x, Type
t) -> String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
x String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_par_x") Type
t) ([(VName, Type)] -> ADM [Param Type])
-> [(VName, Type)] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [Type]
eltps
  [Param Type]
par_r <- ((VName, Type) -> ADM (Param Type))
-> [(VName, Type)] -> ADM [Param Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (\(VName
r, Type
t) -> String -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
String -> dec -> m (Param dec)
newParam (VName -> String
baseString VName
r String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_par_r") Type
t) ([(VName, Type)] -> ADM [Param Type])
-> [(VName, Type)] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ [VName] -> [Type] -> [(VName, Type)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
rs [Type]
eltps

  Lambda SOACS
map_lam <-
    [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda (Param Type
par_i Param Type -> [Param Type] -> [Param Type]
forall a. a -> [a] -> [a]
: [Param Type]
par_x [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
par_r) (ADM Result -> ADM (Lambda (Rep ADM)))
-> ADM Result -> ADM (Lambda (Rep ADM))
forall a b. (a -> b) -> a -> b
$
      ([VName] -> Result) -> ADM [VName] -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (ADM [VName] -> ADM Result)
-> (Exp SOACS -> ADM [VName]) -> Exp SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_contribs"
        (Exp SOACS -> ADM Result) -> ADM (Exp SOACS) -> ADM Result
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< ADM (Exp (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Body (Rep ADM))
-> ADM (Exp (Rep ADM))
forall (m :: * -> *).
(MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
m (Exp (Rep m))
-> m (Body (Rep m)) -> m (Body (Rep m)) -> m (Exp (Rep m))
eIf
          (TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (TPrimExp Bool VName -> ADM (Exp (Rep ADM)))
-> TPrimExp Bool VName -> ADM (Exp (Rep ADM))
forall a b. (a -> b) -> a -> b
$ VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName -> TPrimExp Int64 VName -> TPrimExp Bool VName
forall t v. TPrimExp t v -> TPrimExp t v -> TPrimExp Bool v
.==. TPrimExp Int64 VName
0)
          ([SubExp] -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[SubExp] -> m (Body (Rep m))
resultBodyM ([SubExp] -> ADM (Body (Rep ADM)))
-> [SubExp] -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ (Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
par_r)
          ( ADM Result -> ADM (Body (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
m Result -> m (Body (Rep m))
buildBody_ (ADM Result -> ADM (Body (Rep ADM)))
-> ADM Result -> ADM (Body (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
              SubExp
im1 <- String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"im1" (Exp SOACS -> ADM SubExp) -> ADM (Exp SOACS) -> ADM SubExp
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< TPrimExp Int64 VName -> ADM (Exp (Rep ADM))
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (VName -> TPrimExp Int64 VName
forall a. a -> TPrimExp Int64 a
le64 VName
i TPrimExp Int64 VName
-> TPrimExp Int64 VName -> TPrimExp Int64 VName
forall a. Num a => a -> a -> a
- TPrimExp Int64 VName
1)
              [SubExp]
ys_im1 <- [VName] -> (VName -> ADM SubExp) -> ADM [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM [VName]
ys ((VName -> ADM SubExp) -> ADM [SubExp])
-> (VName -> ADM SubExp) -> ADM [SubExp]
forall a b. (a -> b) -> a -> b
$ \VName
y -> do
                Type
y_t <- VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
y
                String -> Exp (Rep ADM) -> ADM SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp (VName -> String
baseString VName
y String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_last") (Exp (Rep ADM) -> ADM SubExp) -> Exp (Rep ADM) -> ADM SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
y (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
y_t [SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
im1]

              [VName]
lam_res <-
                (SubExpRes -> ADM VName) -> Result -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"const" (Exp SOACS -> ADM VName)
-> (SubExpRes -> Exp SOACS) -> SubExpRes -> ADM VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS)
-> (SubExpRes -> BasicOp) -> SubExpRes -> Exp SOACS
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExp -> BasicOp
SubExp (SubExp -> BasicOp)
-> (SubExpRes -> SubExp) -> SubExpRes -> BasicOp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. SubExpRes -> SubExp
resSubExp)
                  (Result -> ADM [VName]) -> ADM Result -> ADM [VName]
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< Lambda (Rep ADM) -> [ADM (Exp (Rep ADM))] -> ADM Result
forall (m :: * -> *).
MonadBuilder m =>
Lambda (Rep m) -> [m (Exp (Rep m))] -> m Result
eLambda Lambda (Rep ADM)
Lambda SOACS
lam ((SubExp -> ADM (Exp SOACS)) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> ADM (Exp SOACS)
forall (m :: * -> *). MonadBuilder m => SubExp -> m (Exp (Rep m))
eSubExp ([SubExp] -> [ADM (Exp SOACS)]) -> [SubExp] -> [ADM (Exp SOACS)]
forall a b. (a -> b) -> a -> b
$ [SubExp]
ys_im1 [SubExp] -> [SubExp] -> [SubExp]
forall a. [a] -> [a] -> [a]
++ (Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
par_x)

              ([[VName]] -> Result) -> ADM [[VName]] -> ADM Result
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ([VName] -> Result
varsRes ([VName] -> Result)
-> ([[VName]] -> [VName]) -> [[VName]] -> Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [[VName]] -> [VName]
forall a. Monoid a => [a] -> a
mconcat) (ADM [[VName]] -> ADM Result)
-> (((VName, VName, Type) -> ADM [VName]) -> ADM [[VName]])
-> ((VName, VName, Type) -> ADM [VName])
-> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [(VName, VName, Type)]
-> ((VName, VName, Type) -> ADM [VName]) -> ADM [[VName]]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
t a -> (a -> m b) -> m (t b)
forM ([VName] -> [VName] -> [Type] -> [(VName, VName, Type)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [VName]
lam_res ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
par_r) [Type]
eltps) (((VName, VName, Type) -> ADM [VName]) -> ADM Result)
-> ((VName, VName, Type) -> ADM [VName]) -> ADM Result
forall a b. (a -> b) -> a -> b
$
                \(VName
lam_r, VName
r, Type
eltp) -> do
                  let pet :: VName -> PrimExp VName
pet = PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (Type -> PrimType
forall shape u. TypeBase shape u -> PrimType
elemType Type
eltp) (SubExp -> PrimExp VName)
-> (VName -> SubExp) -> VName -> PrimExp VName
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> SubExp
Var

                  Int
-> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest (Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank Type
eltp) [VName
lam_r, VName
r] (([VName] -> [VName] -> ADM [VName]) -> ADM [VName])
-> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ \[VName]
_ [VName
lam_r', VName
r'] ->
                    String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"res" (Exp SOACS -> ADM [VName])
-> (PrimExp VName -> ADM (Exp SOACS))
-> PrimExp VName
-> ADM [VName]
forall (m :: * -> *) b c a.
Monad m =>
(b -> m c) -> (a -> m b) -> a -> m c
<=< PrimExp VName -> ADM (Exp SOACS)
forall a (m :: * -> *).
(ToExp a, MonadBuilder m) =>
a -> m (Exp (Rep m))
toExp (PrimExp VName -> ADM [VName]) -> PrimExp VName -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ VName -> PrimExp VName
pet VName
lam_r' PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
~*~ VName -> PrimExp VName
pet VName
r'
          )

  VName
iota <- String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
  String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"scan_contribs" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$ Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w (VName
iota VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [VName]
rs) ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [] [] Lambda SOACS
map_lam))

diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan VjpOps
ops [VName]
ys SubExp
w [VName]
as Scan SOACS
scan = do
  [VName]
ys_adj <- (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM VName
lookupAdjVal [VName]
ys
  [Type]
as_ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
as
  Lambda SOACS
map1_lam <- VjpOps
-> SubExp
-> Lambda SOACS
-> [VName]
-> [VName]
-> [VName]
-> ADM (Lambda SOACS)
mkScanFusedMapLam VjpOps
ops SubExp
w (Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda Scan SOACS
scan) [VName]
as [VName]
ys [VName]
ys_adj
  [Scan SOACS]
scans_lin_fun_o <- (Type -> ADM (Scan SOACS)) -> [Type] -> ADM [Scan SOACS]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM Type -> ADM (Scan SOACS)
mkScanLinFunO ([Type] -> ADM [Scan SOACS]) -> [Type] -> ADM [Scan SOACS]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType (Lambda SOACS -> [Type]) -> Lambda SOACS -> [Type]
forall a b. (a -> b) -> a -> b
$ Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda Scan SOACS
scan
  VName
iota <-
    String -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"iota" (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp SOACS
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp SOACS) -> BasicOp -> Exp SOACS
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
w (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0) (IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
1) IntType
Int64
  [VName]
r_scan <-
    String -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m [VName]
letTupExp String
"adj_ctrb_scan" (Exp (Rep ADM) -> ADM [VName]) -> Exp (Rep ADM) -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
      Op SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w [VName
iota] ([Scan SOACS] -> [Reduce SOACS] -> Lambda SOACS -> ScremaForm SOACS
forall rep.
[Scan rep] -> [Reduce rep] -> Lambda rep -> ScremaForm rep
ScremaForm [Scan SOACS]
scans_lin_fun_o [] Lambda SOACS
map1_lam))
  [VName]
red_nms <- ((Type, VName, (VName, VName)) -> ADM VName)
-> [(Type, VName, (VName, VName))] -> ADM [VName]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (SubExp -> (Type, VName, (VName, VName)) -> ADM VName
mkScan2ndMaps SubExp
w) ([Type]
-> [VName] -> [(VName, VName)] -> [(Type, VName, (VName, VName))]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 [Type]
as_ts [VName]
ys_adj ([VName] -> [(VName, VName)]
forall a. [a] -> [(a, a)]
pairs [VName]
r_scan))
  [VName]
as_contribs <- VjpOps
-> SubExp
-> Lambda SOACS
-> [VName]
-> [VName]
-> [VName]
-> ADM [VName]
mkScanFinalMap VjpOps
ops SubExp
w (Scan SOACS -> Lambda SOACS
forall rep. Scan rep -> Lambda rep
scanLambda Scan SOACS
scan) [VName]
as [VName]
ys [VName]
red_nms
  (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM ()
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m ()
zipWithM_ VName -> VName -> ADM ()
updateAdj [VName]
as [VName]
as_contribs