module Futhark.AD.Rev.Scan (diffScan, diffScanVec, diffScanAdd) where

import Control.Monad
import Data.List (transpose)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (chunk)

-- | Which operand of a binary scan operator to differentiate with
-- respect to.  Consumed by 'mkScanAdjointLam'.
data FirstOrSecond
  = -- | Differentiate with respect to the leading group of lambda
    -- parameters (the first operand).
    WrtFirst
  | -- | Differentiate with respect to the trailing group of lambda
    -- parameters (the second operand).
    WrtSecond

-- | Construct an @n@-by-@n@ identity matrix whose entries have type @t@.
--
-- Every entry is bound with 'letSubExp' (name hint @"id"@): entries on
-- the diagonal are @oneExp t@, all other entries are @zeroExp t@.  The
-- result is a list of rows.
identityM :: Int -> Type -> ADM [[SubExp]]
identityM n t =
  traverse
    (traverse (letSubExp "id"))
    [[if i == j then oneExp t else zeroExp t | i <- [1 .. n]] | j <- [1 .. n]]

-- | Symbolic matrix product @m1 * m2@ over primitive expressions of
-- type @t@.  Both matrices are lists of rows.
--
-- Each output entry is the left-to-right sum (starting from the blank
-- value of @t@) of the element-wise products of a row of @m1@ with a
-- column of @m2@.
matrixMul :: [[PrimExp VName]] -> [[PrimExp VName]] -> PrimType -> [[PrimExp VName]]
matrixMul m1 m2 t = [[dot r q | q <- cols] | r <- m1]
  where
    zero = primExpFromSubExp t $ Constant $ blankPrimValue t
    -- The columns of @m2@ do not depend on the row of @m1@; compute
    -- them once.
    cols = transpose m2
    dot r q = foldl (~+~) zero $ zipWith (~*~) r q

-- | Symbolic matrix-vector product @m * v@ over primitive expressions
-- of type @t@, where @m@ is a list of rows.
--
-- Each output element is the left-to-right sum (starting from the blank
-- value of @t@) of the element-wise products of @v@ with one row of @m@.
matrixVecMul :: [[PrimExp VName]] -> [PrimExp VName] -> PrimType -> [PrimExp VName]
matrixVecMul m v t = map dot m
  where
    zero = primExpFromSubExp t $ Constant $ blankPrimValue t
    -- The vector element is the left operand of each product.
    dot r = foldl (~+~) zero $ zipWith (~*~) v r

-- | Element-wise symbolic sum of two vectors.  As with 'zipWith', the
-- result is as long as the shorter argument.
vectorAdd :: [PrimExp VName] -> [PrimExp VName] -> [PrimExp VName]
vectorAdd = zipWith (~+~)

-- | Split @lst@ into consecutive chunks of length
-- @length lst \`div\` specialScans s@.
orderArgs :: Special -> [a] -> [[a]]
orderArgs s lst = chunk per_scan lst
  where
    per_scan = length lst `div` specialScans s

-- | Computes @d(x op y)/dx@ or @d(x op y)/dy@.
--
-- The scan operator @lam0@ is renamed and then handed to 'vjpLambda',
-- seeded with @adjs@ as the result adjoints.  With 'WrtFirst' the
-- derivative is taken with respect to the first @len@ parameters of the
-- operator, with 'WrtSecond' with respect to the remaining ones, where
-- @len@ is the number of results of the operator.
mkScanAdjointLam :: VjpOps -> Lambda SOACS -> FirstOrSecond -> [SubExp] -> ADM (Lambda SOACS)
mkScanAdjointLam ops lam0 which adjs = do
  let len = length $ lambdaReturnType lam0
      select = case which of
        WrtFirst -> take len
        WrtSecond -> drop len
  lam <- renameLambda lam0
  let p2diff = select $ lambdaParams lam
  vjpLambda ops (fmap AdjVal adjs) (map paramName p2diff) lam

-- | Build the lambda that is mapped over the iteration index @i@
-- before the linear-function scan.  With @j = w - (i + 1)@ it should
-- generate something like:
--
-- > \i -> if i == 0
-- >         then (ys_adj[j], identity)
-- >         else (ys_adj[j], df2dx ys[j] xs[j+1])
--
-- where @ys@ is the result of the scan, @xs@ is the input of the scan,
-- @ys_adj@ is the known adjoint of @ys@, and @i@ draws values from
-- @iota w@.
--
-- The Jacobian part (the identity matrix, or the transposed results of
-- the adjoint lambdas) is first reduced by @case_jac@ according to the
-- special case of @s@; adjoints and Jacobian entries are then grouped
-- with 'orderArgs' and concatenated group by group.
mkScanFusedMapLam ::
  VjpOps ->
  SubExp ->
  Lambda SOACS ->
  [VName] ->
  [VName] ->
  [VName] ->
  Special ->
  Int ->
  ADM (Lambda SOACS)
mkScanFusedMapLam ops w scn_lam xs ys ys_adj s d = do
  let sc = specialCase s
      k = specialSubSize s
  ys_ts <- traverse lookupType ys
  -- Bound lazily: when @ys@ is empty the identity matrix is empty and
  -- the element type is never demanded.
  let elem_t = case ys_ts of
        y_t : _ -> rowType y_t
        [] -> error "mkScanFusedMapLam: scan has no results"
  idmat <- identityM (length ys) elem_t
  -- One adjoint lambda per row of the identity matrix.
  lams <- traverse (mkScanAdjointLam ops scn_lam WrtFirst) idmat
  par_i <- newParam "i" $ Prim int64
  let i = paramName par_i
      -- These actions are shared by both branches, but are run inside
      -- each branch body so the bindings are emitted there.
      rev_index :: ADM SubExp
      rev_index = letSubExp "j" =<< toExp (pe64 w - (le64 i + 1))
      adjs_at :: SubExp -> ADM [SubExp]
      adjs_at j = forM ys_adj $ \y_ ->
        letSubExp (baseString y_ ++ "_j") =<< eIndex y_ [eSubExp j]
  mkLambda [par_i] $
    fmap varsRes . letTupExp "x"
      =<< eIf
        (toExp $ le64 i .==. 0)
        ( buildBody_ $ do
            j <- rev_index
            y_s <- adjs_at j
            let zso = orderArgs s y_s
                ido = orderArgs s $ case_jac k sc idmat
            pure $ subExpsRes $ concat $ zipWith (++) zso $ fmap concat ido
        )
        ( buildBody_ $ do
            j <- rev_index
            j1 <- letSubExp "j1" =<< toExp (pe64 w - le64 i)
            y_s <- adjs_at j

            -- Apply every adjoint lambda to (ys[j], xs[j1]).
            let args =
                  map (`eIndex` [eSubExp j]) ys ++ map (`eIndex` [eSubExp j1]) xs
            lam_rs <- traverse (`eLambda` args) lams

            let yso = orderArgs s $ subExpsRes y_s
                jaco = orderArgs s $ case_jac k sc $ transpose lam_rs

            pure $ concat $ zipWith (++) yso $ fmap concat jaco
        )
  where
    -- Select the part of the Jacobian that is kept for each special
    -- case:
    --
    -- * 'Generic': the whole matrix.
    -- * 'ZeroQuadrant': rows are grouped @k@ at a time, and from each
    --   row of the group with index @i@ only columns @i*k@ up to (but
    --   not including) @(i+1)*k@ are kept.
    -- * 'MatrixMul': the top-left @k@-by-@k@ corner.
    case_jac :: Int -> SpecialCase -> [[a]] -> [[a]]
    case_jac _ Generic jac = jac
    case_jac k ZeroQuadrant jac =
      concat
        $ zipWith
          (\i -> map (take k . drop (i * k)))
          [0 .. d `div` k]
        $ chunk k jac
    case_jac k MatrixMul jac =
      take k <$> take k jac

-- | @a1 a2 b -> a2 + b * a1@
--
-- In the 'MatrixMul' special case, @a1@ is split into chunks of
-- @specialSubSize s@ elements and @b@ is multiplied onto each chunk
-- separately, with the products concatenated; in every other case @b@
-- is multiplied onto the whole of @a1@.
linFunT0 :: [PrimExp VName] -> [PrimExp VName] -> [[PrimExp VName]] -> Special -> PrimType -> [PrimExp VName]
linFunT0 a1 a2 b s pt = a2 `vectorAdd` b_times_a1
  where
    apply_b v = matrixVecMul b v pt
    b_times_a1 = case specialCase s of
      MatrixMul -> concatMap apply_b $ chunk (specialSubSize s) a1
      _ -> apply_b a1

-- | Build the scan operator that composes linear functions:
--
-- > \(a1, b1) (a2, b2) -> (a2 + b2 * a1, b2 * b1)
--
-- The lambda takes its parameters in the order @a1s ++ b1s ++ a2s ++
-- b2s@, each of type @rowType t@.  The @b@ parameters are viewed as
-- matrices with rows of length @n@, where @n@ is the second component
-- of @specialNeutral s@.  The neutral element is a run of zeros
-- followed by a flattened identity matrix, sized by @specialNeutral s@.
mkScanLinFunO :: Type -> Special -> ADM (Scan SOACS)
mkScanLinFunO t s = do
  let pt = elemType t
      pet = primExpFromSubExp pt . Var
      (_, n) = specialNeutral s
  neu_elm <- mkNeutral $ specialNeutral s
  (a1s, b1s, a2s, b2s) <- mkParams $ specialParams s

  let params = [Param mempty v (rowType t) | v <- a1s ++ b1s ++ a2s ++ b2s]
  lam <- mkLambda params . fmap subExpsRes $ do
    let a1s' = map pet a1s
        a2s' = map pet a2s
        b1sm = chunk n $ map pet b1s
        b2sm = chunk n $ map pet b2s
        -- a2 + b2 * a1
        t0 = linFunT0 a1s' a2s' b2sm s pt
        -- b2 * b1
        t1 = concat $ matrixMul b2sm b1sm pt
    traverse (letSubExp "r" <=< toExp) $ t0 ++ t1

  pure $ Scan lam neu_elm
  where
    -- @a@ zeros of the row type, followed by the rows of a @b@-by-@b@
    -- identity matrix over the element type.
    mkNeutral :: (Int, Int) -> ADM [SubExp]
    mkNeutral (a, b) = do
      zeros <- replicateM a $ letSubExp "zeros" $ zeroExp $ rowType t
      idmat <- identityM b $ Prim $ elemType t
      pure $ zeros ++ concat idmat

    -- Fresh names for both operands: @a@ names for each @a@ component
    -- and @b@ names for each @b@ component, created in parameter order.
    mkParams :: (Int, Int) -> ADM ([VName], [VName], [VName], [VName])
    mkParams (a, b) =
      (,,,)
        <$> replicateM a (newVName "a1")
        <*> replicateM b (newVName "b1")
        <*> replicateM a (newVName "a2")
        <*> replicateM b (newVName "b2")

-- | Build the final map of the reverse-mode rule for scan, turning the
-- scanned adjoint contributions @ds@ into contributions for the scan
-- inputs @xs@.  In pseudo-code:
--
-- > let xs_contribs =
-- >    map3 (\ i a r -> if i==0 then r else (df2dy (ys[i-1]) a) \bar{*} r)
-- >         (iota n) xs (reverse ds)
--
-- The reversal of @ds@ is never materialised: the map reads element
-- @w - (i + 1)@ of every array in @ds@ instead.
--
-- Arguments: the VJP operations, the scan width @w@, the scan operator
-- @scan_lam@, the scan inputs @xs@, the scan results @ys@ and the arrays
-- @ds@.  Returns the names bound to the result of the emitted map.
mkScanFinalMap :: VjpOps -> SubExp -> Lambda SOACS -> [VName] -> [VName] -> [VName] -> ADM [VName]
mkScanFinalMap ops w scan_lam xs ys ds = do
  idx_param <- newParam "i" $ Prim int64
  x_params <-
    zipWithM
      (\x -> newParam (baseString x ++ "_par_x"))
      xs
      (lambdaReturnType scan_lam)
  let idx = le64 $ paramName idx_param
      x_args = map (Var . paramName) x_params

  contrib_lam <- mkLambda (idx_param : x_params) $ do
    -- Index into the (unreversed) arrays of 'ds' for this iteration.
    rev_idx <- letSubExp "j" =<< toExp (pe64 w - (idx + 1))
    ds_elems <- forM ds $ \dd ->
      letExp (baseString dd ++ "_dj") =<< eIndex dd [eSubExp rev_idx]
    let ds_vars = map Var ds_elems

        -- Branch for i /= 0: apply the lambda produced by
        -- 'mkScanAdjointLam' (with 'WrtSecond', seeded by the current
        -- elements of 'ds') to ys[i-1] followed by the current x.
        viaAdjoint = buildBody_ $ do
          adj_lam <- mkScanAdjointLam ops scan_lam WrtSecond ds_vars
          prev_idx <- letSubExp "im1" =<< toExp (idx - 1)
          prev_ys <- forM ys $ \y ->
            letSubExp (baseString y <> "_im1") =<< eIndex y [eSubExp prev_idx]
          eLambda adj_lam $ map eSubExp $ prev_ys ++ x_args

    -- For i == 0 the elements of 'ds' are returned unchanged.
    branch <- eIf (toExp $ idx .==. 0) (resultBodyM ds_vars) viaAdjoint
    varsRes <$> letTupExp "scan_contribs" branch

  iota <- letExp "iota" $ BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  letTupExp "scan_contribs" $ Op $ Screma w (iota : xs) $ mapSOAC contrib_lam

-- | The structure that 'cases' detected in the Jacobian matrix of a scan
-- operator.
data SpecialCase
  = -- | No block structure was found ('subMats' returned 'Nothing').
    Generic
  | -- | Every entry outside the diagonal blocks equals zero, but the
    -- diagonal blocks do not all compare equal.
    ZeroQuadrant
  | -- | Every entry outside the diagonal blocks equals zero and all
    -- diagonal blocks compare equal.
    MatrixMul
  deriving (Show)

-- | The result of classifying a scan operator with 'cases': which
-- 'SpecialCase' applies, together with the sizes derived from it.
--
-- For an operator with @d@ results and detected block size @k@, 'cases'
-- fills the fields in as follows:
--
-- * 'Generic':      @Special (d, d) (d, d * d) 1 d@
-- * 'MatrixMul':    @Special (d, k) (d, k * k) 1 k@
-- * 'ZeroQuadrant': @Special (k, k) (k, k * k) (d `div` k) k@
data Special = Special
  { -- | A pair of sizes.  NOTE(review): presumably the shape used when
    -- building neutral elements for the generated scan operator; the
    -- consumer is not visible here - confirm.
    specialNeutral :: (Int, Int),
    -- | A pair of sizes.  NOTE(review): presumably the number of
    -- parameters of each kind taken by the generated scan operator -
    -- confirm against its construction.
    specialParams :: (Int, Int),
    -- | How many copies of the generated scan operator 'diffScan' puts
    -- into its scanomap.
    specialScans :: Int,
    -- | The block size @k@ found by 'subMats', or @d@ in the 'Generic'
    -- case.
    specialSubSize :: Int,
    -- | Which structure was detected.
    specialCase :: SpecialCase
  }
  deriving (Show)

-- | Find the smallest block size for which a @d@ by @d@ matrix is
-- block-diagonal.
--
-- @subMats d mat zero@ considers every divisor @m@ of @d@ with
-- @1 <= m <= d `div` 2@ in increasing order and returns the first one for
-- which every entry of @mat@ lying outside the @m@ by @m@ diagonal blocks
-- compares equal to @zero@.  Returns 'Nothing' when no such divisor exists
-- (in particular whenever @d < 2@).
--
-- Rows and columns are numbered from 0; only the first @d@ rows and the
-- first @d@ entries of each row are inspected.
subMats :: Int -> [[Exp SOACS]] -> Exp SOACS -> Maybe Int
subMats d mat zero =
  -- Pattern match instead of the partial 'head'; the filter is lazy, so
  -- candidates after the first success are never checked.
  case filter blockDiagonal candidates of
    m : _ -> Just m
    [] -> Nothing
  where
    -- Divisors of d that are at most d/2, smallest first.
    candidates = [m | m <- [1 .. d `div` 2], d `mod` m == 0]

    indices = [0 .. d - 1]

    -- Entry (i, j) lies inside a diagonal block of size m exactly when
    -- i and j fall into the same group of m consecutive indices.
    blockDiagonal m =
      and
        [ v == zero || i `div` m == j `div` m
          | (row, i) <- zip mat indices,
            (v, j) <- zip row indices
        ]

-- | Classify the @d@ by @d@ Jacobian @mat@ of a scan operator whose
-- elements have type @t@.
--
-- * If 'subMats' finds a block size @k@ and all @k@ by @k@ diagonal blocks
--   compare equal, the result is the 'MatrixMul' case.
-- * If a block size is found but the diagonal blocks differ, the result is
--   the 'ZeroQuadrant' case, with @d `div` k@ scans.
-- * Otherwise the result is the 'Generic' case.
cases :: Int -> Type -> [[Exp SOACS]] -> Special
cases d t mat
  | Just k <- subMats d mat $ zeroExp t =
      let -- The i'th group of k rows, restricted to columns
          -- [i*k, i*k + k), i.e. the i'th diagonal block.
          diag_blocks =
            zipWith (\i -> map (take k . drop (i * k))) [0 .. d `div` k] $
              chunk k mat
          -- Pattern match rather than the partial 'head'/'tail'.  An
          -- empty block list (only possible for an empty matrix) counts
          -- as vacuously uniform instead of raising an exception.
          all_equal = case diag_blocks of
            [] -> True
            b : bs -> all (== b) bs
       in if all_equal
            then Special (d, k) (d, k * k) 1 k MatrixMul
            else Special (k, k) (k, k * k) (d `div` k) k ZeroQuadrant
cases d _ _ = Special (d, d) (d, d * d) 1 d Generic

-- | Determine which 'Special' case applies to the scan operator @lam@.
--
-- One lambda is built with 'mkScanAdjointLam' ('WrtFirst') per row of the
-- identity matrix produced by 'identityM'.  These are applied to a common
-- set of fresh parameters inside a single lambda whose results are the
-- transposed, concatenated results of the individual lambdas.  That
-- lambda is simplified, its results are read back as a matrix with @d@
-- entries per row (@d@ being the number of results of @lam@), and the
-- matrix is classified by 'cases'.
--
-- The type of the first result of @lam@ is used as the element type
-- throughout.
identifyCase :: VjpOps -> Lambda SOACS -> ADM Special
identifyCase ops lam = do
  let ret_ts = lambdaReturnType lam
      d = length ret_ts
      -- Kept as a lazy binding, so binding it does not by itself force
      -- 'head' on an empty result list.
      elem_t = head ret_ts
  seeds <- identityM d elem_t
  adj_lams <- mapM (mkScanAdjointLam ops lam WrtFirst) seeds

  fst_params <- mapM (newParam "tmp1") ret_ts
  snd_params <- mapM (newParam "tmp2") ret_ts
  let params = fst_params ++ snd_params
  jac_lam <-
    mkLambda params $
      concat . transpose <$> mapM (`eLambda` map eParam params) adj_lams

  simp <- simplifyLambda jac_lam
  pure
    . cases d elem_t
    . chunk d
    . map (BasicOp . SubExp . resSubExp)
    . bodyResult
    $ lambdaBody simp

-- | Emit the reverse-mode code for a scan with operator @scan@ over the
-- input arrays @as@ of width @w@, whose results are @ys@.
--
-- The steps are:
--
-- 1. classify the operator with 'identifyCase';
-- 2. look up the adjoints of @ys@ and build the fused map lambda
--    ('mkScanFusedMapLam') and the scan operator ('mkScanLinFunO');
-- 3. run a scanomap over @iota w@ using 'specialScans' renamed copies of
--    that scan operator;
-- 4. regroup the scanomap results with 'orderArgs', keep the first
--    @length as `div` specialScans@ names of every group, and pass them
--    to 'mkScanFinalMap';
-- 5. add the resulting contributions to the adjoints of @as@.
diffScan :: VjpOps -> [VName] -> SubExp -> [VName] -> Scan SOACS -> ADM ()
diffScan ops ys w as scan = do
  special <- identifyCase ops op_lam
  ys_bar <- traverse lookupAdjVal ys
  as_types <- traverse lookupType as
  fused_map <- mkScanFusedMapLam ops w op_lam as ys ys_bar special n_as
  lin_scan <- mkScanLinFunO (head as_types) special
  lin_scans <- copiesOf (specialScans special) lin_scan
  iota <-
    letExp "iota" . BasicOp $ Iota w (intConst Int64 0) (intConst Int64 1) Int64
  scanned <-
    letTupExp "adj_ctrb_scan" $
      Op $
        Screma w [iota] $
          scanomapSOAC lin_scans fused_map

  let per_group = n_as `div` specialScans special
      ds = concatMap (take per_group) $ orderArgs special scanned
  contribs <- mkScanFinalMap ops w op_lam as ys ds
  zipWithM_ updateAdj as contribs
  where
    op_lam = scanLambda scan
    n_as = length as

    -- @n@ copies of the scan @s@, each with a freshly renamed lambda and
    -- the same neutral elements.
    copiesOf :: Int -> Scan SOACS -> ADM [Scan SOACS]
    copiesOf n s =
      replicateM n $ flip Scan (scanNeutral s) <$> renameLambda (scanLambda s)

diffScanVec ::
  VjpOps ->
  [VName] ->
  StmAux () ->
  SubExp ->
  Lambda SOACS ->
  [SubExp] ->
  [VName] ->
  ADM () ->
  ADM ()
diffScanVec :: VjpOps
-> [VName]
-> StmAux ()
-> SubExp
-> Lambda SOACS
-> [SubExp]
-> [VName]
-> ADM ()
-> ADM ()
diffScanVec VjpOps
ops [VName]
ys StmAux ()
aux SubExp
w Lambda SOACS
lam [SubExp]
ne [VName]
as ADM ()
m = do
  Seq (Stm SOACS)
stmts <- ADM [()] -> ADM (Stms (Rep ADM))
forall (m :: * -> *) a. MonadBuilder m => m a -> m (Stms (Rep m))
collectStms_ (ADM [()] -> ADM (Stms (Rep ADM)))
-> ADM [()] -> ADM (Stms (Rep ADM))
forall a b. (a -> b) -> a -> b
$ do
    Int
rank <- Type -> Int
forall shape u. ArrayShape shape => TypeBase shape u -> Int
arrayRank (Type -> Int) -> ADM Type -> ADM Int
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType ([VName] -> VName
forall a. HasCallStack => [a] -> a
head [VName]
as)
    let rear :: [Int]
rear = [Int
1, Int
0] [Int] -> [Int] -> [Int]
forall a. [a] -> [a] -> [a]
++ Int -> [Int] -> [Int]
forall a. Int -> [a] -> [a]
drop Int
2 [Int
0 .. Int
rank Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
1]

    [VName]
transp_as <-
      (VName -> ADM VName) -> [VName] -> ADM [VName]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse
        (\VName
a -> [Char] -> Exp (Rep ADM) -> ADM VName
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m VName
letExp (VName -> [Char]
baseString VName
a [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
"_transp") (Exp (Rep ADM) -> ADM VName) -> Exp (Rep ADM) -> ADM VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
rear VName
a)
        [VName]
as

    [Type]
ts <- (VName -> ADM Type) -> [VName] -> ADM [Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse VName -> ADM Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
transp_as
    let n :: SubExp
n = Int -> [Type] -> SubExp
forall u. Int -> [TypeBase Shape u] -> SubExp
arraysSize Int
0 [Type]
ts

    [Param Type]
as_par <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ([Char] -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"as_par" (Type -> ADM (Param Type))
-> (Type -> Type) -> Type -> ADM (Param Type)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Type -> Type
forall u. TypeBase Shape u -> TypeBase Shape u
rowType) [Type]
ts
    [Param Type]
ne_par <- (Type -> ADM (Param Type)) -> [Type] -> ADM [Param Type]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ([Char] -> Type -> ADM (Param Type)
forall (m :: * -> *) dec.
MonadFreshNames m =>
[Char] -> dec -> m (Param dec)
newParam [Char]
"ne_par") ([Type] -> ADM [Param Type]) -> [Type] -> ADM [Param Type]
forall a b. (a -> b) -> a -> b
$ Lambda SOACS -> [Type]
forall rep. Lambda rep -> [Type]
lambdaReturnType Lambda SOACS
lam

    ScremaForm SOACS
scan_form <- [Scan SOACS] -> ADM (ScremaForm SOACS)
forall rep (m :: * -> *).
(Buildable rep, MonadFreshNames m) =>
[Scan rep] -> m (ScremaForm rep)
scanSOAC [Lambda SOACS -> [SubExp] -> Scan SOACS
forall rep. Lambda rep -> [SubExp] -> Scan rep
Scan Lambda SOACS
lam ((Param Type -> SubExp) -> [Param Type] -> [SubExp]
forall a b. (a -> b) -> [a] -> [b]
map (VName -> SubExp
Var (VName -> SubExp) -> (Param Type -> VName) -> Param Type -> SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Param Type -> VName
forall dec. Param dec -> VName
paramName) [Param Type]
ne_par)]

    Lambda SOACS
map_lam <-
      [LParam (Rep ADM)] -> ADM Result -> ADM (Lambda (Rep ADM))
forall (m :: * -> *).
MonadBuilder m =>
[LParam (Rep m)] -> m Result -> m (Lambda (Rep m))
mkLambda ([Param Type]
as_par [Param Type] -> [Param Type] -> [Param Type]
forall a. [a] -> [a] -> [a]
++ [Param Type]
ne_par) (ADM Result -> ADM (Lambda SOACS))
-> (SOAC SOACS -> ADM Result) -> SOAC SOACS -> ADM (Lambda SOACS)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ([VName] -> Result) -> ADM [VName] -> ADM Result
forall a b. (a -> b) -> ADM a -> ADM b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap [VName] -> Result
varsRes (ADM [VName] -> ADM Result)
-> (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM Result
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [Char] -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"map_res" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM (Lambda SOACS))
-> SOAC SOACS -> ADM (Lambda SOACS)
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
w ((Param Type -> VName) -> [Param Type] -> [VName]
forall a b. (a -> b) -> [a] -> [b]
map Param Type -> VName
forall dec. Param dec -> VName
paramName [Param Type]
as_par) ScremaForm SOACS
scan_form

    [VName]
transp_ys <-
      [Char] -> Exp (Rep ADM) -> ADM [VName]
forall (m :: * -> *).
MonadBuilder m =>
[Char] -> Exp (Rep m) -> m [VName]
letTupExp [Char]
"trans_ys" (Exp SOACS -> ADM [VName])
-> (SOAC SOACS -> Exp SOACS) -> SOAC SOACS -> ADM [VName]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Op SOACS -> Exp SOACS
SOAC SOACS -> Exp SOACS
forall rep. Op rep -> Exp rep
Op (SOAC SOACS -> ADM [VName]) -> SOAC SOACS -> ADM [VName]
forall a b. (a -> b) -> a -> b
$
        SubExp -> [VName] -> ScremaForm SOACS -> SOAC SOACS
forall rep. SubExp -> [VName] -> ScremaForm rep -> SOAC rep
Screma SubExp
n ([VName]
transp_as [VName] -> [VName] -> [VName]
forall a. [a] -> [a] -> [a]
++ [SubExp] -> [VName]
subExpVars [SubExp]
ne) (Lambda SOACS -> ScremaForm SOACS
forall rep. Lambda rep -> ScremaForm rep
mapSOAC Lambda SOACS
map_lam)

    (VName -> VName -> ADM ()) -> [VName] -> [VName] -> ADM [()]
forall (m :: * -> *) a b c.
Applicative m =>
(a -> b -> m c) -> [a] -> [b] -> m [c]
zipWithM
      (\VName
y VName
x -> StmAux () -> ADM () -> ADM ()
forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux ()
aux (ADM () -> ADM ()) -> ADM () -> ADM ()
forall a b. (a -> b) -> a -> b
$ [VName] -> Exp (Rep ADM) -> ADM ()
forall (m :: * -> *).
MonadBuilder m =>
[VName] -> Exp (Rep m) -> m ()
letBindNames [VName
y] (Exp (Rep ADM) -> ADM ()) -> Exp (Rep ADM) -> ADM ()
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep ADM)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep ADM)) -> BasicOp -> Exp (Rep ADM)
forall a b. (a -> b) -> a -> b
$ [Int] -> VName -> BasicOp
Rearrange [Int]
rear VName
x)
      [VName]
ys
      [VName]
transp_ys

  (Stm SOACS -> ADM () -> ADM ())
-> ADM () -> Seq (Stm SOACS) -> ADM ()
forall a b. (a -> b -> b) -> b -> Seq a -> b
forall (t :: * -> *) a b.
Foldable t =>
(a -> b -> b) -> b -> t a -> b
foldr (VjpOps -> Stm SOACS -> ADM () -> ADM ()
vjpStm VjpOps
ops) ADM ()
m Seq (Stm SOACS)
stmts

-- | Accumulate into the adjoint of @as@ the contribution of a scan over a
-- single array @as@ of length @n@ whose result is @ys@.
--
-- The contribution is built in three steps, all driven by an @iota n@:
--
-- 1. read the adjoint of @ys@ back to front (index @n - i - 1@),
-- 2. scan those values with a freshly renamed copy of the operator @lam'@
--    and neutral element @ne@,
-- 3. read the scanned array back to front again.
--
-- The result of step 3 is passed to 'updateAdj' for @as@.
--
-- NOTE(review): the name suggests the operator is expected to be a plain
-- addition; nothing in this function checks that, so callers presumably
-- establish it. The 'VjpOps' argument is not used here.
diffScanAdd :: VjpOps -> VName -> SubExp -> Lambda SOACS -> SubExp -> VName -> ADM ()
diffScanAdd _ops ys n lam' ne as = do
  -- Work on a renamed copy so the names bound inside the operator stay unique
  -- when it is inserted into the new scan below.
  lam <- renameLambda lam'
  ys_bar <- lookupAdjVal ys

  -- Map part of the scanomap: fetches ys_bar in reversed order.
  map_scan <- rev_arr_lam ys_bar

  iota <-
    letExp "iota" $ BasicOp $ Iota n (intConst Int64 0) (intConst Int64 1) Int64

  -- Scan of the reversed adjoint of the result.
  scan_res <-
    letExp "res_rev" $ Op $ Screma n [iota] $ scanomapSOAC [Scan lam [ne]] map_scan

  -- Reverse once more to get the contribution in the original element order.
  rev_lam <- rev_arr_lam scan_res
  contrb <- letExp "contrb" $ Op $ Screma n [iota] $ mapSOAC rev_lam

  updateAdj as contrb
  where
    -- Build a lambda @\i -> arr[n - i - 1]@ taking an @i64@ index, i.e. an
    -- element-wise reader that walks @arr@ from its last element to its first.
    rev_arr_lam :: VName -> ADM (Lambda SOACS)
    rev_arr_lam arr = do
      par_i <- newParam "i" $ Prim int64
      mkLambda [par_i] $ do
        a <-
          letExp "ys_bar_rev"
            =<< eIndex arr [toExp (pe64 n - le64 (paramName par_i) - 1)]
        pure [varRes a]