{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Map (vjpMap) where
import Control.Monad
import Data.Bifunctor (first)
import Futhark.AD.Rev.Monad
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Util (splitAt3)
-- | Classification of an active free variable according to the type of
-- its current adjoint (as computed by 'classifyAdjVars').  The 'VName'
-- stored in every constructor is the primal variable, not its adjoint.
data AdjVar
  = -- | The adjoint already has an accumulator type.
    FreeAcc VName
  | -- | The adjoint is an array with this shape and element type.
    FreeArr VName Shape PrimType
  | -- | The adjoint has any other type.
    FreeNonAcc VName
-- | Classify each variable by the type of its adjoint, as obtained from
-- 'lookupAdjVal'.  The variables are processed in order, and the result
-- has one entry per input variable.
classifyAdjVars :: [VName] -> ADM [AdjVar]
classifyAdjVars = mapM classify
  where
    classify :: VName -> ADM AdjVar
    classify v = do
      v_adj <- lookupAdjVal v
      v_adj_t <- lookupType v_adj
      pure $ case v_adj_t of
        Array pt shape _ -> FreeArr v shape pt
        Acc {} -> FreeAcc v
        _ -> FreeNonAcc v
-- | Split classified variables into three lists, keeping their relative
-- order:
--
--   * variables with array adjoints, paired with shape and element type;
--   * variables with accumulator adjoints;
--   * all remaining variables.
partitionAdjVars :: [AdjVar] -> ([(VName, (Shape, PrimType))], [VName], [VName])
partitionAdjVars = foldr step ([], [], [])
  where
    -- The lazy (~) pattern keeps this exactly as lazy as the equivalent
    -- explicit recursion with a where-bound tuple.
    step fv ~(xs, ys, zs) =
      case fv of
        FreeArr v shape t -> ((v, (shape, t)) : xs, ys, zs)
        FreeAcc v -> (xs, v : ys, zs)
        FreeNonAcc v -> (xs, ys, v : zs)
-- | Like 'buildBody', but the resulting body is passed through
-- 'renameBody' so that all of its bindings get fresh names.  The extra
-- value produced by the builder action is returned unchanged.
buildRenamedBody ::
  (MonadBuilder m) =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildRenamedBody m = do
  (body, x) <- buildBody m
  body' <- renameBody body
  pure (body', x)
-- | Run the action @m@ inside a 'WithAcc' construct that has one
-- accumulator per element of @inputs@.  @m@ receives the names of the
-- accumulator parameters, and the results of the 'WithAcc' are bound to
-- fresh variables, which are returned.
--
-- With no inputs, no 'WithAcc' is created.  Instead @m@ is run on an empty
-- list, and each of its results is copied into a fresh variable.
withAcc ::
  [(Shape, [VName], Maybe (Lambda SOACS, [SubExp]))] ->
  ([VName] -> ADM Result) ->
  ADM [VName]
withAcc [] m =
  mapM (letExp "withacc_res" . BasicOp . SubExp . resSubExp) =<< m []
withAcc inputs m = do
  (cert_params, acc_params) <- fmap unzip $
    forM inputs $ \(shape, arrs, _) -> do
      cert_param <- newParam "acc_cert_p" $ Prim Unit
      -- Types of the accumulated arrays with the accumulator's shape
      -- stripped from the outside.
      ts <- mapM (fmap (stripArray (shapeRank shape)) . lookupType) arrs
      acc_param <- newParam "acc_p" $ Acc (paramName cert_param) shape ts NoUniqueness
      pure (cert_param, acc_param)
  acc_lam <-
    subAD $ mkLambda (cert_params ++ acc_params) $ m $ map paramName acc_params
  -- Same name hint as the no-input case above (was misspelled "withhacc_res").
  letTupExp "withacc_res" $ WithAcc inputs acc_lam
-- | Emit return-sweep code for a @map@ of @map_lam@ over the arrays @as@
-- with width @w@, given the adjoints of the map's results.
--
-- If every result adjoint is 'AdjSparse' with outer shape exactly @[w]@,
-- no reverse map is constructed.  Instead, the differentiated lambda body
-- is emitted once for every (bounds check, index, value) triple stored in
-- the sparse adjoints.
--
-- Otherwise, a reverse lambda is built with 'vjpLambda' and mapped over
-- three groups of inputs: the primal inputs, the result adjoints, and the
-- adjoints of the active free variables.  Free variables with array
-- adjoints are given accumulators while the reverse map runs (see
-- @accAdjoints@).  The resulting contributions are then added to the
-- adjoints of @as@ and of the free variables.
vjpMap :: VjpOps -> [Adj] -> StmAux () -> SubExp -> Lambda SOACS -> [VName] -> ADM ()
vjpMap ops res_adjs _ w map_lam as
  | Just res_ivs <- mapM isSparse res_adjs = returnSweepCode $ do
      -- Active variables free in the lambda, excluding the mapped arrays.
      free <- filterM isActive $ namesToList $ freeIn map_lam `namesSubtract` namesFromList as
      free_ts <- mapM lookupType free
      let adjs_for :: [VName]
          adjs_for = map paramName (lambdaParams map_lam) ++ free
          adjs_ts :: [Type]
          adjs_ts = map paramType (lambdaParams map_lam) ++ free_ts

      -- oneHot: result adjoints in which only result number res_i is
      -- nonzero; every other result gets an AdjZero of its own type.
      let oneHot :: Int -> Adj -> [Adj]
          oneHot res_i adj_v = zipWith f [0 :: Int ..] $ lambdaReturnType map_lam
            where
              f :: Int -> Type -> Adj
              f j t
                | res_i == j = adj_v
                | otherwise = AdjZero (arrayShape t) (elemType t)

          -- Body for an out-of-bounds index.  It writes blank values at
          -- adj_i tagged OutOfBounds (presumably ignored downstream;
          -- confirm in updateAdjIndex).
          ooBounds :: SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          ooBounds adj_i = subAD . buildRenamedBody $ do
            forM_ (zip as adjs_ts) $ \(a, t) -> do
              scratch <- letSubExp "oo_scratch" =<< eBlank t
              updateAdjIndex a (OutOfBounds, adj_i) scratch
            -- Manifest the adjoints of the free variables.  Only this
            -- branch's adjoint reconstruction function is used for the
            -- result of the branch (see forPos), so presumably this keeps
            -- the representations of both branches aligned.
            forM_ free $ \v -> insAdj v =<< adjVal =<< lookupAdj v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Body for an in-bounds index: bind the lambda parameters to
          -- row adj_i of each input, run the differentiated lambda with a
          -- one-hot result adjoint, and propagate the resulting adjoints.
          inBounds :: Int -> SubExp -> SubExp -> ADM (Body SOACS, [SubExp] -> [Adj])
          inBounds res_i adj_i adj_v = subAD . buildRenamedBody $ do
            forM_ (zip (lambdaParams map_lam) as) $ \(p, a) -> do
              a_t <- lookupType a
              letBindNames [paramName p] . BasicOp . Index a $
                fullSlice a_t [DimFix adj_i]
            adj_elems <-
              fmap (map resSubExp) . bodyBind . lambdaBody
                =<< vjpLambda ops (oneHot res_i (AdjVal adj_v)) adjs_for map_lam
            let (as_adj_elems, free_adj_elems) = splitAt (length as) adj_elems
            forM_ (zip as as_adj_elems) $ \(a, a_adj_elem) ->
              updateAdjIndex a (AssumeBounds, adj_i) a_adj_elem
            forM_ (zip free free_adj_elems) $ \(v, adj_se) -> do
              adj_se_v <- letExp "adj_v" (BasicOp $ SubExp adj_se)
              insAdj v adj_se_v
            first subExpsRes . adjsReps <$> mapM lookupAdj (as <> free)

          -- Emit one instance of the differentiated lambda for a single
          -- sparse position of result number res_i.
          forPos :: Int -> (InBounds, SubExp, SubExp) -> ADM ()
          forPos res_i (check, adj_i, adj_v) = do
            adjs <-
              case check of
                CheckBounds b -> do
                  (obbranch, mkadjs) <- ooBounds adj_i
                  (ibbranch, _) <- inBounds res_i adj_i adj_v
                  -- Use the provided condition if there is one, otherwise
                  -- test whether adj_i is within [0, w).
                  fmap mkadjs . letTupExp' "map_adj_elem"
                    =<< eIf
                      (maybe (eDimInBounds (eSubExp w) (eSubExp adj_i)) eSubExp b)
                      (pure ibbranch)
                      (pure obbranch)
                AssumeBounds -> do
                  (body, mkadjs) <- inBounds res_i adj_i adj_v
                  mkadjs . map resSubExp <$> bodyBind body
                OutOfBounds ->
                  mapM lookupAdj as
            zipWithM_ setAdj (as <> free) adjs

          -- Emit code for every sparse position of result number res_i.
          forRes :: Int -> [(InBounds, SubExp, SubExp)] -> ADM ()
          forRes res_i = mapM_ (forPos res_i)

      zipWithM_ forRes [0 ..] res_ivs
  where
    -- The stored (check, index, value) triples of a sparse adjoint whose
    -- shape is exactly [w]; Nothing for any other adjoint.
    isSparse :: Adj -> Maybe [(InBounds, SubExp, SubExp)]
    isSparse (AdjSparse (Sparse shape _ ivs)) = do
      guard $ shapeDims shape == [w]
      Just ivs
    isSparse _ =
      Nothing
vjpMap ops pat_adj aux w map_lam as = returnSweepCode $ do
  -- Each result adjoint is made into a variable.  Accumulator-typed
  -- results have theirs replicated w times so that every one of them can
  -- be an input array of the reverse map.
  pat_adj_vals <- forM (zip pat_adj (lambdaReturnType map_lam)) $ \(adj, t) ->
    case t of
      Acc {} -> letExp "acc_adj_rep" . BasicOp . Replicate (Shape [w]) . Var =<< adjVal adj
      _ -> adjVal adj
  pat_adj_params <-
    mapM (newParam "map_adj_p" . rowType <=< lookupType) pat_adj_vals

  map_lam' <- renameLambda map_lam
  free <- filterM isActive $ namesToList $ freeIn map_lam'

  accAdjoints free $ \free_with_adjs free_without_adjs -> do
    free_adjs <- mapM lookupAdjVal free_with_adjs
    free_adjs_ts <- mapM lookupType free_adjs
    free_adjs_params <- mapM (newParam "free_adj_p") free_adjs_ts
    let lam_rev_params :: [Param Type]
        lam_rev_params =
          lambdaParams map_lam' ++ pat_adj_params ++ free_adjs_params
        adjs_for :: [VName]
        adjs_for = map paramName (lambdaParams map_lam') ++ free
    lam_rev <-
      mkLambda lam_rev_params . subAD . noAdjsFor free_without_adjs $ do
        zipWithM_ insAdj free_with_adjs $ map paramName free_adjs_params
        bodyBind . lambdaBody
          =<< vjpLambda ops (map adjFromParam pat_adj_params) adjs_for map_lam'

    -- The reverse map returns the parameter contributions first, followed
    -- by the free-variable contributions.
    (param_contribs, free_contribs) <-
      fmap (splitAt (length (lambdaParams map_lam'))) $
        auxing aux . letTupExp "map_adjs" . Op $
          Screma w (as ++ pat_adj_vals ++ free_adjs) (mapSOAC lam_rev)

    -- Free contributions are handled first.  freeContrib may install
    -- (insAdj) an adjoint rather than add to it, so doing it after the
    -- updates below could discard them when a name is in both free and as.
    zipWithM_ freeContrib free free_contribs
    let param_ts :: [Type]
        param_ts = map paramType (lambdaParams map_lam')
    forM_ (zip3 param_ts as param_contribs) $ \(param_t, a, param_contrib) ->
      case param_t of
        Acc {} -> freeContrib a param_contrib
        _ -> updateAdj a param_contrib
  where
    -- Prepend n fresh i64 parameters to a lambda.
    addIdxParams :: Int -> Lambda SOACS -> ADM (Lambda SOACS)
    addIdxParams n lam = do
      idxs <- replicateM n $ newParam "idx" $ Prim int64
      pure $ lam {lambdaParams = idxs ++ lambdaParams lam}

    -- The addition lambda for type t, preceded by n index parameters.
    accAddLambda :: Int -> Type -> ADM (Lambda SOACS)
    accAddLambda n t = addIdxParams n =<< addLambda t

    -- A WithAcc input for a variable whose adjoint is an array: the
    -- adjoint array, combined with an addition lambda whose neutral
    -- element is zero.
    withAccInput ::
      (VName, (Shape, PrimType)) ->
      ADM (Shape, [VName], Maybe (Lambda SOACS, [SubExp]))
    withAccInput (v, (shape, pt)) = do
      v_adj <- lookupAdjVal v
      add_lam <- accAddLambda (shapeRank shape) $ Prim pt
      zero <- letSubExp "zero" $ zeroExp $ Prim pt
      pure (shape, [v_adj], Just (add_lam, [zero]))

    -- Run the action m inside a WithAcc.  Inside it, every free variable
    -- with an array adjoint has an accumulator installed as its adjoint.
    -- m receives two arguments:
    --
    --   * the free variables that now have accumulator adjoints (those
    --     that already had one, then the array ones);
    --   * the set of remaining free variables.
    --
    -- The adjoints of all free variables, and of those elements of as
    -- that are not free, are returned from the WithAcc and reinstalled
    -- afterwards, so updates made inside m remain visible.
    accAdjoints :: [VName] -> ([VName] -> Names -> ADM ()) -> ADM ()
    accAdjoints free m = do
      (arr_free, acc_free, nonacc_free) <-
        partitionAdjVars <$> classifyAdjVars free
      arr_free' <- mapM withAccInput arr_free
      -- Input arrays that are not also free in the lambda.
      let as_nonfree :: [VName]
          as_nonfree = filter (`notElem` free) as
      (arr_adjs, acc_adjs, rest_adjs) <-
        fmap (splitAt3 (length arr_free) (length acc_free)) . withAcc arr_free' $ \accs -> do
          zipWithM_ insAdj (map fst arr_free) accs
          () <- m (acc_free ++ map fst arr_free) (namesFromList nonacc_free)
          acc_free_adj <- mapM lookupAdjVal acc_free
          arr_free_adj <- mapM (lookupAdjVal . fst) arr_free
          nonacc_free_adj <- mapM lookupAdjVal nonacc_free
          as_nonfree_adj <- mapM lookupAdjVal as_nonfree
          pure $ varsRes $ arr_free_adj <> acc_free_adj <> nonacc_free_adj <> as_nonfree_adj
      zipWithM_ insAdj acc_free acc_adjs
      zipWithM_ insAdj (map fst arr_free) arr_adjs
      let (nonacc_adjs, as_nonfree_adjs) = splitAt (length nonacc_free) rest_adjs
      zipWithM_ insAdj nonacc_free nonacc_adjs
      zipWithM_ insAdj as_nonfree as_nonfree_adjs

    -- Add the per-iteration contributions contribs to the adjoint of v.
    -- If the rows are accumulators, contribs itself becomes the adjoint.
    -- Otherwise the rows are summed with a commutative reduction over w,
    -- and the sum is passed to updateAdj.
    freeContrib :: VName -> VName -> ADM ()
    freeContrib v contribs = do
      contribs_t <- lookupType contribs
      case rowType contribs_t of
        Acc {} -> insAdj v contribs
        t -> do
          lam <- addLambda t
          zero <- letSubExp "zero" $ zeroExp t
          reduce <- reduceSOAC [Reduce Commutative lam [zero]]
          contrib_sum <-
            letExp (baseString v <> "_contrib_sum") . Op $
              Screma w [contribs] reduce
          updateAdj v contrib_sum