{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.AD.Rev.Monad
( ADM,
RState (..),
runADM,
Adj (..),
InBounds (..),
Sparse (..),
adjFromParam,
adjFromVar,
lookupAdj,
lookupAdjVal,
adjVal,
updateAdj,
updateSubExpAdj,
updateAdjSlice,
updateAdjIndex,
setAdj,
insAdj,
adjsReps,
copyConsumedArrsInStm,
copyConsumedArrsInBody,
addSubstitution,
returnSweepCode,
adjVName,
subAD,
noAdjsFor,
subSubsts,
isActive,
tabNest,
oneExp,
zeroExp,
unitAdjOfType,
addLambda,
VjpOps (..),
setLoopTape,
lookupLoopTape,
substLoopTape,
renameLoopTape,
)
where
import Control.Monad
import Control.Monad.State.Strict
import Data.Bifunctor (second)
import Data.List (foldl')
import qualified Data.Map as M
import Data.Maybe
import qualified Futhark.Analysis.Alias as Alias
import Futhark.Analysis.PrimExp.Convert
import Futhark.Builder
import Futhark.IR.Aliases (consumedInStms)
import Futhark.IR.Prop.Aliases
import Futhark.IR.SOACS
import Futhark.Tools
import Futhark.Transform.Substitute
import Futhark.Util (chunks)
-- | An expression producing zero of the given type: a constant for a
-- primitive type, or that constant replicated over the shape of an
-- array type.  Any other kind of type is an internal error.
zeroExp :: Type -> Exp rep
zeroExp ty =
  case ty of
    Prim pt -> BasicOp . SubExp $ zeroOf pt
    Array pt shape _ -> BasicOp . Replicate shape $ zeroOf pt
    _ -> error $ "zeroExp: " ++ pretty ty
  where
    zeroOf = Constant . blankPrimValue
-- | The value one of a primitive type; 'True' for 'Bool' and the unit
-- value for 'Unit'.
onePrim :: PrimType -> PrimValue
onePrim pt =
  case pt of
    IntType it -> IntValue (intValue it (1 :: Int))
    FloatType ft -> FloatValue (floatValue ft (1 :: Double))
    Bool -> BoolValue True
    Unit -> UnitValue
-- | An expression producing one of the given type (see 'onePrim'): a
-- constant for a primitive type, or that constant replicated over the
-- shape of an array type.  Any other kind of type is an internal error.
--
-- Both branches build the constant with 'Constant' directly, matching
-- 'zeroExp'; 'onePrim' already yields a 'PrimValue', so no 'IsValue'
-- conversion is needed.
oneExp :: Type -> Exp rep
oneExp (Prim pt) = BasicOp $ SubExp $ Constant $ onePrim pt
oneExp (Array pt shape _) = BasicOp $ Replicate shape $ Constant $ onePrim pt
oneExp t = error $ "oneExp: " ++ pretty t
-- | How an entry of a 'Sparse' adjoint is written when the sparse array
-- is materialised by 'sparseArray'.
data InBounds
  = -- | Written with a bounds-checked ('Safe') update.  The optional
    -- 'SubExp' presumably names a boolean that holds when the index is
    -- in bounds (TODO confirm: 'sparseArray' does not consult it).
    CheckBounds (Maybe SubExp)
  | -- | Written with an unchecked ('Unsafe') update.
    AssumeBounds
  | -- | Not written at all; 'sparseArray' generates no update for it.
    OutOfBounds
  deriving (Eq, Ord, Show)
-- | A symbolic array that is zero everywhere except at the listed
-- entries.  'sparseArray' materialises it by creating an array of zeroes
-- and applying one update per entry.
data Sparse = Sparse
  { -- | Shape of the represented array.
    sparseShape :: Shape,
    -- | Element type of the represented array.
    sparseType :: PrimType,
    -- | Entries of the form @(bounds, index, value)@, where the index
    -- selects a position in the outermost dimension.
    sparseIdxVals :: [(InBounds, SubExp, SubExp)]
  }
  deriving (Eq, Ord, Show)
-- | The adjoint of a variable, in one of three representations.
data Adj
  = -- | A sparse array; see 'Sparse'.
    AdjSparse Sparse
  | -- | An explicit value.
    AdjVal SubExp
  | -- | An array of zeroes with the given shape and element type.
    AdjZero Shape PrimType
  deriving (Eq, Ord, Show)
-- | Renaming affects only an 'AdjVal' that refers to a variable; every
-- other adjoint is returned unchanged.
--
-- NOTE(review): the 'SubExp's inside 'AdjSparse' and the 'Shape's of
-- 'AdjSparse'/'AdjZero' are not renamed here; confirm this is intended.
instance Substitute Adj where
  substituteNames subst adj =
    case adj of
      AdjVal (Var v) -> AdjVal (Var (substituteNames subst v))
      _ -> adj
-- | Bind a fresh variable to zeroes of type @t@ replicated over @shape@.
-- For a rank-0 shape the zero of @t@ is bound directly.  Otherwise the
-- zero is bound first and then replicated, and the replicate carries the
-- @sequential@ attribute.
zeroArray :: MonadBuilder m => Shape -> Type -> m VName
zeroArray shape t =
  if shapeRank shape == 0
    then letExp "zero" $ zeroExp t
    else do
      zero <- letSubExp "zero" $ zeroExp t
      attributing (oneAttr "sequential") . letExp "zeroes_" . BasicOp $
        Replicate shape zero
-- | Materialise a 'Sparse' array.  An array of zeroes of the sparse
-- shape and element type is bound first; then one update is applied per
-- entry, in list order.  'AssumeBounds' entries become 'Unsafe' updates,
-- 'CheckBounds' entries become 'Safe' updates, and 'OutOfBounds' entries
-- are skipped.
--
-- NOTE(review): the condition carried by @'CheckBounds' (Just b)@ is not
-- consulted here; confirm that a plain 'Safe' update is intended.
sparseArray :: (MonadBuilder m, Rep m ~ SOACS) => Sparse -> m VName
sparseArray (Sparse shape t ivs) = do
  zeroes <- zeroArray shape (Prim t)
  foldM writeEntry zeroes ivs
  where
    arr_t = Prim t `arrayOfShape` shape
    writeEntry arr (bounds, i, v) =
      case bounds of
        AssumeBounds -> updateWith Unsafe
        CheckBounds _ -> updateWith Safe
        OutOfBounds -> pure arr
      where
        -- Write v at index i of the outermost dimension of arr.
        updateWith safety =
          letExp "sparse" . BasicOp $
            Update safety arr (fullSlice arr_t [DimFix i]) v
-- | The adjoint whose value is held in the given variable.
adjFromVar :: VName -> Adj
adjFromVar v = AdjVal (Var v)
-- | The adjoint whose value is held in the variable named by a parameter.
adjFromParam :: Param t -> Adj
adjFromParam p = adjFromVar (paramName p)
-- | Bind the value one of the given type (see 'oneExp') and return it as
-- an explicit adjoint.
unitAdjOfType :: Type -> ADM Adj
unitAdjOfType t = do
  unit_se <- letSubExp "adj_unit" (oneExp t)
  pure (AdjVal unit_se)
-- | Split an adjoint into the 'SubExp's that make up its symbolic form,
-- together with a function that rebuilds an adjoint of the same kind
-- from a replacement list.
--
-- The replacement list must be laid out like the returned one:
--
-- * 'AdjVal' uses exactly one element;
-- * 'AdjZero' uses none;
-- * 'AdjSparse' uses an index/value pair per entry.
--
-- A wrong-length list for 'AdjVal' or 'AdjZero' is an internal error
-- with a descriptive message.  For 'AdjSparse', surplus or missing
-- elements truncate the entry list.
adjRep :: Adj -> ([SubExp], [SubExp] -> Adj)
adjRep (AdjVal se) = ([se], rebuild)
  where
    rebuild [se'] = AdjVal se'
    rebuild ses =
      error $ "adjRep: expected 1 SubExp for AdjVal, got " ++ show (length ses)
adjRep (AdjZero shape pt) = ([], rebuild)
  where
    rebuild [] = AdjZero shape pt
    rebuild ses =
      error $ "adjRep: expected no SubExps for AdjZero, got " ++ show (length ses)
adjRep (AdjSparse (Sparse shape pt ivs)) =
  (concatMap ivRep ivs, AdjSparse . Sparse shape pt . repIvs ivs)
  where
    ivRep (_, i, v) = [i, v]
    repIvs ((check, _, _) : ivs') (i : v : ses) =
      (check', i, v) : repIvs ivs' ses
      where
        -- An OutOfBounds entry is turned into a bounds-checked one with a
        -- constant-false condition, presumably because the replacement
        -- index is no longer known statically (TODO confirm).
        check' = case check of
          AssumeBounds -> AssumeBounds
          CheckBounds b -> CheckBounds b
          OutOfBounds -> CheckBounds (Just (constant False))
    repIvs _ _ = []
-- | Flatten a list of adjoints into a single list of 'SubExp's (the
-- concatenation of their 'adjRep's).  The returned function splits a
-- replacement list into chunks of the original per-adjoint sizes and
-- rebuilds each adjoint from its chunk.
adjsReps :: [Adj] -> ([SubExp], [SubExp] -> [Adj])
adjsReps adjs = (concat reps, rebuildAll)
  where
    (reps, rebuilders) = unzip (map adjRep adjs)
    rebuildAll ses = zipWith ($) rebuilders (chunks (map length reps) ses)
-- | State threaded through the 'ADM' monad.
data RState = RState
  { -- | The adjoint recorded for each variable (written by 'setAdj').
    stateAdjs :: M.Map VName Adj,
    -- | Loop-tape substitutions (TODO confirm semantics against the
    -- loop-tape helpers).
    stateLoopTape :: Substitutions,
    -- | Accumulated substitutions; presumably extended by
    -- 'addSubstitution' (verify).
    stateSubsts :: Substitutions,
    -- | Source of fresh names.
    stateNameSource :: VNameSource
  }
-- | The reverse-mode AD monad: a 'BuilderT' for 'SOACS' running over a
-- strict 'State' of 'RState'.  Every class instance listed here is
-- inherited from the wrapped monad via newtype deriving.
newtype ADM a = ADM (BuilderT SOACS (State RState) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadState RState,
      MonadFreshNames,
      HasScope SOACS,
      LocalScope SOACS
    )
-- | Every builder operation is delegated to the wrapped 'BuilderT'.
instance MonadBuilder ADM where
  type Rep ADM = SOACS
  mkExpDecM pat = ADM . mkExpDecM pat
  mkBodyM stms = ADM . mkBodyM stms
  mkLetNamesM names = ADM . mkLetNamesM names
  addStms stms = ADM (addStms stms)
  collectStms (ADM m) = ADM (collectStms m)
-- | Fresh names are read from and written back to the 'stateNameSource'
-- field of the state.
instance MonadFreshNames (State RState) where
  getNameSource = stateNameSource <$> get
  putNameSource src = modify $ \env -> env {stateNameSource = src}
-- | Run an 'ADM' computation in an empty scope.  It starts from empty
-- adjoint, loop-tape and substitution maps.  Fresh names are taken
-- from, and the final name source is returned to, the enclosing
-- 'MonadFreshNames' monad.  Statements produced at the top level are
-- discarded.
runADM :: MonadFreshNames m => ADM a -> m a
runADM (ADM m) = modifyNameSource run
  where
    run src =
      let initial = RState mempty mempty mempty src
          (x, final) = runState (fst <$> runBuilderT m mempty) initial
       in (x, stateNameSource final)
-- | Bind the value of an adjoint to a variable.  Sparse adjoints are
-- materialised with 'sparseArray' and zero adjoints with 'zeroArray'.
adjVal :: Adj -> ADM VName
adjVal adj =
  case adj of
    AdjVal se -> letExp "const_adj" $ BasicOp $ SubExp se
    AdjSparse sparse -> sparseArray sparse
    AdjZero shape t -> zeroArray shape (Prim t)
-- | Record @v_adj@ as the adjoint of @v@, replacing any previous entry.
setAdj :: VName -> Adj -> ADM ()
setAdj v v_adj = modify insertAdj
  where
    insertAdj env = env {stateAdjs = M.insert v v_adj (stateAdjs env)}
-- | Record the variable @v_adj@ as the adjoint of @v@.
insAdj :: VName -> VName -> ADM ()
insAdj v v_adj = setAdj v (AdjVal (Var v_adj))
-- | A fresh name for the adjoint of a variable, derived from the
-- variable's base name with an @_adj@ suffix.
adjVName :: VName -> ADM VName
adjVName v = newVName $ baseString v ++ "_adj"
-- | Make a copy of every array consumed by the given statement.
--
-- Returns a map from each consumed array to its copy, together with
-- the statements that create the copies.  For each copy @v'@ of @v@, a
-- substitution from @v'@ back to @v@ is also recorded in the 'ADM'
-- state with 'addSubstitution'.  Consumed names that are not arrays
-- are not copied.
copyConsumedArrsInStm :: Stm SOACS -> ADM (Substitutions, Stms SOACS)
copyConsumedArrsInStm s = inScopeOf s . collectStms $ copyInStm s
  where
    -- Run alias analysis on the single statement to find what it consumes.
    copyInStm :: Stm SOACS -> ADM Substitutions
    copyInStm stm = do
      let consumed =
            namesToList . consumedInStms . fst $
              Alias.analyseStms mempty (oneStm stm)
      M.fromList . concat <$> mapM copyOne consumed

    -- Copy one consumed name if it is an array; otherwise produce no pair.
    copyOne :: VName -> ADM [(VName, VName)]
    copyOne v = inScopeOf s $ do
      v_t <- lookupType v
      case v_t of
        Array {} -> do
          v' <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          addSubstitution v' v
          pure [(v, v')]
        _ -> pure []
-- | Make a copy of every array consumed by the given body, except the
-- names listed in @dontCopy@.  The copy statements are emitted with
-- 'letExp', and the result maps each copied array to its copy.
--
-- Fails with an error if a consumed name has an accumulator type.
-- Consumed names of other non-array types are not copied.
copyConsumedArrsInBody :: [VName] -> Body SOACS -> ADM Substitutions
copyConsumedArrsInBody dontCopy b = do
  let consumed = namesToList $ consumedInBody $ Alias.analyseBody mempty b
      toCopy = filter (`notElem` dontCopy) consumed
  copies <- mapM copyOne toCopy
  pure $ mconcat copies
  where
    copyOne :: VName -> ADM Substitutions
    copyOne v = do
      v_t <- lookupType v
      case v_t of
        Acc {} -> error $ "copyConsumedArrsInBody: Acc " <> pretty v
        Array {} -> do
          v' <- letExp (baseString v <> "_ad_copy") $ BasicOp $ Copy v
          pure $ M.singleton v v'
        _ -> pure M.empty
-- | Run an action, then re-emit the statements it produced after
-- applying the substitutions held in the 'ADM' state ('stateSubsts').
-- The substitutions are read after the action has run, so any it
-- records itself are applied as well.
returnSweepCode :: ADM a -> ADM a
returnSweepCode m = do
  (result, sweepStms) <- collectStms m
  substs <- gets stateSubsts
  let renamed = substituteNames substs sweepStms
  addStms renamed
  pure result
-- | Record in the 'ADM' state ('stateSubsts') that @v@ is to be
-- replaced by @v'@.  A later call for the same @v@ overrides an
-- earlier one.
addSubstitution :: VName -> VName -> ADM ()
addSubstitution v v' =
  modify $ \env ->
    let substs = stateSubsts env
     in env {stateSubsts = M.insert v v' substs}
-- | Run an action while pretending that the given names have no
-- adjoints, then restore the adjoints they had beforehand.
--
-- For names that had an adjoint before, the saved adjoint replaces
-- whatever the action recorded.  Names that had no adjoint before keep
-- whatever adjoint (if any) the action gave them.
noAdjsFor :: Names -> ADM a -> ADM a
noAdjsFor names m = do
  -- Save each existing adjoint together with its own name.  Saving
  -- only the adjoints and zipping them against all names would pair
  -- them with the wrong names whenever one of the names lacked an
  -- adjoint.
  old <- gets $ \env ->
    mapMaybe (\v -> (,) v <$> M.lookup v (stateAdjs env)) names'
  modify $ \env ->
    env {stateAdjs = foldl' (flip M.delete) (stateAdjs env) names'}
  x <- m
  -- Map union is left-biased, so the saved adjoints take precedence.
  modify $ \env ->
    env {stateAdjs = M.fromList old <> stateAdjs env}
  pure x
  where
    names' = namesToList names
-- | The binary operator used to add two adjoint values of a primitive
-- type: wrapping addition for integers, floating-point addition for
-- floats, and logical conjunction for 'Bool' and 'Unit'.
--
-- NOTE(review): conjunction has identity True; confirm this matches
-- the value used as the zero adjoint for 'Bool'.
addBinOp :: PrimType -> BinOp
addBinOp pt =
  case pt of
    IntType it -> Add it OverflowWrap
    FloatType ft -> FAdd ft
    Bool -> LogAnd
    Unit -> LogAnd
-- | @tabNest n vs f@ builds @n@ nested @map@s over the arrays @vs@.
-- Each level also maps over an @iota@ of the outer dimension, so the
-- index at that level is available.  At the innermost level, @f@ is
-- called with the index variables (outermost first) and the
-- corresponding elements of @vs@.  The names @f@ returns become the
-- results of the nest, bound with the name hint @tab@.
tabNest :: Int -> [VName] -> ([VName] -> [VName] -> ADM [VName]) -> ADM [VName]
tabNest depth arrs body = go [] depth arrs
  where
    go :: [VName] -> Int -> [VName] -> ADM [VName]
    -- Indices are accumulated innermost-first, hence the reverse.
    go idxs 0 xs = body (reverse idxs) xs
    go idxs k xs = do
      xs_ts <- mapM lookupType xs
      let w = arraysSize 0 xs_ts
      iota <-
        letExp "tab_iota" $
          BasicOp $
            Iota w (intConst Int64 0) (intConst Int64 1) Int64
      i_p <- newParam "i" $ Prim int64
      -- One lambda parameter per array, typed as that array's row.
      x_ps <- forM (zip xs xs_ts) $ \(x, x_t) ->
        newParam (baseString x <> "_p") $ rowType x_t
      ((ret, res), stms) <-
        collectStms $
          localScope (scopeOfLParams (i_p : x_ps)) $ do
            inner <- go (paramName i_p : idxs) (k - 1) (map paramName x_ps)
            inner_ts <- mapM lookupType inner
            pure (inner_ts, varsRes inner)
      let lam = Lambda (i_p : x_ps) (Body () stms res) ret
      letTupExp "tab" $ Op $ Screma w (iota : xs) (mapSOAC lam)
-- | A two-parameter lambda that adds values of the given type.
--
-- For a primitive type this is 'binOpLambda' of 'addBinOp'.  For an
-- array type, the lambda maps the row-type addition lambda over its two
-- array parameters.  Any other type is an error.
addLambda :: Type -> ADM (Lambda SOACS)
addLambda t =
  case t of
    Prim pt -> binOpLambda (addBinOp pt) pt
    Array {} -> arrayAdd
    _ -> error $ "addLambda: " ++ show t
  where
    arrayAdd :: ADM (Lambda SOACS)
    arrayAdd = do
      xs_p <- newParam "xs" t
      ys_p <- newParam "ys" t
      rowLam <- addLambda (rowType t)
      body <- insertStmsM $ do
        let inputs = [paramName xs_p, paramName ys_p]
        res <-
          letSubExp "lam_map" $
            Op $
              Screma (arraySize 0 t) inputs (mapSOAC rowLam)
        pure $ resultBody [res]
      pure $
        Lambda
          { lambdaParams = [xs_p, ys_p],
            lambdaReturnType = [t],
            lambdaBody = body
          }
-- | An expression that adds the values of @x@ and @y@.  Only @x@'s type
-- is inspected; presumably @y@ has the same type.
--
-- Primitives are combined with 'addBinOp'.  Arrays are added
-- element-wise by a @map@ of 'addLambda' over the row type.  Any other
-- type is an error.
addExp :: VName -> VName -> ADM (Exp SOACS)
addExp x y = lookupType x >>= build
  where
    build :: Type -> ADM (Exp SOACS)
    build (Prim pt) =
      pure . BasicOp $ BinOp (addBinOp pt) (Var x) (Var y)
    build x_t@Array {} = do
      rowAdd <- addLambda (rowType x_t)
      pure . Op $ Screma (arraySize 0 x_t) [x, y] (mapSOAC rowAdd)
    build x_t =
      error $ "addExp: unexpected type: " ++ pretty x_t
-- | The adjoint currently recorded for @v@ in 'stateAdjs'.
--
-- If none is recorded, an 'AdjZero' is returned without modifying the
-- state.  For an accumulator with a single primitive element type, it
-- uses the accumulator's shape and that type.  Otherwise it uses the
-- shape and element type of @v@'s own type.
lookupAdj :: VName -> ADM Adj
lookupAdj v =
  gets (M.lookup v . stateAdjs) >>= maybe zeroAdj pure
  where
    zeroAdj :: ADM Adj
    zeroAdj = do
      v_t <- lookupType v
      pure $ case v_t of
        Acc _ shape [Prim t] _ -> AdjZero shape t
        _ -> AdjZero (arrayShape v_t) (elemType v_t)
-- | The adjoint of @v@ as a variable: the result of 'lookupAdj' passed
-- through 'adjVal'.
lookupAdjVal :: VName -> ADM VName
lookupAdjVal v = do
  adj <- lookupAdj v
  adjVal adj
-- | Add the contribution @d@ to the adjoint of @v@.
--
-- * If @v@ has no recorded adjoint, @d@ becomes its adjoint.
-- * If the current adjoint is an accumulator, @d@ is written into it
--   with 'UpdateAcc', inside a 'tabNest' as deep as @d@'s rank.
-- * Otherwise the current adjoint and @d@ are summed with 'addExp'.
--
-- In every case the result is recorded as @v@'s adjoint with 'insAdj'.
updateAdj :: VName -> VName -> ADM ()
updateAdj v d = do
  maybeAdj <- gets (M.lookup v . stateAdjs)
  case maybeAdj of
    Nothing -> insAdj v d
    Just adj -> do
      cur <- adjVal adj
      cur_t <- lookupType cur
      new <- case cur_t of
        Acc {} -> accumulate cur
        _ -> letExp (baseString v <> "_adj") =<< addExp cur d
      insAdj v new
  where
    accumulate :: VName -> ADM VName
    accumulate acc = do
      rank <- length . arrayDims <$> lookupType d
      ~[acc'] <-
        tabNest rank [d, acc] $ \is [d_elem, acc_elem] ->
          letTupExp "acc" $
            BasicOp $
              UpdateAcc acc_elem (map Var is) [Var d_elem]
      pure acc'
-- | Add the contribution @d@ to the part of @v@'s adjoint selected by
-- @slice@.
--
-- A slice made of a single fixed index is delegated to
-- 'updateAdjIndex' with 'AssumeBounds'.  Otherwise:
--
-- * If the adjoint is an accumulator, the elements of @d@ are written
--   into it with 'UpdateAcc', at indices computed by 'fixSlice'.
-- * If not, the sliced part of the adjoint (the whole adjoint when @v@
--   is primitive) is summed with @d@ via 'addExp' and written back in
--   place with 'letInPlace'.
--
-- The result is recorded as @v@'s adjoint with 'insAdj'.
updateAdjSlice :: Slice SubExp -> VName -> VName -> ADM ()
updateAdjSlice (Slice [DimFix i]) v d =
  updateAdjIndex v (AssumeBounds, i) (Var d)
updateAdjSlice slice v d = do
  v_t <- lookupType v
  adj <- lookupAdjVal v
  adj_t <- lookupType adj
  adj' <- case adj_t of
    Acc {} -> intoAcc adj
    _ -> intoArray v_t adj
  insAdj v adj'
  where
    intoAcc :: VName -> ADM VName
    intoAcc acc = do
      ~[acc'] <-
        tabNest (length (sliceDims slice)) [d, acc] $ \is [d_elem, acc_elem] -> do
          let idxs = fixSlice (fmap pe64 slice) (map le64 is)
          slice' <- mapM (toSubExp "index") idxs
          letTupExp (baseString acc_elem) $
            BasicOp $
              UpdateAcc acc_elem slice' [Var d_elem]
      pure acc'

    intoArray :: Type -> VName -> ADM VName
    intoArray v_t arr = do
      part <-
        if primType v_t
          then pure arr
          else letExp (baseString v ++ "_slice") $ BasicOp $ Index arr slice
      letInPlace "updated_adj" arr slice =<< addExp part d
-- | Add the contribution @d@ to the adjoint of a subexpression.
-- Constants are ignored; variables are handled by 'updateAdj'.
updateSubExpAdj :: SubExp -> VName -> ADM ()
updateSubExpAdj se d =
  case se of
    Constant {} -> pure ()
    Var v -> updateAdj v d
-- | Record the contribution @se@ to the adjoint of element @i@ of @v@.
--
-- What happens depends on the current adjoint of @v@:
--
-- * no adjoint, or 'AdjZero': an 'AdjSparse' holding only this update is
--   installed, with the shape and element type of @v@'s type;
--
-- * 'AdjSparse': the update is prepended to the existing list of updates;
--
-- * 'AdjVal': code is emitted that performs the update.
--
--   - If the adjoint is an accumulator ('Acc'), an 'UpdateAcc' is emitted
--     at index @i@ followed by the indices produced by 'tabNest', nested
--     over the dimensions of @se@.
--   - Otherwise, element @i@ is read, 'addExp' combines it with @se@, and
--     the result is written back with an 'Update'.
--
-- The 'InBounds' tag controls the emitted update:
--
-- * 'CheckBounds' gives a 'Safe' update;
-- * 'AssumeBounds' gives an 'Unsafe' one;
-- * 'OutOfBounds' reinserts the existing adjoint unchanged.
--
-- In the 'AdjVal' case a @se_v@ binding for @se@ is emitted before this
-- dispatch, regardless of the tag.
updateAdjIndex :: VName -> (InBounds, SubExp) -> SubExp -> ADM ()
updateAdjIndex v (check, i) se = do
  maybeAdj <- gets $ M.lookup v . stateAdjs
  t <- lookupType v
  let iv = (check, i, se)
      -- A fresh sparse adjoint shaped like @v@ that holds only this update.
      newSparse = setAdj v $ AdjSparse $ Sparse (arrayShape t) (elemType t) [iv]
  case maybeAdj of
    Nothing -> newSparse
    Just AdjZero {} -> newSparse
    Just (AdjSparse (Sparse shape pt ivs)) ->
      setAdj v $ AdjSparse $ Sparse shape pt $ iv : ivs
    Just adj@AdjVal {} -> do
      v_adj <- adjVal adj
      v_adj_t <- lookupType v_adj
      se_v <- letExp "se_v" $ BasicOp $ SubExp se
      insAdj v
        =<< case v_adj_t of
          Acc {}
            | check == OutOfBounds ->
                pure v_adj
            | otherwise -> do
                dims <- arrayDims <$> lookupType se_v
                ~[v_adj'] <-
                  tabNest (length dims) [se_v, v_adj] $ \is [se_v', acc] ->
                    letTupExp "acc" . BasicOp $
                      UpdateAcc acc (i : map Var is) [Var se_v']
                pure v_adj'
          _ -> do
            let slice = fullSlice v_adj_t [DimFix i]
                -- Read element @i@, combine it with @se@, write it back.
                stms s = do
                  v_adj_i <-
                    letExp (baseString v_adj <> "_i") . BasicOp $
                      Index v_adj slice
                  se_update <- letSubExp "updated_adj_i" =<< addExp se_v v_adj_i
                  letExp (baseString v_adj) . BasicOp $
                    Update s v_adj slice se_update
            case check of
              CheckBounds _ -> stms Safe
              AssumeBounds -> stms Unsafe
              OutOfBounds -> pure v_adj
-- | Returns 'True' unless the type of @v@ is @'Prim' 'Unit'@.
isActive :: VName -> ADM Bool
isActive v = do
  t <- lookupType v
  pure (t /= Prim Unit)
-- | Run @m@ and return its result.
--
-- Afterwards the adjoint map ('stateAdjs') is restored to what it was
-- before @m@ ran. All other state changes made by @m@ are kept.
subAD :: ADM a -> ADM a
subAD m = do
  savedAdjs <- gets stateAdjs
  m <* modify (\s -> s {stateAdjs = savedAdjs})
-- | Run @m@ and return its result.
--
-- Afterwards the substitution map ('stateSubsts') is restored to what it
-- was before @m@ ran. All other state changes made by @m@ are kept.
subSubsts :: ADM a -> ADM a
subSubsts m = do
  savedSubsts <- gets stateSubsts
  m <* modify (\s -> s {stateSubsts = savedSubsts})
-- | Reverse-mode differentiation operations, bundled as a record of
-- functions so they can be passed around as a value.
--
-- NOTE(review): presumably this indirection exists to break an import
-- cycle between this module and the modules implementing the VJP
-- transformation. Confirm before relying on it.
data VjpOps = VjpOps
  { -- | Transform a 'Lambda', given a list of 'Adj' and a list of 'VName'.
    -- Their exact roles are defined by the implementation.
    vjpLambda :: [Adj] -> [VName] -> Lambda SOACS -> ADM (Lambda SOACS),
    -- | Transform a single statement. The second argument is an @'ADM' ()@
    -- action supplied by the caller (presumably the continuation — TODO
    -- confirm).
    vjpStm :: Stm SOACS -> ADM () -> ADM ()
  }
-- | Associate the loop tape @vs@ with @v@ in 'stateLoopTape'.
--
-- Any previous association for @v@ is overwritten.
setLoopTape :: VName -> VName -> ADM ()
setLoopTape v vs = modify addTape
  where
    addTape env = env {stateLoopTape = M.insert v vs (stateLoopTape env)}
-- | The loop tape associated with @v@ in 'stateLoopTape', if any.
lookupLoopTape :: VName -> ADM (Maybe VName)
lookupLoopTape v = M.lookup v <$> gets stateLoopTape
-- | If @v@ has an associated loop tape, associate that same tape with @v'@.
--
-- The association for @v@ itself is left in place. If @v@ has no tape,
-- nothing happens.
substLoopTape :: VName -> VName -> ADM ()
substLoopTape v v' = do
  tape <- lookupLoopTape v
  case tape of
    Nothing -> pure ()
    Just vs -> setLoopTape v' vs
-- | Apply 'substLoopTape' to every @(from, to)@ pair in the substitution
-- map.
--
-- Pairs are processed in ascending order of the source name, which is the
-- order of 'M.toList'.
renameLoopTape :: Substitutions -> ADM ()
renameLoopTape substs =
  forM_ (M.toList substs) $ \(from, to) ->
    substLoopTape from to