{-|
Module      : What4.Utils.LeqMap
Copyright   : (c) Galois, Inc 2015-2020
License     : BSD3
Maintainer  : Joe Hendrix

This module defines a strict map.

It is similar to Data.Map.Strict, but provides some additional operations
including splitEntry, splitLeq, fromDistinctDescList.
-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ScopedTypeVariables #-}
module What4.Utils.LeqMap
  ( LeqMap
  , toList
  , findMin
  , findMax
  , null
  , empty
  , mapKeysMonotonic
  , union
  , fromDistinctAscList
  , fromDistinctDescList
  , toDescList
  , deleteFindMin
  , deleteFindMax
  , minViewWithKey
  , filterGt
  , filterLt
  , insert
  , lookupLE
  , lookupLT
  , lookupGE
  , lookupGT
  , keys
  , mergeWithKey
  , singleton
  , foldlWithKey'
  , size
  , splitEntry
  , splitLeq
  ) where

import Control.Applicative hiding (empty)
import Prelude hiding (lookup, null)
import Data.Traversable (foldMapDefault)

-- | A 'Maybe'-like type whose 'JustS' payload is a strict field.
data MaybeS a = NothingS | JustS !a

-- | Number of entries stored in a (sub)tree.
type Size = Int

-- | A binary search tree with the entry count cached in every node.
--
-- A 'Bin' node carries, in order: the cached size of the subtree rooted at
-- the node, the key, the value, the left subtree and the right subtree.
-- Every field is strict, and the size is unpacked into the constructor.
-- 'Tip' is the empty tree.
data LeqMap k p
   = Bin {-# UNPACK #-} !Size !k !p !(LeqMap k p) !(LeqMap k p)
   | Tip

-- | Build a 'Bin' node, computing the cached size from the two subtrees.
-- No rebalancing is performed.
bin :: k -> p -> LeqMap k p -> LeqMap k p -> LeqMap k p
bin k x l r = Bin (size l + size r + 1) k x l r

-- | Build a node from key @k@, value @x@ and subtrees @l@ and @r@, rotating
-- to the right when the left subtree is too heavy.
--
-- A rotation happens only when @l@ is a 'Bin' whose size exceeds
-- @max 1 (delta * size r)@; otherwise the node is built as-is with 'bin'.
balanceL :: k -> p -> LeqMap k p -> LeqMap k p -> LeqMap k p
balanceL k x l r =
  case l of
    Bin ls lk lx ll lr | ls > max 1 (delta*size r) ->
      case lr of
        -- The inner (left-right) subtree is large relative to the outer
        -- (left-left) one: double rotation, @lrk@ becomes the new root.
        Bin lrs lrk lrx lrl lrr | lrs >= ratio* size ll ->
          bin lrk lrx (bin lk lx ll lrl) (bin k x lrr r)
        -- Otherwise a single rotation, @lk@ becomes the new root.
        _ -> bin lk lx ll (bin k x lr r)

    _ -> bin k x l r

-- balanceR is called when right subtree might have been inserted to or when
-- left subtree might have been deleted from.
balanceR :: k -> p -> LeqMap k p -> LeqMap k p -> LeqMap k p
balanceR k x l r = case l of
  -- Empty left subtree: the small cases are enumerated by hand so the sizes
  -- can be written as literals.
  Tip -> case r of
    Tip -> Bin 1 k x Tip Tip
    (Bin _ _ _ Tip Tip) -> Bin 2 k x Tip r
    (Bin _ rk rx Tip rr@(Bin{})) -> Bin 3 rk rx (Bin 1 k x Tip Tip) rr
    (Bin _ rk rx (Bin _ rlk rlx _ _) Tip) -> Bin 3 rlk rlx (Bin 1 k x Tip Tip) (Bin 1 rk rx Tip Tip)
    (Bin rs rk rx rl@(Bin rls rlk rlx rll rlr) rr@(Bin rrs _ _ _ _))
      -- Single rotation: @rk@ becomes the root.
      | rls < ratio*rrs -> Bin (1+rs) rk rx (Bin (1+rls) k x Tip rl) rr
      -- Double rotation: @rlk@ becomes the root.
      | otherwise -> Bin (1+rs) rlk rlx (Bin (1+size rll) k x Tip rll) (Bin (1+rrs+size rlr) rk rx rlr rr)

  (Bin ls _ _ _ _) -> case r of
    Tip -> Bin (1+ls) k x l Tip

    (Bin rs rk rx rl rr)
      -- Right subtree too heavy: rotate left.
      | rs > delta*ls -> case (rl, rr) of
          (Bin rls rlk rlx rll rlr, Bin rrs _ _ _ _)
            | rls < ratio*rrs -> Bin (1+ls+rs) rk rx (Bin (1+ls+rls) k x l rl) rr
            | otherwise -> Bin (1+ls+rs) rlk rlx (Bin (1+ls+size rll) k x l rll) (Bin (1+rrs+size rlr) rk rx rlr rr)
          (_, _) -> error "Failure in Data.Map.balanceR"
      | otherwise -> Bin (1+ls+rs) k x l r

-- | Balancing parameters.  @delta@ bounds how much larger one sibling may be
-- than the other before a rotation is triggered; @ratio@ selects between a
-- single and a double rotation.
delta,ratio :: Int
delta = 3
ratio = 2

-- | Add an entry at the far right of the tree.  The key is not compared
-- with anything, so the caller must ensure it is larger than every key
-- already present.
insertMax :: k -> p -> LeqMap k p -> LeqMap k p
insertMax kx x t =
  case t of
    Tip -> singleton kx x
    Bin _ ky y l r -> balanceR ky y l (insertMax kx x r)

-- | Add an entry at the far left of the tree.  The key is not compared
-- with anything, so the caller must ensure it is smaller than every key
-- already present.
insertMin :: k -> p -> LeqMap k p -> LeqMap k p
insertMin kx x t =
  case t of
    Tip -> singleton kx x
    Bin _ ky y l r -> balanceL ky y (insertMin kx x l) r

-- | @link kx x l r@ joins two trees with a separating entry, rebalancing as
-- needed.  No keys are compared; the caller must ensure every key in @l@ is
-- smaller than @kx@ and every key in @r@ is larger.
link :: k -> p -> LeqMap k p -> LeqMap k p -> LeqMap k p
link kx x Tip r  = insertMin kx x r
link kx x l Tip  = insertMax kx x l
link kx x l@(Bin sizeL ky y ly ry) r@(Bin sizeR kz z lz rz)
  -- Right tree is much larger: descend along its left spine.
  | delta*sizeL < sizeR  = balanceL kz z (link kx x l lz) rz
  -- Left tree is much larger: descend along its right spine.
  | delta*sizeR < sizeL  = balanceR ky y ly (link kx x ry r)
  | otherwise            = bin kx x l r

-- | Two maps are equal when they hold the same entries in the same key
-- order; the shape of the trees is irrelevant.
instance (Ord k, Eq p) => Eq (LeqMap k p) where
  x == y = size x == size y && toList x == toList y

instance Functor (LeqMap k) where
  fmap _ Tip = Tip
  fmap f (Bin s k a l r) = Bin s k (f a) (fmap f l) (fmap f r)

-- | Folds follow the traversal order of the 'Traversable' instance.
instance Foldable (LeqMap k) where
  foldMap = foldMapDefault

-- | Effects are sequenced node value first, then the left subtree, then the
-- right subtree.
--
-- NOTE(review): this is not ascending key order; callers relying on effect
-- order should be aware of it.
instance Traversable (LeqMap k) where
  traverse _ Tip = pure Tip
  traverse f (Bin s k a l r) = Bin s k <$> f a <*> traverse f l <*> traverse f r

-- | Return the empty map
empty :: LeqMap k p
empty = Tip

-- | A map holding exactly one entry.
singleton :: k -> p -> LeqMap k p
singleton k a = Bin 1 k a Tip Tip

-- | Number of entries, read from the size cached in the root node.
size :: LeqMap k p -> Int
size Tip = 0
size (Bin s _ _ _ _) = s

-- | Is the map empty?
null :: LeqMap k p -> Bool
null Tip = True
null Bin{} = False

-- | The entry at the end of the right spine (the largest key).  Calls
-- 'error' on an empty map.
findMax :: LeqMap k p -> (k,p)
findMax Tip = error "findMax of empty map."
findMax (Bin _ k0 a0 _ r0) = go k0 a0 r0
  where go :: k -> p -> LeqMap k p -> (k,p)
        go _ _ (Bin _ k a _ r) = go k a r
        go k a Tip = (k, a)

-- | The entry at the end of the left spine (the smallest key).  Calls
-- 'error' on an empty map.
findMin :: LeqMap k p -> (k,p)
findMin Tip = error "findMin of empty map."
findMin (Bin _ k0 a0 l0 _) = go k0 a0 l0
  where go :: k -> p -> LeqMap k p -> (k,p)
        go _ _ (Bin _ k a l _) = go k a l
        go k a Tip = (k, a)

-- | All entries in ascending key order.
--
-- The list is built with an accumulator, so each node contributes exactly
-- one cons cell and the traversal is linear in the size of the map.  The
-- result is still produced lazily from the left.
toList :: LeqMap k p -> [(k,p)]
toList m = go m []
  where
    go :: LeqMap k p -> [(k,p)] -> [(k,p)]
    go Tip acc = acc
    go (Bin _ k a l r) acc = go l ((k, a) : go r acc)

-- | Apply a function to every key without re-sorting.  The tree shape is
-- kept, so the function must preserve the relative order of the keys.
mapKeysMonotonic :: (k1 -> k2) -> LeqMap k1 p -> LeqMap k2 p
mapKeysMonotonic _ Tip = Tip
mapKeysMonotonic f (Bin s k a l r) =
  Bin s (f k) a (mapKeysMonotonic f l) (mapKeysMonotonic f r)

-- | @splitLeq k m@ returns a pair @(lo, hi)@ where @lo@ holds the entries of
-- @m@ whose key is less than or equal to @k@ and @hi@ holds the entries
-- whose key is greater than @k@.  Both halves are forced before the pair is
-- returned.
splitLeq :: Ord k => k -> LeqMap k p -> (LeqMap k p, LeqMap k p)
splitLeq k m = seq k $
  case m of
    Tip -> (Tip, Tip)
    Bin _ kx x l r ->
      case compare k kx of
        LT ->
          let (ll, lr) = splitLeq k l
              r' = link kx x lr r
           in seq r' (ll, r')
        GT ->
          let (rl, rr) = splitLeq k r
              l' = link kx x l rl
           in seq l' (l', rr)
        -- An entry equal to the pivot belongs to the "less or equal" half.
        EQ ->
          let l' = insertMax kx x l
           in seq l' (l', r)
{-# INLINABLE splitLeq #-}

-- | Decompose a non-empty map into its left subtree, root entry and right
-- subtree.  Which entry is the root depends on the internal tree shape.
splitEntry :: LeqMap k p -> Maybe (LeqMap k p, (k, p), LeqMap k p)
splitEntry Tip = Nothing
splitEntry (Bin _ k a l r) = Just (l, (k, a), r)

-- | Insert an entry.  When the key is already present, both the stored key
-- and the stored value are replaced by the new ones.
insert :: Ord k => k -> p -> LeqMap k p -> LeqMap k p
insert = go
  where
    go :: Ord k => k -> p -> LeqMap k p -> LeqMap k p
    -- This guard is never satisfied; it only forces the key and the value.
    go kx x _ | seq kx $ seq x $ False = error "insert bad"
    go kx x Tip = singleton kx x
    go kx x (Bin sz ky y l r) =
        case compare kx ky of
            LT -> balanceL ky y (go kx x l) r
            GT -> balanceR ky y l (go kx x r)
            EQ -> Bin sz kx x l r

-- | Worker for 'lookupLE': @ky@/@y@ is the best candidate found so far and
-- is returned when the search runs off the tree.
lookupLE_Just :: Ord k => k -> k -> p -> LeqMap k p -> (k, p)
lookupLE_Just _ ky y Tip = (ky,y)
lookupLE_Just k ky y (Bin _ kx x l r) =
  case compare kx k of
    LT -> lookupLE_Just k kx x r
    GT -> lookupLE_Just k ky y l
    EQ -> (kx, x)
{-# INLINABLE lookupLE_Just #-}

-- | Worker for 'lookupGE': @ky@/@y@ is the best candidate found so far and
-- is returned when the search runs off the tree.
lookupGE_Just :: Ord k => k -> k -> p -> LeqMap k p -> (k, p)
lookupGE_Just _ ky y Tip = (ky,y)
lookupGE_Just k ky y (Bin _ kx x l r) =
  case compare kx k of
    LT -> lookupGE_Just k ky y r
    GT -> lookupGE_Just k kx x l
    EQ -> (kx, x)
{-# INLINABLE lookupGE_Just #-}

-- | Worker for 'lookupLT': @ky@/@y@ is the best candidate found so far and
-- is returned when the search runs off the tree.
lookupLT_Just :: Ord k => k -> k -> p -> LeqMap k p -> (k, p)
lookupLT_Just _ ky y Tip = (ky,y)
lookupLT_Just k ky y (Bin _ kx x l r) =
  case kx < k of
    True  -> lookupLT_Just k kx x r
    False -> lookupLT_Just k ky y l
{-# INLINABLE lookupLT_Just #-}

-- | Worker for 'lookupGT': @ky@/@y@ is the best candidate found so far and
-- is returned when the search runs off the tree.
lookupGT_Just :: Ord k => k -> k -> p -> LeqMap k p -> (k, p)
lookupGT_Just _ ky y Tip = (ky,y)
lookupGT_Just k ky y (Bin _ kx x l r) =
  case kx > k of
    True  -> lookupGT_Just k kx x l
    False -> lookupGT_Just k ky y r
{-# INLINABLE lookupGT_Just #-}

-- | Find largest element that is less than or equal to key (if any).
lookupLE :: Ord k => k -> LeqMap k p -> Maybe (k,p)
lookupLE k0 m0 = seq k0 (goNothing k0 m0)
  where
    goNothing :: Ord k => k -> LeqMap k p -> Maybe (k,p)
    goNothing _ Tip = Nothing
    goNothing k (Bin _ kx x l r) =
      case compare kx k of
        LT -> Just $ lookupLE_Just k kx x r
        GT -> goNothing k l
        EQ -> Just (kx, x)
{-# INLINABLE lookupLE #-}

-- | Find smallest element that is greater than or equal to key (if any).
lookupGE :: Ord k => k -> LeqMap k p -> Maybe (k,p)
lookupGE k0 m0 = seq k0 (goNothing k0 m0)
  where
    goNothing :: Ord k => k -> LeqMap k p -> Maybe (k,p)
    goNothing _ Tip = Nothing
    goNothing k (Bin _ kx x l r) =
      case compare kx k of
        LT -> goNothing k r
        GT -> Just $ lookupGE_Just k kx x l
        EQ -> Just (kx, x)
{-# INLINABLE lookupGE #-}

-- | Find largest element that is less than key (if any).
lookupLT :: Ord k => k -> LeqMap k p -> Maybe (k,p)
lookupLT k0 m0 = k0 `seq` search k0 m0
  where
    -- No candidate has been seen yet; once one is found the search is
    -- handed over to 'lookupLT_Just', which can no longer fail.
    search :: Ord k => k -> LeqMap k p -> Maybe (k,p)
    search _ Tip = Nothing
    search k (Bin _ kx x l r)
      | kx < k    = Just (lookupLT_Just k kx x r)
      | otherwise = search k l
{-# INLINABLE lookupLT #-}

-- | Find smallest element that is greater than key (if any).
lookupGT :: Ord k => k -> LeqMap k p -> Maybe (k,p)
lookupGT k0 m0 = k0 `seq` search k0 m0
  where
    -- No candidate has been seen yet; once one is found the search is
    -- handed over to 'lookupGT_Just', which can no longer fail.
    search :: Ord k => k -> LeqMap k p -> Maybe (k,p)
    search _ Tip = Nothing
    search k (Bin _ kx x l r)
      | kx > k    = Just (lookupGT_Just k kx x l)
      | otherwise = search k r
{-# INLINABLE lookupGT #-}

-- | 'filterGt' with an optional bound; 'NothingS' leaves the map untouched.
filterMGt :: Ord k => MaybeS k -> LeqMap k p -> LeqMap k p
filterMGt bound t =
  case bound of
    NothingS -> t
    JustS b0 -> filterGt b0 t
{-# INLINABLE filterMGt #-}

-- | @filterGt b t@ keeps the entries of @t@ whose key is strictly greater
-- than @b@.  The bound is forced before the tree is inspected.
filterGt :: Ord k => k -> LeqMap k p -> LeqMap k p
filterGt b t = b `seq` keepAbove t
  where
    keepAbove Tip = Tip
    keepAbove (Bin _ kx x l r) =
      case compare b kx of
        -- The root survives; only the left subtree needs filtering.
        LT -> link kx x (keepAbove l) r
        -- The root is the bound itself: everything to its right survives.
        EQ -> r
        -- The root and its whole left subtree are dropped.
        GT -> keepAbove r
{-# INLINABLE filterGt #-}

-- | 'filterLt' with an optional bound; 'NothingS' leaves the map untouched.
filterMLt :: Ord k => MaybeS k -> LeqMap k p -> LeqMap k p
filterMLt bound t =
  case bound of
    NothingS -> t
    JustS b  -> filterLt b t
{-# INLINABLE filterMLt #-}

-- | @filterLt b t@ keeps the entries of @t@ whose key is strictly less than
-- @b@.  The bound is forced before the tree is inspected.
filterLt :: Ord k => k -> LeqMap k p -> LeqMap k p
filterLt b t = b `seq` keepBelow t
  where
    keepBelow Tip = Tip
    keepBelow (Bin _ kx x l r) =
      case compare kx b of
        -- The root survives; only the right subtree needs filtering.
        LT -> link kx x l (keepBelow r)
        -- The root is the bound itself: everything to its left survives.
        EQ -> l
        -- The root and its whole right subtree are dropped.
        GT -> keepBelow l
{-# INLINABLE filterLt #-}

-- | @trim lo hi t@ dispatches on which of the two optional bounds are
-- present: 'greater' for a lower bound only, 'lesser' for an upper bound
-- only, 'middle' for both, and the identity when neither is given.
trim :: Ord k => MaybeS k -> MaybeS k -> LeqMap k p -> LeqMap k p
trim lo hi t =
  case lo of
    NothingS ->
      case hi of
        NothingS -> t
        JustS hk -> lesser hk t
    JustS lk ->
      case hi of
        NothingS -> greater lk t
        JustS hk -> middle lk hk t
{-# INLINABLE trim #-}

-- | @lesser hi m@ descends through left children of @m@ until it reaches a
-- subtree that is empty or whose root key is less than @hi@, and returns
-- that subtree.  Entries deeper in the returned subtree are not checked
-- against @hi@.
lesser :: Ord k => k -> LeqMap k p -> LeqMap k p
lesser hi (Bin _ k _ l _) | hi <= k = lesser hi l
lesser _ t' = t'
{-# INLINABLE lesser #-}

-- | @k `mgt` b@ holds when @k@ is strictly greater than the optional bound
-- @b@; a missing bound ('NothingS') is satisfied by every key.
mgt :: Ord k => k -> MaybeS k -> Bool
mgt _ NothingS = True
mgt k (JustS y) = k > y

-- | @middle lo hi t@ descends from the root of @t@ until it reaches a
-- subtree that is empty or whose root key lies strictly between @lo@ and
-- @hi@, and returns that subtree.  Entries deeper in the returned subtree
-- are not checked against the bounds.
middle :: Ord k => k -> k -> LeqMap k p -> LeqMap k p
middle lo hi (Bin _ k _ _ r) | k <= lo = middle lo hi r
middle lo hi (Bin _ k _ l _) | k >= hi = middle lo hi l
middle _  _  t' = t'
{-# INLINABLE middle #-}

-- | @greater lo t@ descends through right children of @t@ until it reaches
-- a subtree that is empty or whose root key is greater than @lo@, and
-- returns that subtree.  Entries deeper in the returned subtree are not
-- checked against @lo@.
greater :: Ord k => k -> LeqMap k p -> LeqMap k p
greater lo (Bin _ k _ _ r) | k <= lo = greater lo r
greater _  t' = t'
{-# INLINABLE greater #-}

-- | Left-biased union: when a key occurs in both maps the entry from the
-- first map is kept.
union :: Ord k => LeqMap k p -> LeqMap k p -> LeqMap k p
union Tip t2  = t2
union t1 Tip  = t1
union t1 t2 = hedgeUnion NothingS NothingS t1 t2
{-# INLINABLE union #-}

-- | Insert an entry unless the key is already present, in which case the
-- map is returned unchanged.
insertR :: Ord k => k -> p -> LeqMap k p -> LeqMap k p
insertR = go
  where
    go :: Ord k => k -> p -> LeqMap k p -> LeqMap k p
    -- This guard is never satisfied; it only forces the key and the value.
    go kx x _ | seq kx $ seq x $ False = error "insert bad"
    go kx x Tip = singleton kx x
    go kx x t@(Bin _ ky y l r) =
        case compare kx ky of
            LT -> balanceL ky y (go kx x l) r
            GT -> balanceR ky y l (go kx x r)
            EQ -> t
{-# INLINABLE insertR #-}

-- left-biased hedge union
--
-- The two 'MaybeS' arguments are the optional lower and upper key bounds of
-- the range currently being merged; the second tree has been trimmed
-- against them but may still contain entries outside the range.
hedgeUnion :: Ord k => MaybeS k -> MaybeS k -> LeqMap k p -> LeqMap k p -> LeqMap k p
hedgeUnion _   _   t1  Tip = t1
hedgeUnion blo bhi Tip (Bin _ kx x l r) =
  link kx x (filterMGt blo l) (filterMLt bhi r)
hedgeUnion _   _   t1  (Bin _ kx x Tip Tip) = insertR kx x t1
 -- According to benchmarks, this special case increases
 -- performance up to 30%. It does not help in difference or intersection.
hedgeUnion blo bhi (Bin _ kx x l r) t2 =
    link kx x (hedgeUnion blo bmi l (trim blo bmi t2))
              (hedgeUnion bmi bhi r (trim bmi bhi t2))
  where bmi = JustS kx
{-# INLINABLE hedgeUnion #-}

-- | Strict left fold over the entries in ascending key order.
--
-- The accumulator is forced to weak head normal form before every
-- application of the folding function, including the initial value, so no
-- chain of suspended applications is built up.
foldlWithKey' :: (a -> k -> b -> a) -> a -> LeqMap k b -> a
foldlWithKey' f = go
  where
    go z t = z `seq`
      case t of
        Tip -> z
        Bin _ kx x l r ->
          let zl = go z l
           in zl `seq` go (f zl kx x) r

-- | All keys in ascending order.
--
-- The list is built with an accumulator, so each node contributes exactly
-- one cons cell and the traversal is linear in the size of the map.  The
-- result is still produced lazily from the left.
keys :: LeqMap k p -> [k]
keys m = go m []
  where
    go :: LeqMap k p -> [k] -> [k]
    go Tip acc = acc
    go (Bin _ kx _ l r) acc = go l (kx : go r acc)

-- | The entry with the smallest key together with the map without it, or
-- 'Nothing' for an empty map.
minViewWithKey :: LeqMap k p -> Maybe ((k,p), LeqMap k p)
minViewWithKey Tip = Nothing
minViewWithKey t@Bin{} = Just (deleteFindMin t)

-- | Remove and return the entry with the smallest key.  On an empty map the
-- returned entry is an 'error' thunk and the returned map is empty.
deleteFindMin :: LeqMap k p -> ((k,p),LeqMap k p)
deleteFindMin t
  = case t of
      Bin _ k x Tip r -> ((k,x),r)
      Bin _ k x l r   -> let (km,l') = deleteFindMin l in (km,balanceR k x l' r)
      Tip             -> (error "LeqMap.deleteFindMin: can not return the minimal element of an empty map", Tip)

-- | Remove and return the entry with the largest key.  On an empty map the
-- returned entry is an 'error' thunk and the returned map is empty.
deleteFindMax :: LeqMap k p -> ((k,p),LeqMap k p)
deleteFindMax t
  = case t of
      Bin _ k x l Tip -> ((k,x),l)
      Bin _ k x l r   -> let (km,r') = deleteFindMax r in (km,balanceL k x l r')
      Tip             -> (error "LeqMap.deleteFindMax: can not return the maximal element of an empty map", Tip)

-- | Merge two 'Integer'-keyed maps in 'IO'.
--
-- The first function combines a value from the left map with a value from
-- the right map; the second converts a left value that has no right
-- counterpart, and the third converts a right value that has no left
-- counterpart.  When the left map is the larger one the arguments are
-- swapped (with the combining function flipped) so that the smaller map
-- drives the recursion.
--
-- While descending, @hedgeMerge@ carries an optional left value and an
-- optional right value alongside the key bounds.  A carried value is
-- combined with entries of the other map that have no counterpart of their
-- own.
--
-- NOTE(review): this looks like each entry is treated as covering every key
-- from its own key up to the next entry, so that the most recent value at or
-- below a key stands in for a missing one -- confirm against the callers.
mergeWithKey :: forall a b c
             .  (a -> b -> IO c)
             -> (a -> IO c)
             -> (b -> IO c)
             -> LeqMap Integer a
             -> LeqMap Integer b
             -> IO (LeqMap Integer c)
mergeWithKey f0 g1 g2 = go
  where
    go Tip t2 = traverse g2 t2
    go t1 Tip = traverse g1 t1
    go t1 t2 | size t1 <= size t2 = hedgeMerge NothingS NothingS NothingS t1 NothingS t2
             | otherwise = mergeWithKey (flip f0) g2 g1 t2 t1

    hedgeMerge :: MaybeS Integer
               -> MaybeS Integer
               -> MaybeS a
               -> LeqMap Integer a
               -> MaybeS b
               -> LeqMap Integer b
               -> IO (LeqMap Integer c)
    -- This guard is never satisfied; it only forces the bounds and the
    -- carried values.
    hedgeMerge mlo mhi a _ b _ | seq mlo $ seq mhi $ seq a $ seq b $ False = error "hedgeMerge"
    -- Nothing left on the right: convert the left tree, using the carried
    -- right value when there is one.
    hedgeMerge _ _ _ t1 mb Tip = do
      case mb of
        NothingS -> traverse g1 t1
        JustS b -> traverse (`f0` b) t1
    -- Nothing left on the left: convert the in-range part of the right
    -- tree, using the carried left value when there is one.
    hedgeMerge blo bhi ma Tip _ (Bin _ kx x l r) = do
      case ma of
        NothingS ->
          link kx <$> g2 x
                  <*> traverse g2 (filterMGt blo l)
                  <*> traverse g2 (filterMLt bhi r)
        JustS a ->
          link kx <$> f0 a x
                  <*> traverse (f0 a) (filterMGt blo l)
                  <*> traverse (f0 a) (filterMLt bhi r)
    hedgeMerge blo bhi a (Bin _ kx x l r) mb t2 = do
      let bmi = JustS kx
      case lookupLE kx t2 of
        -- The right tree has an entry at or below @kx@ inside the current
        -- range: pair it with @x@ and carry both values to the right.
        Just (ky,y) | ky `mgt` blo -> do
          l' <- hedgeMerge blo bmi a l mb (trim blo bmi t2)
          x' <- f0 x y
          r' <- hedgeMerge bmi bhi (JustS x) r (JustS y) (trim bmi bhi t2)
          return $! link kx x' l' r'
        -- No such entry: fall back on the carried right value, if any.
        _ -> do
          case mb of
            NothingS -> do
              l' <- traverse g1 l
              x' <- g1 x
              r' <- hedgeMerge bmi bhi (JustS x) r mb (trim bmi bhi t2)
              return $! link kx x' l' r'
            JustS b -> do
              l' <- traverse (`f0` b) l
              x' <- f0 x b
              r' <- hedgeMerge bmi bhi (JustS x) r mb (trim bmi bhi t2)
              return $! link kx x' l' r'
{-# INLINE mergeWithKey #-}

-- | Lazy left fold over the entries in ascending key order.
foldlWithKey :: (a -> k -> b -> a) -> a -> LeqMap k b -> a
foldlWithKey f z = go z
  where
    go z' Tip              = z'
    go z' (Bin _ kx x l r) = go (f (go z' l) kx x) r
{-# INLINE foldlWithKey #-}

-- | All entries in descending key order.
toDescList :: LeqMap k p -> [(k,p)]
toDescList = foldlWithKey (\xs k x -> (k,x):xs) []

-- | Create a map from a list of entries in ascending key order.
--
-- The keys are never compared, so the caller must supply them in strictly
-- ascending order.  Every value is forced as it is inserted.
fromDistinctAscList :: [(k,p)] -> LeqMap k p
fromDistinctAscList [] = Tip
fromDistinctAscList ((kx0, x0) : xs0) = x0 `seq` go 0 (Bin 1 kx0 x0 Tip Tip) xs0
  where
    go :: Int -> LeqMap k p -> [(k,p)] -> LeqMap k p
    go _ t [] = t
    go s l ((kx, x) : xs) = case create s xs of
                              (r, ys) -> x `seq` go (s + 1) (link kx x l r) ys

    -- @create k l@ extracts at most @2^(k+1) - 1@ elements from @l@ and
    -- creates a map.
    -- The remaining elements (if any) are returned as well.
    create :: Int -> [(k, p)] -> (LeqMap k p, [(k,p)])
    -- Reached end of list.
    create _ [] = (Tip, [])
    -- Extract single element
    create 0 ((kx,x) : xs') = x `seq` (Bin 1 kx x Tip Tip, xs')
    create s xs
      | otherwise = case create (s - 1) xs of
                      res@(_, []) -> res
                      (l, (ky, y):ys) -> case create (s - 1) ys of
                        (r, zs) -> y `seq` (link ky y l r, zs)

-- | Create a map from a list of keys in descending order.
--
-- The keys are never compared, so the caller must supply them in strictly
-- descending order.  Every value is forced as it is inserted.
fromDistinctDescList :: [(k,p)] -> LeqMap k p
fromDistinctDescList [] = Tip
fromDistinctDescList ((kx0, x0) : xs0) = x0 `seq` go 0 (Bin 1 kx0 x0 Tip Tip) xs0
  where
    go :: Int -> LeqMap k p -> [(k,p)] -> LeqMap k p
    go _ t [] = t
    go s r ((kx, x) : xs) = case create s xs of
                              (l, ys) -> x `seq` go (s + 1) (link kx x l r) ys

    -- @create k l@ extracts at most @2^(k+1) - 1@ elements from @l@ and
    -- creates a map.
    -- The remaining elements (if any) are returned as well.
    create :: Int -> [(k, p)] -> (LeqMap k p, [(k,p)])
    -- Reached end of list.
    create _ [] = (Tip, [])
    -- Extract single element
    create 0 ((kx,x) : xs') = x `seq` (Bin 1 kx x Tip Tip, xs')
    create s xs
      | otherwise = case create (s - 1) xs of
                      res@(_, []) -> res
                      (r, (ky, y):ys) -> case create (s - 1) ys of
                        (l, zs) -> y `seq` (link ky y l r, zs)