{-# OPTIONS_GHC -Wno-overlapping-patterns -Wno-incomplete-patterns -Wno-incomplete-uni-patterns -Wno-incomplete-record-updates #-}
module Futhark.Optimise.Simplify.Rules.Index
( IndexResult (..),
simplifyIndexing,
)
where
import Data.List.NonEmpty (NonEmpty (..))
import Data.Maybe
import Futhark.Analysis.PrimExp.Convert
import Futhark.Analysis.SymbolTable qualified as ST
import Futhark.Construct
import Futhark.IR
import Futhark.Optimise.Simplify.Rules.Simple
import Futhark.Util
-- | Is this sub-expression a constant value for which 'oneIsh' holds?
-- Any other sub-expression yields 'False'.
isCt1 :: SubExp -> Bool
isCt1 (Constant v) = oneIsh v
isCt1 _ = False
-- | Is this sub-expression a constant value for which 'zeroIsh' holds?
-- Any other sub-expression yields 'False'.
isCt0 :: SubExp -> Bool
isCt0 (Constant v) = zeroIsh v
isCt0 _ = False
-- | The outcome of a successful 'simplifyIndexing' rewrite.
data IndexResult
  = -- | The indexing is equivalent to indexing the given array with the
    -- given slice, under the given certificates.
    IndexResult Certs VName (Slice SubExp)
  | -- | The indexing is equivalent to the given sub-expression, under the
    -- given certificates.
    SubExpResult Certs SubExp
-- | Try to simplify the indexing @idd[inds]@ by inspecting how @idd@ is
-- defined in the symbol table.
--
-- Returns 'Nothing' when no rule applies.  Otherwise returns a builder
-- action that may emit new statements and produces one of two results:
-- a replacement sub-expression ('SubExpResult'), or an equivalent indexing
-- of a different array ('IndexResult').
--
-- Parameters:
--
-- * @vtable@: symbol table used to look up the definition of @idd@ and of
--   the names it refers to.
-- * @seType@: type lookup for sub-expressions.
-- * @idd@: the array being indexed.
-- * @inds@: the slice applied to @idd@.
-- * @consuming@: when 'True', the rules that forward the index into the
--   operand of a @Replicate@ or @Copy@ are disabled.
simplifyIndexing ::
  MonadBuilder m =>
  ST.SymbolTable (Rep m) ->
  TypeLookup ->
  VName ->
  Slice SubExp ->
  Bool ->
  Maybe (m IndexResult)
simplifyIndexing vtable seType idd (Slice inds) consuming =
  case defOf idd of
    _
      -- A slice covering the whole array is the identity.
      | Just t <- seType (Var idd),
        Slice inds == fullSlice t [] ->
          Just $ pure $ SubExpResult mempty $ Var idd
      -- The symbol table can express the element as a scalar expression.
      -- Inline it when it is cheap and its certificates are in scope.
      | Just inds' <- sliceIndices (Slice inds),
        Just (ST.Indexed cs e) <- ST.index idd inds' vtable,
        worthInlining e,
        all (`ST.elem` vtable) (unCerts cs) ->
          Just $ SubExpResult cs <$> toSubExp "index_primexp" e
      -- The symbol table can express the element as an element of another,
      -- available array.  Index that array directly instead.
      | Just inds' <- sliceIndices (Slice inds),
        Just (ST.IndexedArray cs arr inds'') <- ST.index idd inds' vtable,
        all (worthInlining . untyped) inds'',
        arr `ST.available` vtable,
        all (`ST.elem` vtable) (unCerts cs) ->
          Just $
            IndexResult cs arr . Slice . map DimFix
              <$> mapM (toSubExp "index_primexp") inds''
    -- idd is not bound to a BasicOp in the symbol table.
    Nothing -> Nothing
    -- idd is just another name for v.
    Just (SubExp (Var v), cs) ->
      Just $ pure $ IndexResult cs v $ Slice inds
    Just (Iota _ x s to_it, cs)
      -- Element ii of the iota is x + ii*s, computed at type to_it.
      | [DimFix ii] <- inds,
        Just (Prim (IntType from_it)) <- seType ii ->
          Just $
            let mul :: PrimExp v -> PrimExp v -> PrimExp v
                mul = BinOpExp $ Mul to_it OverflowWrap
                add :: PrimExp v -> PrimExp v -> PrimExp v
                add = BinOpExp $ Add to_it OverflowWrap
             in fmap (SubExpResult cs) $
                  toSubExp "index_iota" $
                    ( sExt to_it (primExpFromSubExp (IntType from_it) ii)
                        `mul` primExpFromSubExp (IntType to_it) s
                    )
                      `add` primExpFromSubExp (IntType to_it) x
      -- Slicing an iota yields another iota.  Element j of the slice is
      --   x + (i_offset + j*i_stride)*s = (x + i_offset*s) + j*(i_stride*s),
      -- so the new start is x + i_offset*s and the new stride is i_stride*s.
      | [DimSlice i_offset i_n i_stride] <- inds ->
          Just $ do
            i_offset' <- asIntS to_it i_offset
            i_stride' <- asIntS to_it i_stride
            let mul :: PrimExp v -> PrimExp v -> PrimExp v
                mul = BinOpExp $ Mul to_it OverflowWrap
                add :: PrimExp v -> PrimExp v -> PrimExp v
                add = BinOpExp $ Add to_it OverflowWrap
            i_offset'' <-
              toSubExp "iota_offset" $
                ( primExpFromSubExp (IntType to_it) i_offset'
                    `mul` primExpFromSubExp (IntType to_it) s
                )
                  `add` primExpFromSubExp (IntType to_it) x
            -- Both s and i_stride' have type to_it, so multiply at that type.
            i_stride'' <-
              letSubExp "iota_offset" $
                BasicOp $
                  BinOp (Mul to_it OverflowWrap) s i_stride'
            fmap (SubExpResult cs) $
              letSubExp "slice_iota" $
                BasicOp $
                  Iota i_n i_offset'' i_stride'' to_it
    -- Indexing a rotation: each index is shifted by its rotation offset and
    -- reduced modulo the dimension size.  The rule is skipped when some
    -- offset that is not the constant zero meets a slice in the same
    -- dimension.
    Just (Rotate offsets a, cs)
      | not $ or $ zipWith rotateAndSlice offsets inds ->
          Just $ do
            dims <- arrayDims <$> lookupType a
            let adjustI i o d = do
                  i_p_o <-
                    letSubExp "i_p_o" $
                      BasicOp $
                        BinOp (Add Int64 OverflowWrap) i o
                  letSubExp "rot_i" $
                    BasicOp $
                      BinOp (SMod Int64 Unsafe) i_p_o d
                adjust (DimFix i, o, d) =
                  DimFix <$> adjustI i o d
                adjust (DimSlice i n s, o, d) =
                  DimSlice <$> adjustI i o d <*> pure n <*> pure s
            IndexResult cs a . Slice <$> mapM adjust (zip3 inds offsets dims)
      where
        rotateAndSlice :: SubExp -> DimIndex d -> Bool
        rotateAndSlice r DimSlice {} = not $ isCt0 r
        rotateAndSlice _ _ = False
    -- Indexing the result of an index: compose the two slices.
    Just (Index aa ais, cs) ->
      Just $
        IndexResult cs aa
          <$> subExpSlice (sliceSlice (primExpSlice ais) (primExpSlice (Slice inds)))
    -- Fixing the single outer dimension of a replicate of an available
    -- array yields that array, or the remaining slice of it.
    Just (Replicate (Shape [_]) (Var vv), cs)
      | [DimFix {}] <- inds,
        not consuming,
        ST.available vv vtable ->
          Just $ pure $ SubExpResult cs $ Var vv
      | DimFix {} : is' <- inds,
        not consuming,
        ST.available vv vtable ->
          Just $ pure $ IndexResult cs vv $ Slice is'
    -- Fixing the single dimension of a replicated constant yields it.
    Just (Replicate (Shape [_]) val@(Constant _), cs)
      | [DimFix {}] <- inds,
        not consuming ->
          Just $ pure $ SubExpResult cs val
    -- Fixed indices into the replicated dimensions can be dropped.  Build a
    -- smaller replicate that covers only the sliced dimensions.
    Just (Replicate (Shape ds) v, cs)
      | (ds_inds, rest_inds) <- splitAt (length ds) inds,
        (ds', ds_inds') <- unzip $ mapMaybe index ds_inds,
        ds' /= ds ->
          Just $ do
            arr <- letExp "smaller_replicate" $ BasicOp $ Replicate (Shape ds') v
            pure $ IndexResult cs arr $ Slice $ ds_inds' ++ rest_inds
      where
        -- A fixed index removes the dimension.  A slice of n elements becomes
        -- a dimension of size n in the smaller replicate.  All elements of a
        -- replicated dimension are equal, so the original offset and stride
        -- are irrelevant.  The new array is indexed with the full slice
        -- 0:n:1; keeping the original stride would read past the end of the
        -- size-n dimension whenever that stride is not 0 or 1.
        index :: DimIndex SubExp -> Maybe (SubExp, DimIndex SubExp)
        index DimFix {} = Nothing
        index (DimSlice _ n _) =
          Just (n, DimSlice (constant (0 :: Int64)) n (constant (1 :: Int64)))
    -- Indexing a rearrangement whose reach (per 'rearrangeReach') does not
    -- exceed the number of leading fixed indices: index the source with the
    -- indices permuted back by the inverse permutation.
    Just (Rearrange perm src, cs)
      | rearrangeReach perm <= length (takeWhile isIndex inds) ->
          let inds' = rearrangeShape (rearrangeInverse perm) inds
           in Just $ pure $ IndexResult cs src $ Slice inds'
      where
        isIndex :: DimIndex d -> Bool
        isIndex DimFix {} = True
        isIndex _ = False
    -- Indexing a copy is redirected to its source when every dimension is
    -- indexed and either:
    --   * all indices are fixed, or
    --   * src is not bound at the current loop depth (or is absent from the
    --     symbol table).
    -- NOTE(review): the depth test presumably prevents a slice from aliasing
    -- src while src is updated in place at this depth; confirm.
    Just (Copy src, cs)
      | Just dims <- arrayDims <$> seType (Var src),
        length inds == length dims,
        all (isJust . dimFix) inds
          || maybe True ((ST.loopDepth vtable /=) . ST.entryDepth) (ST.lookup src vtable),
        not consuming,
        ST.available src vtable ->
          Just $ pure $ IndexResult cs src $ Slice inds
    Just (Reshape ReshapeCoerce newshape src, cs)
      -- Only dimensions covered by the index changed: index src directly.
      | Just olddims <- arrayDims <$> seType (Var src),
        changed_dims <- zipWith (/=) (shapeDims newshape) olddims,
        not $ or $ drop (length inds) changed_dims ->
          Just $ pure $ IndexResult cs src $ Slice inds
      -- A same-rank coercion indexed in every dimension: index src directly.
      | Just olddims <- arrayDims <$> seType (Var src),
        length newshape == length inds,
        length olddims == length (shapeDims newshape) ->
          Just $ pure $ IndexResult cs src $ Slice inds
    -- A reshape of a one-dimensional array into one dimension: index the
    -- original array directly.
    Just (Reshape _ (Shape [_]) v2, cs)
      | Just [_] <- arrayDims <$> seType (Var v2) ->
          Just $ pure $ IndexResult cs v2 $ Slice inds
    -- Indexing a concatenation along dimension d with a fixed index in that
    -- dimension, where the result is a scalar.  A chain of nested Match
    -- expressions picks the operand by comparing i with each operand's
    -- start offset, testing the operands from last to first.  The rule is
    -- skipped while any operand is itself bound to a Concat.
    -- NOTE(review): presumably this keeps the generated match chain short
    -- until inner concatenations have been simplified; confirm.
    Just (Concat d (x :| xs) _, cs)
      | not $ any isConcat $ x : xs,
        Just (ibef, DimFix i, iaft) <- focusNth d inds,
        Just (Prim res_t) <-
          (`setArrayDims` sliceDims (Slice inds))
            <$> ST.lookupType x vtable ->
          Just $ do
            x_len <- arraySize d <$> lookupType x
            xs_lens <- mapM (fmap (arraySize d) . lookupType) xs
            -- Running sum of lengths: each xs operand's start is the total
            -- length of everything before it.
            let step offset len = do
                  next <-
                    letSubExp "index_concat_add" $
                      BasicOp $
                        BinOp (Add Int64 OverflowWrap) offset len
                  pure (next, offset)
            (_, starts) <- mapAccumLM step x_len xs_lens
            let xs_and_starts = reverse $ zip xs starts
            let mkBranch [] =
                  letSubExp "index_concat" $
                    BasicOp $
                      Index x $
                        Slice $
                          ibef ++ DimFix i : iaft
                mkBranch ((x', start) : xs_and_starts') = do
                  cmp <-
                    letSubExp "index_concat_cmp" $
                      BasicOp $
                        CmpOp (CmpSle Int64) start i
                  (thisres, thisstms) <- collectStms $ do
                    i' <-
                      letSubExp "index_concat_i" $
                        BasicOp $
                          BinOp (Sub Int64 OverflowWrap) i start
                    letSubExp "index_concat" . BasicOp . Index x' $
                      Slice (ibef ++ DimFix i' : iaft)
                  thisbody <- mkBodyM thisstms [subExpRes thisres]
                  (altres, altstms) <- collectStms $ mkBranch xs_and_starts'
                  altbody <- mkBodyM altstms [subExpRes altres]
                  letSubExp "index_concat_branch" $
                    Match [cmp] [Case [Just $ BoolValue True] thisbody] altbody $
                      MatchDec [primBodyType res_t] MatchNormal
            SubExpResult cs <$> mkBranch xs_and_starts
    -- A constant Int64 first index into an array literal selects that
    -- element; further indices are forwarded when the element is a variable.
    Just (ArrayLit ses _, cs)
      | DimFix (Constant (IntValue (Int64Value i))) : inds' <- inds,
        Just se <- maybeNth i ses ->
          case inds' of
            [] -> Just $ pure $ SubExpResult cs se
            _ | Var v2 <- se -> Just $ pure $ IndexResult cs v2 $ Slice inds'
            _ -> Nothing
    -- The outer dimension of idd is the constant one (per 'isCt1').  A fixed
    -- first index that is not already the constant zero is replaced with the
    -- literal 0, the only in-bounds value.
    _
      | Just t <- seType $ Var idd,
        isCt1 $ arraySize 0 t,
        DimFix i : inds' <- inds,
        not $ isCt0 i ->
          Just . pure . IndexResult mempty idd . Slice $
            DimFix (constant (0 :: Int64)) : inds'
    _ -> Nothing
  where
    -- The defining BasicOp of a name and the certificates attached to its
    -- binding.  'Nothing' if the name is unknown or not bound to a BasicOp.
    defOf :: VName -> Maybe (BasicOp, Certs)
    defOf v = do
      (BasicOp op, def_cs) <- ST.lookupExp v vtable
      pure (op, def_cs)

    -- Whether a scalar expression is cheap enough to inline.  Expressions
    -- whose size reaches the ad-hoc threshold of 20 (per
    -- 'primExpSizeAtLeast') are rejected outright.
    worthInlining :: PrimExp v -> Bool
    worthInlining e
      | primExpSizeAtLeast 20 e = False
      | otherwise = worthInlining' e

    -- Structural check: reject any expression that contains a power
    -- operation or a function call anywhere inside it.
    worthInlining' :: PrimExp v -> Bool
    worthInlining' (BinOpExp Pow {} _ _) = False
    worthInlining' (BinOpExp FPow {} _ _) = False
    worthInlining' (BinOpExp _ x y) = worthInlining' x && worthInlining' y
    worthInlining' (CmpOpExp _ x y) = worthInlining' x && worthInlining' y
    worthInlining' (ConvOpExp _ x) = worthInlining' x
    worthInlining' (UnOpExp _ x) = worthInlining' x
    worthInlining' FunExp {} = False
    worthInlining' _ = True

    -- Whether a name is bound to a Concat in the symbol table.
    isConcat :: VName -> Bool
    isConcat v
      | Just (Concat {}, _) <- defOf v = True
      | otherwise = False