{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev (revVJP) where
import Control.Monad
import Data.List ((\\))
import Data.List.NonEmpty (NonEmpty (..))
import Data.Map qualified as M
import Futhark.AD.Derivatives
import Futhark.AD.Rev.Loop
import Futhark.AD.Rev.Monad
import Futhark.AD.Rev.SOAC
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util (takeLast)
-- | Return the name bound by a pattern that has exactly one element.
-- Any other pattern is an internal error; the message includes the
-- pretty-printed pattern.
patName :: Pat Type -> ADM VName
patName (Pat [pe]) = pure $ patElemName pe
patName pat = error $ "Expected single-element pattern: " ++ prettyString pat
-- | Shared prologue for differentiating a single-result 'BasicOp'.
--
-- Emits the statement @pat = op@ (with annotations @aux@), then runs the
-- action @m@. It then returns the single name bound by @pat@ together
-- with the name holding that variable's adjoint value ('lookupAdjVal').
-- @pat@ must bind exactly one name; see 'patName'.
commonBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM (VName, VName)
commonBasicOp pat aux op m = do
  addStm $ Let pat aux $ BasicOp op
  m
  pat_v <- patName pat
  pat_adj <- lookupAdjVal pat_v
  pure (pat_v, pat_adj)
-- | Reverse-mode differentiation of a statement whose expression is a
-- 'BasicOp'.
--
-- For every operation except 'UpdateAcc', the forward statement is
-- emitted by 'commonBasicOp', which also runs @m@. The code that
-- propagates the result's adjoint back to the operands is then generated
-- inside 'returnSweepCode'. 'UpdateAcc' may bind several names, so it
-- emits its statement directly. 'FlatIndex' and 'FlatUpdate' are not
-- supported and raise an error.
diffBasicOp :: Pat Type -> StmAux () -> BasicOp -> ADM () -> ADM ()
diffBasicOp pat aux e m =
  case e of
    -- The result is a boolean. Its adjoint is converted to the compared
    -- type and added to the adjoints of both operands.
    CmpOp cmp x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = cmpOpType cmp
            update contrib = do
              updateSubExpAdj x contrib
              updateSubExpAdj y contrib

        case t of
          FloatType ft ->
            update <=< letExp "contrib" $
              Match
                [Var pat_adj]
                [ Case [Just $ BoolValue True] $
                    resultBody [constant (floatValue ft (1 :: Int))]
                ]
                (resultBody [constant (floatValue ft (0 :: Int))])
                (MatchDec [Prim (FloatType ft)] MatchNormal)
          IntType it ->
            update <=< letExp "contrib" $ BasicOp $ ConvOp (BToI it) (Var pat_adj)
          Bool ->
            update pat_adj
          Unit ->
            pure ()

    -- Convert the adjoint back using the flipped conversion.
    ConvOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        contrib <-
          letExp "contrib" $ BasicOp $ ConvOp (flipConvOp op) $ Var pat_adj
        updateSubExpAdj x contrib

    -- adj(x) += adj(pat) * d(op)/dx, with the partial from 'pdUnOp'.
    UnOp op x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = unOpType op
        contrib <- do
          let x_pe = primExpFromSubExp t x
              pat_adj' = primExpFromSubExp t (Var pat_adj)
              dx = pdUnOp op x_pe
          letExp "contrib" <=< toExp $ pat_adj' ~*~ dx
        updateSubExpAdj x contrib

    -- Both operands receive adj(pat) times their partial from 'pdBinOp'.
    BinOp op x y -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let t = binOpType op
            (wrt_x, wrt_y) =
              pdBinOp op (primExpFromSubExp t x) (primExpFromSubExp t y)
            pat_adj' = primExpFromSubExp t $ Var pat_adj
        adj_x <- letExp "binop_x_adj" <=< toExp $ pat_adj' ~*~ wrt_x
        adj_y <- letExp "binop_y_adj" <=< toExp $ pat_adj' ~*~ wrt_y
        updateSubExpAdj x adj_x
        updateSubExpAdj y adj_y

    -- The adjoint passes through unchanged.
    SubExp se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj

    -- No adjoint contribution.
    Assert {} ->
      void $ commonBasicOp pat aux e m

    -- Element i receives row i of the result's adjoint.
    ArrayLit elems _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      t <- lookupType pat_adj
      returnSweepCode $
        forM_ (zip [(0 :: Int64) ..] elems) $ \(i, se) -> do
          let slice = fullSlice t [DimFix (constant i)]
          updateSubExpAdj se <=< letExp "elem_adj" $ BasicOp $ Index pat_adj slice

    Index arr slice -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateAdjSlice slice arr pat_adj

    FlatIndex {} -> error "FlatIndex not handled by AD yet."

    FlatUpdate {} -> error "FlatUpdate not handled by AD yet."

    -- The adjoint passes through unchanged.
    Opaque _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateSubExpAdj se pat_adj

    -- Reshape the adjoint back to the original shape of arr.
    Reshape k _ arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        arr_shape <- arrayShape <$> lookupType arr
        updateAdj arr <=< letExp "adj_reshape" . BasicOp $
          Reshape k arr_shape pat_adj

    -- Undo the permutation on the adjoint.
    Rearrange perm arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $
        updateAdj arr <=< letExp "adj_rearrange" . BasicOp $
          Rearrange (rearrangeInverse perm) pat_adj

    -- Rotate the adjoint by the negated amounts.
    Rotate rots arr -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let neg = BasicOp . BinOp (Sub Int64 OverflowWrap) (intConst Int64 0)
        rots' <- mapM (letSubExp "rot_neg" . neg) rots
        updateAdj arr <=< letExp "adj_rotate" . BasicOp $
          Rotate rots' pat_adj

    -- Flatten the replicated dimensions into a single dimension of size
    -- product ns. Then reduce over it with 'addLambda' (neutral element
    -- zero) to obtain x's contribution.
    Replicate (Shape ns) x -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        x_t <- subExpType x
        lam <- addLambda x_t
        ne <- letSubExp "zero" $ zeroExp x_t
        n <-
          letSubExp "rep_size"
            =<< foldBinOp (Mul Int64 OverflowUndef) (intConst Int64 1) ns
        pat_adj_flat <-
          letExp (baseString pat_adj <> "_flat") . BasicOp $
            Reshape ReshapeArbitrary (Shape $ n : arrayDims x_t) pat_adj
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj x
          =<< letExp "rep_contrib" (Op $ Screma n [pat_adj_flat] reduce)

    -- Each operand receives its consecutive slice of the adjoint along
    -- dimension d.
    Concat d (arr :| arrs) _ -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        let sliceAdj :: SubExp -> [VName] -> ADM [VName]
            sliceAdj _ [] = pure []
            sliceAdj start (v : vs) = do
              v_t <- lookupType v
              -- The slice is placed at dimension d (see 'sliceAt'), so its
              -- width must be v's extent along d, not its outer extent.
              let w = arraySize d v_t
                  slice = DimSlice start w (intConst Int64 1)
              pat_adj_slice <-
                letExp (baseString pat_adj <> "_slice") $
                  BasicOp $
                    Index pat_adj (sliceAt v_t d [slice])
              start' <-
                letSubExp "start" $ BasicOp $ BinOp (Add Int64 OverflowUndef) start w
              slices <- sliceAdj start' vs
              pure $ pat_adj_slice : slices

        slices <- sliceAdj (intConst Int64 0) $ arr : arrs
        zipWithM_ updateAdj (arr : arrs) slices

    -- The adjoint passes through unchanged.
    Copy se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateAdj se pat_adj

    -- The adjoint passes through unchanged.
    Manifest _ se -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ updateAdj se pat_adj

    -- No adjoint contribution.
    Scratch {} ->
      void $ commonBasicOp pat aux e m

    -- Sum-reduce the result's adjoint with 'addLambda' (neutral element
    -- zero).
    -- NOTE(review): the reduced value is added to the adjoint of the
    -- length n, not of the start/stride operands. Confirm this is
    -- intended.
    Iota n _ _ t -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        ne <- letSubExp "zero" $ zeroExp $ Prim $ IntType t
        lam <- addLambda $ Prim $ IntType t
        reduce <- reduceSOAC [Reduce Commutative lam [ne]]
        updateSubExpAdj n
          =<< letExp "iota_contrib" (Op $ Screma n [pat_adj] reduce)

    -- The written value receives the adjoint slice. The source array
    -- receives the result's adjoint with that slice overwritten by zeroes.
    Update safety arr slice v -> do
      (_pat_v, pat_adj) <- commonBasicOp pat aux e m
      returnSweepCode $ do
        v_adj <- letExp "update_val_adj" $ BasicOp $ Index pat_adj slice
        t <- lookupType v_adj
        -- Array-typed slices are copied (presumably so v's adjoint does
        -- not alias pat_adj, which is updated in place below; confirm).
        v_adj_copy <-
          case t of
            Array {} -> letExp "update_val_adj_copy" $ BasicOp $ Copy v_adj
            _ -> pure v_adj
        updateSubExpAdj v v_adj_copy
        zeroes <- letSubExp "update_zero" . zeroExp =<< subExpType v
        updateAdj arr
          =<< letExp "update_src_adj" (BasicOp $ Update safety pat_adj slice zeroes)

    -- The pattern may bind several names, so the statement is emitted
    -- directly. Each value v then receives the element at indices is of
    -- the corresponding result's adjoint.
    UpdateAcc _ is vs -> do
      addStm $ Let pat aux $ BasicOp e
      m
      pat_adjs <- mapM lookupAdjVal (patNames pat)
      returnSweepCode $
        forM_ (zip pat_adjs vs) $ \(adj, v) -> do
          adj_i <-
            letExp "updateacc_val_adj" $ BasicOp $ Index adj $ Slice $ map DimFix is
          updateSubExpAdj v adj_i
-- | The differentiation callbacks passed to 'vjpSOAC'. The SOAC rules
-- use them to differentiate nested lambdas ('diffLambda') and statements
-- ('diffStm').
vjpOps :: VjpOps
vjpOps =
  VjpOps
    { vjpLambda = diffLambda,
      vjpStm = diffStm
    }
-- | Reverse-mode differentiation of a single statement, followed by the
-- action @m@. The clauses below dispatch on the statement's expression:
-- 'BasicOp', calls to builtin functions, 'Match', SOACs, loops and
-- 'WithAcc'.
diffStm :: Stm SOACS -> ADM () -> ADM ()
diffStm (Let pat aux (BasicOp e)) m =
  diffBasicOp pat aux e m
-- Calls to builtin functions. Each argument receives adj(pat) times its
-- partial derivative from 'pdBuiltin', converted from the return type to
-- the argument type. Calls to other functions are not handled by this
-- clause.
-- NOTE(review): unlike the other clauses, the adjoint code here is not
-- wrapped in 'returnSweepCode'. Confirm this is intentional.
diffStm stm@(Let pat _ (Apply f args _ _)) m
  | Just (ret, argts) <- M.lookup f builtInFunctions = do
      addStm stm
      m

      pat_adj <- lookupAdjVal =<< patName pat
      let arg_pes = zipWith primExpFromSubExp argts (map fst args)
          pat_adj' = primExpFromSubExp ret (Var pat_adj)
          convert ft tt
            | ft == tt = id
          convert (IntType ft) (IntType tt) = ConvOpExp (SExt ft tt)
          convert (FloatType ft) (FloatType tt) = ConvOpExp (FPConv ft tt)
          convert Bool (FloatType tt) = ConvOpExp (BToF tt)
          convert (FloatType ft) Bool = ConvOpExp (FToB ft)
          convert ft tt = error $ "diffStm.convert: " ++ prettyString (f, ft, tt)

      contribs <-
        case pdBuiltin f arg_pes of
          Nothing ->
            error $ "No partial derivative defined for builtin function: " ++ prettyString f
          Just derivs ->
            forM (zip derivs argts) $ \(deriv, argt) ->
              letExp "contrib" <=< toExp . convert ret argt $ pat_adj' ~*~ deriv

      zipWithM_ updateSubExpAdj (map fst args) contribs
-- Every branch is differentiated ('diffBody') with respect to the union
-- of the free variables of all branches. The differentiated match is then
-- executed, and its last (length branches_free) results become the
-- adjoints of those variables.
diffStm stm@(Let pat _ (Match ses cases defbody _)) m = do
  addStm stm
  m
  returnSweepCode $ do
    let cases_free = map freeIn cases
        defbody_free = freeIn defbody
        branches_free = namesToList $ mconcat $ defbody_free : cases_free

    adjs <- mapM lookupAdj $ patNames pat

    branches_free_adj <-
      ( pure . takeLast (length branches_free)
          <=< letTupExp "branch_adj"
          <=< renameExp
        )
        =<< eMatch
          ses
          (map (fmap $ diffBody adjs branches_free) cases)
          (diffBody adjs branches_free defbody)
    zipWithM_ insAdj branches_free branches_free_adj
diffStm (Let pat aux (Op soac)) m =
  vjpSOAC vjpOps pat aux soac m
diffStm (Let pat aux loop@DoLoop {}) m =
  diffLoop diffStms pat aux loop m
-- The lambda is differentiated ('diffLambda'') with respect to its
-- active, non-accumulator free variables. 'WithAcc' is then re-run on the
-- inputs, with their combining lambdas renamed. The results become the
-- adjoints of the input arrays, followed by those free variables.
diffStm stm@(Let pat _aux (WithAcc inputs lam)) m = do
  addStm stm
  m
  returnSweepCode $ do
    adjs <- mapM lookupAdj $ patNames pat
    lam' <- renameLambda lam
    free_vars <- filterM isActive $ namesToList $ freeIn lam'
    free_accs <- filterM (fmap isAcc . lookupType) free_vars
    let free_vars' = free_vars \\ free_accs
    lam'' <- diffLambda' adjs free_vars' lam'
    inputs' <- mapM renameInputLambda inputs
    free_adjs <- letTupExp "with_acc_contrib" $ WithAcc inputs' lam''
    zipWithM_ insAdj (arrs <> free_vars') free_adjs
  where
    -- All arrays supplied as accumulator inputs, in order.
    arrs = concatMap (\(_, as, _) -> as) inputs
    -- Give an input's combining lambda (if any) fresh names.
    renameInputLambda (shape, as, Just (f, nes)) = do
      f' <- renameLambda f
      pure (shape, as, Just (f', nes))
    renameInputLambda input = pure input
diffLambda' :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda' [Adj]
res_adjs [VName]
get_adjs_for (Lambda [LParam SOACS]
params Body SOACS
body [Type]
ts) =
forall rep (m :: * -> *) a.
LocalScope rep m =>
Scope rep -> m a -> m a
localScope (forall rep dec. (LParamInfo rep ~ dec) => [Param dec] -> Scope rep
scopeOfLParams [LParam SOACS]
params) forall a b. (a -> b) -> a -> b
$ do
Body () Stms SOACS
stms Result
res <- [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody [Adj]
res_adjs [VName]
get_adjs_for Body SOACS
body
let body' :: Body SOACS
body' = forall rep. BodyDec rep -> Stms rep -> Result -> Body rep
Body () Stms SOACS
stms forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput SOACS]
inputs) Result
res forall a. Semigroup a => a -> a -> a
<> forall a. Int -> [a] -> [a]
takeLast (forall (t :: * -> *) a. Foldable t => t a -> Int
length [VName]
get_adjs_for) Result
res
[Type]
ts' <- forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType [VName]
get_adjs_for
forall (f :: * -> *) a. Applicative f => a -> f a
pure forall a b. (a -> b) -> a -> b
$ forall rep. [LParam rep] -> Body rep -> [Type] -> Lambda rep
Lambda [LParam SOACS]
params Body SOACS
body' forall a b. (a -> b) -> a -> b
$ forall a. Int -> [a] -> [a]
take (forall (t :: * -> *) a. Foldable t => t a -> Int
length [WithAccInput SOACS]
inputs) [Type]
ts forall a. Semigroup a => a -> a -> a
<> [Type]
ts'
diffStm Stm SOACS
stm ADM ()
_ = forall a. HasCallStack => [Char] -> a
error forall a b. (a -> b) -> a -> b
$ [Char]
"diffStm unhandled:\n" forall a. [a] -> [a] -> [a]
++ forall a. Pretty a => a -> [Char]
prettyString Stm SOACS
stm
-- | Differentiate a sequence of statements, one statement at a time.
--
-- For the head statement, 'copyConsumedArrsInStm' yields a set of copy
-- statements and a renaming (presumably for arrays the statement
-- consumes; see that function). The renaming is applied to the head
-- statement and to all remaining statements. The copy statements are
-- differentiated first. Then 'diffStm' handles the renamed head
-- statement, receiving the differentiation of the remaining statements
-- as its continuation. Finally, every renamed variable is given the
-- adjoint that was accumulated for its replacement name.
diffStms :: Stms SOACS -> ADM ()
diffStms all_stms =
  case stmsHead all_stms of
    Nothing -> pure ()
    Just (stm, rest) -> do
      (subst, copy_stms) <- copyConsumedArrsInStm stm
      let (stm', rest') = substituteNames subst (stm, rest)
      diffStms copy_stms
      diffStm stm' $ diffStms rest'
      -- Transfer adjoints from the replacement names back to the originals.
      mapM_ (\(from, to) -> lookupAdj to >>= setAdj from) $ M.toList subst
-- | Transformation applied to statements before they are differentiated.
-- At present this consists solely of 'stripmineStms'.
preprocess :: Stms SOACS -> ADM (Stms SOACS)
preprocess stms = stripmineStms stms
-- | Differentiate a body.
--
-- The last @length res_adjs@ results of the body are seeded with the
-- given adjoints; constant results are skipped. The preprocessed
-- statements are then differentiated, and the adjoints of the variables
-- in @get_adjs_for@ are read back.
--
-- The returned body has these properties:
--
-- * Its statements are those emitted while all of this runs.
-- * It returns the original results followed by the requested adjoints.
--
-- All of this runs inside 'subAD' and 'subSubsts'.
diffBody :: [Adj] -> [VName] -> Body SOACS -> ADM (Body SOACS)
diffBody res_adjs get_adjs_for (Body () stms res) =
  subAD . subSubsts $ do
    (adjs, stms') <- collectStms $ do
      zipWithM_ seedResult (takeLast (length res_adjs) res) res_adjs
      preprocess stms >>= diffStms
      mapM lookupAdjVal get_adjs_for
    pure . Body () stms' $ res <> varsRes adjs
  where
    -- Accumulate the given adjoint into a variable result; constants have
    -- no adjoint, so they are ignored.
    seedResult :: SubExpRes -> Adj -> ADM ()
    seedResult (SubExpRes _ se) v_adj =
      case se of
        Constant _ -> pure ()
        Var v -> void $ updateAdj v =<< adjVal v_adj
-- | Differentiate a lambda with respect to the variables in @get_adjs_for@.
--
-- The lambda body is differentiated with 'diffBody', with the lambda
-- parameters in scope. The resulting lambda keeps the same parameters.
-- It returns only the trailing @length get_adjs_for@ results, which are
-- the requested adjoints. Its return types are the types of the
-- @get_adjs_for@ variables; the original return types are discarded.
diffLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS)
diffLambda res_adjs get_adjs_for (Lambda params body _) =
  localScope (scopeOfLParams params) $ do
    Body () adj_stms adj_res <- diffBody res_adjs get_adjs_for body
    let n_adjs = length get_adjs_for
        adj_body = Body () adj_stms (takeLast n_adjs adj_res)
    Lambda params adj_body <$> mapM lookupType get_adjs_for
-- | Construct the reverse-mode vector-Jacobian-product lambda of a lambda.
--
-- The result lambda takes the original parameters followed by one
-- adjoint parameter per result of the original body:
--
-- * For a variable result, the adjoint parameter's name comes from
--   'adjVName'.
-- * For a constant result, a fresh name based on @"const_adj"@ is used.
--
-- The result lambda returns the original results followed by the
-- adjoints of the original parameters. Its return types are therefore
-- the original return types followed by the parameter types.
revVJP :: MonadFreshNames m => Scope SOACS -> Lambda SOACS -> m (Lambda SOACS)
revVJP scope (Lambda params body ts) =
  runADM $ localScope (scope <> scopeOfLParams params) $ do
    let mkAdjParam :: SubExp -> Type -> ADM (Param Type)
        mkAdjParam se t = do
          adj_name <- case subExpVar se of
            Just v -> adjVName v
            Nothing -> newVName "const_adj"
          pure $ Param mempty adj_name t
    params_adj <- zipWithM mkAdjParam (map resSubExp (bodyResult body)) ts
    body' <- localScope (scopeOfLParams params_adj) $ do
      let seeds = map adjFromParam params_adj
          wanted = map paramName params
      diffBody seeds wanted body
    pure $ Lambda (params <> params_adj) body' (ts <> map paramType params)