module Futhark.Optimise.Simplify.Rules.ClosedForm
( foldClosedForm,
loopClosedForm,
)
where
import Control.Monad
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename
-- | @foldClosedForm look pat lam accs arrs@ tries to replace a fold of
-- @lam@ over the arrays @arrs@, with initial accumulators @accs@, by a
-- closed-form expression bound to @pat@.
--
-- The rule only fires when @pat@ has exactly one element and that
-- element is of primitive type, and when 'checkResults' recognises the
-- body of @lam@.  In every other case it fails with 'cannotSimplify'.
--
-- On success the statements emitted are:
--
-- 1. a boolean that is true when the input size is zero, and
--
-- 2. a 'Match' on that boolean bound to @pat@, returning @accs@ for
--    empty input and a renamed copy of the closed-form body otherwise.
foldClosedForm ::
  (BuilderOps rep) =>
  VarLookup rep ->
  Pat (LetDec rep) ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  RuleM rep ()
foldClosedForm look pat lam accs arrs = do
  -- Size of dimension 0 of the input arrays, as computed by 'arraysSize'.
  inputsize <- arraysSize 0 <$> mapM lookupType arrs

  -- Only a single primitive result is supported.
  t <- case patTypes pat of
    [Prim t] -> pure t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      (patNames pat)
      inputsize
      mempty
      Int64
      knownBnds
      (map paramName (lambdaParams lam))
      (lambdaBody lam)
      accs

  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpEq int64) inputsize (intConst Int64 0)

  -- For empty input the fold yields the initial accumulators; the
  -- closed form is only used in the default branch.
  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just $ BoolValue True] <$> resultBodyM accs)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType t] MatchNormal)
        )
  where
    -- Lambda parameters whose value is known up front.
    knownBnds = determineKnownBindings look lam accs arrs
-- | @loopClosedForm pat merge i it bound body@ tries to replace a
-- for-loop by a closed-form expression bound to @pat@.
--
-- * @merge@ pairs each loop parameter with its initial value.
--
-- * @i@ is a set of names passed to 'checkResults' as untouchable,
--   i.e. names that are never treated as free in @body@.
--
-- * @it@ is the integer type of @bound@, the iteration count.
--
-- The rule only fires when @pat@ has exactly one element and that
-- element is of primitive type, and when 'checkResults' recognises
-- @body@.  In every other case it fails with 'cannotSimplify'.
--
-- The result is a 'Match' that returns the initial merge values when
-- the loop would not execute (@bound <= 0@), and a renamed copy of the
-- closed-form body otherwise.
loopClosedForm ::
  (BuilderOps rep) =>
  Pat (LetDec rep) ->
  [(FParam rep, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body rep ->
  RuleM rep ()
loopClosedForm pat merge i it bound body = do
  -- Only a single primitive result is supported.
  t <- case patTypes pat of
    [Prim t] -> pure t
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      (map identName mergeidents)
      body
      mergeexp

  -- The closed forms built by 'checkResults' are only valid for at
  -- least one iteration: for 'LogAnd' the closed form is @acc && el@,
  -- which differs from @acc@ when no iteration runs.  Hence a zero
  -- bound must take the initial-value branch too, not only a negative
  -- one, which is why this is a signed less-than-or-equal comparison.
  isEmpty <- newVName "bound_is_zero"
  letBindNames [isEmpty] $
    BasicOp $
      CmpOp (CmpSle it) bound (intConst it 0)

  letBind pat
    =<< ( Match [Var isEmpty]
            <$> (pure . Case [Just (BoolValue True)] <$> resultBodyM mergeexp)
            <*> renameBody closedBody
            <*> pure (MatchDec [primBodyType t] MatchNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergeidents = map paramIdent mergepat
    mergenames = map paramName mergepat
    -- Every loop parameter is known to start out as its initial value.
    knownBnds = M.fromList $ zip mergenames mergeexp
-- | @checkResults pat size untouchable it knownBnds params body accs@
-- builds a body computing, in closed form, the results of applying
-- @body@ @size@ times.  It fails with 'cannotSimplify' if any result
-- is not of a recognised shape.
--
-- * @pat@: the names to which the closed-form results are bound; they
--   also form the result of the returned body.
--
-- * @size@: the iteration count, of integer type @it@.
--
-- * @untouchable@: extra names that must not be treated as free.
--
-- * @knownBnds@: known values for names that are not free in @body@.
--
-- * @params@: the lambda-bound names; the first @length accs@ of them
--   are the accumulator parameters.
--
-- * @accs@: the initial accumulator values.
--
-- A result is recognised when it is a variable bound in @body@ by a
-- single-name statement of the form @x `op` y@ with @x /= y@, where one
-- operand is the corresponding accumulator parameter and the other is
-- free (or has a known value).  The supported operators are:
--
-- * 'LogAnd': produces @acc && el@.
--
-- * 'Add': produces @acc + el * size@.
--
-- * 'FAdd': produces @acc + el * size@, with @size@ converted to the
--   float type.
checkResults ::
  (BuilderOps rep) =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  [VName] ->
  Body rep ->
  [SubExp] ->
  RuleM rep (Body rep)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), stms) <-
    collectStms $
      zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM stms $ varsRes pat
  where
    -- Expressions of the single-name statements of the body.
    stmMap = makeBindMap body
    (accparams, _) = splitAt (length accs) params
    res = bodyResult body

    -- Names that cannot be referenced directly from the closed form.
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, SubExpRes _ (Var v)) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v stmMap,
        x /= y = do
          -- One of x,y must be *this* accumulator, and the other must
          -- be something that is free in the body (or has a known
          -- value).
          let isThisAccum = (== Var accparam)
          (this, el) <- liftMaybe $
            case ( (asFreeSubExp x, isThisAccum y),
                   (asFreeSubExp y, isThisAccum x)
                 ) of
              ((Just free, True), _) -> Just (acc, free)
              (_, (Just free, True)) -> Just (acc, free)
              _ -> Nothing

          case bop of
            LogAnd ->
              letBindNames [p] $ BasicOp $ BinOp LogAnd this el
            Add t w -> do
              size' <- asIntS t size
              letBindNames [p]
                =<< eBinOp
                  (Add t w)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (Mul t w) el size')
            FAdd t | Just properly_typed_size <- properFloatSize t -> do
              size' <- properly_typed_size
              letBindNames [p]
                =<< eBinOp
                  (FAdd t)
                  (eSubExp this)
                  (pure $ BasicOp $ BinOp (FMul t) el size')
            _ -> cannotSimplify
    checkResult _ _ = cannotSimplify

    -- A variable that is not free in the body is only usable if its
    -- value is known; anything else is usable as-is.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- An action that converts the (signed) iteration count to the
    -- given float type.  Currently always 'Just'.
    properFloatSize t =
      Just $
        letSubExp "converted_size" $
          BasicOp $
            ConvOp (SIToFP it t) size
-- | @determineKnownBindings look lam accs arrs@ computes the lambda
-- parameters of @lam@ whose value is known before the fold runs.
--
-- The first @length accs@ parameters of @lam@ are taken to be the
-- accumulator parameters and are mapped to the corresponding element
-- of @accs@.  Each remaining parameter is paired positionally with an
-- array from @arrs@; it is included only if @look@ shows that array to
-- be a 'Replicate' with a non-empty shape and no certificates, in which
-- case it is mapped to the replicated value.
determineKnownBindings ::
  VarLookup rep ->
  Lambda rep ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  accBnds <> arrBnds
  where
    (accparams, arrparams) =
      splitAt (length accs) $ lambdaParams lam

    accBnds =
      M.fromList $
        zip (map paramName accparams) accs

    arrBnds =
      M.fromList $
        mapMaybe isReplicate $
          zip (map paramName arrparams) arrs

    -- Keep the parameter only if its array is a certificate-free
    -- replicate along at least one dimension.
    isReplicate (p, v)
      | Just (BasicOp (Replicate (Shape (_ : _)) ve), cs) <- look v,
        cs == mempty =
          Just (p, ve)
    isReplicate _ = Nothing
-- | Map each name bound by a single-name statement in the body to the
-- expression of that statement.  Statements whose pattern binds zero
-- or several names are left out.
makeBindMap :: Body rep -> M.Map VName (Exp rep)
makeBindMap = M.fromList . mapMaybe isSingletonStm . stmsToList . bodyStms
  where
    isSingletonStm (Let pat _ e) = case patNames pat of
      [v] -> Just (v, e)
      _ -> Nothing