module Camfort.Specification.Stencils.InferenceBackend where
import Prelude hiding (sum)
import Data.Generics.Uniplate.Operations
import Data.List hiding (sum)
import Data.Data
import Control.Arrow ((***))
import Data.Function
import Camfort.Specification.Stencils.Model
import Camfort.Helpers
import Camfort.Helpers.Vec
import Debug.Trace
import Unsafe.Coerce
import Camfort.Specification.Stencils.Syntax
-- | A span given as a pair of corners: (lower, upper).
type Span a = (a, a)

-- | The degenerate span whose lower and upper corners are the same value.
mkTrivialSpan :: a -> (a, a)
mkTrivialSpan a = (a, a)
-- | Infer a specification from a list of relative index vectors.
--
-- The indices are first passed through 'hasDuplicates'; the returned flag
-- sets the linearity of the result (via 'fromBool' / 'setLinearity'), and
-- the returned index list (presumably de-duplicated — confirm) is turned
-- into minimal regions and then into a simplified spatial specification.
inferFromIndices :: VecList Int -> Specification
inferFromIndices (VL ixs) =
    setLinearity (fromBool repeated) (Specification (Left spatial))
  where
    (distinctIxs, repeated) = hasDuplicates ixs
    spatial = simplify (fromRegionsToSpec (inferMinimalVectorRegions distinctIxs))
-- | Like 'inferFromIndices', but uses the index list as given: no duplicate
-- check is made and the linearity is left untouched.
inferFromIndicesWithoutLinearity :: VecList Int -> Specification
inferFromIndicesWithoutLinearity (VL ixs) =
    Specification (Left (simplify regions))
  where
    regions = fromRegionsToSpec (inferMinimalVectorRegions ixs)
-- | Run 'simplifySpatial' on the spatial specification inside a 'Result'.
simplify :: Result Spatial -> Result Spatial
simplify result = simplifySpatial <$> result
-- | Simplify a sum of region products.
--
-- 'reducor' searches orderings of the products for one that
-- 'normaliseNoSort' shrinks, measured by the total number of regions across
-- all products. The result is put in a canonical order: each product's
-- regions are sorted, then the products themselves are sorted.
simplifySpatial :: Spatial -> Spatial
simplifySpatial (Spatial lin (Sum ps)) = Spatial lin (Sum (canonicalise reduced))
  where
    reduced = reducor ps normaliseNoSort regionCount
    canonicalise = sort . map (Product . sort . unProd)
    -- Total number of regions over all products.
    regionCount :: [RegionProd] -> Int
    regionCount = length . concatMap unProd
-- | Repeatedly apply a reduction @f@ until no ordering of the list shrinks.
--
-- Orderings of @xs@ (from 'permutations') are tried in turn. As soon as
-- @f@ produces a result whose @size@ is strictly smaller than its input,
-- the search restarts from the orderings of that result. When only the
-- last ordering remains, @f@ is applied to it and that value is returned.
-- Note that the number of orderings grows factorially with the list length.
reducor :: [a] -> ([a] -> [a]) -> ([a] -> Int) -> [a]
reducor xs f size = search (permutations xs)
  where
    search (cand : rest)
      | null rest = f cand
      | size shrunk < size cand = search (permutations shrunk)
      | otherwise = search rest
      where
        shrunk = f cand
-- | Combine the specifications of all regions with 'sum', starting from
-- 'zero'; each region is translated by 'toSpecND'.
fromRegionsToSpec :: IsNatural n => [Span (Vec n Int)] -> Result Spatial
fromRegionsToSpec = foldr (sum . toSpecND) zero
-- | Translate an n-dimensional span into a specification.
--
-- Each dimension is translated with 'toSpec1D', numbering dimensions from 1.
-- The per-dimension results are combined with 'prod', and the empty vector
-- yields 'one'.
toSpecND :: Span (Vec n Int) -> Result Spatial
toSpecND sp = dims 1 sp
  where
    -- Own signature needed: recursion is at the tail's (shorter) length.
    dims :: Int -> Span (Vec m Int) -> Result Spatial
    dims _ (Nil, Nil) = one
    dims d (Cons lo los, Cons hi his) =
      toSpec1D d lo hi `prod` dims (d + 1) (los, his)
-- | Describe the interval @[l, u]@ of relative offsets along dimension @dim@
-- as a spatial specification.
--
-- The following intervals are described exactly:
--
-- * the origin alone;
-- * @[l, 0]@ and @[l, -1]@: backward, with and without the origin;
-- * @[0, u]@ and @[1, u]@: forward, with and without the origin;
-- * @[l, u]@ with @l < 0 < u@: centered when symmetric, otherwise backward
--   plus forward.
--
-- Any other interval is over-approximated with 'upperBound'. If either
-- bound equals 'absoluteRep', the dimension contributes an empty product.
toSpec1D :: Dimension -> Int -> Int -> Result Spatial
toSpec1D dim l u
  | l == absoluteRep || u == absoluteRep =
      exactly [Product []]
  | l == 0 && u == 0 =
      exactly [Product [Centered 0 dim True]]
  | l < 0 && u == 0 =
      exactly [Product [Backward (abs l) dim True]]
  -- The interval stops one step before the origin, so the origin is
  -- excluded. This must test @u == -1@: intervals ending at @u == 1@ do
  -- contain the origin and are handled by the centered and
  -- backward-plus-forward cases below.
  | l < 0 && u == (-1) =
      exactly [Product [Backward (abs l) dim False]]
  | l == 0 && u > 0 =
      exactly [Product [Forward u dim True]]
  | l == 1 && u > 0 =
      exactly [Product [Forward u dim False]]
  | l < 0 && u > 0 && abs l == u =
      exactly [Product [Centered u dim True]]
  | l < 0 && u > 0 && abs l /= u =
      exactly [Product [Backward (abs l) dim True], Product [Forward u dim True]]
  -- Remaining intervals (e.g. ones starting beyond +1 or ending before -1)
  -- do not touch the origin, so only an over-approximation is given.
  | otherwise =
      upperBound $ Spatial NonLinear (Sum [Product
        [if l > 0 then Forward u dim True else Backward (abs l) dim True]])
  where
    exactly prods = Exact $ Spatial NonLinear (Sum prods)
-- | Put a span into canonical form: in every dimension, the first corner
-- holds the smaller bound and the second corner holds the larger one.
--
-- The corners are compared component by component. Comparing only the
-- first component and swapping whole corners would leave a mixed span such
-- as ([0, 5], [1, 2]) reversed in its second dimension, which in turn would
-- give 'spanBoundingBox' a wrong box.
normaliseSpan :: Span (Vec n Int) -> Span (Vec n Int)
normaliseSpan (Nil, Nil) = (Nil, Nil)
normaliseSpan (Cons l ls, Cons u us) =
  let (ls', us') = normaliseSpan (ls, us)
  in (Cons (min l u) ls', Cons (max l u) us')
-- | The smallest span enclosing both arguments.
--
-- Both spans are passed through 'normaliseSpan' first. The result then takes
-- the component-wise minimum of the lower corners and the component-wise
-- maximum of the upper corners.
spanBoundingBox :: Span (Vec n Int) -> Span (Vec n Int) -> Span (Vec n Int)
spanBoundingBox s1 s2 = merge (normaliseSpan s1) (normaliseSpan s2)
  where
    -- Own signature needed: recursion is at the tail's (shorter) length.
    merge :: Span (Vec m Int) -> Span (Vec m Int) -> Span (Vec m Int)
    merge (Nil, Nil) (Nil, Nil) = (Nil, Nil)
    merge (Cons lo1 los1, Cons hi1 his1) (Cons lo2 los2, Cons hi2 his2) =
        (Cons (min lo1 lo2) lowRest, Cons (max hi1 hi2) highRest)
      where
        (lowRest, highRest) = merge (los1, his1) (los2, his2)
-- | Try to join two spans along their first component.
--
-- The spans join when all their other components agree and the first span's
-- upper bound in the first component is immediately followed by the second
-- span's lower bound. The result is a singleton list holding the joined
-- span, or the empty list when they cannot be joined. Two zero-dimensional
-- spans always join.
composeConsecutiveSpans :: Span (Vec n Int)
                        -> Span (Vec n Int) -> [Span (Vec n Int)]
composeConsecutiveSpans (Nil, Nil) (Nil, Nil) = [(Nil, Nil)]
composeConsecutiveSpans (Cons lo1 loRest1, Cons hi1 hiRest1)
                        (Cons lo2 loRest2, Cons hi2 hiRest2) =
  [ (Cons lo1 loRest1, Cons hi2 hiRest2)
  | loRest1 == loRest2
  , hiRest1 == hiRest2
  , hi1 + 1 == lo2
  ]
-- | Grow regions from individual index vectors.
--
-- Start from one trivial span per index. Then repeatedly run a coalescing
-- pass ('allRegionPermutations') followed by 'minimaliseRegions', until a
-- pass leaves the list of spans unchanged.
inferMinimalVectorRegions :: (Permutable n) => [Vec n Int] -> [Span (Vec n Int)]
inferMinimalVectorRegions ixs = untilStable (map mkTrivialSpan ixs)
  where
    untilStable current
      | next == current = next
      | otherwise = untilStable next
      where
        next = minimaliseRegions (allRegionPermutations current)
-- | One coalescing pass over a list of regions, tried under every permutation
-- of the dimensions.
--
-- For each dimension permutation, the pass does four things:
--
-- 1. Permute every span.
-- 2. Sort the permuted spans by lower corner.
-- 3. Merge neighbours that 'composeConsecutiveSpans' accepts (which only
--    joins along the first component).
-- 4. Map the results back with the matching inverse.
--
-- The results of all permutations are concatenated and de-duplicated.
allRegionPermutations :: (Permutable n)
                      => [Span (Vec n Int)] -> [Span (Vec n Int)]
allRegionPermutations =
    -- NOTE(review): '><' comes from Camfort.Helpers; presumably it applies
    -- coalesceRegions to the span list and id to the inverse — confirm.
    nub . concat . unpermuteIndices . map (coalesceRegions >< id) . groupByPerm . map permutationss
  where
    -- Every dimension permutation of a span, paired with a function that
    -- undoes it. Lower and upper corners are permuted in lock-step by zipping
    -- the two 'permutationsV' lists, and only the lower corner's inverse is
    -- kept. NOTE(review): this assumes 'permutationsV' enumerates
    -- permutations in the same order for any two vectors of the same length.
    permutationss :: Permutable n
                  => Span (Vec n Int)
                  -> [(Span (Vec n Int), Vec n Int -> Vec n Int)]
    permutationss (l, u) = map (\((l', un1), (u', un2)) -> ((l', u'), un1))
                         $ zip (permutationsV l) (permutationsV u)

    -- Sort spans by lower corner, using Vec's Ord instance (defined elsewhere).
    sortByFst = sortBy (\(l1, u1) (l2, u2) -> compare l1 l2)

    -- Turn "per span, all permutations" into "per permutation, all spans".
    -- Each row takes its inverse from its first entry. 'head' is safe here
    -- because 'transpose' never produces an empty row.
    groupByPerm :: [[(Span (Vec n Int), Vec n Int -> Vec n Int)]]
                -> [( [Span (Vec n Int)] , Vec n Int -> Vec n Int)]
    groupByPerm = map (\ixP -> let unPerm = snd $ head ixP
                               in (map fst ixP, unPerm)) . transpose

    -- Merge neighbouring spans after sorting. NOTE(review): only spans that
    -- end up adjacent after sorting by lower corner can be merged, so which
    -- merges are found depends on Vec's Ord instance — confirm that it
    -- places first-component neighbours next to each other.
    coalesceRegions :: [Span (Vec n Int)] -> [Span (Vec n Int)]
    coalesceRegions = nub . foldL composeConsecutiveSpans . sortByFst

    -- Map each permutation's spans back to the original dimension order,
    -- applying the inverse to both corners.
    unpermuteIndices :: [([Span (Vec n Int)], Vec n Int -> Vec n Int)]
                     -> [[Span (Vec n Int)]]
    unpermuteIndices = nub . map (\(rs, unPerm) -> map (unPerm *** unPerm) rs)
-- | Left-to-right merging pass over a list.
--
-- Adjacent elements @a@ and @b@ are offered to the combining function. If it
-- returns @[]@, @a@ is emitted and the scan moves on from @b@. Otherwise the
-- returned elements replace @a@ and @b@ at the front of the list, and the
-- scan continues from there. Lists with fewer than two elements are returned
-- unchanged.
foldL :: (a -> a -> [a]) -> [a] -> [a]
foldL combine = go
  where
    go (a : rest@(b : more)) =
      case combine a b of
        [] -> a : go rest
        merged -> go (merged ++ more)
    go short = short
-- | Remove redundant regions. Only regions that are not strictly contained
-- (per 'containedWithin' and not equal) in some other region of the input
-- are kept. They stay in input order, and duplicates are then dropped with
-- 'nub'.
--
-- Each region is tested against the whole input directly, rather than being
-- replaced by the regions that contain it. The latter would keep the middle
-- of a nested chain A inside B inside C, because B is a container of A even
-- though B itself lies inside C.
minimaliseRegions :: [Span (Vec n Int)] -> [Span (Vec n Int)]
minimaliseRegions regions = nub (filter isMaximal regions)
  where
    isMaximal r = not (any (strictlyInside r) regions)
    strictlyInside r other = containedWithin r other && r /= other
-- | Does the first span lie inside the second (inclusive bounds) in every
-- dimension?
containedWithin :: Span (Vec n Int) -> Span (Vec n Int) -> Bool
containedWithin (Nil, Nil) (Nil, Nil) = True
containedWithin (Cons lo1 los1, Cons hi1 his1) (Cons lo2 los2, Cons hi2 his2)
  | lo1 < lo2 = False
  | hi1 > hi2 = False
  | otherwise = containedWithin (los1, his1) (los2, his2)
-- | Length-indexed vectors whose elements can be selected and permuted.
class Permutable (n :: Nat) where
  -- | Every way of picking one element out of the vector. For a vector of
  -- length @S m@, each result holds three things: the picked element, the
  -- remaining elements, and a function that reinserts an element into the
  -- remainder at the picked position (see 'Selection').
  selectionsV :: Vec n a -> [Selection n a]
  -- | Every permutation of the vector, each paired with a function that
  -- restores the original element order from a permuted vector.
  permutationsV :: Vec n a -> [(Vec n a, Vec n a -> Vec n a)]

-- | Element type of 'selectionsV'. For @S n@ it is a triple: the chosen
-- element, the remaining @n@ elements, and a function rebuilding an
-- @S n@-vector from them. The @Z@ case is never produced in practice,
-- because 'selectionsV' on an empty vector returns @[]@.
type family Selection n a where
  Selection Z a = a
  Selection (S n) a = (a, Vec n a, a -> Vec n a -> Vec (S n) a)
-- | Empty vectors: there is nothing to select, and the only permutation is
-- the empty vector itself, with 'id' as its inverse.
instance Permutable Z where
  selectionsV Nil = []
  permutationsV Nil = [(Nil, id)]

-- | One-element vectors: the only selection is the single element, with
-- 'Cons' rebuilding the vector. The only permutation is the identity.
instance Permutable (S Z) where
  selectionsV (Cons x xs)
    = [(x, Nil, Cons)]
  permutationsV (Cons x Nil)
    = [(Cons x Nil, id)]
-- | Vectors of length two or more, defined using the instance for the
-- tail's length.
instance Permutable (S n) => Permutable (S (S n)) where
  -- Either select the head (rebuilt with 'Cons'), or select from the tail
  -- and keep the head in front of the remainder.
  selectionsV (Cons x xs) =
    (x, xs, Cons) : [ (y, Cons x ys, unselect unSel)
                    | (y, ys, unSel) <- selectionsV xs ]
    where
      -- Lift a tail reinsertion function to one that leaves the head in place.
      unselect :: (a -> Vec n a -> Vec (S n) a)
               -> (a -> Vec (S n) a -> Vec (S (S n)) a)
      unselect f y' (Cons x' ys') = Cons x' (f y' ys')
  -- For each selection, put the selected element first and permute the
  -- remainder. The inverse un-permutes the tail, then reinserts the leading
  -- element at its original position.
  permutationsV xs =
    [ (Cons y zs, \(Cons y' zs') -> unSel y' (unPerm zs'))
    | (y, ys, unSel) <- selectionsV xs,
      (zs, unPerm) <- permutationsV ys ]
-- | A list of vectors that all share one hidden length @n@, packaged with the
-- 'IsNatural' and 'Permutable' evidence for that length.
data VecList a where VL :: (IsNatural n, Permutable n) => [Vec n a] -> VecList a
-- | A single vector of hidden length, packaged with the 'IsNatural' and
-- 'Permutable' evidence for that length.
data List a where
  List :: (IsNatural n, Permutable n) => Vec n a -> List a
-- | The empty packaged vector.
lnil :: List a
lnil = List Nil

-- | Prepend an element to a packaged vector.
--
-- The three cases exist to supply 'Permutable' evidence for the new length.
-- Lengths 0 and 1 grow into lengths covered by the concrete @S Z@ and
-- @S (S Z)@ instances. For a vector of length @S (S m)@, the evidence
-- already packaged in 'List' is exactly what the @S (S (S m))@ instance
-- requires. A single case @List xs -> List (Cons x xs)@ could not derive
-- @Permutable (S n)@ from @Permutable n@. (Assumes matching 'IsNatural'
-- instances exist for successor lengths; they are defined elsewhere.)
lcons :: a -> List a -> List a
lcons x (List Nil) = List (Cons x Nil)
lcons x (List (Cons y Nil)) = List (Cons x (Cons y Nil))
lcons x (List (Cons y (Cons z xs))) = List (Cons x (Cons y (Cons z xs)))
-- | Convert an ordinary list into a packaged vector of the same elements,
-- built with 'lcons' from 'lnil'.
fromList :: [a] -> List a
fromList [] = lnil
fromList (y : ys) = lcons y (fromList ys)
-- | Build a 'VecList' from nested lists, one inner list per index vector.
--
-- Each new vector is aligned with the current head of the list using
-- 'zipVec' (defined elsewhere). The claim that the aligned vectors have the
-- same length as the rest of the list is then asserted with 'unsafeCoerce'.
--
-- NOTE(review): 'preCondition' fabricates an 'EqT' proof with 'unsafeCoerce'
-- and never inspects its arguments. This is only sound if both vectors
-- returned by 'zipVec' have exactly the same length as every vector already
-- in the tail @xs@. If 'zipVec' can return a different length (e.g. when
-- the inner lists have different lengths), the resulting 'VecList' is
-- ill-typed at runtime. Confirm 'zipVec''s contract, or validate lengths
-- before calling.
fromLists :: [[Int]] -> VecList Int
fromLists [] = VL ([] :: [Vec Z Int])
fromLists (xs:xss) = consList (fromList xs) (fromLists xss)
  where
    consList :: List Int -> VecList Int -> VecList Int
    consList (List vec) (VL []) = VL [vec]
    consList (List vec) (VL (x:xs))
      = let (vec', x') = zipVec vec x
        in
        case (preCondition x' xs, preCondition vec' xs) of
          (ReflEq, ReflEq) -> VL (vec' : (x' : xs))
      where
        -- Unchecked type-equality "proof" (see NOTE above).
        preCondition :: Vec n a -> [Vec n1 a] -> EqT n n1
        preCondition xs x = unsafeCoerce ReflEq
-- | Type-equality witness: pattern matching on 'ReflEq' brings @a ~ b@
-- into scope.
data EqT (a :: k) (b :: k) where
  ReflEq :: EqT a a