{-
   Copyright 2016, Dominic Orchard, Andrew Rice, Mistral Contrastin, Matthew Danish

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
-}

{-
This file gives an executable implementation of the model for
abstract stencil specifications. This model is used to drive both
the specification checking and program synthesis features.
-}

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE MultiWayIf #-}

module Camfort.Specification.Stencils.Model
  ( Interval(..)
  , Bound(..)
  , approxVec
  , Offsets(..)
  , UnionNF
  , vecLength
  , unfCompare
  , optimise
  , maximas
  , Approximation(..)
  , lowerBound, upperBound
  , fromExact
  , Multiplicity(..)
  , Peelable(..)
  ) where

import Prelude hiding (pred)

import qualified Control.Monad as CM
import Algebra.Lattice
import qualified Data.List.NonEmpty as NE
import qualified Data.Set as S
import Data.Foldable
import Data.SBV
import Data.Data
import Data.List (sortBy, nub)
import Data.Maybe (fromJust)
import qualified Camfort.Specification.Stencils.PartialOrd as PO
import qualified Camfort.Helpers.Vec as V
import System.IO.Unsafe

-- Utility container
--
-- A container can be queried concretely ('member') and can be turned into a
-- symbolic membership predicate for the solver ('compile').
class Container a where
  -- | Type of the concrete values tested by 'member'.
  type MemberTyp a
  -- | Type of the symbolic values consumed by 'compile'.
  type CompTyp a

  -- | Concrete membership test.
  member :: MemberTyp a -> a -> Bool
  -- | Symbolic membership test.
  compile :: a -> (CompTyp a -> SBool)

--------------------------------------------------------------------------------
-- Arbitrary sets representing offsets
--------------------------------------------------------------------------------

-- | Either an explicit set of offsets or the set of all integers.
data Offsets =
    Offsets (S.Set Int64)
  | SetOfIntegers
  deriving Eq

-- | Total order used for storage/sorting: explicit sets are ordered by the
-- 'Ord' instance of 'S.Set' and all sort below 'SetOfIntegers'.
instance Ord Offsets where
  Offsets s `compare` Offsets s' = s `compare` s'
  Offsets _ `compare` SetOfIntegers = LT
  SetOfIntegers `compare` Offsets _ = GT
  SetOfIntegers `compare` SetOfIntegers = EQ

instance Container Offsets where
  type MemberTyp Offsets = Int64
  type CompTyp Offsets = SInt64

  member i (Offsets s) = i `S.member` s
  -- Every integer belongs to 'SetOfIntegers'.
  member _ _ = True

  compile (Offsets s) i = i `sElem` map fromIntegral (S.toList s)
  compile SetOfIntegers _ = sTrue

-- | Join is set union and meet is set intersection; 'SetOfIntegers' absorbs
-- joins and is the identity for meets.
instance Lattice Offsets where
  (Offsets s) \/ (Offsets s') = Offsets $ s `S.union` s'
  _ \/ _ = SetOfIntegers
  (Offsets s) /\ (Offsets s') = Offsets $ s `S.intersection` s'
  off@Offsets{} /\ _ = off
  _ /\ o = o

instance BoundedJoinSemiLattice Offsets where
  bottom = Offsets S.empty

instance BoundedMeetSemiLattice Offsets where
  top = SetOfIntegers

--------------------------------------------------------------------------------
-- Interval as defined in the paper
--------------------------------------------------------------------------------

-- | Kind-level tag distinguishing arbitrary intervals from standard ones.
data Bound = Arbitrary | Standard

-- | Interval data structure assumes the following:
-- 1. The first num. param. is less than the second;
-- 2. For holed intervals, first num. param. <= 0 <= second num. param.;
--
-- The 'Bool' of 'IntervHoled' states whether 0 is a member of the interval
-- ('True' means there is no hole at 0).
data Interval a where
  IntervArbitrary         :: Int -> Int -> Interval 'Arbitrary
  IntervInfiniteArbitrary :: Interval 'Arbitrary
  IntervHoled             :: Int64 -> Int64 -> Bool -> Interval 'Standard
  IntervInfinite          :: Interval 'Standard

deriving instance Eq (Interval a)

instance Show (Interval 'Standard) where
  show IntervInfinite = "IntervInfinite"
  show (IntervHoled lb up p) =
    "Interv [" ++ show lb ++ "," ++ show up ++ "]^" ++ show p

-- | Convert an arbitrary interval into a standard one.
--
-- * An interval containing 0 is represented exactly, without a hole.
-- * An interval ending at -1 or starting at 1 is represented exactly by
--   extending it to 0 and marking 0 as a hole.
-- * An interval lying strictly beyond -1 or 1 has no exact representation;
--   only an upper bound (extended to 0, with a hole) is returned.
--
-- Calls 'error' when the lower bound exceeds the upper bound.
approxInterv :: Interval 'Arbitrary -> Approximation (Interval 'Standard)
approxInterv (IntervArbitrary a b)
  | a > b = error
    "Interval condition violated: lower bound is bigger than the upper bound."
  | a <=  0, b >=  0 = Exact $ IntervHoled a' b' True
  | a <= -1, b == -1 = Exact $ IntervHoled a' 0 False
  | a ==  1, b >=  1 = Exact $ IntervHoled 0 b' False
  | a >   1, b >   1 = Bound Nothing $ Just $ IntervHoled 0 b' False
  | a <  -1, b <  -1 = Bound Nothing $ Just $ IntervHoled a' 0 False
  | otherwise = error "Impossible: All posibilities are covered."
  where
    a' = fromIntegral a
    b' = fromIntegral b
approxInterv IntervInfiniteArbitrary = Exact IntervInfinite

-- | Convert a vector of arbitrary intervals component-wise.
--
-- The result is 'Exact' when every component converts exactly; otherwise it
-- is an upper bound built from the upper bound of every component.
approxVec :: forall n .
             V.Vec n (Interval 'Arbitrary)
          -> Approximation (V.Vec n (Interval 'Standard))
approxVec v =
  case findApproxIntervs stdVec of
    ([],_) -> Exact . fmap fromExact $ stdVec
    _      -> Bound Nothing (Just $ upperBound <$> stdVec)
  where
    stdVec :: V.Vec n (Approximation (Interval 'Standard))
    stdVec = fmap approxInterv v

    -- Indices of the bounded components (first list) and of the exact
    -- components (second list).
    findApproxIntervs :: forall n' .
                         V.Vec n' (Approximation (Interval 'Standard))
                      -> ([ Int ], [ Int ])
    findApproxIntervs v' = findApproxIntervs' 0 v' ([],[])

    findApproxIntervs' :: forall n' .
                          Int
                       -> V.Vec n' (Approximation (Interval 'Standard))
                       -> ([ Int ], [ Int ])
                       -> ([ Int ], [ Int ])
    findApproxIntervs' _ V.Nil acc = acc
    findApproxIntervs' i (V.Cons x xs) (bixs, eixs) =
      findApproxIntervs' (i+1) xs $
        case x of
          Bound{} -> (i:bixs, eixs)
          Exact{} -> (bixs, i:eixs)

instance Container (Interval 'Standard) where
  type MemberTyp (Interval 'Standard) = Int64
  type CompTyp (Interval 'Standard) = SInt64

  -- Membership of 0 is decided solely by the hole flag.
  member 0 (IntervHoled _ _ b) = b
  member i (IntervHoled a b _) = i >= a && i <= b
  member _ _ = True

  compile (IntervHoled i1 i2 b) i
    | b = inRange i range
    | otherwise = inRange i range .&& i ./= 0
    where
      range = (fromIntegral i1, fromIntegral i2)
  compile IntervInfinite _ = sTrue

-- | Join takes the widest bounds and has no hole if either operand has none;
-- meet takes the narrowest bounds and has no hole only if both have none.
instance Lattice (Interval 'Standard) where
  (IntervHoled lb ub noHole) \/ (IntervHoled lb' ub' noHole') =
    IntervHoled (min lb lb') (max ub ub') (noHole || noHole')
  _ \/ _ = top

  (IntervHoled lb ub noHole) /\ (IntervHoled lb' ub' noHole') =
    IntervHoled (max lb lb') (min ub ub') (noHole && noHole')
  int@IntervHoled{} /\ _ = int
  _ /\ int = int

instance BoundedJoinSemiLattice (Interval 'Standard) where
  bottom = IntervHoled 0 0 False

instance BoundedMeetSemiLattice (Interval 'Standard) where
  top = IntervInfinite

-- instance BoundedLattice (Interval 'Standard)

--------------------------------------------------------------------------------
-- Union of cartesian products normal form
--------------------------------------------------------------------------------

-- | A non-empty union of products, each product being a vector with one
-- component per dimension.
type UnionNF n a = NE.NonEmpty (V.Vec n a)

-- | Length of the first vector of the union.
vecLength :: UnionNF n a -> V.Natural n
vecLength = V.lengthN . NE.head

instance Container a => Container (UnionNF n a) where
  type MemberTyp (UnionNF n a) = V.Vec n (MemberTyp a)
  type CompTyp (UnionNF n a) = V.Vec n (CompTyp a)

  -- A point is in the union when it is in at least one product, i.e. when
  -- every coordinate is a member of the matching component.
  member is = any (member' is)
    where
      member' is' space = and $ V.zipWith member is' space

  -- Disjunction over the products of the conjunction over the dimensions.
  compile spaces is = foldr1 (.||) $ NE.map (`compile'` is) spaces
    where
      compile' space is' =
        foldr' (\(set, i) -> (.&&) $ compile set i) sTrue $ V.zip space is'

-- | Join concatenates the two unions; meet is the union of the pointwise
-- meets of every pair of products taken from the two operands.
instance BoundedMeetSemiLattice a => Lattice (UnionNF n a) where
  oi \/ oi' = oi <> oi'
  (/\) = CM.liftM2 (V.zipWith (/\))

-- | Compare two unions by asking the solver to prove that their compiled
-- membership predicates agree on every point.
--
-- * 'EQ' when the solver result carries no model (no counter example).
-- * 'GT' when the counter example is a member of the first operand.
-- * 'LT' when the counter example is a member of the second operand only.
--
-- Fails (inside the 'unsafePerformIO' action) on a hard proof error, on a
-- probable model, or on a counter example of unexpected size.
--
-- NOTE(review): the solver is run through 'unsafePerformIO'; presumably this
-- is considered safe because the answer depends only on the arguments --
-- confirm.
unfCompare :: forall a b n . ( Container a,          Container b
                             , MemberTyp a ~ Int64,  MemberTyp b ~ Int64
                             , CompTyp a ~ SInt64,   CompTyp b ~ SInt64
                             )
           => UnionNF n a -> UnionNF n b -> Ordering
unfCompare oi oi' = unsafePerformIO $ do
    thmRes <- prove pred
    case thmRes of
      -- Tell the user if there was a hard proof error (e.g., if
      -- z3 is not installed/accessible).
      -- TODO: give more information
      ThmResult (ProofError _ msgs _) -> fail $ unlines msgs
      _ ->
        if modelExists thmRes
        then do
          ce <- counterExample thmRes
          case V.fromList ce of
            V.VecBox cev ->
              case V.proveEqSize (NE.head oi) cev of
                Just V.ReflEq ->
                  -- TODO: The second branch is defensive programming the
                  -- member check is not necessary unless the counter example
                  -- is bogus (it shouldn't be). Delete if it adversely
                  -- affects the performance.
                  if | cev `member` oi  -> return GT
                     | cev `member` oi' -> return LT
                     | otherwise -> fail
                         "Impossible: counter example is in \
                         \neither of the operands"
                Nothing -> fail
                  "Impossible: Counter example size doesn't \
                  \match the original vector size."
        else return EQ
  where
    -- Extract the counter example assignment from the solver result.
    counterExample :: ThmResult -> IO [ Int64 ]
    counterExample thmRes =
      case getModelAssignment thmRes of
        Right (False, ce) -> return ce
        Right (True, _) -> fail "Returned probable model."
        Left str -> fail str

    -- One free symbolic variable per dimension; both unions must agree on
    -- whether the resulting point is a member.
    pred :: Predicate
    pred = do
      freeVars <- (mkFreeVars . dimensionality) oi :: Symbolic [ SInt64 ]
      case V.fromList freeVars of
        V.VecBox freeVarVec ->
          case V.proveEqSize (NE.head oi) freeVarVec of
            Just V.ReflEq -> return $
              compile oi freeVarVec .== compile oi' freeVarVec
            Nothing -> fail $
              "Impossible: Free variables size doesn't match that of the " ++
              "union parameter."
    dimensionality = V.length . NE.head

--------------------------------------------------------------------------------
-- Optimise unions
--------------------------------------------------------------------------------

-- | Offsets are ordered by inclusion, with 'SetOfIntegers' as the greatest
-- element. This agrees with the 'Lattice' instance, where join is union and
-- meet is intersection.
instance PO.PartialOrd Offsets where
  -- Unqualified '<=' on sets would be the lexicographic 'Ord' comparison,
  -- which is not inclusion (e.g. {0} <= {1}); use the subset test instead.
  (Offsets s) <= (Offsets s') = s `S.isSubsetOf` s'
  SetOfIntegers <= Offsets{} = False
  _ <= SetOfIntegers = True

-- | A holed interval is below another when its bounds lie within the other's
-- bounds and it does not contain 0 unless the other one does too.
instance PO.PartialOrd (Interval 'Standard) where
  (IntervHoled lb ub p) <= (IntervHoled lb' ub' p') =
    (p' || not p) && lb >= lb' && ub <= ub'
  IntervInfinite <= IntervHoled{} = False
  _ <= IntervInfinite = True

-- | Vectors are ordered component-wise.
instance PO.PartialOrd a => PO.PartialOrd (V.Vec n a) where
  v <= v' = and $ V.zipWith (PO.<=) v v'

-- | Simplify a union by repeatedly keeping maxima and merging products with
-- 'unionLemma' until the list no longer changes, then keeping maxima once
-- more.
optimise :: UnionNF n (Interval 'Standard) -> UnionNF n (Interval 'Standard)
optimise = NE.fromList . maximas . fixedPointUnion . NE.toList
  where
    fixedPointUnion unf =
      let unf' = unionLemma . maximas $ unf
      in if unf' == unf then unf' else fixedPointUnion unf'

-- | For every element @el@ of the list, build the group of all elements @x@
-- of the list with @p el x@, sorted with @ord@; duplicate groups are removed.
sensibleGroupBy :: Eq a
                => (a -> a -> Ordering)
                -> (a -> a -> Bool)
                -> [ a ]
                -> [ [ a ] ]
sensibleGroupBy ord p l = nub . map (\el -> sortBy ord . filter (p el) $ l) $ l

-- | For every vector, group it with the vectors it is below and keep the
-- first element that 'PO.maxima' reports for that group; duplicates are
-- removed.
--
-- NOTE(review): @ord@ applies 'fromJust' to 'PO.compare', so sorting relies
-- on the members of each group being pairwise comparable -- confirm.
maximas :: [ V.Vec n (Interval 'Standard) ] -> [ V.Vec n (Interval 'Standard) ]
maximas = nub
        . fmap (head . PO.maxima)
        . sensibleGroupBy ord (PO.<=)
  where
    ord a b = fromJust $ a `PO.compare` b

-- | Union lemma says that if we have a product of intervals (as defined in
-- the paper) and we union two that agree in each dimension except one.
-- The union is again a product of intervals that agrees with the original
-- dimensions in all dimensions except the original differing one. At that
-- point it is the union of intervals, which is itself still an interval.
unionLemma :: [ V.Vec n (Interval 'Standard) ]
           -> [ V.Vec n (Interval 'Standard) ]
unionLemma = map joinGroup . sensibleGroupBy eqOrLess agreeButOne
  where
    -- Collapse a group into a single product by joining it point-wise.
    --
    -- NOTE(review): a group holds every vector that differs from its pivot
    -- in at most one dimension, so two members may differ from the pivot in
    -- different dimensions; their point-wise join then looks like it could
    -- cover more than their union -- confirm against the paper.
    joinGroup vecs = foldr1 (V.zipWith (\/)) vecs

    -- Comparison handed to the grouping's sort: equal vectors compare as EQ
    -- and any other pair as LT.
    eqOrLess v w
      | v == w    = EQ
      | otherwise = LT

    -- Holds when the two vectors differ in at most one position (identical
    -- vectors included).
    agreeButOne :: Eq a => V.Vec n a -> V.Vec n a -> Bool
    agreeButOne = go False
      where
        -- The flag records whether a differing position was already seen.
        go :: Eq a => Bool -> V.Vec n a -> V.Vec n a -> Bool
        go _ V.Nil V.Nil = True
        go seenDiff (V.Cons x xs) (V.Cons y ys)
          | x == y    = go seenDiff xs ys
          | seenDiff  = False
          | otherwise = go True xs ys

--------------------------------------------------------------------------------
-- Injections for multiplicity and exactness
--------------------------------------------------------------------------------

-- | A value that is either known exactly or only through an optional lower
-- bound (first field) and an optional upper bound (second field).
data Approximation a
  = Exact a
  | Bound (Maybe a) (Maybe a)
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)

-- | Extract the value of an 'Exact' approximation; calls 'error' on 'Bound'.
fromExact :: Approximation a -> a
fromExact approx =
  case approx of
    Exact x -> x
    Bound{} -> error "Can't retrieve from bounded as if it was exact."

-- | The lower bound of an approximation; an 'Exact' value is its own lower
-- bound. Calls 'error' when a 'Bound' carries no lower bound.
lowerBound :: Approximation a -> a
lowerBound approx =
  case approx of
    Bound (Just lo) _ -> lo
    Bound Nothing _   -> error "Approximation doesn't have a lower bound."
    Exact x           -> x

-- | The upper bound of an approximation; an 'Exact' value is its own upper
-- bound. Calls 'error' when a 'Bound' carries no upper bound.
upperBound :: Approximation a -> a
upperBound approx =
  case approx of
    Bound _ (Just hi) -> hi
    Bound _ Nothing   -> error "Approximation doesn't have a upper bound."
    Exact x           -> x

-- | Wrappers from which the wrapped core value can be extracted.
class Peelable a where
  -- | Type of the value left after peeling.
  type CoreTyp a
  -- | Drop the wrapper and return the core value.
  peel :: a -> CoreTyp a

-- | Tags a value with one of two multiplicity markers.
data Multiplicity a
  = Mult a
  | Once a
  deriving (Eq, Show, Functor, Foldable, Traversable, Data, Typeable)

instance Peelable (Multiplicity a) where
  type CoreTyp (Multiplicity a) = a

  -- Both constructors carry exactly one payload; return it.
  peel wrapped =
    case wrapped of
      Mult x -> x
      Once x -> x

{-
data Approximation a = Exact a | Lower a | Upper a
  deriving (Eq, Show, Functor, Data, Typeable)

instance Peelable Approximation where
  peel (Exact a) = a
  peel (Lower a) = a
  peel (Upper a) = a
-}