{-| Module : What4.Expr.UnaryBV Description : A "unary" bitvector representation Copyright : (c) Galois, Inc 2015-2020 License : BSD3 Maintainer : Joe Hendrix This module defines a data structure for representing a symbolic bitvector using a form of "unary" representation. The idea behind this representation is that we associate a predicate to each possible value of the bitvector that is true if the symbolic value is less than or equal to the possible value. As an example, if we had the unary term 'x' equal to "{ 0 -> false, 1 -> p, 2 -> q, 3 -> t }", then 'x' cannot be '0', has the value '1' if 'p' is true, the value '2' if 'q & not p' is true, and '3' if 'not q' is true. By construction, we should have that 'p => q'. -} {-# LANGUAGE CPP #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE DoAndIfThenElse #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE KindSignatures #-} {-# LANGUAGE PatternGuards #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE TypeOperators #-} #if MIN_VERSION_base(4,9,0) {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} #endif module What4.Expr.UnaryBV ( UnaryBV , width , size , traversePreds , constant , asConstant , unsignedEntries , unsignedRanges , evaluate , sym_evaluate , instantiate , domain -- * Operations , add , neg , mux , eq , slt , ult , uext , sext , trunc ) where import Control.Exception (assert) import Control.Lens import Control.Monad import Data.Bits import Data.Hashable import Data.Parameterized.Classes import Data.Parameterized.NatRepr import qualified GHC.TypeNats as Type import What4.BaseTypes import What4.Interface import What4.Utils.BVDomain (BVDomain) import qualified What4.Utils.BVDomain as BVD import qualified Data.Map.Strict as Map type IntMap = Map.Map Integer -- | @splitLeq k m@ returns a pair @(l,h)@ where @l@ contains -- all the bindings with a key less than or equal to @k@, and -- @h@ contains the ones greater than @k@. 
splitLeq :: Integer -> IntMap a -> (IntMap a, IntMap a)
splitLeq k m = (lowerWithKey, above)
  where
    (below, atKey, above) = Map.splitLookup k m
    -- A binding stored exactly at @k@ belongs to the lower half.
    lowerWithKey = maybe below (\v -> Map.insert k v below) atKey

-- | Break a non-empty map into the bindings below some pivot entry, the
-- pivot entry itself, and the bindings above it.  Returns 'Nothing' when
-- the map has no bindings.
--
-- The pivot is chosen from the pieces returned by 'Map.splitRoot', so which
-- entry is picked depends on the internal shape of the map.
splitEntry :: IntMap a -> Maybe (IntMap a, (Integer, a), IntMap a)
splitEntry = pick . Map.splitRoot
  where
    pick pieces =
      case pieces of
        [] -> Nothing
        -- Only one piece: its minimum is the pivot and nothing lies below.
        [only] ->
          fmap (\(entry, rest) -> (Map.empty, entry, rest))
               (Map.minViewWithKey only)
        lo : cand : his ->
          case Map.minViewWithKey cand of
            -- Empty candidate piece: drop it and retry with the next one.
            Nothing -> pick (lo : his)
            Just (entry, rest)
              | Map.null rest -> Just (lo, entry, Map.unions his)
              | otherwise     -> Just (lo, entry, Map.unions (rest : his))

-- | Collapse runs of adjacent entries that carry equal predicates, keeping
-- only the first (smallest-key) entry of each run.
stripDuplicatePreds :: Eq p => [(Integer,p)] -> [(Integer,p)]
stripDuplicatePreds entries =
  case entries of
    []           -> []
    first : rest -> keep first rest
  where
    -- @keep cur xs@ emits @cur@ once the next distinct predicate is found.
    keep cur [] = [cur]
    keep cur@(_, p) (nxt@(_, q) : more)
      | p == q    = keep cur more
      | otherwise = cur : keep nxt more

------------------------------------------------------------------------
-- UnaryBV

-- | A unary bitvector encoding where the map contains predicates
-- such that @u^.unaryBVMap^.at i@ holds iff the value represented by @u@
-- is less than or equal to @i@.
--
-- The map stored in the representation should never be empty, and the
-- largest integer stored in the map should be associated with a predicate
-- representing "true".  Consequently a map with exactly one binding
-- represents a constant.
data UnaryBV p (n::Type.Nat)
   = UnaryBV
     { width      :: !(NatRepr n)
       -- ^ Width of the bitvector.
     , unaryBVMap :: !(IntMap p)
       -- ^ Map from candidate values to "value is at most this" predicates.
     }

-- | Returns the number of distinct values that this could be.
size :: UnaryBV p n -> Int
size = Map.size . unaryBVMap

-- | Traversal over every predicate stored in the bitvector; the width and
-- the keys are left unchanged.
traversePreds :: Traversal (UnaryBV p n) (UnaryBV q n) p q
traversePreds f (UnaryBV w m) = fmap (UnaryBV w) (traverse f m)

-- | Two unary bitvectors are equal when their widths are equal and their
-- maps are equal.
instance Eq p => TestEquality (UnaryBV p) where
  testEquality x y =
    case testEquality (width x) (width y) of
      Just Refl
        | unaryBVMap x == unaryBVMap y -> Just Refl
      _ -> Nothing

-- | The hash folds every key and predicate of the map into the salt, in
-- ascending key order.
instance Hashable p => Hashable (UnaryBV p n) where
  hashWithSalt s0 u =
    Map.foldlWithKey'
      (\salt k e -> hashWithSalt (hashWithSalt salt k) e)
      s0
      (unaryBVMap u)

-- | Create a unary bitvector term from a constant.  The integer is first
-- normalized with 'toUnsigned' at the given width.
constant :: IsExprBuilder sym
         => sym
         -> NatRepr n
         -> Integer
         -> UnaryBV (Pred sym) n
constant sym w v = UnaryBV w (Map.singleton (toUnsigned w v) (truePred sym))

-- | Return the value of the bitvector when its map holds exactly one
-- binding (i.e. it is a constant), and 'Nothing' otherwise.
asConstant :: IsExpr p => UnaryBV (p BaseBoolType) w -> Maybe Integer
asConstant x =
  case Map.toList (unaryBVMap x) of
    [(v, _)] -> Just v
    _        -> Nothing

-- | @unsignedRanges v@ returns a set of predicates and ranges
-- where we know that for each entry @(p,l,h)@ and each value
-- @i : l <= i & i <= h@:
--   @p@ iff. @v <= i@
--
-- Each range starts at a key of the map and ends just before the next key;
-- the last range ends at the largest unsigned value of the width.
unsignedRanges :: UnaryBV p n -> [(p, Integer, Integer)]
unsignedRanges v =
  case Map.toList (unaryBVMap v) of
    [] -> error "internal: unsignedRanges given illegal UnaryBV"
    entries ->
      let top    = maxUnsigned (width v)
          -- Upper bound of each range: one below the following key.
          uppers = [ k - 1 | (k, _) <- drop 1 entries ] ++ [top]
       in [ (p, lo, hi) | ((lo, p), hi) <- zip entries uppers ]

-- | The bindings of the underlying map in ascending key order.
unsignedEntries :: (1 <= n) => UnaryBV p n -> [(Integer, p)]
unsignedEntries = Map.toList . unaryBVMap

-- | Evaluate a unary bitvector as an integer given an evaluation function.
evaluate :: Monad m => (p -> m Bool) -> UnaryBV p n -> m Integer
evaluate f0 u = go f0 (unaryBVMap u) (maxUnsigned (width u))
  where -- @go f m bnd@ searches @m@ for the smallest key whose predicate
        -- evaluates to true; @bnd@ is returned when no entry remains.
        go :: Monad m => (p -> m Bool) -> IntMap p -> Integer -> m Integer
        go f m bnd =
          case splitEntry m of
            Nothing -> return bnd
            Just (l,(k,v),h) -> do
              b <- f v
              case b of
                -- value <= k
                True  -> go f l k
                -- value > k
                False -> go f h bnd

-- | Evaluate a unary bitvector given an evaluation function.
--
-- This function is used to convert a unary bitvector into some other representation
-- such as a binary bitvector or vector of bits.
--
-- It is polymorphic over the result type 'r', and requires functions for manipulating
-- values of type 'r' to construct it.
sym_evaluate :: (Applicative m, Monad m)
             => (Integer -> m r)
                -- ^ Function for mapping an integer to its bitvector
                -- representation.
             -> (p -> r -> r -> m r)
                -- ^ Function for performing an 'ite' expression on 'r'.
             -> UnaryBV p n
                -- ^ Unary bitvector to evaluate.
             -> m r
sym_evaluate cns0 ite0 u = go cns0 ite0 (unaryBVMap u) (maxUnsigned (width u))
  where -- Builds a decision tree that mirrors the search done by 'evaluate':
        -- the "then" branch explores the lower entries with bound @k@, the
        -- "else" branch explores the higher entries with the current bound.
        go :: (Applicative m, Monad m)
           => (Integer -> m r)
           -> (p -> r -> r -> m r)
           -> IntMap p
           -> Integer
           -> m r
        go cns ite m bnd =
          case splitEntry m of
            Nothing -> cns bnd
            Just (l,(k,v),h) -> do
              join $ ite v <$> go cns ite l k
                           <*> go cns ite h bnd

-- | This function instantiates the predicates in a unary predicate with new predicates.
--
-- The mapping 'f' should be monotonic, that is for all predicates 'p' and 'q',
-- such that 'p |- q', 'f' should satisfy the constraint that 'f p |- f q'.
--
-- Adjacent entries whose new predicates are equal are collapsed with
-- 'stripDuplicatePreds'.
instantiate :: (Applicative m, Eq q) => (p -> m q) -> UnaryBV p w -> m (UnaryBV q w)
instantiate f u = fin <$> traverse f (unaryBVMap u)
  where fin m = UnaryBV { width = width u
                        , unaryBVMap = Map.fromDistinctAscList l
                        }
          where l = stripDuplicatePreds (Map.toList m)

-- | Return potential values for abstract domain.
--
-- The function argument reports whether a predicate is known to be true,
-- known to be false, or unknown ('Nothing').  The candidate values are
-- collected in ascending order before being handed to
-- 'BVD.fromAscEltList'.
domain :: forall p n
        . (1 <= n)
       => (p -> Maybe Bool)
       -> UnaryBV p n
       -> BVDomain n
domain f u = BVD.fromAscEltList (width u) (go (unaryBVMap u))
  where go :: IntMap p -> [Integer]
        go m =
          case splitEntry m of
            Nothing -> []
            Just (l,(k,v),h) -> do
              case f v of
                -- value <= k.  Every key in @l@ is smaller than @k@, so @k@
                -- must come last to keep the list ascending.
                Just True  -> go l ++ [k]
                -- value > k
                Just False -> go h
                Nothing -> go l ++ (k:go h)

------------------------------------------------------------------------
-- Operations

-- | This merges two maps used for a unary bitvector into a single map that
-- combines them.
--
-- The result has a binding for every key that occurs in either input.  The
-- predicate at a key is @f px py@, where @px@ (resp. @py@) is the predicate
-- the first (resp. second) map stores at that key or, when the key is
-- missing from that map, the predicate of its closest smaller key (the
-- false predicate when there is none).  Once one of the maps is exhausted,
-- the true predicate is used for that side.
mergeWithKey :: forall sym
              . IsExprBuilder sym
             => sym
             -> (Pred sym -> Pred sym -> IO (Pred sym))
             -> IntMap (Pred sym)
             -> IntMap (Pred sym)
             -> IO (IntMap (Pred sym))
mergeWithKey sym f x y = go Map.empty (falsePred sym) (Map.toList x) (falsePred sym) (Map.toList y)
  where go :: IntMap (Pred sym)
           -> Pred sym
           -> [(Integer, Pred sym)]
           -> Pred sym
           -> [(Integer, Pred sym)]
           -> IO (IntMap (Pred sym))
        -- Force "m" to be evaluated; the guard itself is never true.
        go m _ _ _ _ | seq m $ False = error "go bad"
        go m x_prev x_a@((x_k,x_p):x_r) y_prev y_a@((y_k,y_p):y_r) =
          case compare x_k y_k of
            LT -> do
              p <- f x_p y_prev
              go (Map.insert x_k p m) x_p x_r y_prev y_a
            GT -> do
              p <- f x_prev y_p
              go (Map.insert y_k p m) x_prev x_a y_p y_r
            EQ -> do
              p <- f x_p y_p
              go (Map.insert x_k p m) x_p x_r y_p y_r
        -- First list exhausted: its side is the true predicate from here on.
        go m _ [] _ y_a = do
          go1 m (truePred sym `f`) y_a
        -- Second list exhausted: symmetric case.
        go m _ x_a _ [] = do
          go1 m (`f` truePred sym) x_a

        -- Insert the remaining entries after applying @fn@ to each predicate.
        go1 m fn ((y_k,y_p):y_r) = do
          p <- fn y_p
          go1 (Map.insert y_k p m) fn y_r
        go1 m _ [] = return m

-- | @mux sym c x y@ returns value equal to if @c@ then @x@ else @y@.
-- The number of entries in the return value is at most @size x@
-- + @size y@.
mux :: forall sym n
     . (1 <= n, IsExprBuilder sym)
    => sym
    -> Pred sym
    -> UnaryBV (Pred sym) n
    -> UnaryBV (Pred sym) n
    -> IO (UnaryBV (Pred sym) n)
mux sym c x y =
  fmap (UnaryBV (width x)) $
  mergeWithKey sym (itePred sym c) (unaryBVMap x) (unaryBVMap y)

-- | Return predicate that holds if bitvectors are equal.
--
-- Only keys present in both maps can contribute: for each shared key the
-- result includes the conjunction of "x is exactly this key" and "y is
-- exactly this key".
eq :: (1 <= n, IsExprBuilder sym)
   => sym
   -> UnaryBV (Pred sym) n
   -> UnaryBV (Pred sym) n
   -> IO (Pred sym)
eq sym0 x0 y0 =
    let (x_k, x_p) = Map.findMin (unaryBVMap x0)
     in go sym0 (falsePred sym0) x_k x_p (unaryBVMap x0) (unaryBVMap y0)
  where go :: IsExprBuilder sym
           => sym
           -> Pred sym
           -> Integer
           -> Pred sym
           -> IntMap (Pred sym)
           -> IntMap (Pred sym)
           -> IO (Pred sym)
        go sym r x_k x_p x y
          | Just (y_k, y_p) <- Map.lookupGE x_k y =
            case x_k == y_k of
              -- Keys differ: continue from the larger key with the roles of
              -- the two maps swapped.
              False -> do
                go sym r y_k y_p y x
              True -> do
                let x_lt = maybe (falsePred sym) snd (Map.lookupLT x_k x)
                let y_lt = maybe (falsePred sym) snd (Map.lookupLT x_k y)
                -- "value <= k" and not "value <= previous key".
                x_is_eq <- andPred sym x_p =<< notPred sym x_lt
                y_is_eq <- andPred sym y_p =<< notPred sym y_lt
                r' <- orPred sym r =<< andPred sym x_is_eq y_is_eq
                case Map.lookupGE (x_k+1) x of
                  Just (x_k', x_p') -> go sym r' x_k' x_p' x y
                  Nothing -> return r'
        go _ r _ _ _ _ = return r

-- | @compareLt sym x y@ returns a predicate built as the disjunction, over
-- keys @k@ of @y@, of @x < k & not (y < k)@, conjoined with the predicate
-- stored at the largest key of @y@.  The result is the false predicate when
-- @y@ is empty.
compareLt :: forall sym
           . IsExprBuilder sym
          => sym
          -> IntMap (Pred sym)
          -> IntMap (Pred sym)
          -> IO (Pred sym)
compareLt sym x y
    | Map.null y = return (falsePred sym)
    | otherwise  = go (falsePred sym) 0
  where go :: Pred sym
              -- ^ Return predicate for cases where x is less than minimum.
           -> Integer
              -- ^ Minimum value to consider for x.
           -> IO (Pred sym)
        go r min_x
            -- Let x_k be min entry in x to consider next.
          | Just (x_k, _) <- Map.lookupGE min_x x
            -- Get smallest entry in y that is larger than x_k.
          , Just (y_k, _) <- Map.lookupGT x_k y
            -- Lookup largest predicate in x for value that is less than y_k.
          , Just (x_k_max, x_p) <- Map.lookupLT y_k x = do
            -- We know the following:
            -- 1. min_x <= x_k <= x_k_max < y_k.
            -- 2. y > x_k  => y >= y_k
            -- 3. x < y_k  => x_p

            -- Get predicate asserting x < y_k && y >= y_k.
            x_and_y_lt_x_k <-
              case Map.lookupLT y_k y of
                Nothing -> return $ x_p
                Just (_,y_lt_y_k) -> andPred sym x_p =<< notPred sym y_lt_y_k
            r' <- orPred sym r x_and_y_lt_x_k
            go r' (x_k_max+1)
        -- No further entries: restrict the result to the case where @y@ is
        -- bounded by its largest key in this map.
        go r _ = andPred sym (snd (Map.findMax y)) r

-- | Return predicate that holds if first value is less than other
-- (unsigned comparison).
ult :: (1 <= n, IsExprBuilder sym)
    => sym
    -> UnaryBV (Pred sym) n
    -> UnaryBV (Pred sym) n
    -> IO (Pred sym)
ult sym x y = compareLt sym (unaryBVMap x) (unaryBVMap y)

-- | Return predicate that holds if first value is less than other
-- (signed comparison).
--
-- Keys up to @maxSigned (width x)@ are the non-negative values and larger
-- keys are the negative ones.  The result is the disjunction of three
-- cases: both non-negative, both negative, and @x@ negative with @y@
-- non-negative.
slt :: (1 <= n, IsExprBuilder sym)
    => sym
    -> UnaryBV (Pred sym) n
    -> UnaryBV (Pred sym) n
    -> IO (Pred sym)
slt sym x y = do
  let mid = maxSigned (width x)
  -- Split map so that we separate the values that will remain positive
  -- from the values that will be negative.
  let (x_pos,x_neg) = splitLeq mid (unaryBVMap x)
  -- Split map so that we separate the values that will remain positive
  -- from the values that will be negative.
  let (y_pos,y_neg) = splitLeq mid (unaryBVMap y)

  -- x is negative iff it is not bounded by its largest non-negative key.
  x_is_neg <-
    if Map.null x_pos
      then return $ truePred sym
      else notPred sym (snd (Map.findMax x_pos))
  -- y is non-negative iff it is bounded by its largest non-negative key.
  let y_is_nonneg
        | Map.null y_pos = falsePred sym
        | otherwise      = snd (Map.findMax y_pos)
  pos_case <- compareLt sym x_pos y_pos
  neg_case <- andPred sym x_is_neg =<< compareLt sym x_neg y_neg
  -- A negative x is always less than a non-negative y; neither of the
  -- same-sign comparisons above asserts this.
  mixed_case <- andPred sym x_is_neg y_is_nonneg
  orPred sym pos_case =<< orPred sym neg_case mixed_case

-- | @splitOnAddOverflow v x@ splits the entries of @x@ into those with keys
-- at most @maxUnsigned (width x) - v@ and those with larger keys.
splitOnAddOverflow :: Integer
                   -> UnaryBV p n
                   -> (IntMap p, IntMap p)
splitOnAddOverflow v x = assert (0 <= v && v <= limit) $
    splitLeq overflow_limit (unaryBVMap x)
  where limit = maxUnsigned (width x)
        overflow_limit = limit - v

-- | Rename the keys of a map with a monotonic function, transform its
-- predicates, and merge the result into an existing map with 'orPred'.
completeList :: IsExprBuilder sym
             => sym
             -> IntMap (Pred sym) -- ^ Map to merge into
             -> (Integer -> Integer)  -- ^ Monotonic function to update keys with
             -> (Pred sym -> IO (Pred sym)) -- ^ Function on predicate.
             -> IntMap (Pred sym)
             -> IO (IntMap (Pred sym))
completeList sym x keyFn predFn m0 = do
  let m1 = Map.mapKeysMonotonic keyFn m0
  m2 <- traverse predFn m1
  mergeWithKey sym (orPred sym) x m2

-- | Helper for 'add': merges into @m0@ the entries of @y@ shifted by the
-- single @x@ entry @x_val@.
--
-- @x_lt@ is the predicate of the preceding @x@ entry (false for the first)
-- and @x_leq@ is the predicate stored at @x_val@.
addConstant :: forall sym n
             . (1 <= n, IsExprBuilder sym)
            => sym
            -> IntMap (Pred sym)
            -> Pred sym
            -> Integer
            -> Pred sym
            -> UnaryBV (Pred sym) n
            -> IO (IntMap (Pred sym))
addConstant sym m0 x_lt x_val x_leq y = do
  let w = width y
  let (y_low, y_high) = splitOnAddOverflow x_val y
  -- Add entries when we don't overflow.
  m1 <- completeList sym m0 (x_val +) (andPred sym x_leq) y_low
  -- If no overflow then we are done, otherwise continue.
  case Map.null y_high of
    True -> return m1
    False -> do
      -- Compute amount of offset to apply to the overflowing entries of y.
      let x_off = x_val-2^natValue w
      -- Generate predicate asserting that y overflows and x == x_val
      x_eq <- andPred sym x_leq =<< notPred sym x_lt
      p <- case Map.null y_low of
             True -> return $ x_eq
             False -> andPred sym x_eq =<< notPred sym (snd (Map.findMax y_low))
      -- Complete next entries
      completeList sym m1 (x_off +) (andPred sym p) y_high

-- | Add two bitvectors.
--
-- The number of integers in the result will be at most the product of the sizes
-- of the individual bitvectors.
add :: forall sym n
     . (1 <= n, IsExprBuilder sym)
    => sym
    -> UnaryBV (Pred sym) n
    -> UnaryBV (Pred sym) n
    -> IO (UnaryBV (Pred sym) n)
add sym x y = go_x Map.empty (falsePred sym) (unsignedEntries x)
  where w = width x
        -- Fold over the entries of @x@, threading the predicate of the
        -- previous entry so 'addConstant' can isolate "x == x_val".
        go_x :: IntMap (Pred sym)
             -> Pred sym
             -> [(Integer, Pred sym)]
             -> IO (UnaryBV (Pred sym) n)
        go_x m0 _ [] = do
          return $! UnaryBV w m0
        go_x m0 x_lt ((x_val,x_leq):remaining) = do
          m2 <- addConstant sym m0 x_lt x_val x_leq y
          go_x m2 x_leq remaining

-- | Negate a bitvector.
-- The size of the result will be equal to the size of the input.
--
-- Calls 'error' if the map of the input is empty.
neg :: forall sym n
     . (1 <= n, IsExprBuilder sym)
    => sym
    -> UnaryBV (Pred sym) n
    -> IO (UnaryBV (Pred sym) n)
neg sym x
    | Map.null (unaryBVMap x) = error "Illegal unary value"
    | otherwise =
      case Map.deleteFindMin (unaryBVMap x) of
        -- Special case for constant 0.
        ((0,_), m) | Map.null m -> return x
        -- Treat 0 case specially, then recurse on remaining elements.
        ((0,x_p), m) -> go [(0, x_p)] x_p (Map.toDescList m)
        -- Value can't be 0, so just recurse on all elements.
        _ -> go [] (falsePred sym) (Map.toDescList (unaryBVMap x))
  where w = width x
        -- Iterate through remaining pairs in descending order.
        go :: [(Integer, Pred sym)]
              -- ^ Entries in descending order
           -> Pred sym
              -- ^ Predicate for first false.
           -> [(Integer, Pred sym)]
              -- ^ Remaining elements in descending order.
           -> IO (UnaryBV (Pred sym) n)
        go m p ((x_k,_) : r@((_,y_p):_)) = seq m $ do
          let z_k = toUnsigned w (negate x_k)
          -- The negated value is at most z_k when the input is above the
          -- next smaller key, or when @p@ holds.
          q <- orPred sym p =<< notPred sym y_p
          let pair = (z_k,q)
          let m' = pair : m
          seq z_k $ seq pair $ seq m' $ go m' p r
        go m _ [(x_k,_)] = seq m $ do
          let z_k = toUnsigned w (negate x_k)
          -- The largest key of the result is bound to the true predicate.
          let q = truePred sym
          return $! UnaryBV w (Map.fromDistinctAscList (reverse ((z_k,q) : m)))
        go _ _ [] = error "Illegal value return in UnaryBV.neg"

-- | Perform an unsigned extension.  The map is reused unchanged; only the
-- width changes.
uext :: (1 <= u, u+1 <= r) => UnaryBV p u -> NatRepr r -> UnaryBV p r
uext x w' = UnaryBV w' (unaryBVMap x)

-- | Perform a signed extension.  Keys above @maxSigned@ of the original
-- width are shifted up by the difference between the two moduli; the other
-- keys are kept as they are.
sext :: (1 <= u, u+1 <= r) => UnaryBV p u -> NatRepr r -> UnaryBV p r
sext x w' = UnaryBV w' (Map.union neg_entries l)
  where w = width x
        mid = maxSigned w
        (l,h) = splitLeq mid (unaryBVMap x)
        diff = 2^natValue w' - 2^natValue w
        neg_entries = Map.mapKeysMonotonic (+ diff) h

-- | Perform a truncation.
--
-- The input map is processed one block of @2^u@ consecutive keys at a time
-- (blocks share the bits above the low @u@ bits); each block is reduced
-- with 'toUnsigned' and merged into the result.
trunc :: forall sym u r
       . (IsExprBuilder sym, 1 <= u, u <= r)
      => sym
      -> UnaryBV (Pred sym) r
      -> NatRepr u
      -> IO (UnaryBV (Pred sym) u)
trunc sym x w
    | Just Refl <- testEquality w (width x) = return x
    | otherwise = go Map.empty (truePred sym) (unaryBVMap x)
  where go :: IntMap (Pred sym)
           -> Pred sym
           -> IntMap (Pred sym)
           -> IO (UnaryBV (Pred sym) u)
        go result toRemove remaining
          | Map.null remaining = return $! UnaryBV w result
          | otherwise = do
            let (k,_) = Map.findMin remaining
            -- Get base offset: @k@ with its low @u@ bits cleared, i.e. the
            -- first value of the block that contains @k@.
            let base = k .&. complement (maxUnsigned w)
            -- Last value of that block.
            let next = base + maxUnsigned w
            let (l,h) = splitLeq next remaining
            assert (not (Map.null l)) $ do
              -- Get entries to add.  Within one block 'toUnsigned' is
              -- monotonic, as 'completeList' requires.
              result' <- completeList sym result (toUnsigned w) (andPred sym toRemove) l
              let (_,p) = Map.findMax l
              -- Later blocks only apply when the value exceeds this block.
              toRemove' <- notPred sym p
              go result' toRemove' h