{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.AD.Rev.Map (vjpMap) where

import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)

-- | A classification of a free variable based on its adjoint.  The
-- 'VName' stored is *not* the adjoint, but the primal variable.
--
-- The classification is produced by 'classifyAdjVars', which looks
-- at the type of the variable's current adjoint value.
data AdjVar
  = -- | Adjoint is already an accumulator.
    FreeAcc VName
  | -- | Adjoint is currently an array (not an accumulator), and
    -- should be given an accumulator adjoint.  The 'Shape' and
    -- 'PrimType' are the shape and element type of that adjoint array.
    FreeArr VName Shape PrimType
  | -- | Does not need an accumulator adjoint (might still be an array).
    FreeNonAcc VName

-- | Classify each variable according to the type of its current
-- adjoint value: an array adjoint yields 'FreeArr' (carrying the
-- adjoint's shape and element type), an accumulator adjoint yields
-- 'FreeAcc', and anything else yields 'FreeNonAcc'.  The order of the
-- input list is preserved.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM classify
  where
    classify v = do
      -- The classification is driven by the type of the adjoint, not
      -- the type of the primal variable itself.
      v_adj_t <- lookupType =<< lookupAdjVal v
      pure $ case v_adj_t of
        Array pt shape _ -> FreeArr v shape pt
        Acc {} -> FreeAcc v
        _ -> FreeNonAcc v

-- | Split a list of classified variables into three lists, each
-- preserving the relative order of the input:
--
-- 1. the 'FreeArr' variables, paired with their shape and element type;
--
-- 2. the 'FreeAcc' variables;
--
-- 3. the 'FreeNonAcc' variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr place ([], [], [])
  where
    -- The lazy pattern keeps the fold as lazy in the tail as the
    -- directly recursive formulation with a @where@-bound tuple.
    place fv ~(arrs, accs, others) =
      case fv of
        FreeArr v shape t -> ((v, (shape, t)) : arrs, accs, others)
        FreeAcc v -> (arrs, v : accs, others)
        FreeNonAcc v -> (arrs, accs, v : others)

-- | Like 'buildBody', but the constructed body is passed through
-- 'renameBody' before being returned.  The auxiliary value produced
-- by the builder action is returned unchanged alongside the renamed
-- body.
buildRenamedBody ::
  MonadBuilder m =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)

-- | Run an action with a list of accumulators, one per input, and
-- return the names bound to its results.
--
-- Each input is a triple of accumulator shape, the arrays backing the
-- accumulator, and an optional combining operator (a lambda with its
-- neutral elements).  The action receives the names of the
-- accumulator parameters, in the same order as the inputs.
--
-- With no inputs, no @WithAcc@ expression is constructed: the action
-- is run directly with an empty list and each of its results is bound
-- to a fresh name.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      -- The element types of the accumulator are the types of the
      -- backing arrays with the accumulator's own dimensions removed.
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <-
        newParam "acc_p" $
          Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  -- The lambda takes all certificate parameters first, followed by
  -- all accumulator parameters.
  acc_lam <-
    subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
  -- NOTE: the spelling of this name hint is kept as-is so that the
  -- names of generated variables do not change.
  letTupExp "withhacc_res" $ WithAcc inputs acc_lam

-- | Generate return sweep code for a map.
--
-- The arguments are, in order: the 'VjpOps' handed on to 'vjpLambda',
-- the adjoints of the map results, the 'StmAux' of the original
-- statement, the width of the map, the map lambda, and the input
-- arrays.
--
-- This first equation handles the case where every result adjoint is
-- sparse with shape @[w]@.  It emits straight-line code that
-- differentiates the lambda once per sparse position, instead of
-- emitting a map.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Since at most a constant number of adjoints are nonzero
      -- (length res_ivs), there is no need for the return sweep code to
      -- contain a Map at all.

      free <- filterM isActive $ namesToList $ freeIn map_lam
      free_ts <- mapM lookupType free
      let adjs_for = map paramName (lambdaParams map_lam) ++ free
          adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts

      let -- The adjoints of the lambda results when only result
          -- 'res_i' has a nonzero adjoint ('adj_v'); all others are
          -- zero adjoints of the appropriate shape and element type.
          oneHot res_i adj_v = zipWith f [0 :: Int ..] $ lambdaReturnType map_lam
            where
              f j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)
          -- Values for the out-of-bounds case do not matter, as we will
          -- be writing to an out-of-bounds index anyway, which is ignored.
          ooBounds adj_i = subAD . buildRenamedBody $ do
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            first subExpsRes . adjsReps <$> mapM lookupAdj as
          -- Differentiate a single application of the lambda at index
          -- 'adj_i', where result 'res_i' has adjoint 'adj_v'.
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            -- Bind the lambda parameters to the elements at 'adj_i'.
            forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] . BasicOp . Index a $
                fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            forM_ (zip as adj_elems) $ \(a, a_adj_elem) ->
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            first subExpsRes . adjsReps <$> mapM lookupAdj as

          -- Generate an iteration of the map function for every
          -- position.  This is a bit inefficient - probably we could do
          -- some deduplication.
          forPos res_i (check, adj_i, adj_v) = do
            as_adj <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  -- If no explicit condition is given, check 'adj_i'
                  -- against the width of the map.
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  mapM lookupAdj as

            -- The results are all unit, so there is no list to build.
            zipWithM_ setAdj as as_adj

          -- Generate an iteration of the map function for every result.
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- The index/value triples of a sparse adjoint, provided its shape
    -- is exactly the width of the map.
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ =
      Nothing
-- See Note [Adjoints of accumulators] for how we deal with
-- accumulators - it's a bit tricky here.
--
-- General case: emit a single map whose lambda is the reverse-mode
-- derivative of the original lambda.  Its inputs are the original
-- input arrays, the result adjoints, and the adjoints of those free
-- variables that are classified as accumulators or arrays.
vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do
  pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) ->
    case t of
      -- Adjoints of accumulator results are replicated to the width
      -- of the map so that they can be passed as map inputs.
      Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj
      _ -> adjVal adj
  pat_adj_params <-
    mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals

  map_lam' <- renameLambda map_lam
  free <- filterM isActive $ namesToList $ freeIn map_lam'

  accAdjoints free $ \free_with_adjs free_without_adjs -> do
    free_adjs <- mapM lookupAdjVal free_with_adjs
    free_adjs_ts <- mapM lookupType free_adjs
    free_adjs_params <- mapM (newParam "free_adj_p") free_adjs_ts
    let lam_rev_params =
          lambdaParams map_lam' ++ pat_adj_params ++ free_adjs_params
        adjs_for = map paramName (lambdaParams map_lam') ++ free
    lam_rev <-
      mkLambda lam_rev_params . subAD . noAdjsFor free_without_adjs $ do
        -- Inside the lambda, the adjoints of these free variables are
        -- the corresponding lambda parameters.
        zipWithM_ insAdj free_with_adjs $ map paramName free_adjs_params
        bodyBind . lambdaBody
          =<< vjpLambda ops (map adjFromParam pat_adj_params) adjs_for map_lam'

    -- The first results are the contributions to the lambda
    -- parameters; the remaining ones are for the free variables.
    (param_contribs, free_contribs) <-
      fmap (splitAt (length (lambdaParams map_lam'))) $
        auxing aux . letTupExp "map_adjs" . Op $
          Screma w (as ++ pat_adj_vals ++ free_adjs) (mapSOAC lam_rev)

    -- Crucial that we handle the free contribs first in case 'free'
    -- and 'as' intersect.
    zipWithM_ freeContrib free free_contribs
    let param_ts = map paramType (lambdaParams map_lam')
    forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) ->
      case param_t of
        Acc {} -> freeContrib a param_contrib
        _ -> updateAdj a param_contrib
  where
    -- Prepend 'n' fresh @i64@ parameters to a lambda.
    addIdxParams n lam = do
      idxs <- replicateM n $ newParam "idx" $ Prim int64
      pure $ lam {lambdaParams = idxs ++ lambdaParams lam}

    -- An addition lambda for type 't', with 'n' leading index
    -- parameters added.
    accAddLambda n t = addIdxParams n =<< addLambda t

    -- Construct the 'withAcc' input for a free variable whose adjoint
    -- is an array with the given shape and element type: the adjoint
    -- array itself, combined by addition with zero as neutral element.
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- accAddLambda (shapeRank shape) $ Prim pt
      zero <- letSubExp "zero" $ zeroExp $ Prim pt
      pure (shape, [v_adj], Just (add_lam, [zero]))

    -- Run 'm' in a context where the array adjoints of the free
    -- variables have been turned into accumulators.  'm' is given the
    -- free variables whose adjoints are accumulators, and the names
    -- of those that are not.  Afterwards the adjoints are rebound to
    -- the results of the 'withAcc'.
    accAdjoints free m = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      arr_free' <- mapM withAccInput arr_free
      -- We only consider those input arrays that are also not free in
      -- the lambda.
      let as_nonfree = filter (`notElem` free) as
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free)) . withAcc arr_free' $ \accs -> do
          zipWithM_ insAdj (map fst arr_free) accs
          () <- m (acc_free ++ map fst arr_free) (namesFromList nonacc_free)
          acc_free_adj <- mapM lookupAdjVal acc_free
          arr_free_adj <- mapM (lookupAdjVal . fst) arr_free
          nonacc_free_adj <- mapM lookupAdjVal nonacc_free
          as_nonfree_adj <- mapM lookupAdjVal as_nonfree
          -- The order here must agree with how the results are split
          -- up again below.
          pure $ varsRes $ arr_free_adj <> acc_free_adj <> nonacc_free_adj <> as_nonfree_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj (map fst arr_free) arr_adjs
      let (nonacc_adjs, as_nonfree_adjs) = splitAt (length nonacc_free) rest_adjs
      zipWithM_ insAdj nonacc_free nonacc_adjs
      zipWithM_ insAdj as_nonfree as_nonfree_adjs

    -- Incorporate the per-iteration contributions 'contribs' into the
    -- adjoint of 'v'.  An array of accumulators becomes the adjoint
    -- directly; anything else is summed with a commutative reduction
    -- and added to the adjoint.
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> insAdj v contribs
        t -> do
          lam <- addLambda t
          zero <- letSubExp "zero" $ zeroExp t
          reduce <- reduceSOAC [Reduce Commutative lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] reduce
          updateAdj v contrib_sum