{-# LANGUAGE TypeFamilies #-}

-- | Apply all AD operators in the program, leaving AD-free code.
module Futhark.Pass.AD (applyAD, applyADInnermost) where

import Control.Monad
import Control.Monad.Reader
import Futhark.AD.Fwd (fwdJVP)
import Futhark.AD.Rev (revVJP)
import Futhark.Builder
import Futhark.IR.SOACS
import Futhark.IR.SOACS.Simplify (simplifyLambda)
import Futhark.Pass

-- | Whether we apply only the innermost AD operators, or all of them.
-- The former is very useful for debugging, but probably not useful
-- for actual compilation.
data Mode
  = -- | Apply only the innermost AD operators.
    Innermost
  | -- | Apply all AD operators.
    All
  deriving (Eq, Show)

-- | Inline a lambda applied to the given arguments, binding the
-- lambda's results to the names of the given pattern.
--
-- * Each lambda parameter is bound to the corresponding argument.
--   Array-typed parameters are bound through a @Replicate@ with an
--   empty shape rather than a plain @SubExp@ (presumably so that the
--   parameter gets its own copy instead of aliasing the argument —
--   NOTE(review): confirm).
--
-- * The 'StmAux' is applied (via 'auxing') only to these parameter
--   bindings, not to the inlined body or the result bindings.
--
-- * The lambda body is then bound, and each pattern name is bound to
--   the corresponding body result under that result's certificates.
--
-- Parameters/arguments and pattern names/results are paired with
-- 'zip', so surplus elements on either side are silently ignored.
bindLambda ::
  (MonadBuilder m, Rep m ~ SOACS) =>
  Pat Type ->
  StmAux (ExpDec SOACS) ->
  Lambda SOACS ->
  [SubExp] ->
  m ()
bindLambda pat aux (Lambda params _ body) args = do
  auxing aux $
    forM_ (zip params args) $ \(param, arg) ->
      letBindNames [paramName param] . BasicOp $
        case paramType param of
          Array {} -> Replicate mempty arg
          _ -> SubExp arg
  res <- bodyBind body
  forM_ (zip (patNames pat) res) $ \(v, SubExpRes cs se) ->
    certifying cs . letBindNames [v] . BasicOp $ SubExp se

-- | Process a single statement.
--
-- 'VJP' and 'JVP' statements are handed to 'onADStm', which either
-- expands them into ordinary code or keeps them (see 'Mode').  Any
-- other statement is traversed so that AD operators inside nested
-- bodies and SOAC lambdas are processed as well.
onStm :: Mode -> Scope SOACS -> Stm SOACS -> PassM (Stms SOACS)
onStm mode scope (Let pat aux (Op (VJP lam args vec))) =
  onADStm mode scope pat aux differentiate rebuild lam (args ++ vec)
  where
    -- Reverse-mode output is run through 'simplifyLambda' (in the
    -- current scope) before it is inlined.
    differentiate :: Lambda SOACS -> PassM (Lambda SOACS)
    differentiate = (`runReaderT` scope) . simplifyLambda <=< revVJP scope
    rebuild lam' = VJP lam' args vec
onStm mode scope (Let pat aux (Op (JVP lam args vec))) =
  onADStm mode scope pat aux (fwdJVP scope) rebuild lam (args ++ vec)
  where
    rebuild lam' = JVP lam' args vec
onStm mode scope (Let pat aux e) = oneStm . Let pat aux <$> mapExpM mapper e
  where
    mapper :: Mapper SOACS SOACS PassM
    mapper =
      (identityMapper @SOACS)
        { mapOnBody = \bscope -> onBody mode (bscope <> scope),
          mapOnOp = mapSOACM soac_mapper
        }
    soac_mapper :: SOACMapper SOACS SOACS PassM
    soac_mapper = identitySOACMapper {mapOnSOACLambda = onLambda mode scope}

-- | Shared handling of a 'VJP' or 'JVP' statement with lambda @lam@.
--
-- The lambda is first processed recursively with 'onLambda'.  In
-- 'All' mode, or in 'Innermost' mode when that processing left the
-- lambda unchanged, the processed lambda is transformed with
-- @differentiate@ and inlined on @args@ via 'bindLambda'.  Otherwise
-- the AD operator is kept, rebuilt around the processed lambda with
-- @rebuild@.
onADStm ::
  Mode ->
  Scope SOACS ->
  Pat Type ->
  StmAux (ExpDec SOACS) ->
  -- | Differentiation transform applied to the processed lambda.
  (Lambda SOACS -> PassM (Lambda SOACS)) ->
  -- | Rebuilds the original AD operator around a lambda.
  (Lambda SOACS -> SOAC SOACS) ->
  Lambda SOACS ->
  -- | Arguments to the differentiated lambda.
  [SubExp] ->
  PassM (Stms SOACS)
onADStm mode scope pat aux differentiate rebuild lam args = do
  lam' <- onLambda mode scope lam
  if mode == All || lam == lam'
    then do
      lam'' <- differentiate lam'
      runBuilderT_ (bindLambda pat aux lam'' args) scope
    else pure . oneStm . Let pat aux . Op $ rebuild lam'

-- | Process a sequence of statements and concatenate the results.
-- Every statement is processed in the given scope extended with the
-- bindings of the whole sequence.
onStms :: Mode -> Scope SOACS -> Stms SOACS -> PassM (Stms SOACS)
onStms mode scope stms =
  mconcat <$> mapM (onStm mode scope') (stmsToList stms)
  where
    scope' = scopeOf stms <> scope

-- | Process the statements of a body.  The body's result is left
-- untouched.
onBody :: Mode -> Scope SOACS -> Body SOACS -> PassM (Body SOACS)
onBody mode scope body = do
  stms <- onStms mode scope $ bodyStms body
  pure $ body {bodyStms = stms}

-- | Process the body of a lambda, with the lambda's parameters added
-- to the scope.  The parameters and return type are left untouched.
onLambda :: Mode -> Scope SOACS -> Lambda SOACS -> PassM (Lambda SOACS)
onLambda mode scope lam = do
  body <-
    onBody mode (scopeOfLParams (lambdaParams lam) <> scope) $
      lambdaBody lam
  pure $ lam {lambdaBody = body}

-- | Process the body of a function definition.  The scope consists of
-- the top-level constants together with whatever 'scopeOf' yields for
-- the function definition itself.
onFun :: Mode -> Stms SOACS -> FunDef SOACS -> PassM (FunDef SOACS)
onFun mode consts fd = do
  body <-
    onBody mode (scopeOf consts <> scopeOf fd) $
      funDefBody fd
  pure $ fd {funDefBody = body}

-- | Apply all AD operators in the program, leaving AD-free code.
applyAD :: Pass SOACS SOACS
applyAD = adPass All "ad" "Apply AD operators"

-- | Apply only the innermost AD operators.  Very useful for
-- debugging, but probably not for actual compilation.
applyADInnermost :: Pass SOACS SOACS
applyADInnermost =
  adPass Innermost "ad innermost" "Apply innermost AD operators"

-- | Build an AD pass that runs in the given 'Mode', with the given
-- pass name and description.  Top-level constants are processed
-- starting from an empty scope; function bodies are processed by
-- 'onFun'.
adPass :: Mode -> String -> String -> Pass SOACS SOACS
adPass mode name desc =
  Pass
    { passName = name,
      passDescription = desc,
      passFunction =
        intraproceduralTransformationWithConsts
          (onStms mode mempty)
          (onFun mode)
    }