{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}

-- | Index simplification mechanics.
module Futhark.Optimise.Simplify.Rules.Index
  ( IndexResult (..),
    simplifyIndexing,
  )
where

import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util

-- | Does this 'SubExp' denote a constant value for which 'oneIsh'
-- holds?  Variables are never considered constant here, so any
-- 'Var' yields 'False'.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False

-- | Does this 'SubExp' denote a constant value for which 'zeroIsh'
-- holds?  Variables are never considered constant here, so any
-- 'Var' yields 'False'.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False

-- | Some index expressions can be simplified to t'SubExp's, while
-- others produce another index expression (which may be further
-- simplifiable).
data IndexResult
  = -- | The original index is equivalent to indexing the given array
    -- with the given slice, provided the given certificates hold.
    IndexResult Certs VName (Slice SubExp)
  | -- | The original index is equivalent to the given subexpression,
    -- provided the given certificates hold.
    SubExpResult Certs SubExp

-- | Try to simplify an index operation.
simplifyIndexing ::
  (MonadBuilder m) =>
  ST.SymbolTable (Rep m) ->
  TypeLookup ->
  VName ->
  Slice SubExp ->
  Bool ->
  (VName -> Bool) ->
  Maybe (m IndexResult)
simplifyIndexing :: forall (m :: * -> *).
MonadBuilder m =>
SymbolTable (Rep m)
-> TypeLookup
-> VName
-> Slice SubExp
-> Bool
-> (VName -> Bool)
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Rep m)
vtable TypeLookup
seType VName
idd (Slice [DimIndex SubExp]
inds) Bool
consuming VName -> Bool
consumed =
  case VName -> Maybe (BasicOp, Certs)
defOf VName
idd of
    Maybe (BasicOp, Certs)
_
      | Just Type
t <- TypeLookup
seType (VName -> SubExp
Var VName
idd),
        [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice Type
t [] ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd
      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds),
        Just (ST.Indexed Certs
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Rep m) -> Maybe Indexed
forall rep.
ASTRep rep =>
VName -> [SubExp] -> SymbolTable rep -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Rep m)
vtable,
        PrimExp VName -> Bool
forall {v}. PrimExp v -> Bool
worthInlining PrimExp VName
e,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Rep m)
vtable) (Certs -> [VName]
unCerts Certs
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e
      | Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds),
        Just (ST.IndexedArray Certs
cs VName
arr [TPrimExp Int64 VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Rep m) -> Maybe Indexed
forall rep.
ASTRep rep =>
VName -> [SubExp] -> SymbolTable rep -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Rep m)
vtable,
        (TPrimExp Int64 VName -> Bool) -> [TPrimExp Int64 VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (PrimExp VName -> Bool
forall {v}. PrimExp v -> Bool
worthInlining (PrimExp VName -> Bool)
-> (TPrimExp Int64 VName -> PrimExp VName)
-> TPrimExp Int64 VName
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName -> PrimExp VName
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped) [TPrimExp Int64 VName]
inds'',
        VName
arr VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.available` SymbolTable (Rep m)
vtable,
        (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
`ST.elem` SymbolTable (Rep m)
vtable) (Certs -> [VName]
unCerts Certs
cs) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
            Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> ([SubExp] -> [DimIndex SubExp]) -> [SubExp] -> Slice SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> [DimIndex SubExp]
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix
              ([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName -> m SubExp)
-> [TPrimExp Int64 VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM (String -> TPrimExp Int64 VName -> m SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [TPrimExp Int64 VName]
inds''
    Maybe (BasicOp, Certs)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
    Just (SubExp (Var VName
v), Certs
cs) ->
      m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
v (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds
    Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certs
cs)
      | [DimFix SubExp
ii] <- [DimIndex SubExp]
inds,
        Just (Prim (IntType IntType
from_it)) <- TypeLookup
seType SubExp
ii ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
            let mul :: PrimExp v -> PrimExp v -> PrimExp v
mul = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
                add :: PrimExp v -> PrimExp v -> PrimExp v
add = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
             in (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> SubExp -> IndexResult
SubExpResult Certs
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
                  String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
                    ( IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
                        PrimExp VName -> PrimExp VName -> PrimExp VName
forall {v}. PrimExp v -> PrimExp v -> PrimExp v
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
                    )
                      PrimExp VName -> PrimExp VName -> PrimExp VName
forall {v}. PrimExp v -> PrimExp v -> PrimExp v
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
      | [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- [DimIndex SubExp]
inds ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
            SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
            SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
            let mul :: PrimExp v -> PrimExp v -> PrimExp v
mul = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
                add :: PrimExp v -> PrimExp v -> PrimExp v
add = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
            SubExp
i_offset'' <-
              String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBuilder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
                ( PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
                    PrimExp VName -> PrimExp VName -> PrimExp VName
forall {v}. PrimExp v -> PrimExp v -> PrimExp v
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
                )
                  PrimExp VName -> PrimExp VName -> PrimExp VName
forall {v}. PrimExp v -> PrimExp v -> PrimExp v
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
            SubExp
i_stride'' <-
              String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                  BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
            (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certs -> SubExp -> IndexResult
SubExpResult Certs
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
              String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                  SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it
    Just (Index VName
aa Slice SubExp
ais, Certs
cs) ->
      m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
        Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
aa
          (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBuilder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice (Slice (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds)))
    Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certs
cs)
      | [DimFix {}] <- [DimIndex SubExp]
inds,
        VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
ST.available VName
vv SymbolTable (Rep m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
      | DimFix {} : [DimIndex SubExp]
is' <- [DimIndex SubExp]
inds,
        Bool -> Bool
not Bool
consuming,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
consumed VName
vv,
        VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
ST.available VName
vv SymbolTable (Rep m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
vv (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
is'
    Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certs
cs)
      | [DimFix {}] <- [DimIndex SubExp]
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
cs SubExp
val
    Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certs
cs)
      | ([DimIndex SubExp]
ds_inds, [DimIndex SubExp]
rest_inds) <- Int -> [DimIndex SubExp] -> ([DimIndex SubExp], [DimIndex SubExp])
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) [DimIndex SubExp]
inds,
        ([SubExp]
ds', [DimIndex SubExp]
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], [DimIndex SubExp])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], [DimIndex SubExp]))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], [DimIndex SubExp])
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> [DimIndex SubExp] -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index [DimIndex SubExp]
ds_inds,
        [SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds,
        SubExp -> SymbolTable (Rep m) -> Bool
forall rep. SubExp -> SymbolTable rep -> Bool
ST.subExpAvailable SubExp
v SymbolTable (Rep m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
            VName
arr <- String -> Exp (Rep m) -> m VName
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m VName
letExp String
"smaller_replicate" (Exp (Rep m) -> m VName) -> Exp (Rep m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ ShapeBase SubExp -> SubExp -> BasicOp
Replicate ([SubExp] -> ShapeBase SubExp
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
            IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
ds_inds' [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ [DimIndex SubExp]
rest_inds
      where
        index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix {} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
        index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
n SubExp
s)
    Just (Rearrange [Int]
perm VName
src, Certs
cs)
      | [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= [DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall {d}. DimIndex d -> Bool
isIndex [DimIndex SubExp]
inds) ->
          let inds' :: [DimIndex SubExp]
inds' = [Int] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) [DimIndex SubExp]
inds
           in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
src (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds'
      where
        isIndex :: DimIndex d -> Bool
isIndex DimFix {} = Bool
True
        isIndex DimIndex d
_ = Bool
False
    Just (Replicate (Shape []) (Var VName
src), Certs
cs)
      | Just [SubExp]
dims <- Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
        [DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ VName -> Bool
consumed VName
src,
        -- It is generally not safe to simplify a slice of a copy,
        -- because the result may be used in an in-place update of the
        -- original.  But we know this can only happen if the original
        -- is bound the same depth as we are!
        (DimIndex SubExp -> Bool) -> [DimIndex SubExp] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (Maybe SubExp -> Bool
forall a. Maybe a -> Bool
isJust (Maybe SubExp -> Bool)
-> (DimIndex SubExp -> Maybe SubExp) -> DimIndex SubExp -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. DimIndex SubExp -> Maybe SubExp
forall d. DimIndex d -> Maybe d
dimFix) [DimIndex SubExp]
inds
          Bool -> Bool -> Bool
|| Bool -> (Entry (Rep m) -> Bool) -> Maybe (Entry (Rep m)) -> Bool
forall b a. b -> (a -> b) -> Maybe a -> b
maybe Bool
True ((SymbolTable (Rep m) -> Int
forall rep. SymbolTable rep -> Int
ST.loopDepth SymbolTable (Rep m)
vtable /=) (Int -> Bool) -> (Entry (Rep m) -> Int) -> Entry (Rep m) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Entry (Rep m) -> Int
forall rep. Entry rep -> Int
ST.entryDepth) (VName -> SymbolTable (Rep m) -> Maybe (Entry (Rep m))
forall rep. VName -> SymbolTable rep -> Maybe (Entry rep)
ST.lookup VName
src SymbolTable (Rep m)
vtable),
        Bool -> Bool
not Bool
consuming,
        VName -> SymbolTable (Rep m) -> Bool
forall rep. VName -> SymbolTable rep -> Bool
ST.available VName
src SymbolTable (Rep m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
src (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds
    Just (Reshape ReshapeKind
ReshapeCoerce ShapeBase SubExp
newshape VName
src, Certs
cs)
      | Just [SubExp]
olddims <- Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
        [Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) (ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
newshape) [SubExp]
olddims,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop ([DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
inds) [Bool]
changed_dims ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
src (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds
      | Just [SubExp]
olddims <- Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
        ShapeBase SubExp -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [DimIndex SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [DimIndex SubExp]
inds,
        [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length (ShapeBase SubExp -> [SubExp]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase SubExp
newshape) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
src (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds
    Just (Reshape ReshapeKind
_ (Shape [SubExp
_]) VName
v2, Certs
cs)
      | Just [SubExp
_] <- Type -> [SubExp]
forall u. TypeBase (ShapeBase SubExp) u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
v2) ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
v2 (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds
    Just (Concat Int
d (VName
x :| [VName]
xs) SubExp
_, Certs
cs)
      | -- HACK: simplifying the indexing of an N-array concatenation
        -- is going to produce an N-deep if expression, which is bad
        -- when N is large.  To try to avoid that, we use the
        -- heuristic not to simplify as long as any of the operands
        -- are themselves Concats.  The hope it that this will give
        -- simplification some time to cut down the concatenation to
        -- something smaller, before we start inlining.
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any VName -> Bool
isConcat ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs,
        Just ([DimIndex SubExp]
ibef, DimFix SubExp
i, [DimIndex SubExp]
iaft) <- Int
-> [DimIndex SubExp]
-> Maybe ([DimIndex SubExp], DimIndex SubExp, [DimIndex SubExp])
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d [DimIndex SubExp]
inds,
        Just (Prim PrimType
res_t) <-
          (Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase (ShapeBase SubExp) u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims ([DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds))
            (Type -> Type) -> Maybe Type -> Maybe Type
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> SymbolTable (Rep m) -> Maybe Type
forall rep. ASTRep rep => VName -> SymbolTable rep -> Maybe Type
ST.lookupType VName
x SymbolTable (Rep m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
          SubExp
x_len <- Int -> Type -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
d (Type -> SubExp) -> m Type -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType VName
x
          [SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM ((Type -> SubExp) -> m Type -> m SubExp
forall a b. (a -> b) -> m a -> m b
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Type -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
d) (m Type -> m SubExp) -> (VName -> m Type) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m Type
forall rep (m :: * -> *). HasScope rep m => VName -> m Type
lookupType) [VName]
xs

          let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
                SubExp
added <- String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
n SubExp
m
                (SubExp, SubExp) -> m (SubExp, SubExp)
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (SubExp
added, SubExp
n)
          (SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) (t :: * -> *) acc x y.
(Monad m, Traversable t) =>
(acc -> x -> m (acc, y)) -> acc -> t x -> m (acc, t y)
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall {m :: * -> *}.
MonadBuilder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
          let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts

          let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
                String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp] -> Slice SubExp
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp]
ibef [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
iaft
              mkBranch ((VName
x', SubExp
start) : [(VName, SubExp)]
xs_and_starts') = do
                SubExp
cmp <- String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int64) SubExp
start SubExp
i
                (SubExp
thisres, Stms (Rep m)
thisstms) <- m SubExp -> m (SubExp, Stms (Rep m))
forall a. m a -> m (a, Stms (Rep m))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (m SubExp -> m (SubExp, Stms (Rep m)))
-> m SubExp -> m (SubExp, Stms (Rep m))
forall a b. (a -> b) -> a -> b
$ do
                  SubExp
i' <- String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m)) -> BasicOp -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
start
                  String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat" (Exp (Rep m) -> m SubExp)
-> (Slice SubExp -> Exp (Rep m)) -> Slice SubExp -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. BasicOp -> Exp (Rep m)
forall rep. BasicOp -> Exp rep
BasicOp (BasicOp -> Exp (Rep m))
-> (Slice SubExp -> BasicOp) -> Slice SubExp -> Exp (Rep m)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> m SubExp) -> Slice SubExp -> m SubExp
forall a b. (a -> b) -> a -> b
$
                    [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp]
ibef [DimIndex SubExp] -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
iaft)
                Body (Rep m)
thisbody <- Stms (Rep m) -> Result -> m (Body (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> Result -> m (Body (Rep m))
mkBodyM Stms (Rep m)
thisstms [SubExp -> SubExpRes
subExpRes SubExp
thisres]
                (SubExp
altres, Stms (Rep m)
altstms) <- m SubExp -> m (SubExp, Stms (Rep m))
forall a. m a -> m (a, Stms (Rep m))
forall (m :: * -> *) a.
MonadBuilder m =>
m a -> m (a, Stms (Rep m))
collectStms (m SubExp -> m (SubExp, Stms (Rep m)))
-> m SubExp -> m (SubExp, Stms (Rep m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
                Body (Rep m)
altbody <- Stms (Rep m) -> Result -> m (Body (Rep m))
forall (m :: * -> *).
MonadBuilder m =>
Stms (Rep m) -> Result -> m (Body (Rep m))
mkBodyM Stms (Rep m)
altstms [SubExp -> SubExpRes
subExpRes SubExp
altres]
                Certs -> m SubExp -> m SubExp
forall a. Certs -> m a -> m a
forall (m :: * -> *) a. MonadBuilder m => Certs -> m a -> m a
certifying Certs
cs (m SubExp -> m SubExp)
-> (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. String -> Exp (Rep m) -> m SubExp
forall (m :: * -> *).
MonadBuilder m =>
String -> Exp (Rep m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Rep m) -> m SubExp) -> Exp (Rep m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
                  [SubExp]
-> [Case (Body (Rep m))]
-> Body (Rep m)
-> MatchDec (BranchType (Rep m))
-> Exp (Rep m)
forall rep.
[SubExp]
-> [Case (Body rep)]
-> Body rep
-> MatchDec (BranchType rep)
-> Exp rep
Match [SubExp
cmp] [[Maybe PrimValue] -> Body (Rep m) -> Case (Body (Rep m))
forall body. [Maybe PrimValue] -> body -> Case body
Case [PrimValue -> Maybe PrimValue
forall a. a -> Maybe a
Just (PrimValue -> Maybe PrimValue) -> PrimValue -> Maybe PrimValue
forall a b. (a -> b) -> a -> b
$ Bool -> PrimValue
BoolValue Bool
True] Body (Rep m)
thisbody] Body (Rep m)
altbody (MatchDec (BranchType (Rep m)) -> Exp (Rep m))
-> MatchDec (BranchType (Rep m)) -> Exp (Rep m)
forall a b. (a -> b) -> a -> b
$
                    [BranchType (Rep m)] -> MatchSort -> MatchDec (BranchType (Rep m))
forall rt. [rt] -> MatchSort -> MatchDec rt
MatchDec [PrimType -> BranchType (Rep m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] MatchSort
MatchNormal
          Certs -> SubExp -> IndexResult
SubExpResult Certs
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
forall {m :: * -> *}.
MonadBuilder m =>
[(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts
    Just (ArrayLit [SubExp]
ses Type
_, Certs
cs)
      | DimFix (Constant (IntValue (Int64Value Int64
i))) : [DimIndex SubExp]
inds' <- [DimIndex SubExp]
inds,
        Just SubExp
se <- Int64 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int64
i [SubExp]
ses ->
          case [DimIndex SubExp]
inds' of
            [] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
cs SubExp
se
            [DimIndex SubExp]
_ | Var VName
v2 <- SubExp
se -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
cs VName
v2 (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice [DimIndex SubExp]
inds'
            [DimIndex SubExp]
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
    Just (Update Safety
Unsafe VName
_ (Slice [DimIndex SubExp]
update_inds) SubExp
se, Certs
cs)
      | [DimIndex SubExp]
inds [DimIndex SubExp] -> [DimIndex SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
== [DimIndex SubExp]
update_inds,
        SubExp -> SymbolTable (Rep m) -> Bool
forall rep. SubExp -> SymbolTable rep -> Bool
ST.subExpAvailable SubExp
se SymbolTable (Rep m)
vtable ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certs -> SubExp -> IndexResult
SubExpResult Certs
cs SubExp
se
    -- Indexing single-element arrays.  We know the index must be 0.
    Maybe (BasicOp, Certs)
_
      | Just Type
t <- TypeLookup
seType TypeLookup -> TypeLookup
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd,
        SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> Type -> SubExp
forall u. Int -> TypeBase (ShapeBase SubExp) u -> SubExp
arraySize Int
0 Type
t,
        DimFix SubExp
i : [DimIndex SubExp]
inds' <- [DimIndex SubExp]
inds,
        Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
          m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> ([DimIndex SubExp] -> m IndexResult)
-> [DimIndex SubExp]
-> Maybe (m IndexResult)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IndexResult -> m IndexResult
forall a. a -> m a
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult)
-> ([DimIndex SubExp] -> IndexResult)
-> [DimIndex SubExp]
-> m IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certs -> VName -> Slice SubExp -> IndexResult
IndexResult Certs
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> IndexResult)
-> ([DimIndex SubExp] -> Slice SubExp)
-> [DimIndex SubExp]
-> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. [DimIndex SubExp] -> Slice SubExp
forall d. [DimIndex d] -> Slice d
Slice ([DimIndex SubExp] -> Maybe (m IndexResult))
-> [DimIndex SubExp] -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
            SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) DimIndex SubExp -> [DimIndex SubExp] -> [DimIndex SubExp]
forall a. a -> [a] -> [a]
: [DimIndex SubExp]
inds'
    Maybe (BasicOp, Certs)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
  where
    -- Find the statement that binds the given variable in the symbol
    -- table.  Succeeds only when that binding is a 'BasicOp', returning
    -- the operation together with the certificates attached to the
    -- binding; yields 'Nothing' if the name is unknown or is bound by
    -- some other kind of expression.
    defOf name =
      case ST.lookupExp name vtable of
        Just (BasicOp bop, bound_cs) -> Just (bop, bound_cs)
        _ -> Nothing
    -- Heuristic deciding whether an expression is worth inlining.  Any
    -- expression for which 'primExpSizeAtLeast' 20 holds is rejected
    -- (the threshold is totally ad-hoc); smaller ones are left to the
    -- structural check in worthInlining'.
    worthInlining pe =
      not (primExpSizeAtLeast 20 pe) && worthInlining' pe
    -- Structural half of the inlining heuristic.  Rejects the expression
    -- if a 'Pow' or 'FPow' binary operator, or a 'FunExp' node, occurs
    -- anywhere among its binary, comparison, conversion and unary
    -- subterms.  Every other constructor is accepted as-is.
    worthInlining' expr =
      case expr of
        BinOpExp Pow {} _ _ -> False
        BinOpExp FPow {} _ _ -> False
        BinOpExp _ lhs rhs -> worthInlining' lhs && worthInlining' rhs
        CmpOpExp _ lhs rhs -> worthInlining' lhs && worthInlining' rhs
        ConvOpExp _ sub -> worthInlining' sub
        UnOpExp _ sub -> worthInlining' sub
        FunExp {} -> False
        _ -> True

    -- True exactly when the variable is bound (according to defOf) by a
    -- 'Concat' operation.
    isConcat name =
      case defOf name of
        Just (Concat {}, _) -> True
        _ -> False