{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Scatter (vjpScatter) where

import Control.Monad
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Util (chunk)

-- | The conjunction of per-dimension bounds checks: for every pair
-- @(w, i)@ it requires @0 <= i && i < w@ (the lower bound is expressed
-- as @-1 < i@).  The checks are combined right-nested with '.&&.'; the
-- empty list yields the constant @true@.
withinBounds :: [(SubExp, VName)] -> TPrimExp Bool VName
withinBounds = go
  where
    go :: [(SubExp, VName)] -> TPrimExp Bool VName
    go [] = TPrimExp $ ValueExp $ BoolValue True
    go [wi] = inBound wi
    go (wi : wis) = inBound wi .&&. go wis

    -- Check of a single index against its dimension size.
    inBound :: (SubExp, VName) -> TPrimExp Bool VName
    inBound (bound, idx) =
      (le64 idx .<. pe64 bound) .&&. (pe64 (intConst Int64 (-1)) .<. le64 idx)

-- | Generates a (possibly nested) tower-of-maps lambda body that reads
-- one element of an array behind a bounds check.
--
-- Arguments:
--
--   * @arr@: the array being indexed;
--   * @[(w_1, i_1), ..., (w_k, i_k)]@: the enclosing lambda's index
--     parameters, each paired with the bound it must stay below;
--   * the type of @arr[i_1, ..., i_k]@, e.g. @[n_1][n_2]...ptp@.
--
-- For every remaining dimension @n_j@ a @map@ over @iota n_j@ is
-- emitted with a fresh index @j_j@; the guard is produced only at the
-- innermost, scalar level:
--
-- > map (\j_1 -> ... if (i_1 >= 0 && i_1 < w_1) && ... && (i_k >= 0 && i_k < w_k)
-- >                  then arr[i_1, ..., i_k, j_1, ...]
-- >                  else blank)
-- >     (iota n_1)
--
-- Arrays are deliberately kept out of the @if@ branches because they
-- would not flatten well.  Only the outer indices @i_*@ are checked;
-- the @j_*@ range over the array's own dimensions via @iota@.
genIdxLamBody :: VName -> [(SubExp, Param Type)] -> Type -> ADM (Body SOACS)
genIdxLamBody arr_nm outer = build arr_nm outer []
  where
    build :: VName -> [(SubExp, Param Type)] -> [Param Type] -> Type -> ADM (Body SOACS)
    -- A rank-zero array is handled as its scalar element type.
    build arr bounded inner (Array pt (Shape []) _) =
      build arr bounded inner (Prim pt)
    -- Peel off the outermost remaining dimension: map over its iota,
    -- recursing with one more (inner) index parameter.
    build arr bounded inner (Array pt (Shape (d : ds)) _) = do
      j <- newParam "i" (Prim int64)
      let row_t = Prim pt `arrayOfShape` Shape ds
      body_lam <-
        mkLambda [j] $
          bodyBind =<< build arr bounded (inner ++ [j]) row_t
      let outer_params = map snd bounded
      buildBody_ $
        localScope (scopeOfLParams (outer_params ++ inner)) $ do
          dim_iota <-
            letExp "iota" . BasicOp $
              Iota d (intConst Int64 0) (intConst Int64 1) Int64
          elems <-
            letSubExp (baseString arr ++ "_elem") . Op $
              Screma d [dim_iota] (mapSOAC body_lam)
          pure [subExpRes elems]
    -- Innermost level: a guarded scalar read, or a blank value when the
    -- outer indices are out of bounds.
    build arr bounded inner (Prim pt) = do
      let (bounds, outer_params) = unzip bounded
          all_idxs = map paramName (outer_params ++ inner)
          in_bounds = withinBounds $ zip bounds $ map paramName outer_params
      localScope (scopeOfLParams (outer_params ++ inner)) $
        eBody
          [ eIf
              (toExp in_bounds)
              ( do
                  v <-
                    letSubExp "r" . BasicOp . Index arr . Slice $
                      map (DimFix . Var) all_idxs
                  resultBodyM [v]
              )
              (resultBodyM [Constant $ blankPrimValue pt])
          ]
    build _ _ _ _ =
      error "In Rev.hs, helper function genRecLamBody, unreachable case reached!"

-- | Reverse-mode differentiation of a single-destination scatter whose
-- lambda is the identity.
--
-- Original:
--
-- > let ys = scatter xs is vs
--
-- Assumes no duplicate indices in @is@.
--
-- Forward sweep:
--
-- > let xs_save = gather xs is
-- > let ys      = scatter xs is vs
--
-- Return sweep:
--
-- > let vs_ctrbs = gather ys_adj is
-- > vs_adj += vs_ctrbs                     -- via updateAdj
-- > let xs_adj = scatter ys_adj is zeros   -- overwritten cells get no adjoint
-- > let xs     = scatter ys is xs_save     -- restore the primal xs
vjpScatter1 ::
  PatElem Type ->
  StmAux () ->
  (SubExp, [VName], (ShapeBase SubExp, Int, VName)) ->
  ADM () ->
  ADM ()
vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m = do
  let rank = length $ shapeDims shp
      -- Scatter inputs: rank * num_vals index arrays, then num_vals
      -- value arrays.
      (all_inds, val_as) = splitAt (rank * num_vals) ass
      inds_as = chunk rank all_inds
  xs_t <- lookupType xs
  let val_t = stripArray rank xs_t
      -- Row types of the scatter inputs, i.e. the parameter (and result)
      -- types of an identity scatter lambda over them.  Used for the
      -- forward scatter as well as both return-sweep scatters; the
      -- forward lambda previously used rank value parameters instead of
      -- num_vals, and rank index parameters instead of rank * num_vals.
      f_tps = replicate (rank * num_vals) (Prim int64) ++ replicate num_vals val_t
  -- computing xs_save
  xs_saves <- mkGather inds_as xs xs_t
  -- performing the scatter
  id_lam <- mkIdentityLambda f_tps
  addStm $ Let (Pat [pys]) aux $ Op $ Scatter w ass id_lam [(shp, num_vals, xs)]
  m
  let ys = patElemName pys
  -- XXX: Since our restoration of xs will consume ys, we have to
  -- make a copy of ys in the chance that it is actually the result
  -- of the program.  In that case the asymptotics will not be
  -- (locally) preserved, but since ys must necessarily have been
  -- constructed somewhere close, they are probably globally OK.
  ys_copy <- letExp (baseString ys <> "_copy") $ BasicOp $ Copy ys
  returnSweepCode $ do
    ys_adj <- lookupAdjVal ys
    -- computing vs_ctrbs and updating vs_adj
    vs_ctrbs <- mkGather inds_as ys_adj xs_t
    zipWithM_ updateAdj val_as vs_ctrbs -- use Slice?
    -- creating xs_adj.  The zero value arrays must have type [w]val_t,
    -- matching the value inputs; `xs_t` with its outer size replaced
    -- by w is only the same type when rank == 1.
    zeros <-
      replicateM (length val_as) . letExp "zeros" . zeroExp $
        val_t `arrayOfShape` Shape [w]
    f <- mkIdentityLambda f_tps
    xs_adj <-
      letExp (baseString xs ++ "_adj") . Op $
        Scatter w (all_inds ++ zeros) f [(shp, num_vals, ys_adj)]
    insAdj xs xs_adj -- reusing the ys_adj for xs_adj!
    -- restoring the primal xs from ys and the saved values
    f' <- mkIdentityLambda f_tps
    xs_rc <-
      auxing aux . letExp (baseString xs <> "_rc") . Op $
        Scatter w (all_inds ++ xs_saves) f' [(shp, num_vals, ys)]
    addSubstitution xs xs_rc
    addSubstitution ys ys_copy
  where
    -- Gathers `arr` at each group of indices (one group per scattered
    -- value), yielding one array per group.  Each read is generated by
    -- genIdxLamBody, which places the bounds check at the innermost
    -- level of the map nest so the whole nest stays parallel.
    mkGather :: [[VName]] -> VName -> Type -> ADM [VName]
    mkGather idx_groups arr arr_t = do
      ips <- forM idx_groups $ \idxs ->
        mapM (\idx -> newParam (baseString idx ++ "_elem") (Prim int64)) idxs
      gather_lam <- mkLambda (concat ips) . fmap mconcat . forM ips $ \idxs -> do
        let q = length idxs
            ws = take q $ arrayDims arr_t
            eltp = stripArray q arr_t
        bodyBind =<< genIdxLamBody arr (zip ws idxs) eltp
      let soac = Screma w (concat idx_groups) (mapSOAC gather_lam)
      letTupExp (baseString arr ++ "_gather") $ Op soac

-- | Reverse-mode differentiation of a scatter.
--
-- Only identity lambdas are supported; anything else is an error.  A
-- scatter with exactly one destination is handled by 'vjpScatter1'.
-- A scatter with several destinations is split into one
-- single-destination scatter per destination, and those are then
-- differentiated in sequence via 'vjpStm', with @m@ as the innermost
-- continuation.
vjpScatter ::
  VjpOps ->
  Pat Type ->
  StmAux () ->
  (SubExp, [VName], Lambda SOACS, [(Shape, Int, VName)]) ->
  ADM () ->
  ADM ()
vjpScatter ops (Pat pes) aux (w, ass, lam, written_info) m
  | not (isIdentityLambda lam) =
      error "vjpScatter: cannot handle"
  | [(shp, num_vals, xs)] <- written_info,
    [pys] <- pes =
      vjpScatter1 pys aux (w, ass, (shp, num_vals, xs)) m
  | otherwise = do
      -- The inputs are all index arrays (for every destination, in
      -- order) followed by all value arrays.
      let num_all_inds =
            sum [num_res * length (shapeDims shp) | (shp, num_res, _) <- written_info]
          (inds, vals) = splitAt num_all_inds ass
      lst_stms <- chunkScatterInps (inds, vals) (zip pes written_info)
      foldr (vjpStm ops) m lst_stms
  where
    -- Builds one single-destination identity scatter per destination,
    -- consuming that destination's index and value arrays from the
    -- front of the accumulated input lists.  Inputs left over after
    -- the last destination are an error.
    chunkScatterInps ::
      ([VName], [VName]) ->
      [(PatElem Type, (Shape, Int, VName))] ->
      ADM [Stm SOACS]
    chunkScatterInps (acc_inds, acc_vals) [] =
      case (acc_inds, acc_vals) of
        ([], []) -> pure []
        _ -> error "chunkScatterInps: cannot handle"
    chunkScatterInps (acc_inds, acc_vals) ((pe, info@(shp, num_vals, _)) : rest) = do
      let num_inds = num_vals * length (shapeDims shp)
          (curr_inds, other_inds) = splitAt num_inds acc_inds
          (curr_vals, other_vals) = splitAt num_vals acc_vals
      -- Lambda parameters are the *rows* of the input arrays, just as
      -- the index parameters are i64 rather than [w]i64.
      vtps <- mapM (fmap rowType . lookupType) curr_vals
      f <- mkIdentityLambda (replicate num_inds (Prim int64) ++ vtps)
      let stm =
            Let (Pat [pe]) aux . Op $
              Scatter w (curr_inds ++ curr_vals) f [info]
      stms_rest <- chunkScatterInps (other_inds, other_vals) rest
      pure $ stm : stms_rest