{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Some simplification rules for t'BasicOp'.
module Futhark.Optimise.Simplify.Rules.BasicOp
  ( basicOpRules,
  )
where

import Control.Monad
import Data.List (find, foldl', isSuffixOf, sort)
import Data.List.NonEmpty (NonEmpty (..))
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.IR.Prop.Aliases
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules.Loop
import Futhark.Optimise.Simplify.Rules.Simple

-- | Does this subexpression denote a constant that is (numerically) one?
-- Variables are never treated as one, even if their value happens to be.
isCt1 :: SubExp -> Bool
isCt1 se = case se of
  Constant v -> oneIsh v
  Var _ -> False

-- | Does this subexpression denote a constant that is (numerically) zero?
-- Variables are never treated as zero, even if their value happens to be.
isCt0 :: SubExp -> Bool
isCt0 se = case se of
  Constant v -> zeroIsh v
  Var _ -> False

-- | A summary of one argument of an outer-dimension concatenation, used
-- by the concat-fusion simplification to merge neighbouring arguments.
data ConcatArg
  = ArgArrayLit [SubExp]
  -- ^ An array literal; the elements are the rows it contributes.
  | ArgReplicate [SubExp] SubExp
  -- ^ A replicate of the given value.  The list holds the outer widths
  -- of each replicate fused into this argument; they are summed when the
  -- argument is materialised again.
  | ArgVar VName
  -- ^ Any other array, passed through unchanged.

-- | Summarise an argument of an outer-dimension concatenation by looking
-- up its defining statement in the symbol table.  Array literals and
-- replicates are recognised together with the certificates of their
-- defining statement; anything else is kept as a plain variable with no
-- extra certificates.
toConcatArg :: ST.SymbolTable rep -> VName -> (ConcatArg, Certs)
toConcatArg vtable v =
  case ST.lookupBasicOp v vtable of
    Just (ArrayLit ses _, cs) ->
      (ArgArrayLit ses, cs)
    -- Only a replicate with at least one dimension contributes rows.  A
    -- rank-0 replicate is merely a copy of its operand; matching on the
    -- outer dimension directly (rather than 'shapeSize 0', which indexes
    -- the dimension list) keeps this total for that case.
    Just (Replicate (Shape (w : _)) se, cs) ->
      (ArgReplicate [w] se, cs)
    _ ->
      (ArgVar v, mempty)

-- | Materialise a (possibly fused) concatenation argument as an array
-- variable.
--
-- The 'Type' is the full type of one of the arrays being concatenated
-- along the outer dimension; it supplies the row type for array literals
-- and the inner dimensions for replicates.  The certificates are
-- attached to any statement that is generated.  Plain variables are
-- returned as they are.
fromConcatArg ::
  MonadBuilder m =>
  Type ->
  (ConcatArg, Certs) ->
  m VName
fromConcatArg t (ArgArrayLit ses, cs) =
  certifying cs $ letExp "concat_lit" $ BasicOp $ ArrayLit ses $ rowType t
fromConcatArg elem_type (ArgReplicate ws se, cs) = do
  -- 'Replicate shape se' yields an array of shape 'shape' followed by
  -- the shape of 'se'.  The replication shape is therefore only the
  -- prefix of the concatenated array's shape that 'se' does not already
  -- provide; using the full array shape would duplicate the dimensions
  -- of 'se' whenever 'se' is itself an array.
  se_rank <- arrayRank <$> subExpType se
  let rep_shape =
        Shape $ take (arrayRank elem_type - se_rank) $ arrayDims elem_type
  certifying cs $ do
    -- The fused replicate contributes the sum of all the fused widths.
    w <- letSubExp "concat_rep_w" =<< toExp (sum $ map pe64 ws)
    letExp "concat_rep" $ BasicOp $ Replicate (setDim 0 rep_shape w) se
fromConcatArg _ (ArgVar v, _) =
  pure v

-- | Fold step for concat-argument fusion.  The accumulator holds the
-- arguments seen so far in reverse order (most recent first); the new
-- argument is pushed onto it, merging with the most recent one when
-- possible:
--
-- * empty array literals and replicates of constant width zero are
--   dropped (together with their certificates);
-- * replicates of constant width one become singleton array literals;
-- * adjacent array literals are concatenated;
-- * adjacent replicates of the same value have their widths combined.
fuseConcatArg ::
  [(ConcatArg, Certs)] ->
  (ConcatArg, Certs) ->
  [(ConcatArg, Certs)]
fuseConcatArg acc new@(new_arg, new_cs) =
  case (new_arg, acc) of
    (ArgArrayLit [], _) -> acc
    (ArgReplicate [w] se, _)
      | isCt0 w -> acc
      | isCt1 w -> fuseConcatArg acc (ArgArrayLit [se], new_cs)
    (ArgArrayLit y_ses, (ArgArrayLit x_ses, x_cs) : rest) ->
      (ArgArrayLit (x_ses ++ y_ses), x_cs <> new_cs) : rest
    (ArgReplicate y_ws y_se, (ArgReplicate x_ws x_se, x_cs) : rest)
      | x_se == y_se ->
          (ArgReplicate (x_ws ++ y_ws) x_se, x_cs <> new_cs) : rest
    _ -> new : acc

simplifyConcat :: BuilderOps rep => BottomUpRuleBasicOp rep
-- Concatenating rearranged arrays along dimension i is the same as
-- concatenating the original arrays along dimension 0 and rearranging the
-- result, provided the permutation moves dimension 0 to position i.  For
-- i=1 this is:
--
--   concat@1(transpose(x),transpose(y)) == transpose(concat@0(x,y))
simplifyConcat (vtable, _) pat aux (Concat i (x :| xs) new_d)
  | Just r <- arrayRank <$> ST.lookupType x vtable,
    -- Dimension k of 'Rearrange perm v' is dimension (perm !! k) of v, so
    -- this permutation places dimension 0 of v at position i and keeps
    -- the other dimensions in their relative order.  The permutation
    -- [i] ++ [0 .. i - 1] ++ [i + 1 .. r - 1] is its inverse and only
    -- coincides with it for i <= 1.
    let perm = [1 .. i] ++ [0] ++ [i + 1 .. r - 1],
    Just (x', x_cs) <- transposedBy perm x,
    Just (xs', xs_cs) <- unzip <$> mapM (transposedBy perm) xs =
      -- Keep the certificates and attributes of the original concat, as
      -- the other equations of this rule do.
      Simplify $ auxing aux $ do
        concat_rearrange <-
          certifying (x_cs <> mconcat xs_cs) $
            letExp "concat_rearrange" $
              BasicOp $
                Concat 0 (x' :| xs') new_d
        letBind pat $ BasicOp $ Rearrange perm concat_rearrange
  where
    -- Look through a 'Rearrange' by exactly the given permutation,
    -- returning the rearranged array and the certificates of the
    -- rearrange statement.
    transposedBy :: [Int] -> VName -> Maybe (VName, Certs)
    transposedBy perm1 v =
      case ST.lookupExp v vtable of
        Just (BasicOp (Rearrange perm2 v'), vcs)
          | perm1 == perm2 -> Just (v', vcs)
        _ -> Nothing

-- Removing a concatenation that involves only a single array.  This
-- may be produced as a result of other simplification rules.
simplifyConcat _ pat aux (Concat _ (x :| []) _) =
  Simplify $
    -- Still need a copy because Concat produces a fresh array.
    auxing aux $
      letBind pat $
        BasicOp $
          Copy x
-- concat xs (concat ys zs) == concat xs ys zs
simplifyConcat (vtable, _) pat (StmAux cs attrs _) (Concat i (x :| xs) new_d)
  | x' /= x || concat xs' /= xs =
      Simplify $
        certifying (cs <> x_cs <> mconcat xs_cs) $
          attributing attrs $
            letBind pat $
              BasicOp $
                Concat i (x' :| zs ++ concat xs') new_d
  where
    -- 'isConcat' always returns a non-empty list, so this pattern
    -- cannot fail.
    (x' : zs, x_cs) = isConcat x
    (xs', xs_cs) = unzip $ map isConcat xs
    -- Inline an argument that is itself a concatenation along the same
    -- dimension; any other argument stands for itself.
    isConcat :: VName -> ([VName], Certs)
    isConcat v = case ST.lookupBasicOp v vtable of
      Just (Concat j (y :| ys) _, v_cs) | j == i -> (y : ys, v_cs)
      _ -> ([v], mempty)

-- Fusing arguments to the concat when possible.  Only done when
-- concatenating along the outer dimension for now.
simplifyConcat (vtable, _) pat aux (Concat 0 (x :| xs) outer_w)
  | -- We produce the to-be-concatenated arrays in reverse order, so
    -- reverse them back.
    y : ys <-
      forSingleArray $
        reverse $
          foldl' fuseConcatArg mempty $
            map (toConcatArg vtable) $
              x : xs,
    -- Only rewrite if fusion actually reduced the number of arguments.
    length xs /= length ys =
      Simplify $ do
        elem_type <- lookupType x
        y' <- fromConcatArg elem_type y
        ys' <- mapM (fromConcatArg elem_type) ys
        auxing aux $ letBind pat $ BasicOp $ Concat 0 (y' :| ys') outer_w
  where
    -- If we fuse so much that there is only a single input left, then
    -- it must have the right size.
    forSingleArray :: [(ConcatArg, Certs)] -> [(ConcatArg, Certs)]
    forSingleArray [(ArgReplicate _ v, cs)] =
      [(ArgReplicate [outer_w] v, cs)]
    forSingleArray ys = ys
simplifyConcat _ _ _ _ = Skip

ruleBasicOp :: BuilderOps rep => TopDownRuleBasicOp rep
ruleBasicOp :: forall rep. BuilderOps rep => TopDownRuleBasicOp rep
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux BasicOp
op
  | Just (BasicOp
op', Certs
cs) <- forall {k} (rep :: k).
VarLookup rep -> TypeLookup -> BasicOp -> Maybe (BasicOp, Certs)
applySimpleRules VName -> Maybe (Exp rep, Certs)
defOf TypeLookup
seType BasicOp
op =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (Certs
cs forall a. Semigroup a => a -> a -> a
<> forall dec. StmAux dec -> Certs
stmAuxCerts StmAux (ExpDec rep)
aux) forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp BasicOp
op'
  where
    defOf :: VName -> Maybe (Exp rep, Certs)
defOf = (forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
`ST.lookupExp` TopDown rep
vtable)
    seType :: TypeLookup
seType (Var VName
v) = forall {k} (rep :: k).
ASTRep rep =>
VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
v TopDown rep
vtable
    seType (Constant PrimValue
v) = forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ forall shape u. PrimType -> TypeBase shape u
Prim forall a b. (a -> b) -> a -> b
$ PrimValue -> PrimType
primValueType PrimValue
v
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
src Slice SubExp
_ (Var VName
v))
  | Just (BasicOp Scratch {}, Certs
_) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
src
-- If we are writing a single-element slice from some array, and the
-- element of that array can be computed as a PrimExp based on the
-- index, let's just write that instead.
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
aux (Update Safety
safety VName
src (Slice [DimSlice SubExp
i SubExp
n SubExp
s]) (Var VName
v))
  | SubExp -> Bool
isCt1 SubExp
n,
    SubExp -> Bool
isCt1 SubExp
s,
    Just (ST.Indexed Certs
cs PrimExp VName
e) <- forall {k} (rep :: k).
ASTRep rep =>
VName -> [SubExp] -> SymbolTable rep -> Maybe Indexed
ST.index VName
v [IntType -> Integer -> SubExp
intConst IntType
Int64 Integer
0] TopDown rep
vtable =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
        SubExp
e' <- forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"update_elem" PrimExp VName
e
        forall (m :: * -> *) anyrep a.
MonadBuilder m =>
StmAux anyrep -> m a -> m a
auxing StmAux (ExpDec rep)
aux forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs forall a b. (a -> b) -> a -> b
$
          forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update Safety
safety VName
src (forall d. [DimIndex d] -> Slice d
Slice [forall d. d -> DimIndex d
DimFix SubExp
i]) SubExp
e'
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
dest Slice SubExp
destis (Var VName
v))
  | Just (Exp rep
e, Certs
_) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable,
    Exp rep -> Bool
arrayFrom Exp rep
e =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
dest
  where
    arrayFrom :: Exp rep -> Bool
arrayFrom (BasicOp (Copy VName
copy_v))
      | Just (Exp rep
e', Certs
_) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
copy_v TopDown rep
vtable =
          Exp rep -> Bool
arrayFrom Exp rep
e'
    arrayFrom (BasicOp (Index VName
src Slice SubExp
srcis)) =
      VName
src forall a. Eq a => a -> a -> Bool
== VName
dest Bool -> Bool -> Bool
&& Slice SubExp
destis forall a. Eq a => a -> a -> Bool
== Slice SubExp
srcis
    arrayFrom (BasicOp (Replicate Shape
v_shape SubExp
v_se))
      | Just (Replicate Shape
dest_shape SubExp
dest_se, Certs
_) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
dest TopDown rep
vtable,
        SubExp
v_se forall a. Eq a => a -> a -> Bool
== SubExp
dest_se,
        forall d. ShapeBase d -> [d]
shapeDims Shape
v_shape forall a. Eq a => [a] -> [a] -> Bool
`isSuffixOf` forall d. ShapeBase d -> [d]
shapeDims Shape
dest_shape =
          Bool
True
    arrayFrom Exp rep
_ =
      Bool
False
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Update Safety
_ VName
dest Slice SubExp
is SubExp
se)
  | Just Type
dest_t <- forall {k} (rep :: k).
ASTRep rep =>
VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
dest TopDown rep
vtable,
    Shape -> Slice SubExp -> Bool
isFullSlice (forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
dest_t) Slice SubExp
is = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$
      case SubExp
se of
        Var VName
v | Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall (t :: * -> *) a. Foldable t => t a -> Bool
null forall a b. (a -> b) -> a -> b
$ forall d. Slice d -> [d]
sliceDims Slice SubExp
is -> do
          VName
v_reshaped <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp (VName -> String
baseString VName
v forall a. [a] -> [a] -> [a]
++ String
"_reshaped") forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              ReshapeKind -> Shape -> VName -> BasicOp
Reshape ReshapeKind
ReshapeArbitrary (forall shape u. ArrayShape shape => TypeBase shape u -> shape
arrayShape Type
dest_t) VName
v
          forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ VName -> BasicOp
Copy VName
v_reshaped
        SubExp
_ -> forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ [SubExp] -> Type -> BasicOp
ArrayLit [SubExp
se] forall a b. (a -> b) -> a -> b
$ forall u. TypeBase Shape u -> TypeBase Shape u
rowType Type
dest_t
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat (StmAux Certs
cs1 Attrs
attrs ExpDec rep
_) (Update Safety
safety1 VName
dest1 Slice SubExp
is1 (Var VName
v1))
  | Just (Update Safety
safety2 VName
dest2 Slice SubExp
is2 SubExp
se2, Certs
cs2) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v1 TopDown rep
vtable,
    Just (Copy VName
v3, Certs
cs3) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
dest2 TopDown rep
vtable,
    Just (Index VName
v4 Slice SubExp
is4, Certs
cs4) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (BasicOp, Certs)
ST.lookupBasicOp VName
v3 TopDown rep
vtable,
    Slice SubExp
is4 forall a. Eq a => a -> a -> Bool
== Slice SubExp
is1,
    VName
v4 forall a. Eq a => a -> a -> Bool
== VName
dest1 =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$
        forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying (Certs
cs1 forall a. Semigroup a => a -> a -> a
<> Certs
cs2 forall a. Semigroup a => a -> a -> a
<> Certs
cs3 forall a. Semigroup a => a -> a -> a
<> Certs
cs4) forall a b. (a -> b) -> a -> b
$ do
          Slice SubExp
is5 <- forall (m :: * -> *).
MonadBuilder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice forall a b. (a -> b) -> a -> b
$ forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
is1) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
is2)
          forall (m :: * -> *) a. MonadBuilder m => Attrs -> m a -> m a
attributing Attrs
attrs forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Safety -> VName -> Slice SubExp -> SubExp -> BasicOp
Update (forall a. Ord a => a -> a -> a
max Safety
safety1 Safety
safety2) VName
dest1 Slice SubExp
is5 SubExp
se2
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (CmpOp (CmpEq PrimType
t) SubExp
se1 SubExp
se2)
  | Just RuleM rep ()
m <- SubExp -> SubExp -> Maybe (RuleM rep ())
simplifyWith SubExp
se1 SubExp
se2 = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify RuleM rep ()
m
  | Just RuleM rep ()
m <- SubExp -> SubExp -> Maybe (RuleM rep ())
simplifyWith SubExp
se2 SubExp
se1 = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify RuleM rep ()
m
  where
    simplifyWith :: SubExp -> SubExp -> Maybe (RuleM rep ())
simplifyWith (Var VName
v) SubExp
x
      | Just Stm rep
stm <- forall {k} (rep :: k). VName -> SymbolTable rep -> Maybe (Stm rep)
ST.lookupStm VName
v TopDown rep
vtable,
        Match [SubExp
p] [Case [Just (BoolValue Bool
True)] Body rep
tbranch] Body rep
fbranch MatchDec (BranchType rep)
_ <- forall {k} (rep :: k). Stm rep -> Exp rep
stmExp Stm rep
stm,
        Just (SubExp
y, SubExp
z) <-
          forall {k} {k} {dec} {rep :: k} {rep :: k}.
VName -> Pat dec -> Body rep -> Body rep -> Maybe (SubExp, SubExp)
returns VName
v (forall {k} (rep :: k). Stm rep -> Pat (LetDec rep)
stmPat Stm rep
stm) Body rep
tbranch Body rep
fbranch,
        Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Names
boundInBody Body rep
tbranch Names -> Names -> Bool
`namesIntersect` forall a. FreeIn a => a -> Names
freeIn SubExp
y,
        Bool -> Bool
not forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). Body rep -> Names
boundInBody Body rep
fbranch Names -> Names -> Bool
`namesIntersect` forall a. FreeIn a => a -> Names
freeIn SubExp
z = forall a. a -> Maybe a
Just forall a b. (a -> b) -> a -> b
$ do
          SubExp
eq_x_y <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"eq_x_y" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (PrimType -> CmpOp
CmpEq PrimType
t) SubExp
x SubExp
y
          SubExp
eq_x_z <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"eq_x_z" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (PrimType -> CmpOp
CmpEq PrimType
t) SubExp
x SubExp
z
          SubExp
p_and_eq_x_y <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"p_and_eq_x_y" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
LogAnd SubExp
p SubExp
eq_x_y
          SubExp
not_p <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"not_p" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ UnOp -> SubExp -> BasicOp
UnOp UnOp
Not SubExp
p
          SubExp
not_p_and_eq_x_z <-
            forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"p_and_eq_x_y" forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
LogAnd SubExp
not_p SubExp
eq_x_z
          forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$
            forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
              BinOp -> SubExp -> SubExp -> BasicOp
BinOp BinOp
LogOr SubExp
p_and_eq_x_y SubExp
not_p_and_eq_x_z
    simplifyWith SubExp
_ SubExp
_ =
      forall a. Maybe a
Nothing

    returns :: VName -> Pat dec -> Body rep -> Body rep -> Maybe (SubExp, SubExp)
returns VName
v Pat dec
ifpat Body rep
tbranch Body rep
fbranch =
      forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap forall a b. (a, b) -> b
snd forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((forall a. Eq a => a -> a -> Bool
== VName
v) forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall dec. PatElem dec -> VName
patElemName forall b c a. (b -> c) -> (a -> b) -> a -> c
. forall a b. (a, b) -> a
fst) forall a b. (a -> b) -> a -> b
$
        forall a b. [a] -> [b] -> [(a, b)]
zip (forall dec. Pat dec -> [PatElem dec]
patElems Pat dec
ifpat) forall a b. (a -> b) -> a -> b
$
          forall a b. [a] -> [b] -> [(a, b)]
zip (forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (forall {k} (rep :: k). Body rep -> Result
bodyResult Body rep
tbranch)) (forall a b. (a -> b) -> [a] -> [b]
map SubExpRes -> SubExp
resSubExp (forall {k} (rep :: k). Body rep -> Result
bodyResult Body rep
fbranch))
ruleBasicOp TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Replicate (Shape []) se :: SubExp
se@Constant {}) =
  forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
ruleBasicOp TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Replicate Shape
_ SubExp
se)
  | [Acc {}] <- forall dec. Typed dec => Pat dec -> [Type]
patTypes Pat (LetDec rep)
pat =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ SubExp -> BasicOp
SubExp SubExp
se
ruleBasicOp TopDown rep
_ Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Replicate (Shape []) (Var VName
v)) = forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ do
  Type
v_t <- forall {k} (rep :: k) (m :: * -> *).
HasScope rep m =>
VName -> m Type
lookupType VName
v
  forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$
    forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$
      if forall shape u. TypeBase shape u -> Bool
primType Type
v_t
        then SubExp -> BasicOp
SubExp forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
v
        else VName -> BasicOp
Copy VName
v
ruleBasicOp TopDown rep
vtable Pat (LetDec rep)
pat StmAux (ExpDec rep)
_ (Replicate Shape
shape (Var VName
v))
  | Just (BasicOp (Replicate Shape
shape2 SubExp
se), Certs
cs) <- forall {k} (rep :: k).
VName -> SymbolTable rep -> Maybe (Exp rep, Certs)
ST.lookupExp VName
v TopDown rep
vtable =
      forall {k} (rep :: k). RuleM rep () -> Rule rep
Simplify forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs forall a b. (a -> b) -> a -> b
$ forall (m :: * -> *).
MonadBuilder m =>
Pat (LetDec (Rep m)) -> Exp (Rep m) -> m ()
letBind Pat (LetDec rep)
pat forall a b. (a -> b) -> a -> b
$ forall {k} (rep :: k). BasicOp -> Exp rep
BasicOp forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate (Shape
shape forall a. Semigroup a => a -> a -> a
<> Shape
shape2) SubExp
se
-- An array literal whose elements are all the same value is just a
-- replicate of that value, with the literal's length as the size.
ruleBasicOp _ pat _ (ArrayLit (se : ses) _)
  | all (== se) ses = Simplify $ do
      let count = constant (fromIntegral (length ses + 1) :: Int64)
      letBind pat $ BasicOp $ Replicate (Shape [count]) se
-- Indexing a reshaped array with a full set of scalar indices: index
-- the array that was reshaped instead.
ruleBasicOp vtable pat aux (Index idd slice)
  | Just inds <- sliceIndices slice,
    Just (BasicOp (Reshape k newshape idd2), idd_cs) <- ST.lookupExp idd vtable,
    length newshape == length inds =
      Simplify $ case k of
        ReshapeCoerce ->
          -- The slice is reused unchanged on the pre-coercion array
          -- (presumably a coercion leaves the index space as-is).
          certifying idd_cs . auxing aux . letBind pat . BasicOp $
            Index idd2 slice
        ReshapeArbitrary -> do
          -- Linearise indices and map to old index space.
          old_dims <- arrayDims <$> lookupType idd2
          let flat_inds =
                reshapeIndex
                  (map pe64 old_dims)
                  (map pe64 (shapeDims newshape))
                  (map pe64 inds)
          idx_ses <- mapM (toSubExp "new_index") flat_inds
          certifying idd_cs . auxing aux . letBind pat . BasicOp $
            Index idd2 (Slice (map DimFix idx_ses))

-- Copying an iota is pointless; just make it an iota instead.  The
-- looked-up Iota operation is re-bound verbatim under the copy's
-- pattern.
ruleBasicOp vtable pat aux (Copy v)
  | Just (iota@Iota {}, v_cs) <- ST.lookupBasicOp v vtable =
      Simplify . certifying v_cs . auxing aux . letBind pat $ BasicOp iota
-- Handle identity permutation: the rearrange is replaced by the
-- array itself.
ruleBasicOp _ pat _ (Rearrange perm v)
  | is_identity = Simplify . letBind pat . BasicOp . SubExp $ Var v
  where
    -- Assuming perm is a permutation of [0..n-1], it is already in
    -- sorted order exactly when it is the identity.
    is_identity = sort perm == perm
-- Rearranging a rearrange: compose the two permutations and apply
-- the result directly to the innermost array.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rearrange inner_perm e), v_cs) <- ST.lookupExp v vtable =
      Simplify $ do
        let composed = perm `rearrangeCompose` inner_perm
        certifying v_cs . auxing aux . letBind pat . BasicOp $
          Rearrange composed e
-- Rearranging a rotation of a rearrangement: rotate the innermost
-- array instead (with the rotation offsets permuted by the inverse of
-- the inner permutation), then apply one composed rearrange.
ruleBasicOp vtable pat aux (Rearrange perm v)
  | Just (BasicOp (Rotate offsets v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rearrange perm3 v3), v2_cs) <- ST.lookupExp v2 vtable =
      Simplify $ do
        let inner_offsets = rearrangeShape (rearrangeInverse perm3) offsets
            outer_perm = perm `rearrangeCompose` perm3
        rotated <- letExp "rearrange_rotate" . BasicOp $ Rotate inner_offsets v3
        certifying (v_cs <> v2_cs) . auxing aux . letBind pat . BasicOp $
          Rearrange outer_perm rotated

-- Rearranging a replicate where the replicated (outer) dimensions are
-- left untouched: permute the replicated array instead and replicate
-- the result.
ruleBasicOp vtable pat aux (Rearrange perm v1)
  | Just (BasicOp (Replicate dims (Var v2)), v1_cs) <- ST.lookupExp v1 vtable,
    num_dims <- shapeRank dims,
    (rep_perm, rest_perm) <- splitAt num_dims perm,
    not (null rest_perm),
    rep_perm == [0 .. length rep_perm - 1] =
      Simplify . certifying v1_cs . auxing aux $ do
        -- The remaining permutation entries refer to dimensions of v2
        -- offset by the number of replicated dimensions; renumber them.
        inner <-
          letSubExp "rearrange_replicate" . BasicOp $
            Rearrange (map (subtract num_dims) rest_perm) v2
        letBind pat . BasicOp $ Replicate dims inner

-- A zero-rotation is identity: when every offset is a constant zero,
-- the result is the array itself.
ruleBasicOp _ pat _ (Rotate offsets v)
  | all isCt0 offsets =
      Simplify . letBind pat . BasicOp . SubExp $ Var v
-- Rotating a rearrangement of a rotation:
--
--   rotate offsets (rearrange perm (rotate offsets2 v3))
--
-- The inner rotation is moved past the rearrange and merged with the
-- outer one, giving
--
--   rotate (offsets + offsets2') (rearrange perm v3)
--
-- offsets2 is indexed by the dimensions of v3.  Dimension i of
-- 'rearrange perm v3' is dimension (perm !! i) of v3, the same
-- convention 'rearrangeShape' uses.  So the inner offsets must be
-- permuted with perm itself (not its inverse) to line up with the
-- outer offsets.  Using the inverse is only correct when perm is its
-- own inverse; it breaks for e.g. perm = [1,2,0].
ruleBasicOp vtable pat aux (Rotate offsets v)
  | Just (BasicOp (Rearrange perm v2), v_cs) <- ST.lookupExp v vtable,
    Just (BasicOp (Rotate offsets2 v3), v2_cs) <- ST.lookupExp v2 vtable =
      Simplify $ do
        let offsets2' = rearrangeShape perm offsets2
            addOffsets x y =
              letSubExp "summed_offset" $
                BasicOp $
                  BinOp (Add Int64 OverflowWrap) x y
        offsets' <- zipWithM addOffsets offsets offsets2'
        rotate_rearrange <-
          auxing aux $ letExp "rotate_rearrange" $ BasicOp $ Rearrange perm v3
        certifying (v_cs <> v2_cs) $
          letBind pat $
            BasicOp $
              Rotate offsets' rotate_rearrange

-- Combining Rotates: a rotation of a rotation is a single rotation by
-- the elementwise (wrapping, 64-bit) sum of the two offset lists.
ruleBasicOp vtable pat aux (Rotate offsets1 v)
  | Just (BasicOp (Rotate offsets2 v2), v_cs) <- ST.lookupExp v vtable =
      Simplify $ do
        summed <- zipWithM sumOffset offsets1 offsets2
        certifying v_cs . auxing aux . letBind pat . BasicOp $
          Rotate summed v2
  where
    sumOffset x y =
      letSubExp "offset" . BasicOp $ BinOp (Add Int64 OverflowWrap) x y

-- Simplify away 0<=i when 'i' is from a loop of form 'for i < n'.
-- Only a literal 64-bit zero on the left-hand side is recognised.
ruleBasicOp vtable pat aux (CmpOp CmpSle {} (Constant (IntValue (Int64Value 0))) (Var v))
  | Just _ <- ST.lookupLoopVar v vtable =
      Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away i<n when 'i' is from a loop of form 'for i < n', i.e.
-- when the right-hand side is exactly the loop's recorded bound.
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var v) y)
  | ST.lookupLoopVar v vtable == Just y =
      Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant True
-- Simplify away x<0 when 'x' has been used as array size (per the
-- symbol table's entry for x).
ruleBasicOp vtable pat aux (CmpOp CmpSlt {} (Var x) y)
  | isCt0 y,
    Just entry <- ST.lookup x vtable,
    ST.entryIsSize entry =
      Simplify . auxing aux . letBind pat . BasicOp . SubExp $ constant False
-- Remove certificates from a variable binding when the variable's own
-- definition already carries those certificates.  The rule fires only
-- if at least one certificate is actually dropped.
ruleBasicOp vtable pat aux (SubExp (Var v))
  | cs@(_ : _) <- unCerts (stmAuxCerts aux),
    Just v_cs <- unCerts . stmCerts <$> ST.lookupStm v vtable,
    cs' <- [c | c <- cs, c `notElem` v_cs],
    cs' /= cs =
      Simplify . certifying (Certs cs') . letBind pat . BasicOp . SubExp $ Var v
-- Remove UpdateAccs that contribute the neutral value, which is
-- always a no-op: the updated accumulator is bound to the input
-- accumulator unchanged.  The neutral element is taken from the
-- accumulator's WithAcc input, looked up through its token.
ruleBasicOp vtable pat aux (UpdateAcc acc _ vs)
  | Pat [pe] <- pat,
    Acc token _ _ _ <- patElemType pe,
    Just (_, _, Just (_, ne)) <- ST.lookup token vtable >>= ST.entryAccInput,
    vs == ne =
      Simplify . auxing aux . letBind pat . BasicOp . SubExp $ Var acc
-- Manifest of a copy can be simplified to manifesting the original
-- array, if it is still available.
ruleBasicOp vtable pat aux (Manifest perm v1)
  | Just (Copy v2, cs) <- ST.lookupBasicOp v1 vtable,
    ST.available v2 vtable =
      Simplify . auxing aux . certifying cs . letBind pat . BasicOp $
        Manifest perm v2
-- No top-down rule applies to this operation.
ruleBasicOp _ _ _ _ = Skip

-- | The top-down rules of this module: every equation of
-- 'ruleBasicOp', wrapped as a single t'BasicOp' rule.
topDownRules :: BuilderOps rep => [TopDownRule rep]
topDownRules = pure (RuleBasicOp ruleBasicOp)

-- | The bottom-up rules of this module: only concatenation
-- simplification ('simplifyConcat').
bottomUpRules :: BuilderOps rep => [BottomUpRule rep]
bottomUpRules = pure (RuleBasicOp simplifyConcat)

-- | A set of simplification rules for t'BasicOp's.  Includes rules
-- from "Futhark.Optimise.Simplify.Rules.Simple", and is extended with
-- the loop rules from "Futhark.Optimise.Simplify.Rules.Loop".
basicOpRules :: (BuilderOps rep, Aliased rep) => RuleBook rep
basicOpRules = own_rules <> loopRules
  where
    own_rules = ruleBook topDownRules bottomUpRules