{-# LANGUAGE FlexibleContexts #-}

-- | This module implements facilities for determining whether a
-- reduction or fold can be expressed in a closed form (i.e. not as a
-- SOAC).
--
-- Right now, the module can detect only trivial cases.  In the
-- future, we would like to make it more powerful, as well as possibly
-- also being able to analyse sequential loops.
module Futhark.Optimise.Simplify.Rules.ClosedForm
  ( foldClosedForm,
    loopClosedForm,
  )
where

import Control.Monad
import qualified Data.Map.Strict as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Simple (VarLookup)
import Futhark.Transform.Rename

{-
Motivation:

  let {*[int,x_size_27] map_computed_shape_1286} = replicate(x_size_27,
                                                             all_equal_shape_1044) in
  let {*[bool,x_size_27] map_size_checks_1292} = replicate(x_size_27, x_1291) in
  let {bool all_equal_checked_1298, int all_equal_shape_1299} =
    reduceT(fn {bool, int} (bool bacc_1293, int nacc_1294, bool belm_1295,
                            int nelm_1296) =>
              let {bool tuplit_elems_1297} = bacc_1293 && belm_1295 in
              {tuplit_elems_1297, nelm_1296},
            {True, 0}, map_size_checks_1292, map_computed_shape_1286)
-}

-- | @foldClosedForm look pat lam accs arrs@ attempts to replace a fold
-- with lambda @lam@, initial accumulators @accs@ and input arrays @arrs@
-- by closed-form code that binds @pat@.
--
-- Only a single primitive-typed result is supported; anything else fails
-- with 'cannotSimplify'.  The emitted code checks whether the input is
-- empty: if so, @pat@ is bound to @accs@ unchanged, otherwise to the
-- closed form produced by 'checkResults'.
foldClosedForm ::
  (ASTLore lore, BinderOps lore) =>
  VarLookup lore ->
  Pattern lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  RuleM lore ()
foldClosedForm look pat lam accs arrs = do
  arr_ts <- mapM lookupType arrs
  let inputsize = arraysSize 0 arr_ts

  res_t <- case patternTypes pat of
    [Prim pt] -> pure pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      (patternNames pat)
      inputsize
      mempty
      Int64
      knownBnds
      lamParamNames
      (lambdaBody lam)
      accs

  isEmpty <- newVName "fold_input_is_empty"
  letBindNames [isEmpty] . BasicOp $
    CmpOp (CmpEq int64) inputsize (intConst Int64 0)

  -- Empty input: the fold returns its initial accumulators.
  emptyBranch <- resultBodyM accs
  -- Non-empty input: a renamed copy of the closed-form body.
  nonEmptyBranch <- renameBody closedBody
  letBind pat . If (Var isEmpty) emptyBranch nonEmptyBranch $
    IfDec [primBodyType res_t] IfNormal
  where
    lamParamNames = map paramName $ lambdaParams lam
    knownBnds = determineKnownBindings look lam accs arrs

-- | @loopClosedForm pat merge i it bound body@ attempts to replace a
-- @for@-loop by closed-form code that binds @pat@. The loop has
-- parameters @merge@ (each paired with its initial value), index
-- variable @i@ of integer type @it@, iteration bound @bound@ and loop
-- body @body@.
--
-- Only a single primitive-typed result is supported; anything else fails
-- with 'cannotSimplify'.  When @bound@ is zero or negative the loop runs
-- no iterations, so @pat@ is bound to the initial values; otherwise it is
-- bound to the closed form produced by 'checkResults'.
loopClosedForm ::
  (ASTLore lore, BinderOps lore) =>
  Pattern lore ->
  [(FParam lore, SubExp)] ->
  Names ->
  IntType ->
  SubExp ->
  Body lore ->
  RuleM lore ()
loopClosedForm pat merge i it bound body = do
  t <- case patternTypes pat of
    [Prim pt] -> return pt
    _ -> cannotSimplify

  closedBody <-
    checkResults
      mergenames
      bound
      i
      it
      knownBnds
      (map identName mergeidents)
      body
      mergeexp
  isEmpty <- newVName "bound_is_zero"
  -- Use <=, not <: with bound == 0 the closed form is still wrong for
  -- some operators (e.g. 'acc && x' becomes 'init && x' instead of
  -- 'init'), so zero iterations must take the initial-value branch.
  letBindNames [isEmpty] $
    BasicOp $ CmpOp (CmpSle it) bound (intConst it 0)

  letBind pat
    =<< ( If (Var isEmpty)
            <$> resultBodyM mergeexp
            <*> renameBody closedBody
            <*> pure (IfDec [primBodyType t] IfNormal)
        )
  where
    (mergepat, mergeexp) = unzip merge
    mergeidents = map paramIdent mergepat
    mergenames = map paramName mergepat
    knownBnds = M.fromList $ zip mergenames mergeexp

-- | @checkResults pat size untouchable it knownBnds params body accs@
-- tries to express every result of @body@ in closed form. It assumes the
-- body is executed @size@ times (@size@ has integer type @it@), starting
-- from the accumulator values @accs@.  The first @length accs@ names of
-- @params@ are treated as the accumulator parameters.
--
-- It returns a body binding each name of @pat@ to its closed form, or
-- fails with 'cannotSimplify'.  A result is handled only if it is bound
-- to a binary operation between its own accumulator and an operand that
-- is invariant across iterations:
--
--   * @acc && x@ becomes @init && x@;
--   * integer @acc + x@ becomes @init + x * size@;
--   * float @acc + x@ becomes @init + x * size@.
checkResults ::
  BinderOps lore =>
  [VName] ->
  SubExp ->
  Names ->
  IntType ->
  M.Map VName SubExp ->
  -- | Lambda-bound
  [VName] ->
  Body lore ->
  [SubExp] ->
  RuleM lore (Body lore)
checkResults pat size untouchable it knownBnds params body accs = do
  ((), bnds) <-
    collectStms $
      zipWithM_ checkResult (zip pat res) (zip accparams accs)
  mkBodyM bnds $ map Var pat
  where
    bndMap = makeBindMap body
    (accparams, _) = splitAt (length accs) params
    res = bodyResult body

    -- Names that may not be referenced directly by the closed form:
    -- everything bound in the body, all parameters, and the names the
    -- caller marks as untouchable (for a loop, the index variable).
    nonFree = boundInBody body <> namesFromList params <> untouchable

    checkResult (p, Var v) (accparam, acc)
      | Just (BasicOp (BinOp bop x y)) <- M.lookup v bndMap = do
        -- One of x,y must be *this* accumulator, and the other must
        -- be something that is invariant across iterations.
        let isThisAccum = (== Var accparam)
        (this, el) <- liftMaybe $
          case ( (asFreeSubExp x, isThisAccum y),
                 (asFreeSubExp y, isThisAccum x)
               ) of
            ((Just free, True), _) -> Just (acc, free)
            (_, (Just free, True)) -> Just (acc, free)
            _ -> Nothing

        case bop of
          LogAnd ->
            letBindNames [p] $ BasicOp $ BinOp LogAnd this el
          Add t w -> do
            size' <- intSizeAs t
            letBindNames [p]
              =<< eBinOp
                (Add t w)
                (eSubExp this)
                (pure $ BasicOp $ BinOp (Mul t w) el size')
          FAdd t -> do
            size' <- floatSizeAs t
            letBindNames [p]
              =<< eBinOp
                (FAdd t)
                (eSubExp this)
                (pure $ BasicOp $ BinOp (FMul t) el size')
          _ -> cannotSimplify -- Um... sorry.
    checkResult _ _ = cannotSimplify

    -- An iteration-invariant replacement for an operand, if one exists.
    asFreeSubExp :: SubExp -> Maybe SubExp
    asFreeSubExp (Var v)
      -- An accumulator changes on every iteration, so its initial value
      -- is not a valid stand-in (otherwise 'acc + acc' would become
      -- 'init + init * size').
      | v `elem` accparams = Nothing
      | v `nameIn` nonFree = M.lookup v knownBnds
    asFreeSubExp se = Just se

    -- The iteration count as integer type t.  It may be used as-is only
    -- when t is already the type 'it' of 'size'.
    intSizeAs t
      | t == it = pure size
      | otherwise =
        letSubExp "converted_size" $
          BasicOp $ ConvOp (SExt it t) size

    -- The iteration count converted to float type t.
    floatSizeAs t =
      letSubExp "converted_size" $
        BasicOp $ ConvOp (SIToFP it t) size

-- | Map lambda parameters to the values they are known to hold, given the
-- fold's initial accumulators @accs@ and input arrays @arrs@:
--
--   * each accumulator parameter maps to its initial value;
--   * each array parameter maps to the replicated value, provided its
--     array is bound (according to @look@) to a @replicate@ with no
--     certificates.
--
-- Accumulator entries take precedence on a name clash.
determineKnownBindings ::
  VarLookup lore ->
  Lambda lore ->
  [SubExp] ->
  [VName] ->
  M.Map VName SubExp
determineKnownBindings look lam accs arrs =
  M.fromList (zip accNames accs) <> M.fromList replicated
  where
    (accNames, arrNames) =
      splitAt (length accs) . map paramName $ lambdaParams lam
    replicated =
      [ (p, ve)
        | (p, arr) <- zip arrNames arrs,
          Just (BasicOp (Replicate _ ve), cs) <- [look arr],
          cs == mempty
      ]

-- | Map each name bound by a single-name statement of the body to its
-- bound expression.  Statements that bind several names are skipped.
makeBindMap :: Body lore -> M.Map VName (Exp lore)
makeBindMap body =
  M.fromList . mapMaybe singleBinding $ stmsToList (bodyStms body)
  where
    singleBinding (Let stmPat _ e)
      | [v] <- patternNames stmPat = Just (v, e)
      | otherwise = Nothing