{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Futhark.IR.Mem.IxFun
( IxFun (..),
LMAD (..),
LMADDim (..),
Monotonicity (..),
index,
iota,
iotaOffset,
permute,
rotate,
reshape,
slice,
rebase,
shape,
rank,
linearWithOffset,
rearrangeWithOffset,
isDirect,
isLinear,
substituteInIxFun,
leastGeneralGeneralization,
existentialize,
closeEnough,
)
where
import Control.Category
import Control.Monad.Identity
import Control.Monad.State
import Control.Monad.Writer
import Data.Function (on)
import Data.List (sort, sortBy, zip4, zip5, zipWith5)
import Data.List.NonEmpty (NonEmpty (..))
import qualified Data.List.NonEmpty as NE
import qualified Data.Map.Strict as M
import Data.Maybe (isJust)
import Futhark.Analysis.PrimExp
( IntExp,
PrimExp (..),
TPrimExp (..),
primExpType,
)
import Futhark.Analysis.PrimExp.Convert (substituteInPrimExp)
import qualified Futhark.Analysis.PrimExp.Generalize as PEG
import Futhark.IR.Prop
import Futhark.IR.Syntax
( DimChange (..),
DimIndex (..),
ShapeChange,
Slice,
dimFix,
unitSlice,
)
import Futhark.IR.Syntax.Core (Ext (..))
import Futhark.Transform.Rename
import Futhark.Transform.Substitute
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (id, mod, (.))
-- | A shape: one extent per dimension.
type Shape num = [num]
-- | A list of indices, one per dimension.
type Indices num = [num]
-- | A permutation of dimensions, as a list of dimension positions.
type Permutation = [Int]
-- | Monotonicity information attached to each LMAD dimension (field
-- 'ldMon').  'invertMonotonicity' swaps 'Inc' and 'Dec' and leaves
-- 'Unknown' unchanged.
data Monotonicity
  = -- | Presumably "increasing" (TODO confirm against users of 'ldMon').
    Inc
  | -- | Presumably "decreasing" (TODO confirm).
    Dec
  | -- | Presumably "no monotonicity known".
    Unknown
  deriving (Eq, Show)
-- | The per-dimension part of an 'LMAD'.
data LMADDim num = LMADDim
  { -- | Stride of this dimension; 'isDirect' compares it with the
    -- row-major stride derived from the base shape.
    ldStride :: num,
    -- | Rotation of this dimension; 'isDirect' requires it to be zero.
    ldRotate :: num,
    -- | Extent of this dimension.
    ldShape :: num,
    -- | This dimension's entry in the LMAD's permutation
    -- (see 'lmadPermutation').
    ldPerm :: Int,
    -- | Monotonicity of this dimension.
    ldMon :: Monotonicity
  }
  deriving (Eq, Show)
-- | An LMAD: a base offset plus a list of per-dimension descriptors.
-- 'indexLMAD' computes a flat index as the offset plus one
-- contribution per dimension.
data LMAD num = LMAD
  { -- | Offset added to every flat index computed from this LMAD.
    lmadOffset :: num,
    -- | Per-dimension stride, rotation, shape, permutation and
    -- monotonicity.
    lmadDims :: [LMADDim num]
  }
  deriving (Eq, Show)
-- | An index function: a non-empty chain of LMADs together with a base
-- shape and a contiguity flag.  'index' feeds the caller's indices to
-- the head LMAD and threads the result through the remaining ones;
-- 'shape' reports the shape of the head LMAD.
data IxFun num = IxFun
  { -- | The LMAD chain, head first.
    ixfunLMADs :: NonEmpty (LMAD num),
    -- | The base shape; 'isDirect' checks the head LMAD against it.
    base :: Shape num,
    -- | Contiguity flag; 'isDirect' requires it to be 'True'.
    ixfunContig :: Bool
  }
  deriving (Eq, Show)
-- | Pretty-prints a 'Monotonicity' as its constructor name.
instance Pretty Monotonicity where
  ppr mon = text (show mon)
-- | Renders an LMAD as a braced, semicolon-separated record.  Each
-- per-dimension field is shown as one bracketed, comma-separated list.
instance Pretty num => Pretty (LMAD num) where
  ppr (LMAD offset dims) =
    braces . semisep $
      [ "offset: " <> oneLine (ppr offset),
        "strides: " <> perDim ldStride,
        "rotates: " <> perDim ldRotate,
        "shape: " <> perDim ldShape,
        "permutation: " <> perDim ldPerm,
        "monotonicity: " <> perDim ldMon
      ]
    where
      -- One bracketed list holding the chosen field of every dimension.
      perDim field = oneLine . brackets . commasep $ map (ppr . field) dims
-- | Renders an index function as a braced, semicolon-separated record.
-- The record has its base shape, its contiguity flag and its LMADs.
instance Pretty num => Pretty (IxFun num) where
  ppr (IxFun lmads oshp cg) =
    braces . semisep $
      [ "base: " <> brackets (commasep (map ppr oshp)),
        "contiguous: " <> contiguity,
        "LMADs: " <> brackets (commasep (map ppr (NE.toList lmads)))
      ]
    where
      contiguity :: Doc
      contiguity = if cg then "true" else "false"
-- | Renames variables in every numeric component of the LMAD.
instance Substitute num => Substitute (LMAD num) where
  substituteNames substs lmad = fmap (substituteNames substs) lmad
-- | Renames variables in every numeric component of the index function.
-- This covers all LMADs and the base shape.
instance Substitute num => Substitute (IxFun num) where
  substituteNames substs ixfun = fmap (substituteNames substs) ixfun
-- | Renaming is implemented in terms of the 'Substitute' instance.
instance Substitute num => Rename (LMAD num) where
  rename lmad = substituteRename lmad
-- | Renaming is implemented in terms of the 'Substitute' instance.
instance Substitute num => Rename (IxFun num) where
  rename ixfun = substituteRename ixfun
-- | The free variables of an LMAD are those of all its numeric components.
instance FreeIn num => FreeIn (LMAD num) where
  freeIn' lmad = foldMap freeIn' lmad
-- | The free variables of an index function are those of all its numeric
-- components (every LMAD and the base shape).
instance FreeIn num => FreeIn (IxFun num) where
  freeIn' ixfun = foldMap freeIn' ixfun
-- | Mapping is derived from 'traverse' in the 'Identity' applicative.
instance Functor LMAD where
  fmap f lmad = runIdentity (traverse (Identity . f) lmad)
-- | Mapping is derived from 'traverse' in the 'Identity' applicative.
instance Functor IxFun where
  fmap f ixfun = runIdentity (traverse (Identity . f) ixfun)
-- | Folding is derived from 'traverse'.  Each element's monoidal value is
-- emitted with 'tell' and the accumulated output is collected.
instance Foldable LMAD where
  foldMap f lmad = execWriter (traverse (\x -> tell (f x)) lmad)
-- | Folding is derived from 'traverse'.  Each element's monoidal value is
-- emitted with 'tell' and the accumulated output is collected.
instance Foldable IxFun where
  foldMap f ixfun = execWriter (traverse (\x -> tell (f x)) ixfun)
-- | Visits the offset first.  It then visits the stride, rotation and
-- shape of each dimension, in order.  Permutation entries and
-- monotonicity are not numeric and are carried over unchanged.
instance Traversable LMAD where
  traverse f (LMAD offset dims) =
    LMAD <$> f offset <*> traverse visitDim dims
    where
      visitDim (LMADDim stride rot len perm mon) =
        (\stride' rot' len' -> LMADDim stride' rot' len' perm mon)
          <$> f stride
          <*> f rot
          <*> f len
-- | Visits every LMAD (head first) and then the base shape.  The
-- contiguity flag is carried over unchanged.
instance Traversable IxFun where
  traverse f (IxFun lmads oshp cg) =
    (\lmads' oshp' -> IxFun lmads' oshp' cg)
      <$> traverse (traverse f) lmads
      <*> traverse f oshp
-- | Prepend a (possibly empty) list to a non-empty list.
(++@) :: [a] -> NonEmpty a -> NonEmpty a
[] ++@ (y :| ys) = y :| ys
(x : xs) ++@ (y :| ys) = x :| (xs ++ (y : ys))
-- | Concatenate two non-empty lists.
(@++@) :: NonEmpty a -> NonEmpty a -> NonEmpty a
(x :| xs) @++@ (y :| ys) = x :| (xs ++ (y : ys))
-- | Swap 'Inc' and 'Dec'; 'Unknown' stays 'Unknown'.
invertMonotonicity :: Monotonicity -> Monotonicity
invertMonotonicity mon =
  case mon of
    Inc -> Dec
    Dec -> Inc
    Unknown -> Unknown
-- | The permutation of an LMAD: the 'ldPerm' of each dimension, in order.
lmadPermutation :: LMAD num -> Permutation
lmadPermutation lmad = [ldPerm dim | dim <- lmadDims lmad]
-- | Overwrite the 'ldPerm' of each dimension with the corresponding
-- permutation entry.  If the lengths differ, the result keeps only as
-- many dimensions as the shorter of the two (the 'zipWith' behaviour).
setLMADPermutation :: Permutation -> LMAD num -> LMAD num
setLMADPermutation perm lmad =
  lmad {lmadDims = zipWith withPerm (lmadDims lmad) perm}
  where
    withPerm dim p = dim {ldPerm = p}
-- | Overwrite the 'ldShape' of each dimension with the corresponding
-- extent.  If the lengths differ, the result keeps only as many
-- dimensions as the shorter of the two (the 'zipWith' behaviour).
setLMADShape :: Shape num -> LMAD num -> LMAD num
setLMADShape shp lmad =
  lmad {lmadDims = zipWith withShape (lmadDims lmad) shp}
  where
    withShape dim n = dim {ldShape = n}
-- | Substitute variables in an LMAD according to @tab@.  The substitution
-- applies to the offset and to every stride, rotation and shape (the
-- components visited by the 'Traversable' instance).  Permutation and
-- monotonicity are left untouched.
substituteInLMAD ::
  Ord a =>
  M.Map a (PrimExp a) ->
  LMAD (PrimExp a) ->
  LMAD (PrimExp a)
substituteInLMAD tab lmad = fmap (substituteInPrimExp tab) lmad
-- | Substitute variables in every LMAD and in the base shape of a typed
-- index function.  The typed expressions are unwrapped with 'untyped',
-- substituted, and rewrapped with 'TPrimExp'; the contiguity flag is
-- kept.
substituteInIxFun ::
  Ord a =>
  M.Map a (TPrimExp t a) ->
  IxFun (TPrimExp t a) ->
  IxFun (TPrimExp t a)
substituteInIxFun tab (IxFun lmads oshp cg) =
  IxFun (NE.map substLMAD lmads) (map substExp oshp) cg
  where
    tab' = fmap untyped tab
    substExp e = TPrimExp (substituteInPrimExp tab' (untyped e))
    substLMAD lmad = fmap TPrimExp (substituteInLMAD tab' (fmap untyped lmad))
-- | Whether the index function is a direct, row-major view of its base
-- shape.  This holds only when all of the following are true:
--
-- * there is exactly one LMAD and 'ixfunContig' is 'True';
-- * the LMAD's permutation is the identity;
-- * the offset is zero and every rotation is zero;
-- * every dimension's shape equals the corresponding base extent;
-- * every stride equals the row-major stride derived from the base
--   shape (the last dimension has stride 1).
isDirect :: (Eq num, IntegralExp num) => IxFun num -> Bool
isDirect ixfun@(IxFun (LMAD offset dims :| []) oshp True) =
  hasContiguousPerm ixfun
    && length oshp == length dims
    && offset == 0
    && all matchesBase (zip4 dims [0 ..] oshp strides_expected)
  where
    -- Row-major strides of the base shape.  'drop 1' instead of the
    -- partial 'tail' keeps this total even for a rank-0 base shape.
    strides_expected = reverse $ scanl (*) 1 $ reverse $ drop 1 oshp
    -- A dimension matches when it has the expected stride, no rotation,
    -- the base extent, and sits at its own position in the permutation.
    matchesBase (LMADDim s r n p _, i, d, se) =
      s == se && r == 0 && n == d && p == i
isDirect _ = False
-- | True when the index function has exactly one LMAD whose permutation
-- is already in sorted order.
hasContiguousPerm :: IxFun num -> Bool
hasContiguousPerm ixfun =
  case ixfunLMADs ixfun of
    lmad :| [] ->
      let perm = lmadPermutation lmad
       in perm == sort perm
    _ -> False
-- | The shape of an index function: the shape of its head LMAD.
shape :: (Eq num, IntegralExp num) => IxFun num -> Shape num
shape ixfun =
  case ixfunLMADs ixfun of
    outermost :| _ -> lmadShape outermost
-- | The shape of an LMAD: its per-dimension extents ('lmadShapeBase')
-- reordered by 'permuteInv' of its permutation.
lmadShape :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShape lmad = permuteInv perm (lmadShapeBase lmad)
  where
    perm = lmadPermutation lmad
-- | The 'ldShape' of every dimension, in stored (unpermuted) order.
lmadShapeBase :: (Eq num, IntegralExp num) => LMAD num -> Shape num
lmadShapeBase lmad = [ldShape dim | dim <- lmadDims lmad]
-- | Compute the flat memory index for a complete set of array indices.
--
-- The indices go through the LMAD chain as follows:
--
-- * they are first applied to the head LMAD;
-- * while further LMADs follow, the flat index is unflattened into
--   indices for the next LMAD, using that LMAD's base shape reordered
--   by 'permuteFwd' of its permutation;
-- * the flat index produced by the last LMAD is the result.
index ::
  (IntegralExp num, Eq num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun = go (ixfunLMADs ixfun)
  where
    go (lmad :| []) inds = indexLMAD lmad inds
    go (outer :| inner : rest) inds =
      let flat = indexLMAD outer inds
          inner_shape = permuteFwd (lmadPermutation inner) (lmadShapeBase inner)
       in go (inner :| rest) (unflattenIndex inner_shape flat)

-- | Flat index of @inds@ within a single LMAD.  The result is the offset
-- plus the sum of the 'flatOneDim' contributions of each dimension's
-- (stride, rotation, shape) paired with the indices reordered by
-- 'permuteInv' of the LMAD's permutation.
indexLMAD ::
  (IntegralExp num, Eq num) =>
  LMAD num ->
  Indices num ->
  num
indexLMAD lmad@(LMAD off dims) inds =
  off + sum (zipWith flatOneDim dimTriples (permuteInv (lmadPermutation lmad) inds))
  where
    dimTriples = map (\(LMADDim s r n _ _) -> (s, r, n)) dims
-- | A single-LMAD index function over shape @ns@ starting at offset @o@.
--
-- The LMAD is produced by 'makeRotIota' with monotonicity 'Inc', each size
-- in @ns@ paired with a zero first component (presumably the rotation —
-- confirm against 'makeRotIota').  The base shape is @ns@ itself and the
-- contiguity flag is set.
iotaOffset :: IntegralExp num => num -> Shape num -> IxFun num
iotaOffset o ns = IxFun (lmad :| []) ns True
  where
    lmad = makeRotIota Inc o [(0, n) | n <- ns]
-- | A single-LMAD index function over the given shape with offset zero;
-- shorthand for @'iotaOffset' 0@.
iota :: IntegralExp num => Shape num -> IxFun num
iota ns = iotaOffset 0 ns
-- | Permute the dimensions of an index function.
--
-- The new permutation of the outermost LMAD takes, at position @k@, the
-- entry of the current permutation at index @perm_new !! k@.  Inner LMADs,
-- the base shape and the contiguity flag are left untouched.  An index in
-- @perm_new@ that is out of range for the current permutation is an error
-- (via '!!').
permute ::
  IntegralExp num =>
  IxFun num ->
  Permutation ->
  IxFun num
permute (IxFun (lmad :| lmads) oshp cg) perm_new =
  IxFun (setLMADPermutation composed lmad :| lmads) oshp cg
  where
    current = lmadPermutation lmad
    composed = [current !! k | k <- perm_new]
-- | Rotate an index function by the given per-dimension offsets.
--
-- The offsets are given in index order and are brought into the order of
-- the outermost LMAD's dimensions with 'permuteInv' on its permutation.
-- Each offset is added to the corresponding dimension's rotation.
-- Dimensions with stride zero instead get rotation zero and 'Unknown'
-- monotonicity.  Offset, inner LMADs, base shape and contiguity flag are
-- kept.
rotate ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  IxFun num
rotate (IxFun (lmad@(LMAD off dims) :| lmads) oshp cg) offs =
  IxFun (LMAD off (zipWith rotateDim dims offs') :| lmads) oshp cg
  where
    offs' = permuteInv (lmadPermutation lmad) offs
    rotateDim (LMADDim s r n p f) o
      | s == 0 = LMADDim 0 0 n p Unknown
      | otherwise = LMADDim s (r + o) n p f
-- | Try to apply a slice directly to the outermost LMAD of an index
-- function, without pushing a new LMAD.
--
-- The slice is given in index order and is reordered with 'permuteInv' on
-- the LMAD's permutation.  Returns 'Nothing' when 'harmlessRotation'
-- rejects the slice.  Otherwise:
--
--   * every (slice component, dimension) pair is folded through
--     'sliceOne', starting from an LMAD with the original offset and no
--     dimensions;
--   * the contiguity flag is and-ed with 'slicePreservesContiguous';
--   * the permutation is updated.  Entries equal to a position whose
--     slice component is a @DimFix@ are dropped, and each surviving entry
--     @p@ is decreased by the number of such fixed positions below @p@.
--     For example, @[2, 1, 3, 0]@ with fixed positions @[0, 2]@ becomes
--     @[0, 1]@.
sliceOneLMAD ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  Maybe (IxFun num)
sliceOneLMAD (IxFun (lmad@(LMAD _ ldims) :| lmads) oshp cg) is = do
  guard $ harmlessRotation lmad is'
  pure $ IxFun (setLMADPermutation perm' lmad' :| lmads) oshp cg'
  where
    perm = lmadPermutation lmad
    is' = permuteInv perm is
    cg' = cg && slicePreservesContiguous lmad is'
    lmad' = foldl sliceOne (LMAD (lmadOffset lmad) []) $ zip is' ldims
    -- Positions of is' whose component is a DimFix (these dims vanish).
    fixed = [k | (k, ix) <- zip [0 ..] is', isJust (dimFix ix)]
    -- Drop the vanished dims and renumber the remaining ones densely.
    perm' = [p - length (filter (< p) fixed) | p <- perm, p `notElem` fixed]
-- | Can this slice component be applied to this LMAD dimension directly,
-- given the dimension's rotation?
--
-- A @DimFix@ always can.  A slice can too when the dimension has stride
-- zero or rotation zero.  Otherwise only the full slice @unitSlice 0 n@ or
-- the full reversed slice @DimSlice (n - 1) n (-1)@ qualify, where @n@ is
-- the dimension's size.
harmlessRotation' ::
  (Eq num, IntegralExp num) =>
  LMADDim num ->
  DimIndex num ->
  Bool
harmlessRotation' _ DimFix {} = True
harmlessRotation' (LMADDim s r n _ _) dslc =
  s == 0
    || r == 0
    || dslc == DimSlice (n - 1) n (-1)
    || dslc == unitSlice 0 n
-- | 'harmlessRotation'' must hold for every (dimension, slice component)
-- pair of the LMAD.  Pairing stops at the shorter of the two lists.
harmlessRotation ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
harmlessRotation (LMAD _ dims) iss =
  all (uncurry harmlessRotation') (zip dims iss)
-- | Apply one slice component to one LMAD dimension, extending the LMAD
-- passed as the first argument.  The alternatives are tried in order:
--
--   * @DimFix i@: no dimension is added; the offset grows by
--     @'flatOneDim' (s, r, n) i@.
--   * any slice of a stride-zero dimension: a stride-zero, rotation-zero
--     dimension of the slice length with 'Unknown' monotonicity is added.
--   * the full slice @unitSlice 0 n@: the dimension is added unchanged.
--   * the full reversed slice @DimSlice (n - 1) n (-1)@: the offset moves
--     to element @n - 1@, the stride is negated, a non-zero rotation @r@
--     becomes @n - r@ and the monotonicity is inverted.
--   * a stride-zero slice @DimSlice b ne 0@: the offset moves to element
--     @b@ and a stride-zero dimension of length @ne@ is added.
--   * any other slice of a rotation-zero dimension: the offset grows by
--     @s * bs@ and the added dimension has stride @ss * s@; its
--     monotonicity is kept, inverted or 'Unknown' depending on 'sgn' of
--     the slice stride.
--
-- Anything else is reported as an impossible case.  'sliceOneLMAD' only
-- calls this after 'harmlessRotation' has accepted the slice.
sliceOne ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  (DimIndex num, LMADDim num) ->
  LMAD num
sliceOne (LMAD off dims) (dimidx, ld@(LMADDim s r n p m)) =
  case dimidx of
    DimFix i ->
      LMAD (off + flatOneDim (s, r, n) i) dims
    DimSlice _ ne _
      | s == 0 ->
        LMAD off $ snoc (LMADDim 0 0 ne p Unknown)
    _
      | dimidx == unitSlice 0 n ->
        LMAD off $ snoc ld
      | dimidx == DimSlice (n - 1) n (-1) ->
        let r' = if r == 0 then 0 else n - r
            off' = off + flatOneDim (s, 0, n) (n - 1)
         in LMAD off' $ snoc (LMADDim (s * (-1)) r' n p (invertMonotonicity m))
    DimSlice b ne 0 ->
      LMAD (off + flatOneDim (s, r, n) b) $ snoc (LMADDim 0 0 ne p Unknown)
    DimSlice bs ns ss
      | r == 0 ->
        let m' = case sgn ss of
              Just 1 -> m
              Just (-1) -> invertMonotonicity m
              _ -> Unknown
         in LMAD (off + s * bs) $ snoc (LMADDim (ss * s) 0 ns p m')
    _ -> error "slice: reached impossible case"
  where
    -- Newly produced dimensions go after the ones accumulated so far.
    snoc d = dims ++ [d]
-- | Does applying this slice to a contiguous LMAD keep it contiguous?
--
-- Dimensions with stride zero are ignored together with their slice
-- components.  Slice components are first normalised with 'normIndex'.
-- The remaining (dimension, component) pairs must consist of:
--
--   1. any number of fixed dimensions, then
--   2. optionally an outermost sliced dimension with stride @1@ or @-1@
--      that is either unrotated or taken in full, then
--   3. only full slices with stride @1@ or @-1@ (a fixed dimension here
--      breaks contiguity).
slicePreservesContiguous ::
  (Eq num, IntegralExp num) =>
  LMAD num ->
  Slice num ->
  Bool
slicePreservesContiguous (LMAD _ dims) slc =
  outer $ filter ((/= 0) . ldStride . fst) $ zip dims $ map normIndex slc
  where
    unitStride ds = ds == 1 || ds == -1
    -- Skip leading fixed dims up to the outermost sliced one.
    outer ((_, DimFix {}) : rest) = outer rest
    outer ((LMADDim _ r n _ _, DimSlice _ ne ds) : rest) =
      (r == 0 || n == ne) && unitStride ds && all inner rest
    outer [] = True
    -- Dims inside the outermost sliced one must be full unit-stride slices.
    inner (LMADDim _ _ n _ _, DimSlice _ ne ds) = n == ne && unitStride ds
    inner (_, DimFix {}) = False
-- | Normalise a slice component.  A slice of length one, or one with
-- stride zero, only refers to its start element, so it is rewritten to a
-- @DimFix@ of that start.  Every other component is returned as is.
normIndex ::
  (Eq num, IntegralExp num) =>
  DimIndex num ->
  DimIndex num
normIndex idx =
  case idx of
    DimSlice b len stride
      | len == 1 || stride == 0 -> DimFix b
    _ -> idx
-- | Slice an index function.
--
-- An empty slice is an error.  A slice that takes every element of every
-- dimension (@unitSlice 0@ of each entry of 'shape') returns the index
-- function unchanged.  Otherwise the slice is applied to the outermost
-- LMAD with 'sliceOneLMAD'.  If that is not possible, the slice is applied
-- to a fresh 'iota' over the outermost LMAD's shape, and the result is
-- pushed as a new outermost LMAD.  The contiguity flags are and-ed.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice _ [] = error "slice: empty slice"
slice ixfun@(IxFun (lmad@LMAD {} :| lmads) oshp cg) dim_slices
  | dim_slices == map (unitSlice 0) (shape ixfun) = ixfun
  | otherwise =
    case sliceOneLMAD ixfun dim_slices of
      Just ixfun' -> ixfun'
      Nothing -> pushed
  where
    -- Fallback: describe the slice as a new LMAD on top of the old ones.
    pushed =
      case sliceOneLMAD (iota (lmadShape lmad)) dim_slices of
        Just (IxFun (lmad' :| []) _ cg') ->
          IxFun (lmad' :| lmad : lmads) oshp (cg && cg')
        _ -> error "slice: reached impossible case"
-- | Try to perform a reshape by only overwriting the dimension sizes of
-- the outermost LMAD, keeping its offset, strides, rotations and
-- permutation.
--
-- The shape change (in index order) is split by 'splitCoercions' into
-- leading coercions, a middle part and trailing coercions.  This succeeds
-- only when:
--
--   * the new shape has the same rank as the LMAD, and
--   * the middle part is empty, or is a single entry covering exactly one
--     LMAD dimension.
--
-- The new sizes are written position by position into the dimensions
-- (viewed in index order via 'permuteFwd', then mapped back with
-- 'permuteInv').  Inner LMADs, the base shape and the contiguity flag are
-- kept.  Returns 'Nothing' when the conditions do not hold.
reshapeCoercion ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  Maybe (IxFun num)
reshapeCoercion (IxFun (lmad@(LMAD off dims) :| lmads) oldbase cg) newshape = do
  let perm = lmadPermutation lmad
  (head_coercions, reshapes, tail_coercions) <- splitCoercions newshape
  let hd_len = length head_coercions
      num_coercions = hd_len + length tail_coercions
      dims' = permuteFwd perm dims
      mid_dims = take (length dims - num_coercions) $ drop hd_len dims'
      num_rshps = length reshapes
  -- The sizes are overwritten below with 'zipWith', which silently
  -- truncates to the shorter list.  Without this check, a rank-changing
  -- shape change made only of coercions passes the next guard.  The
  -- result would be an LMAD with dropped dimensions or ignored new sizes.
  guard $ length newshape == length dims
  guard (num_rshps == 0 || (num_rshps == 1 && length mid_dims == 1))
  let dims'' =
        permuteInv perm $
          zipWith (\ld n -> ld {ldShape = n}) dims' (newDims newshape)
      lmad' = LMAD off dims''
  return $ IxFun (lmad' :| lmads) oldbase cg
reshapeOneLMAD ::
(Eq num, IntegralExp num) =>
IxFun num ->
ShapeChange num ->
Maybe (IxFun num)
reshapeOneLMAD :: forall num.
(Eq num, IntegralExp num) =>
IxFun num -> ShapeChange num -> Maybe (IxFun num)
reshapeOneLMAD ixfun :: IxFun num
ixfun@(IxFun (lmad :: LMAD num
lmad@(LMAD num
off [LMADDim num]
dims) :| [LMAD num]
lmads) Shape num
oldbase Bool
cg) ShapeChange num
newshape = do
let perm :: Permutation
perm = LMAD num -> Permutation
forall num. LMAD num -> Permutation
lmadPermutation LMAD num
lmad
(ShapeChange num
head_coercions, ShapeChange num
reshapes, ShapeChange num
tail_coercions) <- ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
forall num.
(Eq num, IntegralExp num) =>
ShapeChange num
-> Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions ShapeChange num
newshape
let hd_len :: Int
hd_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
head_coercions
num_coercions :: Int
num_coercions = Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
tail_coercions
dims_perm :: [LMADDim num]
dims_perm = Permutation -> [LMADDim num] -> [LMADDim num]
forall a. Permutation -> [a] -> [a]
permuteFwd Permutation
perm [LMADDim num]
dims
mid_dims :: [LMADDim num]
mid_dims = Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
take ([LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
num_coercions) ([LMADDim num] -> [LMADDim num]) -> [LMADDim num] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$ Int -> [LMADDim num] -> [LMADDim num]
forall a. Int -> [a] -> [a]
drop Int
hd_len [LMADDim num]
dims_perm
mon :: Monotonicity
mon = Bool -> IxFun num -> Monotonicity
forall num.
(Eq num, IntegralExp num) =>
Bool -> IxFun num -> Monotonicity
ixfunMonotonicityRots Bool
True IxFun num
ixfun
Bool -> Maybe ()
forall (f :: * -> *). Alternative f => Bool -> f ()
guard (Bool -> Maybe ()) -> Bool -> Maybe ()
forall a b. (a -> b) -> a -> b
$
(LMADDim num -> Bool) -> [LMADDim num] -> Bool
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Bool
all (\(LMADDim num
s num
r num
_ Int
_ Monotonicity
_) -> num
s num -> num -> Bool
forall a. Eq a => a -> a -> Bool
/= num
0 Bool -> Bool -> Bool
&& num
r num -> num -> Bool
forall a. Eq a => a -> a -> Bool
== num
0) [LMADDim num]
mid_dims
Bool -> Bool -> Bool
&&
Int -> Permutation -> Bool
forall {a}. (Eq a, Num a, Enum a) => a -> [a] -> Bool
consecutive Int
hd_len ((LMADDim num -> Int) -> [LMADDim num] -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm [LMADDim num]
mid_dims)
Bool -> Bool -> Bool
&&
IxFun num -> Bool
forall num. IxFun num -> Bool
hasContiguousPerm IxFun num
ixfun
Bool -> Bool -> Bool
&& Bool
cg
Bool -> Bool -> Bool
&& (Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Inc Bool -> Bool -> Bool
|| Monotonicity
mon Monotonicity -> Monotonicity -> Bool
forall a. Eq a => a -> a -> Bool
== Monotonicity
Dec)
let rsh_len :: Int
rsh_len = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
reshapes
diff :: Int
diff = ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
- [LMADDim num] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [LMADDim num]
dims
iota_shape :: Permutation
iota_shape = [Int
0 .. ShapeChange num -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeChange num
newshape Int -> Int -> Int
forall a. Num a => a -> a -> a
-Int
1]
perm' :: Permutation
perm' =
(Int -> Int) -> Permutation -> Permutation
forall a b. (a -> b) -> [a] -> [b]
map
( \Int
i ->
let ind :: Int
ind =
if Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
then Int
i
else Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff
in if (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len) Bool -> Bool -> Bool
&& (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len)
then Int
i
else
let p :: Int
p = LMADDim num -> Int
forall num. LMADDim num -> Int
ldPerm ([LMADDim num]
dims [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
ind)
in if Int
p Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len
then Int
p
else Int
p Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
diff
)
Permutation
iota_shape
([(Int, (num, num))]
support_inds, [(Int, num)]
repeat_inds) =
(([(Int, (num, num))], [(Int, num)])
-> (Int, DimChange num, Int)
-> ([(Int, (num, num))], [(Int, num)]))
-> ([(Int, (num, num))], [(Int, num)])
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall (t :: * -> *) b a.
Foldable t =>
(b -> a -> b) -> b -> t a -> b
foldl
( \([(Int, (num, num))]
sup, [(Int, num)]
rpt) (Int
i, DimChange num
shpdim, Int
ip) ->
case (Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
hd_len, Int
i Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
hd_len Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
rsh_len, DimChange num
shpdim) of
(Bool
True, Bool
_, DimCoercion num
n) ->
case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! Int
i of
(LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
(LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool
_, Bool
True, DimCoercion num
n) ->
case [LMADDim num]
dims_perm [LMADDim num] -> Int -> LMADDim num
forall a. [a] -> Int -> a
!! (Int
i Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
diff) of
(LMADDim num
0 num
_ num
_ Int
_ Monotonicity
_) -> ([(Int, (num, num))]
sup, (Int
ip, num
n) (Int, num) -> [(Int, num)] -> [(Int, num)]
forall a. a -> [a] -> [a]
: [(Int, num)]
rpt)
(LMADDim num
_ num
r num
_ Int
_ Monotonicity
_) -> ((Int
ip, (num
r, num
n)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool
False, Bool
False, DimChange num
_) ->
((Int
ip, (num
0, DimChange num -> num
forall d. DimChange d -> d
newDim DimChange num
shpdim)) (Int, (num, num)) -> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. a -> [a] -> [a]
: [(Int, (num, num))]
sup, [(Int, num)]
rpt)
(Bool, Bool, DimChange num)
_ -> String -> ([(Int, (num, num))], [(Int, num)])
forall a. HasCallStack => String -> a
error String
"reshape: reached impossible case"
)
([], [])
([(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)]))
-> [(Int, DimChange num, Int)]
-> ([(Int, (num, num))], [(Int, num)])
forall a b. (a -> b) -> a -> b
$ [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a. [a] -> [a]
reverse ([(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)])
-> [(Int, DimChange num, Int)] -> [(Int, DimChange num, Int)]
forall a b. (a -> b) -> a -> b
$ Permutation
-> ShapeChange num -> Permutation -> [(Int, DimChange num, Int)]
forall a b c. [a] -> [b] -> [c] -> [(a, b, c)]
zip3 Permutation
iota_shape ShapeChange num
newshape Permutation
perm'
(Permutation
sup_inds, [(num, num)]
support) = [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. [(a, b)] -> ([a], [b])
unzip ([(Int, (num, num))] -> (Permutation, [(num, num)]))
-> [(Int, (num, num))] -> (Permutation, [(num, num)])
forall a b. (a -> b) -> a -> b
$ ((Int, (num, num)) -> (Int, (num, num)) -> Ordering)
-> [(Int, (num, num))] -> [(Int, (num, num))]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, (num, num)) -> Int)
-> (Int, (num, num))
-> (Int, (num, num))
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, (num, num)) -> Int
forall a b. (a, b) -> a
fst) [(Int, (num, num))]
support_inds
(Permutation
rpt_inds, Shape num
repeats) = [(Int, num)] -> (Permutation, Shape num)
forall a b. [(a, b)] -> ([a], [b])
unzip [(Int, num)]
repeat_inds
LMAD num
off' [LMADDim num]
dims_sup = Monotonicity -> num -> [(num, num)] -> LMAD num
forall num.
IntegralExp num =>
Monotonicity -> num -> [(num, num)] -> LMAD num
makeRotIota Monotonicity
mon num
off [(num, num)]
support
repeats' :: [LMADDim num]
repeats' = (num -> LMADDim num) -> Shape num -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (\num
n -> num -> num -> num -> Int -> Monotonicity -> LMADDim num
forall num. num -> num -> num -> Int -> Monotonicity -> LMADDim num
LMADDim num
0 num
0 num
n Int
0 Monotonicity
Unknown) Shape num
repeats
dims' :: [LMADDim num]
dims' =
((Int, LMADDim num) -> LMADDim num)
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> [a] -> [b]
map (Int, LMADDim num) -> LMADDim num
forall a b. (a, b) -> b
snd ([(Int, LMADDim num)] -> [LMADDim num])
-> [(Int, LMADDim num)] -> [LMADDim num]
forall a b. (a -> b) -> a -> b
$
((Int, LMADDim num) -> (Int, LMADDim num) -> Ordering)
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. (a -> a -> Ordering) -> [a] -> [a]
sortBy (Int -> Int -> Ordering
forall a. Ord a => a -> a -> Ordering
compare (Int -> Int -> Ordering)
-> ((Int, LMADDim num) -> Int)
-> (Int, LMADDim num)
-> (Int, LMADDim num)
-> Ordering
forall b c a. (b -> b -> c) -> (a -> b) -> a -> a -> c
`on` (Int, LMADDim num) -> Int
forall a b. (a, b) -> a
fst) ([(Int, LMADDim num)] -> [(Int, LMADDim num)])
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a b. (a -> b) -> a -> b
$
Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
sup_inds [LMADDim num]
dims_sup [(Int, LMADDim num)]
-> [(Int, LMADDim num)] -> [(Int, LMADDim num)]
forall a. [a] -> [a] -> [a]
++ Permutation -> [LMADDim num] -> [(Int, LMADDim num)]
forall a b. [a] -> [b] -> [(a, b)]
zip Permutation
rpt_inds [LMADDim num]
repeats'
lmad' :: LMAD num
lmad' = num -> [LMADDim num] -> LMAD num
forall num. num -> [LMADDim num] -> LMAD num
LMAD num
off' [LMADDim num]
dims'
IxFun num -> Maybe (IxFun num)
forall (m :: * -> *) a. Monad m => a -> m a
return (IxFun num -> Maybe (IxFun num)) -> IxFun num -> Maybe (IxFun num)
forall a b. (a -> b) -> a -> b
$ NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
forall num. NonEmpty (LMAD num) -> Shape num -> Bool -> IxFun num
IxFun (Permutation -> LMAD num -> LMAD num
forall num. Permutation -> LMAD num -> LMAD num
setLMADPermutation Permutation
perm' LMAD num
lmad' LMAD num -> [LMAD num] -> NonEmpty (LMAD num)
forall a. a -> [a] -> NonEmpty a
:| [LMAD num]
lmads) Shape num
oldbase Bool
cg
where
consecutive :: a -> [a] -> Bool
consecutive a
_ [] = Bool
True
consecutive a
i [a
p] = a
i a -> a -> Bool
forall a. Eq a => a -> a -> Bool
== a
p
consecutive a
i [a]
ps = [Bool] -> Bool
forall (t :: * -> *). Foldable t => t Bool -> Bool
and ([Bool] -> Bool) -> [Bool] -> Bool
forall a b. (a -> b) -> a -> b
$ (a -> a -> Bool) -> [a] -> [a] -> [Bool]
forall a b c. (a -> b -> c) -> [a] -> [b] -> [c]
zipWith a -> a -> Bool
forall a. Eq a => a -> a -> Bool
(==) [a]
ps [a
i, a
i a -> a -> a
forall a. Num a => a -> a -> a
+ a
1 ..]
-- | Split a shape change into three parts: the leading run of coercions, the
-- run of non-coercion entries that follows, and everything after that.
-- Returns 'Nothing' unless that last part consists solely of coercions, i.e.
-- unless the shape change has the form @coercions ++ reshapes ++ coercions@.
splitCoercions ::
  (Eq num, IntegralExp num) =>
  ShapeChange num ->
  Maybe (ShapeChange num, ShapeChange num, ShapeChange num)
splitCoercions newshape'
  | all isCoercion trailing = Just (leading, middle, trailing)
  | otherwise = Nothing
  where
    (leading, rest) = span isCoercion newshape'
    -- 'break' stops at the first coercion after the reshapes.
    (middle, trailing) = break isCoercion rest

    isCoercion DimCoercion {} = True
    isCoercion _ = False
-- | Reshape an index function to @new_shape@.
--
-- First tries 'reshapeCoercion', then 'reshapeOneLMAD'.  If neither applies,
-- the single LMAD of @'iota' ('newDims' new_shape)@ is pushed onto the front
-- of the existing LMAD list, keeping the base shape and contiguity flag.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  ShapeChange num ->
  IxFun num
reshape ixfun new_shape =
  case reshapeCoercion ixfun new_shape of
    Just coerced -> coerced
    Nothing ->
      case reshapeOneLMAD ixfun new_shape of
        Just merged -> merged
        Nothing -> stackIota ixfun
  where
    -- 'iota' is expected to always produce exactly one LMAD.
    stackIota (IxFun (lmad0 :| lmad0s) oshp cg) =
      case iota (newDims new_shape) of
        IxFun (lmad :| []) _ _ -> IxFun (lmad :| lmad0 : lmad0s) oshp cg
        _ -> error "reshape: reached impossible case"
-- | The rank of an index function: the number of dimensions of the first LMAD
-- in its (non-empty) LMAD list.
rank ::
  IntegralExp num =>
  IxFun num ->
  Int
rank (IxFun lmads _ _) = length (lmadDims (NE.head lmads))
-- | Try to rebase @ixfun@ onto @new_base@ by merging the last LMAD of @ixfun@
-- into the first LMAD of @new_base@, instead of stacking the two LMAD lists
-- as the fallback in 'rebase' does.
--
-- Returns 'Nothing' unless all of the following hold:
--
--   * the base of @ixfun@ equals the shape of @new_base@;
--   * @ixfun@ is flagged contiguous;
--   * no dimension of @ixfun@'s last LMAD has monotonicity 'Unknown';
--   * 'hasContiguousPerm' holds for @ixfun@ or for @new_base@;
--   * the two permutations have equal length, or 'hasContiguousPerm' holds
--     for @ixfun@;
--   * each dimension of @ixfun@'s last LMAD has the extent given by @ixfun@'s
--     base shape, except that the innermost one may instead have stride 1.
rebaseNice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  Maybe (IxFun num)
rebaseNice
  new_base@(IxFun (lmad_base :| lmads_base) _ cg_base)
  ixfun@(IxFun lmads shp cg) = do
    guard rebasable
    -- NOTE(review): 'others_rev' holds the remaining LMADs of ixfun in
    -- *reversed* order (a consequence of 'NE.reverse'), so with three or more
    -- LMADs their order is not preserved.  The result also keeps ixfun's base
    -- 'shp' rather than new_base's base.  Both differ from the fallback path
    -- in 'rebase'; confirm whether this is intended.
    return $ IxFun (others_rev ++@ merged_lmads) shp (cg && cg_base)
  where
    last_lmad :| others_rev = NE.reverse lmads
    dims = lmadDims last_lmad
    perm = lmadPermutation last_lmad
    perm_base = lmadPermutation lmad_base

    -- Only the innermost dimension is allowed a partial extent with stride 1.
    inner_flags = replicate (length dims - 1) False ++ [True]
    fullOrUnitStride sn ld inner =
      sn == ldShape ld || (inner && ldStride ld == 1)

    rebasable =
      base ixfun == shape new_base
        && cg
        && all ((/= Unknown) . ldMon) dims
        && (hasContiguousPerm ixfun || hasContiguousPerm new_base)
        && (length perm == length perm_base || hasContiguousPerm ixfun)
        && and (zipWith3 fullOrUnitStride shp dims inner_flags)

    -- Compose the permutations unless ixfun's own permutation is contiguous.
    perm_base'
      | hasContiguousPerm ixfun = perm_base
      | otherwise = map (perm !!) perm_base
    lmad_base' = setLMADPermutation perm_base' lmad_base
    dims_base = lmadDims lmad_base'
    n_fewer_dims = length dims_base - length dims

    -- When ixfun has fewer dimensions than the base, only the trailing base
    -- dimensions are merged.
    (dims_base', offs_contrib) =
      unzip $ zipWith mergeDim (drop n_fewer_dims dims_base) dims

    -- Merge one base dimension with the matching ixfun dimension, yielding
    -- the new dimension (always 'Inc') and its contribution to the offset.
    mergeDim (LMADDim s1 r1 n1 p1 _) (LMADDim _ r2 _ _ m2) =
      (LMADDim s' r' n1 (p1 - n_fewer_dims) Inc, off')
      where
        (s', off')
          | m2 == Inc = (s1, 0)
          | otherwise = (s1 * (-1), s1 * (n1 - 1))
        r'
          | m2 == Inc = if r2 == 0 then r1 else r1 + r2
          | r1 == 0 = r2
          | r2 == 0 = n1 - r1
          | otherwise = n1 - r1 + r2

    off_base = lmadOffset lmad_base' + sum offs_contrib

    -- A nonzero offset in ixfun's last LMAD is scaled by the stride of the
    -- last base dimension and added to the base offset.
    lmad_base''
      | lmadOffset last_lmad == 0 = LMAD off_base dims_base'
      | otherwise =
          setLMADShape
            (lmadShape last_lmad)
            ( LMAD
                (off_base + ldStride (last dims_base) * lmadOffset last_lmad)
                dims_base'
            )

    merged_lmads = lmad_base'' :| lmads_base
-- | Rebase @ixfun@ onto @new_base@.
--
-- Uses 'rebaseNice' when it succeeds.  Otherwise the LMADs of @ixfun@ are
-- concatenated (via '@++@') with those of @new_base@, and the result takes
-- @new_base@'s base shape.  When @ixfun@'s base differs from @new_base@'s
-- shape, @new_base@ is first 'reshape'd with 'DimCoercion's to @ixfun@'s base.
rebase ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  IxFun num ->
  IxFun num
rebase new_base@(IxFun lmads_base shp_base cg_base) ixfun@(IxFun lmads shp cg) =
  case rebaseNice new_base ixfun of
    Just nice -> nice
    Nothing -> IxFun (lmads @++@ base_lmads) base_shape (cg && cg_base)
  where
    (base_lmads, base_shape)
      | base ixfun == shape new_base = (lmads_base, shp_base)
      | otherwise =
          let IxFun coerced_lmads coerced_base _ =
                reshape new_base (map DimCoercion shp)
           in (coerced_lmads, coerced_base)
-- | Monotonicity of an index function, with rotations taken into account
-- (a dimension with nonzero rotation is not considered monotonic).
ixfunMonotonicity :: (Eq num, IntegralExp num) => IxFun num -> Monotonicity
ixfunMonotonicity = ixfunMonotonicityRots ignore_rots
  where
    ignore_rots = False
-- | For an index function made of a single LMAD, where 'hasContiguousPerm'
-- holds, the contiguity flag is set and 'ixfunMonotonicity' is 'Inc', return
-- the LMAD's offset multiplied by @elem_size@.  Otherwise 'Nothing'.
linearWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe num
linearWithOffset ixfun elem_size =
  case ixfun of
    IxFun (lmad :| []) _ cg
      | hasContiguousPerm ixfun && cg && ixfunMonotonicity ixfun == Inc ->
          Just (lmadOffset lmad * elem_size)
    _ -> Nothing
-- | Like 'linearWithOffset', but first replaces the (single) LMAD's
-- permutation with the identity @[0 .. n-1]@.
--
-- On success, also returns each index of the original permutation paired
-- with the base shape entry that 'permuteFwd' places at that position.
-- Index functions with more than one LMAD yield 'Nothing'.
rearrangeWithOffset ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  num ->
  Maybe (num, [(Int, num)])
rearrangeWithOffset (IxFun (lmad :| []) oshp cg) elem_size =
  withPerm <$> linearWithOffset unpermuted elem_size
  where
    perm = lmadPermutation lmad
    unpermuted =
      IxFun (setLMADPermutation [0 .. length perm - 1] lmad :| []) oshp cg
    withPerm offset =
      (offset, zip perm (permuteFwd perm (lmadShapeBase lmad)))
rearrangeWithOffset _ _ = Nothing
-- | True when 'linearWithOffset' with an element size of 1 succeeds with
-- offset 0.
isLinear :: (Eq num, IntegralExp num) => IxFun num -> Bool
isLinear ixfun = linearWithOffset ixfun 1 == Just 0
-- | Apply a permutation: position @i@ of the result is @elems !! (ps !! i)@.
-- Partial: an index in @ps@ outside @elems@ is an error.
permuteFwd :: Permutation -> [a] -> [a]
permuteFwd ps elems = [elems !! p | p <- ps]
-- | Undo a permutation.  Each element is tagged with the corresponding entry
-- of @ps@ and the list is stably sorted by that tag.  When @ps@ is a
-- permutation of @[0 .. length elems - 1]@, this is the inverse of
-- 'permuteFwd'.
permuteInv :: Permutation -> [a] -> [a]
permuteInv ps elems =
  map snd (sortBy byTag (zip ps elems))
  where
    byTag (i, _) (j, _) = compare i j
-- | Flat offset contributed by index @i@ along one dimension described by
-- @(s, r, n)@.  The triple presumably is (stride, rotation, extent), matching
-- the 'LMADDim' field order; confirm against callers.
--
-- A zero stride contributes 0.  Otherwise the index is shifted by @r@ modulo
-- @n@ (skipped when @r@ is 0) and then scaled by @s@.
flatOneDim ::
  (Eq num, IntegralExp num) =>
  (num, num, num) ->
  num ->
  num
flatOneDim (s, r, n) i
  | s == 0 = 0
  | otherwise = rotated * s
  where
    rotated
      | r == 0 = i
      | otherwise = (i + r) `mod` n
-- | Build an LMAD with offset @off@ from @support@, a list of
-- @(rotation, extent)@ pairs with one entry per dimension.
--
-- Every dimension gets monotonicity @mon@ and the identity permutation.
-- Strides are the row-major products of the extents: the stride of
-- dimension @i@ is the product of the extents after @i@.  The strides are
-- negated when @mon@ is 'Dec'.
--
-- Calls 'error' unless @mon@ is 'Inc' or 'Dec'.
makeRotIota ::
  IntegralExp num =>
  Monotonicity ->
  num ->
  [(num, num)] ->
  LMAD num
makeRotIota mon off support
  | mon == Inc || mon == Dec =
      let rk = length support
          (rs, ns) = unzip support
          -- Running products of the extents from the innermost dimension
          -- outwards, then put back in outer-to-inner order.
          ss0 = reverse $ take rk $ scanl (*) 1 $ reverse ns
          ss = if mon == Inc then ss0 else map (* (-1)) ss0
          -- Identity permutation.  These are already 'Int's, so no
          -- 'fromIntegral' conversion is needed.
          ps = [0 .. rk - 1]
          fi = replicate rk mon
       in LMAD off $ zipWith5 LMADDim ss rs ns ps fi
  | otherwise = error "makeRotIota: requires Inc or Dec"
-- | Monotonicity of an index function.  This is the monotonicity of its first
-- LMAD if every other LMAD has the same monotonicity; otherwise 'Unknown'.
--
-- An LMAD is 'Inc' (resp. 'Dec') when each of its dimensions either has
-- stride 0, or has that monotonicity and rotation 0.  The rotation condition
-- is dropped when @ignore_rots@ is set.  'Inc' is tested first, so an LMAD
-- whose strides are all 0 counts as 'Inc'.
ixfunMonotonicityRots ::
  (Eq num, IntegralExp num) =>
  Bool ->
  IxFun num ->
  Monotonicity
ixfunMonotonicityRots ignore_rots (IxFun (lmad :| lmads) _ _)
  | all ((== mon0) . lmadMon) lmads = mon0
  | otherwise = Unknown
  where
    mon0 = lmadMon lmad

    lmadMon (LMAD _ dims)
      | all (dimHas Inc) dims = Inc
      | all (dimHas Dec) dims = Dec
      | otherwise = Unknown

    dimHas mon (LMADDim s r _ _ ldmon) =
      s == 0 || ((ignore_rots || r == 0) && mon == ldmon)
-- | Least general generalisation of two index functions.
--
-- Only index functions built from exactly one LMAD are handled. They must
-- also agree on:
--
--   * the rank of the base shape;
--   * the 'Bool' flag of 'IxFun';
--   * the per-dimension permutation (equal permutation lists also imply an
--     equal LMAD rank);
--   * the per-dimension monotonicity.
--
-- In every other case the result is 'Nothing'.
--
-- Each differing expression is generalised with
-- 'PEG.leastGeneralGeneralization'. One substitution list is threaded
-- through all of those calls in this order: LMAD dimension shapes, base
-- shape, strides, rotations, offset. The resulting index function is
-- returned together with the final substitution list. The permutation,
-- monotonicity and flag are taken from the first argument.
leastGeneralGeneralization ::
  Eq v =>
  IxFun (PrimExp v) ->
  IxFun (PrimExp v) ->
  Maybe (IxFun (PrimExp (Ext v)), [(PrimExp v, PrimExp v)])
leastGeneralGeneralization
  (IxFun (lmad1 :| []) oshp1 ctg1)
  (IxFun (lmad2 :| []) oshp2 ctg2)
    | length oshp1 == length oshp2,
      ctg1 == ctg2,
      perDim ldPerm lmad1 == perDim ldPerm lmad2,
      perDim ldMon lmad1 == perDim ldMon lmad2 =
        let (dshp, m1) = generalize [] (perDim ldShape lmad1) (perDim ldShape lmad2)
            (oshp, m2) = generalize m1 oshp1 oshp2
            (dstd, m3) = generalize m2 (perDim ldStride lmad1) (perDim ldStride lmad2)
            (drot, m4) = generalize m3 (perDim ldRotate lmad1) (perDim ldRotate lmad2)
            (offt, m5) =
              PEG.leastGeneralGeneralization m4 (lmadOffset lmad1) (lmadOffset lmad2)
            -- LMADDim fields are: stride, rotation, shape, permutation, monotonicity.
            dims =
              zipWith5
                LMADDim
                dstd
                drot
                dshp
                (lmadPermutation lmad1)
                (perDim ldMon lmad1)
         in Just (IxFun (LMAD offt dims :| []) oshp ctg1, m5)
    where
      -- Project one field out of every dimension of an LMAD.
      perDim field = map field . lmadDims

      -- Generalise two lists pointwise, truncating to the shorter one as
      -- 'zip' does. The substitution list is threaded left to right.
      generalize subst xs ys = runState (mapM step (zip xs ys)) subst

      -- One pointwise generalisation step, as a state transition on the
      -- substitution list.
      step (pe1, pe2) = state $ \subst -> PEG.leastGeneralGeneralization subst pe1 pe2
leastGeneralGeneralization _ _ = Nothing
-- | Is the permutation the identity?
--
-- Returns 'True' when element @i@ of the list equals @i@ at every position.
-- The empty permutation counts as sequential.
isSequential :: [Int] -> Bool
isSequential xs = and (zipWith (==) xs [0 ..])
-- | Replace an expression by a fresh existential.
--
-- The expression is appended to the end of the state. The returned leaf is
-- @'Ext' i@, where @i@ is the position the expression now occupies in the
-- state (the state's length before appending). The leaf keeps the
-- primitive type of the original expression.
existentializeExp :: TPrimExp t v -> State [TPrimExp t v] (TPrimExp t (Ext v))
existentializeExp e = state $ \seen ->
  ( TPrimExp $ LeafExp (Ext (length seen)) (primExpType (untyped e)),
    seen ++ [e]
  )
-- | Turn every shape, stride and offset expression of an index function into
-- an existential.
--
-- The input must satisfy all of the following:
--
--   * it has exactly one LMAD;
--   * its 'Bool' flag is 'True';
--   * every rotation is zero;
--   * the LMAD shape has the same rank as the base shape;
--   * the LMAD permutation is the identity.
--
-- For any other input the result is 'Nothing' and the state is not touched.
--
-- Existentials are allocated, and the original expressions appended to the
-- state, in a fixed order: first the base shape, then the offset, then the
-- stride and shape of each dimension. Rotations (all zero here) are kept
-- as free expressions rather than existentialised.
existentialize ::
  (IntExp t, Eq v, Pretty v) =>
  IxFun (TPrimExp t v) ->
  State [TPrimExp t v] (Maybe (IxFun (TPrimExp t (Ext v))))
existentialize (IxFun (lmad :| []) oshp True)
  | all ((== 0) . ldRotate) dims,
    length (lmadShape lmad) == length oshp,
    isSequential (map ldPerm dims) = do
      oshp' <- traverse existentializeExp oshp
      -- Applicative effects run left to right, so the offset is
      -- existentialised before any of the dimensions.
      lmad' <-
        LMAD
          <$> existentializeExp (lmadOffset lmad)
          <*> traverse existentializeDim dims
      pure $ Just $ IxFun (lmad' :| []) oshp' True
  where
    dims = lmadDims lmad

    -- Existentialise the stride, then the shape, of one dimension. The
    -- rotation is kept, lifted to 'Free'.
    existentializeDim ::
      LMADDim (TPrimExp t v) ->
      State [TPrimExp t v] (LMADDim (TPrimExp t (Ext v)))
    existentializeDim (LMADDim str rot shp perm mon) = do
      str' <- existentializeExp str
      shp' <- existentializeExp shp
      pure $ LMADDim str' (Free <$> rot) shp' perm mon
existentialize _ = pure Nothing
-- | Coarse structural comparison of two index functions.
--
-- Returns 'True' when all of the following hold:
--
--   * both base shapes have the same rank;
--   * both index functions have the same number of LMADs;
--   * each corresponding pair of LMADs has the same number of dimensions
--     and the same per-dimension permutation.
--
-- No shape, stride, rotation or offset expressions are compared, so this is
-- much weaker than structural equality.
--
-- NOTE(review): the third ('Bool') field of 'IxFun' is not compared either.
-- Confirm with the callers that this is intended.
closeEnough :: IxFun num -> IxFun num -> Bool
closeEnough ixf1 ixf2 =
  length (base ixf1) == length (base ixf2)
    && NE.length lmads1 == NE.length lmads2
    && and (NE.zipWith sameLayout lmads1 lmads2)
  where
    lmads1 = ixfunLMADs ixf1
    lmads2 = ixfunLMADs ixf2

    -- Same rank and the same permutation in every dimension.
    sameLayout :: LMAD num -> LMAD num -> Bool
    sameLayout lmad1 lmad2 =
      length dims1 == length dims2
        && map ldPerm dims1 == map ldPerm dims2
      where
        dims1 = lmadDims lmad1
        dims2 = lmadDims lmad2