{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.IR.Mem.IxFun
( IxFun (..),
Shape,
LMAD (..),
LMADDim (..),
index,
mkExistential,
iota,
iotaOffset,
permute,
reshape,
coerce,
slice,
flatSlice,
rebase,
shape,
permutation,
rank,
isDirect,
substituteInIxFun,
substituteInLMAD,
existentialize,
closeEnough,
equivalent,
permuteInv,
disjoint,
disjoint2,
disjoint3,
dynamicEqualsLMAD,
)
where
import Control.Category
import Control.Monad
import Control.Monad.State
import Data.List (sort, zip4)
import Data.Map.Strict qualified as M
import Data.Traversable
import Futhark.Analysis.PrimExp
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.LMAD hiding
( flatSlice,
index,
iota,
mkExistential,
permutation,
permute,
reshape,
shape,
slice,
)
import Futhark.IR.Mem.LMAD qualified as LMAD
import Futhark.IR.Prop
import Futhark.IR.Syntax
( FlatSlice (..),
Slice (..),
unitSlice,
)
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (gcd, id, mod, (.))
-- | An index function, represented as a single 'LMAD' paired with the
-- shape of the base it is defined over.
--
-- The type parameter @num@ is the type of the expressions used for
-- offsets, strides and sizes.
data IxFun num = IxFun
  { -- | The LMAD through which indexing is performed (see 'index').
    ixfunLMAD :: LMAD num,
    -- | The shape recorded as the base of this index function.  It is
    -- set by 'iotaOffset', 'reshape' and 'coerce', and carried along
    -- unchanged by 'permute', 'slice' and 'flatSlice'.
    base :: Shape num
  }
  deriving (Show, Eq)
-- | Renders the index function as a braced block with two entries
-- combined by 'semistack': @base:@ followed by the bracketed,
-- comma-separated base shape, and @LMAD:@ followed by the LMAD.
instance Pretty num => Pretty (IxFun num) where
  pretty (IxFun lmad oshp) =
    braces . semistack $
      [ "base:" <+> brackets (commasep $ map pretty oshp),
        "LMAD:" <+> pretty lmad
      ]
-- | Name substitution is applied to every @num@ contained in the index
-- function (both the LMAD and the base shape) via its 'Functor'
-- instance.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs = fmap (substituteNames substs)
-- | Renaming is delegated to 'substituteRename', i.e. it is expressed in
-- terms of the 'Substitute' instance.
instance Substitute num => Rename (IxFun num) where
  rename = substituteRename
-- | The free names of an index function are the combined free names of
-- every @num@ it contains (LMAD components and base shape).
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' = foldMap freeIn'
-- | Derived from the 'Traversable' instance via 'fmapDefault'.
instance Functor IxFun where
  fmap = fmapDefault
-- | Derived from the 'Traversable' instance via 'foldMapDefault'.
instance Foldable IxFun where
  foldMap = foldMapDefault
-- | Traverses the components of the LMAD first and the elements of the
-- base shape second.  'Functor' and 'Foldable' are defined in terms of
-- this instance, so they visit elements in the same order.
instance Traversable IxFun where
  traverse f (IxFun lmad oshp) =
    IxFun <$> traverse f lmad <*> traverse f oshp
-- | Substitute leaf names in an index function according to the given
-- table.  The LMAD is handled by 'substituteInLMAD'; each element of the
-- base shape is untyped, rewritten with 'substituteInPrimExp', and
-- wrapped back into a 'TPrimExp'.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmad oshp) =
  IxFun
    (substituteInLMAD tab lmad)
    (map (TPrimExp . substituteInPrimExp tab' . untyped) oshp)
  where
    -- 'substituteInPrimExp' operates on untyped expressions, so the
    -- replacement table is untyped once up front.
    tab' = fmap untyped tab
-- | Is this index function the direct, row-major layout of its base?
--
-- This holds when the LMAD has as many dimensions as the base, its
-- offset is zero, and dimension number @i@ (counting from 0) has
-- permutation entry @i@, the same size as base dimension @i@, and a
-- stride equal to the product of the sizes of all later base dimensions.
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect (IxFun (LMAD offset dims) oshp) =
  length oshp == length dims
    && offset == 0
    && all
      (\(LMADDim s n p, i, d, se) -> s == se && n == d && p == i)
      -- 'zip4' stops at the shortest list, so the position counter can
      -- be left unbounded.
      (zip4 dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides: for base @[a, b, c]@ this is @[b * c, c, 1]@.
    -- @drop 1@ is used rather than the partial 'tail' so that an empty
    -- base is handled without relying on laziness.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
-- | Does the LMAD's permutation list its entries in ascending order,
-- i.e. is it already sorted?
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm (IxFun lmad _) = perm == sort perm
  where
    perm = LMAD.permutation lmad
-- | The shape of the index space of the index function: the LMAD's
-- 'LMAD.shapeBase' rearranged by 'permuteFwd' according to the LMAD's
-- permutation.  The stored 'base' is not consulted.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape (IxFun lmad _) =
  permuteFwd (LMAD.permutation lmad) (LMAD.shapeBase lmad)
-- | The permutation of the index function: the 'LMAD.ldPerm' entry of
-- each LMAD dimension, in dimension order.
permutation :: IxFun num -> Permutation
permutation = map LMAD.ldPerm . LMAD.dims . ixfunLMAD
-- | Apply the index function to a list of indices by handing them to
-- 'LMAD.index' on the underlying LMAD.  The stored 'base' plays no part.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index = LMAD.index . ixfunLMAD
-- | Build an index function from 'LMAD.iota' with the given offset and
-- shape; the same shape is recorded as the 'base'.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (LMAD.iota o ns) ns
-- | 'iotaOffset' with an offset of zero.
iota :: IntegralExp num => Shape num -> IxFun num
iota = iotaOffset 0
-- | Create an index function whose components are all existentials.
--
-- The LMAD is produced by 'LMAD.mkExistential' from the permutation and
-- the starting existential number.  The base consists of @basis_rank@
-- consecutive existentials, the first of which is numbered
-- @start + 1 + 2 * length perm@.
mkExistential :: Int -> [Int] -> Int -> IxFun (Ext a)
mkExistential basis_rank perm start =
  IxFun (LMAD.mkExistential perm start) basis
  where
    dims_rank = length perm
    -- NOTE(review): the skipped @1 + 2 * dims_rank@ numbers are
    -- presumably those consumed by 'LMAD.mkExistential' (one offset plus
    -- two values per dimension) -- confirm against that function.
    basis = take basis_rank $ map Ext [start + 1 + dims_rank * 2 ..]
-- | Permute the dimensions of the index function by applying
-- 'LMAD.permute' to the underlying LMAD.  The 'base' is kept as is.
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun lmad oshp) perm_new =
  IxFun (LMAD.permute lmad perm_new) oshp
-- | Slice an index function.
--
-- If the slice is exactly a unit-stride slice starting at 0 and spanning
-- the full extent of every dimension of 'shape', the index function is
-- returned unchanged.  Otherwise the slice is applied to the underlying
-- LMAD with 'LMAD.slice'; the 'base' is kept as is.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice ixfun@(IxFun lmad oshp) slc@(Slice is)
  | is == map (unitSlice 0) (shape ixfun) = ixfun
  | otherwise = IxFun (LMAD.slice lmad slc) oshp
-- | Flat-slice an index function by applying 'LMAD.flatSlice' to the
-- underlying LMAD.  Produces 'Nothing' exactly when 'LMAD.flatSlice'
-- does; on success the 'base' is kept as is.
flatSlice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  Maybe (IxFun num)
flatSlice (IxFun lmad oshp) s =
  (`IxFun` oshp) <$> LMAD.flatSlice lmad s
-- | Reshape an index function by applying 'LMAD.reshape' to the
-- underlying LMAD.  Produces 'Nothing' exactly when 'LMAD.reshape' does.
-- On success the old 'base' is discarded and the new shape is recorded
-- as the base.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  Maybe (IxFun num)
reshape (IxFun lmad _) new_shape =
  (`IxFun` new_shape) <$> LMAD.reshape lmad new_shape
-- | Replace the sizes of an index function with those of a new shape.
--
-- The 'ldShape' of each LMAD dimension is overwritten with the
-- positionally corresponding element of @new_shape@; offset, strides and
-- permutation are left untouched.  The new shape also replaces the
-- 'base'.
--
-- The pairing uses 'zipWith', so if the LMAD and @new_shape@ differ in
-- length the surplus LMAD dimensions are dropped.
-- NOTE(review): callers presumably always pass a shape of matching rank
-- -- confirm.
coerce ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
coerce (IxFun lmad _) new_shape =
  IxFun (onLMAD lmad) new_shape
  where
    onLMAD (LMAD offset dims) = LMAD offset $ zipWith onDim dims new_shape
    onDim ld d = ld {ldShape = d}
-- | The number of dimensions of the underlying LMAD.
rank :: IntegralExp num => IxFun num -> Int
rank = length . LMAD.dims . ixfunLMAD
-- | Rebase the index function @ixfun@ on top of @new_base@.
--
-- Returns 'Nothing' unless all of the following hold:
--
-- * the 'base' of @ixfun@ equals the 'shape' of @new_base@;
--
-- * at least one of the two index functions has a sorted permutation
--   (see 'hasContiguousPerm');
--
-- * the two LMADs have the same number of dimensions, or @ixfun@ has a
--   sorted permutation;
--
-- * every dimension of @ixfun@'s LMAD except the first has the same size
--   as the corresponding dimension of @ixfun@'s base.
--
-- On success the result keeps the base of @ixfun@.  Its LMAD takes its
-- strides from @new_base@ (with the permutation adjusted as described in
-- the body) and its sizes from @ixfun@.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebase new_base@(IxFun lmad_base _) ixfun@(IxFun lmad shp) = do
  let dims = LMAD.dims lmad
      perm = LMAD.permutation lmad
      perm_base = LMAD.permutation lmad_base
      ixfun_contiguous = hasContiguousPerm ixfun

  guard $
    base ixfun == shape new_base
      && (ixfun_contiguous || hasContiguousPerm new_base)
      && (length perm == length perm_base || ixfun_contiguous)
      -- The leading 'True' exempts the first dimension from the size
      -- check; all remaining dimensions must match the base exactly.
      && and
        ( zipWith3
            (\sn ld inner -> inner || sn == ldShape ld)
            shp
            dims
            (True : replicate (length dims - 1) False)
        )

  let -- When @ixfun@ itself is permuted, compose its permutation with
      -- that of the new base.
      -- NOTE(review): '!!' is partial; this relies on every entry of
      -- @perm_base@ being a valid index into @perm@ -- confirm.
      perm_base'
        | ixfun_contiguous = perm_base
        | otherwise = map (perm !!) perm_base
      lmad_base' = LMAD.setPermutation perm_base' lmad_base
      dims_base = LMAD.dims lmad_base'
      n_fewer_dims = length dims_base - length dims
      -- Keep only the trailing dimensions of the new base that line up
      -- with @dims@, shifting their permutation entries accordingly.
      -- The second lambda argument is matched only to pair the lists;
      -- every offset contribution is zero.
      (dims_base', offs_contrib) =
        unzip $
          zipWith
            ( \(LMADDim s1 n1 p1) (LMADDim {}) ->
                (LMADDim s1 n1 (p1 - n_fewer_dims), 0)
            )
            (drop n_fewer_dims dims_base)
            dims
      off_base = LMAD.offset lmad_base' + sum offs_contrib
      -- NOTE(review): 'last' is partial; this assumes the new base has
      -- at least one dimension -- confirm.
      lmad_base'' =
        LMAD.setShape
          (LMAD.shape lmad)
          ( LMAD
              (off_base + ldStride (last dims_base) * LMAD.offset lmad)
              dims_base'
          )
  pure $ IxFun lmad_base'' shp
-- | Turn every component of an index function into an existential.
--
-- Each expression in the index function is replaced (its previous
-- contents are ignored) by an @int64@ leaf holding a fresh 'Ext'.  The
-- existentials are numbered consecutively from 0 in traversal order,
-- i.e. LMAD components first and base shape second.
existentialize ::
  IxFun (TPrimExp Int64 a) ->
  IxFun (TPrimExp Int64 (Ext b))
existentialize ixfun = evalState (traverse (const mkExt) ixfun) 0
  where
    -- Hand out the current counter value as an existential and bump the
    -- counter.
    mkExt = do
      i <- get
      put $ i + 1
      pure $ TPrimExp $ LeafExp (Ext i) int64
-- | A coarse structural comparison of two index functions.
--
-- They are "close enough" when their bases have the same number of
-- dimensions and their LMADs have identical permutations (which also
-- implies the same number of LMAD dimensions).  Offsets, strides and
-- sizes are not compared.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  length (base ixf1) == length (base ixf2)
    && permutation ixf1 == permutation ixf2
-- | Are two index functions equivalent?
--
-- Equivalence here means that the two LMADs have identical permutations
-- (and hence the same number of dimensions), equal offsets, and equal
-- strides in every dimension.  Dimension sizes and the 'base' are not
-- compared.
equivalent :: Eq num => IxFun num -> IxFun num -> Bool
equivalent ixf1 ixf2 =
  permutation ixf1 == permutation ixf2
    && LMAD.offset lmad1 == LMAD.offset lmad2
    && strides lmad1 == strides lmad2
  where
    lmad1 = ixfunLMAD ixf1
    lmad2 = ixfunLMAD ixf2
    strides = map ldStride . LMAD.dims