{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}
module Futhark.Optimise.Simplify.Rules.Index
( IndexResult (..),
simplifyIndexing,
)
where
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util
-- | Does this subexpression denote a constant accepted by 'oneIsh'?
--
-- Only 'Constant' subexpressions are inspected; any other subexpression
-- yields 'False'.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False
-- | Does this subexpression denote a constant accepted by 'zeroIsh'?
--
-- Only 'Constant' subexpressions are inspected; any other subexpression
-- yields 'False'.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False
-- | The value produced by 'simplifyIndexing' when it finds a rewrite.
data IndexResult
  = -- | Certificates, an array name and a slice.  NOTE(review):
    -- presumably means "index this array with this slice instead, under
    -- these certificates" - confirm against the consumers of this type.
    IndexResult Certificates VName (Slice SubExp)
  | -- | Certificates and a subexpression.  NOTE(review): presumably the
    -- indexing is replaced outright by this subexpression - confirm
    -- against the consumers of this type.
    SubExpResult Certificates SubExp
simplifyIndexing ::
MonadBinder m =>
ST.SymbolTable (Lore m) ->
TypeLookup ->
VName ->
Slice SubExp ->
Bool ->
Maybe (m IndexResult)
simplifyIndexing :: SymbolTable (Lore m)
-> TypeLookup
-> VName
-> Slice SubExp
-> Bool
-> Maybe (m IndexResult)
simplifyIndexing SymbolTable (Lore m)
vtable TypeLookup
seType VName
idd Slice SubExp
inds Bool
consuming =
case VName -> Maybe (BasicOp, Certificates)
defOf VName
idd of
Maybe (BasicOp, Certificates)
_
| Just Type
t <- TypeLookup
seType (VName -> SubExp
Var VName
idd),
Slice SubExp
inds Slice SubExp -> Slice SubExp -> Bool
forall a. Eq a => a -> a -> Bool
== Type -> Slice SubExp -> Slice SubExp
fullSlice Type
t [] ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
forall a. Monoid a => a
mempty (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.Indexed Certificates
cs PrimExp VName
e) <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining PrimExp VName
e,
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp" PrimExp VName
e
| Just [SubExp]
inds' <- Slice SubExp -> Maybe [SubExp]
forall d. Slice d -> Maybe [d]
sliceIndices Slice SubExp
inds,
Just (ST.IndexedArray Certificates
cs VName
arr [TPrimExp Int64 VName]
inds'') <- VName -> [SubExp] -> SymbolTable (Lore m) -> Maybe Indexed
forall lore.
ASTLore lore =>
VName -> [SubExp] -> SymbolTable lore -> Maybe Indexed
ST.index VName
idd [SubExp]
inds' SymbolTable (Lore m)
vtable,
(TPrimExp Int64 VName -> Bool) -> [TPrimExp Int64 VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (PrimExp VName -> Bool
forall v. PrimExp v -> Bool
worthInlining (PrimExp VName -> Bool)
-> (TPrimExp Int64 VName -> PrimExp VName)
-> TPrimExp Int64 VName
-> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TPrimExp Int64 VName -> PrimExp VName
forall t v. TPrimExp t v -> PrimExp v
untyped) [TPrimExp Int64 VName]
inds'',
(VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
`ST.elem` SymbolTable (Lore m)
vtable) (Certificates -> [VName]
unCertificates Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult)
-> ([SubExp] -> Slice SubExp) -> [SubExp] -> IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (SubExp -> DimIndex SubExp) -> [SubExp] -> Slice SubExp
forall a b. (a -> b) -> [a] -> [b]
map SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix
([SubExp] -> IndexResult) -> m [SubExp] -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (TPrimExp Int64 VName -> m SubExp)
-> [TPrimExp Int64 VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (String -> TPrimExp Int64 VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_primexp") [TPrimExp Int64 VName]
inds''
Maybe (BasicOp, Certificates)
Nothing -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Just (SubExp (Var VName
v), Certificates
cs) -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v Slice SubExp
inds
Just (Iota SubExp
_ SubExp
x SubExp
s IntType
to_it, Certificates
cs)
| [DimFix SubExp
ii] <- Slice SubExp
inds,
Just (Prim (IntType IntType
from_it)) <- TypeLookup
seType SubExp
ii ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
let mul :: PrimExp v -> PrimExp v -> PrimExp v
mul = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
add :: PrimExp v -> PrimExp v -> PrimExp v
add = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
in (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"index_iota" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
( IntType -> PrimExp VName -> PrimExp VName
forall v. IntType -> PrimExp v -> PrimExp v
sExt IntType
to_it (PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
from_it) SubExp
ii)
PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
)
PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
| [DimSlice SubExp
i_offset SubExp
i_n SubExp
i_stride] <- Slice SubExp
inds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
i_offset' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_offset
SubExp
i_stride' <- IntType -> SubExp -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
IntType -> SubExp -> m SubExp
asIntS IntType
to_it SubExp
i_stride
let mul :: PrimExp v -> PrimExp v -> PrimExp v
mul = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Mul IntType
to_it Overflow
OverflowWrap
add :: PrimExp v -> PrimExp v -> PrimExp v
add = BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall v. BinOp -> PrimExp v -> PrimExp v -> PrimExp v
BinOpExp (BinOp -> PrimExp v -> PrimExp v -> PrimExp v)
-> BinOp -> PrimExp v -> PrimExp v -> PrimExp v
forall a b. (a -> b) -> a -> b
$ IntType -> Overflow -> BinOp
Add IntType
to_it Overflow
OverflowWrap
SubExp
i_offset'' <-
String -> PrimExp VName -> m SubExp
forall (m :: * -> *) a.
(MonadBinder m, ToExp a) =>
String -> a -> m SubExp
toSubExp String
"iota_offset" (PrimExp VName -> m SubExp) -> PrimExp VName -> m SubExp
forall a b. (a -> b) -> a -> b
$
( PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
x
PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
`mul` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
s
)
PrimExp VName -> PrimExp VName -> PrimExp VName
forall v. PrimExp v -> PrimExp v -> PrimExp v
`add` PrimType -> SubExp -> PrimExp VName
primExpFromSubExp (IntType -> PrimType
IntType IntType
to_it) SubExp
i_offset'
SubExp
i_stride'' <-
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"iota_offset" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Mul IntType
Int64 Overflow
OverflowWrap) SubExp
s SubExp
i_stride'
(SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs) (m SubExp -> m IndexResult) -> m SubExp -> m IndexResult
forall a b. (a -> b) -> a -> b
$
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"slice_iota" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ SubExp -> SubExp -> SubExp -> IntType -> BasicOp
Iota SubExp
i_n SubExp
i_offset'' SubExp
i_stride'' IntType
to_it
Just (Rotate [SubExp]
offsets VName
a, Certificates
cs)
| Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (SubExp -> DimIndex SubExp -> Bool)
-> [SubExp] -> Slice SubExp -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> DimIndex SubExp -> Bool
forall d. SubExp -> DimIndex d -> Bool
rotateAndSlice [SubExp]
offsets Slice SubExp
inds -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
[SubExp]
dims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> m Type -> m [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
a
let adjustI :: SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d = do
SubExp
i_p_o <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"i_p_o" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
o
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"rot_i" (BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Safety -> BinOp
SMod IntType
Int64 Safety
Unsafe) SubExp
i_p_o SubExp
d)
adjust :: (DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (DimFix SubExp
i, SubExp
o, SubExp
d) =
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d
adjust (DimSlice SubExp
i SubExp
n SubExp
s, SubExp
o, SubExp
d) =
SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (SubExp -> SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> SubExp -> SubExp -> SubExp -> f SubExp
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> SubExp -> m SubExp
adjustI SubExp
i SubExp
o SubExp
d f (SubExp -> SubExp -> DimIndex SubExp)
-> f SubExp -> f (SubExp -> DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
n f (SubExp -> DimIndex SubExp) -> f SubExp -> f (DimIndex SubExp)
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> SubExp -> f SubExp
forall (f :: * -> *) a. Applicative f => a -> f a
pure SubExp
s
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
a (Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ((DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp))
-> [(DimIndex SubExp, SubExp, SubExp)] -> m (Slice SubExp)
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM (DimIndex SubExp, SubExp, SubExp) -> m (DimIndex SubExp)
forall (f :: * -> *).
MonadBinder f =>
(DimIndex SubExp, SubExp, SubExp) -> f (DimIndex SubExp)
adjust (Slice SubExp
-> [SubExp] -> [SubExp] -> [(DimIndex SubExp, SubExp, SubExp)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Slice SubExp
inds [SubExp]
offsets [SubExp]
dims)
where
rotateAndSlice :: SubExp -> DimIndex d -> Bool
rotateAndSlice SubExp
r DimSlice {} = Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
r
rotateAndSlice SubExp
_ DimIndex d
_ = Bool
False
Just (Index VName
aa Slice SubExp
ais, Certificates
cs) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
aa
(Slice SubExp -> IndexResult) -> m (Slice SubExp) -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
forall (m :: * -> *).
MonadBinder m =>
Slice (TPrimExp Int64 VName) -> m (Slice SubExp)
subExpSlice (Slice (TPrimExp Int64 VName)
-> Slice (TPrimExp Int64 VName) -> Slice (TPrimExp Int64 VName)
forall d. Num d => Slice d -> Slice d -> Slice d
sliceSlice (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
ais) (Slice SubExp -> Slice (TPrimExp Int64 VName)
primExpSlice Slice SubExp
inds))
Just (Replicate (Shape [SubExp
_]) (Var VName
vv), Certificates
cs)
| [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
vv
| DimFix {} : Slice SubExp
is' <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
vv Slice SubExp
is'
Just (Replicate (Shape [SubExp
_]) val :: SubExp
val@(Constant PrimValue
_), Certificates
cs)
| [DimFix {}] <- Slice SubExp
inds, Bool -> Bool
not Bool
consuming -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
val
Just (Replicate (Shape [SubExp]
ds) SubExp
v, Certificates
cs)
| (Slice SubExp
ds_inds, Slice SubExp
rest_inds) <- Int -> Slice SubExp -> (Slice SubExp, Slice SubExp)
forall a. Int -> [a] -> ([a], [a])
splitAt ([SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
ds) Slice SubExp
inds,
([SubExp]
ds', Slice SubExp
ds_inds') <- [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. [(a, b)] -> ([a], [b])
unzip ([(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp))
-> [(SubExp, DimIndex SubExp)] -> ([SubExp], Slice SubExp)
forall a b. (a -> b) -> a -> b
$ (DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp))
-> Slice SubExp -> [(SubExp, DimIndex SubExp)]
forall a b. (a -> Maybe b) -> [a] -> [b]
mapMaybe DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index Slice SubExp
ds_inds,
[SubExp]
ds' [SubExp] -> [SubExp] -> Bool
forall a. Eq a => a -> a -> Bool
/= [SubExp]
ds ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
VName
arr <- String -> Exp (Lore m) -> m VName
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m VName
letExp String
"smaller_replicate" (Exp (Lore m) -> m VName) -> Exp (Lore m) -> m VName
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ Shape -> SubExp -> BasicOp
Replicate ([SubExp] -> Shape
forall d. [d] -> ShapeBase d
Shape [SubExp]
ds') SubExp
v
IndexResult -> m IndexResult
forall (m :: * -> *) a. Monad m => a -> m a
return (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
arr (Slice SubExp -> IndexResult) -> Slice SubExp -> IndexResult
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ds_inds' Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ Slice SubExp
rest_inds
where
index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
index DimFix {} = Maybe (SubExp, DimIndex SubExp)
forall a. Maybe a
Nothing
index (DimSlice SubExp
_ SubExp
n SubExp
s) = (SubExp, DimIndex SubExp) -> Maybe (SubExp, DimIndex SubExp)
forall a. a -> Maybe a
Just (SubExp
n, SubExp -> SubExp -> SubExp -> DimIndex SubExp
forall d. d -> d -> d -> DimIndex d
DimSlice (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) SubExp
n SubExp
s)
Just (Rearrange [Int]
perm VName
src, Certificates
cs)
| [Int] -> Int
rearrangeReach [Int]
perm Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
<= Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ((DimIndex SubExp -> Bool) -> Slice SubExp -> Slice SubExp
forall a. (a -> Bool) -> [a] -> [a]
takeWhile DimIndex SubExp -> Bool
forall d. DimIndex d -> Bool
isIndex Slice SubExp
inds) ->
let inds' :: Slice SubExp
inds' = [Int] -> Slice SubExp -> Slice SubExp
forall a. [Int] -> [a] -> [a]
rearrangeShape ([Int] -> [Int]
rearrangeInverse [Int]
perm) Slice SubExp
inds
in m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds'
where
isIndex :: DimIndex d -> Bool
isIndex DimFix {} = Bool
True
isIndex DimIndex d
_ = Bool
False
Just (Copy VName
src, Certificates
cs)
| Just [SubExp]
dims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
dims,
Just [SubExp]
_ <- (DimIndex SubExp -> Maybe SubExp) -> Slice SubExp -> Maybe [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM DimIndex SubExp -> Maybe SubExp
forall d. DimIndex d -> Maybe d
dimFix Slice SubExp
inds,
Bool -> Bool
not Bool
consuming,
VName -> SymbolTable (Lore m) -> Bool
forall lore. VName -> SymbolTable lore -> Bool
ST.available VName
src SymbolTable (Lore m)
vtable ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape ShapeChange SubExp
newshape VName
src, Certificates
cs)
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
[Bool]
changed_dims <- (SubExp -> SubExp -> Bool) -> [SubExp] -> [SubExp] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith SubExp -> SubExp -> Bool
forall a. Eq a => a -> a -> Bool
(/=) [SubExp]
newdims [SubExp]
olddims,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
or ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> [Bool] -> [Bool]
forall a. Int -> [a] -> [a]
drop (Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds) [Bool]
changed_dims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
| Just [SubExp]
newdims <- ShapeChange SubExp -> Maybe [SubExp]
forall d. ShapeChange d -> Maybe [d]
shapeCoercion ShapeChange SubExp
newshape,
Just [SubExp]
olddims <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
src),
ShapeChange SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange SubExp
newshape Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Slice SubExp -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length Slice SubExp
inds,
[SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
olddims Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== [SubExp] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [SubExp]
newdims ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
src Slice SubExp
inds
Just (Reshape [DimChange SubExp
_] VName
v2, Certificates
cs)
| Just [SubExp
_] <- Type -> [SubExp]
forall u. TypeBase Shape u -> [SubExp]
arrayDims (Type -> [SubExp]) -> Maybe Type -> Maybe [SubExp]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> TypeLookup
seType (VName -> SubExp
Var VName
v2) ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds
Just (Concat Int
d VName
x [VName]
xs SubExp
_, Certificates
cs)
|
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ (VName -> Bool) -> [VName] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
any VName -> Bool
isConcat ([VName] -> Bool) -> [VName] -> Bool
forall a b. (a -> b) -> a -> b
$ VName
x VName -> [VName] -> [VName]
forall a. a -> [a] -> [a]
: [VName]
xs,
Just (Slice SubExp
ibef, DimFix SubExp
i, Slice SubExp
iaft) <- Int
-> Slice SubExp
-> Maybe (Slice SubExp, DimIndex SubExp, Slice SubExp)
forall int a. Integral int => int -> [a] -> Maybe ([a], a, [a])
focusNth Int
d Slice SubExp
inds,
Just (Prim PrimType
res_t) <-
(Type -> [SubExp] -> Type
forall oldshape u.
TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
`setArrayDims` Slice SubExp -> [SubExp]
forall d. Slice d -> [d]
sliceDims Slice SubExp
inds)
(Type -> Type) -> Maybe Type -> Maybe Type
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> SymbolTable (Lore m) -> Maybe Type
forall lore.
ASTLore lore =>
VName -> SymbolTable lore -> Maybe Type
ST.lookupType VName
x SymbolTable (Lore m)
vtable -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ do
SubExp
x_len <- Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d (Type -> SubExp) -> m Type -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType VName
x
[SubExp]
xs_lens <- (VName -> m SubExp) -> [VName] -> m [SubExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
mapM ((Type -> SubExp) -> m Type -> m SubExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
d) (m Type -> m SubExp) -> (VName -> m Type) -> VName -> m SubExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. VName -> m Type
forall lore (m :: * -> *). HasScope lore m => VName -> m Type
lookupType) [VName]
xs
let add :: SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
n SubExp
m = do
SubExp
added <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_add" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Add IntType
Int64 Overflow
OverflowWrap) SubExp
n SubExp
m
(SubExp, SubExp) -> m (SubExp, SubExp)
forall (m :: * -> *) a. Monad m => a -> m a
return (SubExp
added, SubExp
n)
(SubExp
_, [SubExp]
starts) <- (SubExp -> SubExp -> m (SubExp, SubExp))
-> SubExp -> [SubExp] -> m (SubExp, [SubExp])
forall (m :: * -> *) acc x y.
Monad m =>
(acc -> x -> m (acc, y)) -> acc -> [x] -> m (acc, [y])
mapAccumLM SubExp -> SubExp -> m (SubExp, SubExp)
forall (m :: * -> *).
MonadBinder m =>
SubExp -> SubExp -> m (SubExp, SubExp)
add SubExp
x_len [SubExp]
xs_lens
let xs_and_starts :: [(VName, SubExp)]
xs_and_starts = [(VName, SubExp)] -> [(VName, SubExp)]
forall a. [a] -> [a]
reverse ([(VName, SubExp)] -> [(VName, SubExp)])
-> [(VName, SubExp)] -> [(VName, SubExp)]
forall a b. (a -> b) -> a -> b
$ [VName] -> [SubExp] -> [(VName, SubExp)]
forall a b. [a] -> [b] -> [(a, b)]
zip [VName]
xs [SubExp]
starts
let mkBranch :: [(VName, SubExp)] -> m SubExp
mkBranch [] =
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
mkBranch ((VName
x', SubExp
start) : [(VName, SubExp)]
xs_and_starts') = do
SubExp
cmp <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_cmp" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ CmpOp -> SubExp -> SubExp -> BasicOp
CmpOp (IntType -> CmpOp
CmpSle IntType
Int64) SubExp
start SubExp
i
(SubExp
thisres, Stms (Lore m)
thisbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ do
SubExp
i' <- String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_i" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ BinOp -> SubExp -> SubExp -> BasicOp
BinOp (IntType -> Overflow -> BinOp
Sub IntType
Int64 Overflow
OverflowWrap) SubExp
i SubExp
start
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$ BasicOp -> Exp (Lore m)
forall lore. BasicOp -> ExpT lore
BasicOp (BasicOp -> Exp (Lore m)) -> BasicOp -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$ VName -> Slice SubExp -> BasicOp
Index VName
x' (Slice SubExp -> BasicOp) -> Slice SubExp -> BasicOp
forall a b. (a -> b) -> a -> b
$ Slice SubExp
ibef Slice SubExp -> Slice SubExp -> Slice SubExp
forall a. [a] -> [a] -> [a]
++ SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix SubExp
i' DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
iaft
BodyT (Lore m)
thisbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
thisbnds [SubExp
thisres]
(SubExp
altres, Stms (Lore m)
altbnds) <- m SubExp -> m (SubExp, Stms (Lore m))
forall (m :: * -> *) a.
MonadBinder m =>
m a -> m (a, Stms (Lore m))
collectStms (m SubExp -> m (SubExp, Stms (Lore m)))
-> m SubExp -> m (SubExp, Stms (Lore m))
forall a b. (a -> b) -> a -> b
$ [(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts'
BodyT (Lore m)
altbody <- Stms (Lore m) -> [SubExp] -> m (BodyT (Lore m))
forall (m :: * -> *).
MonadBinder m =>
Stms (Lore m) -> [SubExp] -> m (Body (Lore m))
mkBodyM Stms (Lore m)
altbnds [SubExp
altres]
String -> Exp (Lore m) -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
String -> Exp (Lore m) -> m SubExp
letSubExp String
"index_concat_branch" (Exp (Lore m) -> m SubExp) -> Exp (Lore m) -> m SubExp
forall a b. (a -> b) -> a -> b
$
SubExp
-> BodyT (Lore m)
-> BodyT (Lore m)
-> IfDec (BranchType (Lore m))
-> Exp (Lore m)
forall lore.
SubExp
-> BodyT lore -> BodyT lore -> IfDec (BranchType lore) -> ExpT lore
If SubExp
cmp BodyT (Lore m)
thisbody BodyT (Lore m)
altbody (IfDec (BranchType (Lore m)) -> Exp (Lore m))
-> IfDec (BranchType (Lore m)) -> Exp (Lore m)
forall a b. (a -> b) -> a -> b
$
[BranchType (Lore m)] -> IfSort -> IfDec (BranchType (Lore m))
forall rt. [rt] -> IfSort -> IfDec rt
IfDec [PrimType -> BranchType (Lore m)
forall rt. IsBodyType rt => PrimType -> rt
primBodyType PrimType
res_t] IfSort
IfNormal
Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs (SubExp -> IndexResult) -> m SubExp -> m IndexResult
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [(VName, SubExp)] -> m SubExp
forall (m :: * -> *).
MonadBinder m =>
[(VName, SubExp)] -> m SubExp
mkBranch [(VName, SubExp)]
xs_and_starts
Just (ArrayLit [SubExp]
ses Type
_, Certificates
cs)
| DimFix (Constant (IntValue (Int64Value Int64
i))) : Slice SubExp
inds' <- Slice SubExp
inds,
Just SubExp
se <- Int64 -> [SubExp] -> Maybe SubExp
forall int a. Integral int => int -> [a] -> Maybe a
maybeNth Int64
i [SubExp]
ses ->
case Slice SubExp
inds' of
[] -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> SubExp -> IndexResult
SubExpResult Certificates
cs SubExp
se
Slice SubExp
_ | Var VName
v2 <- SubExp
se -> m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> m IndexResult -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$ IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult) -> IndexResult -> m IndexResult
forall a b. (a -> b) -> a -> b
$ Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
cs VName
v2 Slice SubExp
inds'
Slice SubExp
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
Maybe (BasicOp, Certificates)
_
| Just Type
t <- TypeLookup
seType TypeLookup -> TypeLookup
forall a b. (a -> b) -> a -> b
$ VName -> SubExp
Var VName
idd,
SubExp -> Bool
isCt1 (SubExp -> Bool) -> SubExp -> Bool
forall a b. (a -> b) -> a -> b
$ Int -> Type -> SubExp
forall u. Int -> TypeBase Shape u -> SubExp
arraySize Int
0 Type
t,
DimFix SubExp
i : Slice SubExp
inds' <- Slice SubExp
inds,
Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ SubExp -> Bool
isCt0 SubExp
i ->
m IndexResult -> Maybe (m IndexResult)
forall a. a -> Maybe a
Just (m IndexResult -> Maybe (m IndexResult))
-> (Slice SubExp -> m IndexResult)
-> Slice SubExp
-> Maybe (m IndexResult)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. IndexResult -> m IndexResult
forall (f :: * -> *) a. Applicative f => a -> f a
pure (IndexResult -> m IndexResult)
-> (Slice SubExp -> IndexResult) -> Slice SubExp -> m IndexResult
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Certificates -> VName -> Slice SubExp -> IndexResult
IndexResult Certificates
forall a. Monoid a => a
mempty VName
idd (Slice SubExp -> Maybe (m IndexResult))
-> Slice SubExp -> Maybe (m IndexResult)
forall a b. (a -> b) -> a -> b
$
SubExp -> DimIndex SubExp
forall d. d -> DimIndex d
DimFix (Int64 -> SubExp
forall v. IsValue v => v -> SubExp
constant (Int64
0 :: Int64)) DimIndex SubExp -> Slice SubExp -> Slice SubExp
forall a. a -> [a] -> [a]
: Slice SubExp
inds'
Maybe (BasicOp, Certificates)
_ -> Maybe (m IndexResult)
forall a. Maybe a
Nothing
where
-- Look up the expression bound to 'v' in the enclosing 'vtable'.
-- Produces the operation and the certificates returned by the lookup
-- only when that expression is a 'BasicOp'; every other outcome
-- (no entry, or a different kind of expression) gives 'Nothing'.
defOf :: VName -> Maybe (BasicOp, Certificates)
defOf v =
  case ST.lookupExp v vtable of
    Just (BasicOp op, def_cs) -> Just (op, def_cs)
    _ -> Nothing
-- Decide whether a primitive expression should be inlined.  An
-- expression for which @primExpSizeAtLeast 20@ holds is rejected
-- outright; anything else is accepted exactly when the structural
-- check 'worthInlining'' accepts it.
worthInlining :: PrimExp v -> Bool
worthInlining e = not (primExpSizeAtLeast 20 e) && worthInlining' e
-- Structural part of the inlining check.  Rejects any expression that
-- contains a 'Pow' or 'FPow' binary operation or a 'FunExp'; binary,
-- comparison, conversion and unary operations are accepted when all of
-- their operands are, and every remaining form is accepted.
worthInlining' :: PrimExp v -> Bool
worthInlining' expr =
  case expr of
    BinOpExp Pow {} _ _ -> False
    BinOpExp FPow {} _ _ -> False
    BinOpExp _ lhs rhs -> bothWorth lhs rhs
    CmpOpExp _ lhs rhs -> bothWorth lhs rhs
    ConvOpExp _ operand -> worthInlining' operand
    UnOpExp _ operand -> worthInlining' operand
    FunExp {} -> False
    _ -> True
  where
    -- Both operands must pass; the right one is only inspected when
    -- the left one has already passed.
    bothWorth lhs rhs = worthInlining' lhs && worthInlining' rhs
-- True exactly when 'defOf' reports that 'v' is defined by a 'Concat'
-- operation; False for any other definition or when 'defOf' finds
-- nothing.
isConcat :: VName -> Bool
isConcat v =
  case defOf v of
    Just (Concat {}, _) -> True
    _ -> False