module Futhark.IR.Prop.Rearrange
( rearrangeShape,
rearrangeInverse,
rearrangeReach,
rearrangeCompose,
isPermutationOf,
transposeIndex,
isMapTranspose,
)
where
import Data.List (sortOn, tails)
import Futhark.Util
-- | Apply the permutation @perm@ to the list @l@: element @i@ of the
-- result is @l !! (perm !! i)@, so the result has as many elements as
-- @perm@.
--
-- Each index in @perm@ is only checked to lie in @[0, length l)@;
-- duplicated or missing indices are not detected.  An out-of-range index
-- raises an 'error' when the corresponding result element is demanded.
--
-- Uses '!!', so the cost is proportional to @length perm * length l@;
-- fine for the short lists this is applied to.
--
-- >>> rearrangeShape [2, 0, 1] "abc"
-- "cab"
rearrangeShape :: [Int] -> [a] -> [a]
rearrangeShape perm l = map pick perm
  where
    pick i
      | 0 <= i, i < n = l !! i
      | otherwise =
          error $ show perm ++ " is not a valid permutation for input."
    n = length l
-- | Invert a permutation.  For a permutation @perm@ of @[0 .. n-1]@ the
-- result @inv@ satisfies
-- @rearrangeShape inv (rearrangeShape perm xs) == xs@.
--
-- Each position is paired with the value stored there.  Sorting the pairs
-- by value then yields, at index @v@, the position that held @v@.  The
-- result is only meaningful when @perm@ really is a permutation.
--
-- >>> rearrangeInverse [2, 0, 1]
-- [1,2,0]
rearrangeInverse :: [Int] -> [Int]
rearrangeInverse perm = map snd (sortOn fst (zip perm [0 ..]))
-- | Return the smallest @k@ such that @perm@ leaves every position from
-- @k@ onwards in place, i.e. @drop k perm == [k .. length perm - 1]@.
-- Positions at index @k@ and beyond are not moved by @perm@.
--
-- >>> rearrangeReach [1, 0, 2]
-- 2
-- >>> rearrangeReach [0, 1, 2]
-- 0
rearrangeReach :: [Int] -> Int
rearrangeReach perm =
  case dropWhile (uncurry (/=)) $ zip (tails perm) (tails [0 .. n - 1]) of
    -- Unreachable in practice: both 'tails' lists end in @[]@, and those
    -- compare equal, so 'dropWhile' always leaves at least one pair.
    [] -> n + 1
    (perm', _) : _ -> n - length perm'
  where
    n = length perm
-- | Compose two permutations.  Applying @rearrangeCompose perm1 perm2@
-- is the same as applying @perm2@ first and then @perm1@:
--
-- @rearrangeShape (rearrangeCompose perm1 perm2) xs
--   == rearrangeShape perm1 (rearrangeShape perm2 xs)@
--
-- This holds because element @i@ of the composition is
-- @perm2 !! (perm1 !! i)@, which is exactly 'rearrangeShape' applied
-- to @perm2@ as an ordinary list.
rearrangeCompose :: [Int] -> [Int] -> [Int]
rearrangeCompose = rearrangeShape
-- | @isPermutationOf l1 l2@ returns @Just perm@ when @l1@ is a
-- reordering of @l2@, where @perm@ satisfies
-- @rearrangeShape perm l2 == l1@.  Otherwise it returns 'Nothing'.
--
-- Each element of @l1@ is matched against the first not-yet-used equal
-- element of @l2@, so duplicated values are paired up in order.  A
-- length mismatch yields 'Nothing': either @l2@ runs out while matching,
-- or unused elements of @l2@ remain at the end.  Cost is
-- O(@length l1 * length l2@).
--
-- >>> isPermutationOf "cab" "abc"
-- Just [2,0,1]
isPermutationOf :: Eq a => [a] -> [a] -> Maybe [Int]
isPermutationOf l1 l2 =
  -- Thread a copy of @l2@ through @l1@.  Matched slots are overwritten
  -- with 'Nothing' so they cannot be matched again.
  case mapAccumLM (pick 0) (map Just l2) l1 of
    Just (l2', perm)
      | all (== Nothing) l2' -> Just perm
    _ -> Nothing
  where
    -- Find the first slot (counting from @i@) still holding @y@, mark it
    -- used, and return its index.
    pick :: Eq b => Int -> [Maybe b] -> b -> Maybe ([Maybe b], Int)
    pick _ [] _ = Nothing
    pick i (x : xs) y
      | Just y == x = Just (Nothing : xs, i)
      | otherwise = do
          (xs', v) <- pick (i + 1) xs y
          pure (x : xs', v)
-- | @transposeIndex k n xs@ moves the element at index @k@ of @xs@ @n@
-- positions later in the list, or @-n@ positions earlier when @n@ is
-- negative.  The elements in between shift over to make room.
--
-- * When @k + n@ reaches or passes the end of the list, the target
--   position wraps around to @(k + n) `mod` length xs@.  The element may
--   then end up before its original position.
-- * A leftward move past the start stops at the front of the list.
-- * If @k >= length xs@, or @xs@ is empty, the list is returned
--   unchanged.
--
-- >>> transposeIndex 0 1 "abc"
-- "bac"
-- >>> transposeIndex 2 (-2) "abc"
-- "cab"
-- >>> transposeIndex 2 1 "abc"
-- "cab"
transposeIndex :: Int -> Int -> [a] -> [a]
transposeIndex k n l
  -- Must come first: the wrap-around case divides by the length, which
  -- would raise a divide-by-zero exception for an empty list.
  | null l = l
  | k + n >= len =
      -- After this rewrite k + n' lies in [0, len), so the recursion
      -- does not take this branch again.
      let n' = ((k + n) `mod` len) - k
       in transposeIndex k n' l
  | n < 0,
    (pre, needle : end) <- splitAt k l,
    (beg, mid) <- splitAt (length pre + n) pre =
      beg ++ [needle] ++ mid ++ end
  | (beg, needle : post) <- splitAt k l,
    (mid, end) <- splitAt n post =
      beg ++ mid ++ [needle] ++ end
  | otherwise = l
  where
    len = length l
-- | Recognise permutations that transpose two blocks of positions after
-- an untouched prefix.
--
-- The permutation is split into three parts:
--
-- * @mapped@: the longest prefix equal to @[0, 1, .., m-1]@;
-- * @pretrans@: the longest run of consecutive increasing values that
--   starts at the first remaining element;
-- * @posttrans@: everything after that.
--
-- The result is @Just (m, length pretrans, length posttrans)@ when both
-- of the latter parts are non-empty and @posttrans@ is exactly
-- @[m .. m + length posttrans - 1]@.  Otherwise the result is 'Nothing'.
-- For a valid permutation this means
-- @perm == [0 .. m-1] ++ [m+p .. n-1] ++ [m .. m+p-1]@, where
-- @p = length posttrans@.
--
-- >>> isMapTranspose [0, 3, 4, 1, 2]
-- Just (1,2,2)
-- >>> isMapTranspose [0, 1, 2]
-- Nothing
isMapTranspose :: [Int] -> Maybe (Int, Int, Int)
isMapTranspose perm
  | posttrans == [numMapped .. numMapped + length posttrans - 1],
    not (null pretrans),
    not (null posttrans) =
      Just (numMapped, length pretrans, length posttrans)
  | otherwise =
      Nothing
  where
    (mapped, notmapped) = findIncreasingFrom 0 perm
    (pretrans, posttrans) = findTransposed notmapped
    numMapped = length mapped

    -- Split off the longest prefix that counts up by one from @x@.  When
    -- the guard fails, matching falls through to the catch-all equation.
    findIncreasingFrom :: Int -> [Int] -> ([Int], [Int])
    findIncreasingFrom x (i : is)
      | i == x =
          let (js, ps) = findIncreasingFrom (x + 1) is
           in (i : js, ps)
    findIncreasingFrom _ is =
      ([], is)

    -- Split off the longest consecutive run starting at the first element.
    findTransposed :: [Int] -> ([Int], [Int])
    findTransposed [] =
      ([], [])
    findTransposed (i : is) =
      findIncreasingFrom i (i : is)