module Futhark.Representation.AST.Attributes.Rearrange
( rearrangeShape
, rearrangeInverse
, rearrangeReach
, rearrangeCompose
, isPermutationOf
, transposeIndex
, isMapTranspose
) where
import Data.List
import Futhark.Util
-- | Reorder a list according to a permutation: element @i@ of the result
-- is element @perm !! i@ of the input.
--
-- Only the bounds of each index are checked, so duplicate indices are
-- accepted. An out-of-range index turns the corresponding result element
-- into an 'error', raised only when that element is forced.
--
-- >>> rearrangeShape [2,0,1] "abc"
-- "cab"
rearrangeShape :: [Int] -> [a] -> [a]
rearrangeShape perm xs = [select k | k <- perm]
  where
    len = length xs
    select k
      | k < 0 || k >= len =
          error $ show perm ++ " is not a valid permutation for input."
      | otherwise = xs !! k
-- | Invert a permutation. For a permutation @perm@ of @[0 .. n-1]@ the
-- result @inv@ satisfies @inv !! (perm !! i) == i@ for every @i@.
--
-- >>> rearrangeInverse [2,0,1]
-- [1,2,0]
rearrangeInverse :: [Int] -> [Int]
rearrangeInverse perm =
  -- Tag each target with its source position, order by target, and read
  -- the source positions back off in that order.
  [src | (_, src) <- sortOn fst (zip perm [0 ..])]
-- | The reach of a permutation: the smallest @k@ such that
-- @perm !! i == i@ for every @i >= k@. Equivalently, the length of @perm@
-- minus the length of its longest suffix of fixed points.
--
-- >>> rearrangeReach [1,0,2]
-- 2
-- >>> rearrangeReach [0,1,2]
-- 0
rearrangeReach :: [Int] -> Int
rearrangeReach perm =
  -- Pair each element with its index and strip the trailing pairs where the
  -- element is its own index. Whatever remains is the part the permutation
  -- actually moves. This is a single O(n) pass rather than comparing every
  -- suffix of perm against the corresponding suffix of [0 .. n-1].
  length $ dropWhileEnd (uncurry (==)) $ zip perm [0 ..]
-- | Compose two permutations. The first argument is applied after the
-- second, so that for any list @xs@
--
-- > rearrangeShape (rearrangeCompose outer inner) xs
-- >   == rearrangeShape outer (rearrangeShape inner xs)
--
-- The composition is @inner@ reordered by @outer@.
rearrangeCompose :: [Int] -> [Int] -> [Int]
rearrangeCompose outer inner = rearrangeShape outer inner
-- | @isPermutationOf l1 l2@ checks whether @l1@ is a rearrangement of @l2@.
--
-- On success it returns @Just perm@ with @l1 !! i == l2 !! (perm !! i)@.
-- Each element of @l1@ is matched to the earliest not-yet-used equal
-- element of @l2@. The result is 'Nothing' if some element of @l1@ has no
-- unused match, or if any element of @l2@ is left unmatched.
--
-- >>> isPermutationOf "abc" "cab"
-- Just [1,2,0]
isPermutationOf :: Eq a => [a] -> [a] -> Maybe [Int]
isPermutationOf l1 l2 = match (map Just l2) l1
  where
    -- Walk l1, claiming one slot of l2 per element. A claimed slot is set
    -- to Nothing so that it cannot be matched twice.
    match slots [] =
      if all (== Nothing) slots then Just [] else Nothing
    match slots (y : ys) = do
      (slots', idx) <- claim 0 slots y
      rest <- match slots' ys
      return (idx : rest)

    -- Find the first slot still holding y, clear it, and report its index.
    claim :: Eq b => Int -> [Maybe b] -> b -> Maybe ([Maybe b], Int)
    claim _ [] _ = Nothing
    claim i (slot : more) y
      | Just y == slot = Just (Nothing : more, i)
      | otherwise = do
          (more', j) <- claim (i + 1) more y
          return (slot : more', j)
-- | @transposeIndex k n l@ moves the element at index @k@ by @n@ positions.
-- The element moves right when @n >= 0@ and left when @n < 0@, and the
-- other elements keep their relative order.
--
-- * If @k + n@ reaches or passes the end of the list, the target position
--   wraps around to @(k + n) `mod` length l@.
-- * A leftward move past the front stops at index 0.
-- * An index @k@ at or past the end leaves the list unchanged.
-- * An empty list is returned unchanged.
--
-- >>> transposeIndex 0 1 "abc"
-- "bac"
-- >>> transposeIndex 2 (-2) "abc"
-- "cab"
transposeIndex :: Int -> Int -> [a] -> [a]
transposeIndex k n l
  -- Without this guard the wrap-around case below would compute
  -- `mod` 0 and throw a divide-by-zero exception.
  | null l = l
  | k + n >= len =
      -- Re-express the move relative to k so that the target lands at
      -- (k + n) `mod` len; this is always in range, so we recurse once.
      let n' = ((k + n) `mod` len) - k
      in transposeIndex k n' l
  | n < 0,
    (pre, needle : end) <- splitAt k l,
    (beg, mid) <- splitAt (length pre + n) pre =
      beg ++ [needle] ++ mid ++ end
  | (beg, needle : post) <- splitAt k l,
    (mid, end) <- splitAt n post =
      beg ++ mid ++ [needle] ++ end
  | otherwise = l
  where
    len = length l
-- | Recognise a permutation that has the following three parts, in order:
--
-- 1. an identity prefix @[0 .. m-1]@;
-- 2. a non-empty run of consecutive values;
-- 3. a non-empty tail that is exactly @[m .. m+b-1]@.
--
-- On a match the result is @Just (m, a, b)@, where @a@ is the length of
-- the run in part 2. Any other shape, including the identity, gives
-- 'Nothing'.
--
-- >>> isMapTranspose [0,2,3,1]
-- Just (1,2,1)
isMapTranspose :: [Int] -> Maybe (Int, Int, Int)
isMapTranspose perm
  | posttrans == [m .. m + length posttrans - 1],
    not (null pretrans),
    not (null posttrans) =
      Just (m, length pretrans, length posttrans)
  | otherwise = Nothing
  where
    (mapped, rest) = consecutiveFrom 0 perm
    m = length mapped
    -- The first block is the consecutive run starting at whatever value
    -- follows the identity prefix.
    (pretrans, posttrans) = case rest of
      [] -> ([], [])
      r : _ -> consecutiveFrom r rest
    -- Split xs after its longest prefix of the form [s, s+1, s+2, ...].
    consecutiveFrom :: Int -> [Int] -> ([Int], [Int])
    consecutiveFrom s xs =
      splitAt (length (takeWhile (uncurry (==)) (zip xs (iterate (+ 1) s)))) xs