{-# LANGUAGE BangPatterns, UnboxedTuples, TypeFamilies, PatternGuards, MagicHash, CPP, TupleSections #-}
-- | Balanced binary search trees used as the 'TrieKey' instance for
-- @Ordered k@ keys.  The size cached in each node is the sum of 'getSize#'
-- over the values of that subtree, not a count of nodes.
module Data.TrieMap.OrdMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized
import Data.TrieMap.Modifiers

import Control.Applicative
import Control.Monad hiding (join)

import Prelude hiding (lookup)

import GHC.Exts

-- Balance parameters used by join, merge, balance and the rotations.
#define DELTA 5#
#define RATIO 2#

type OrdMap k = TrieMap (Ordered k)

-- | Zipper context recording the route from the root down to a hole.
-- @LeftBin k a path r@ means we descended into the left child of a node
-- holding @k@ and @a@ whose right child is @r@; 'RightBin' is the mirror.
data Path k a
  = Root
  | LeftBin k a !(Path k a) !(OrdMap k a)
  | RightBin k a !(OrdMap k a) !(Path k a)

-- | A one-element map, or the empty map when given 'Nothing'.
singletonMaybe :: Sized a => k -> Maybe a -> OrdMap k a
singletonMaybe k = maybe Tip (singleton k)

instance Ord k => TrieKey (Ordered k) where
  data TrieMap (Ordered k) a
    = Tip
    | Bin Int# k a !(OrdMap k a) !(OrdMap k a)
  -- A position in the tree: Empty when the searched key is absent, Full
  -- when it is present (carrying the matched node's two subtrees).
  data Hole (Ordered k) a
    = Empty k !(Path k a)
    | Full k !(Path k a) !(OrdMap k a) !(OrdMap k a)

  emptyM = Tip
  singletonM (Ord k) = singleton k
  nullM Tip = True
  nullM _ = False
  sizeM = size#
  lookupM (Ord k) = lookup k
  traverseWithKeyM f = traverseWithKey (f . Ord)
  foldrWithKeyM f = foldrWithKey (f . Ord)
  foldlWithKeyM f = foldlWithKey (f . Ord)
  mapWithKeyM f = mapWithKey (f . Ord)
  mapMaybeM f = mapMaybe (f . Ord)
  mapEitherM f = mapEither (f . Ord)
  isSubmapM = isSubmap
  fromAscListM f xs = fromAscList (f . Ord) [(k, a) | (Ord k, a) <- xs]
  fromDistAscListM xs = fromDistinctAscList [(k, a) | (Ord k, a) <- xs]
  unionM _ Tip m2 = m2
  unionM _ m1 Tip = m1
  unionM f m1 m2 = hedgeUnionWithKey (f . Ord) (const LT) (const GT) m1 m2
  isectM f = isect (f . Ord)
  diffM _ Tip _ = Tip
  diffM _ m1 Tip = m1
  diffM f m1 m2 = hedgeDiffWithKey (f . Ord) (const LT) (const GT) m1 m2

  singleHoleM (Ord k) = Empty k Root
  keyM (Empty k _) = Ord k
  keyM (Full k _ _ _) = Ord k

  -- Rebuild the entries whose keys precede the hole, optionally also
  -- keeping a value at the hole key itself (as the new maximum).
  beforeM a (Empty k path) = before (singletonMaybe k a) path
  beforeM a (Full k path l _) = before t path
    where
      t = case a of
        Nothing -> l
        Just a' -> insertMax k a' l

  -- Mirror image of beforeM: the entries whose keys follow the hole.
  afterM a (Empty k path) = after (singletonMaybe k a) path
  afterM a (Full k path _ r) = after t path
    where
      t = case a of
        Nothing -> r
        Just a' -> insertMin k a' r

  searchM (Ord k) = search k Root

  -- Descend to the value containing the weighted position i#, returning
  -- the offset of i# inside that value, the value, and its hole.
  indexM i# = indexT Root i#
    where
      indexT path i# (Bin _ kx x l r)
        | i# <# sl# = indexT (LeftBin kx x path r) i# l
        | i# <# sx# = (# i# -# sl#, x, Full kx path l r #)
        | otherwise = indexT (RightBin kx x l path) (i# -# sx#) r
        where
          !sl# = size# l
          !sx# = getSize# x +# sl#
      indexT _ _ _ = (# error err, error err, error err #)
        where err = "Error: empty trie"

  -- Every element paired with its hole: left subtree first, then the
  -- node itself, then the right subtree.
  extractHoleM = extractHole Root
    where
      extractHole path (Bin _ kx x l r) =
        extractHole (LeftBin kx x path r) l
          `mplus` return (x, Full kx path l r)
          `mplus` extractHole (RightBin kx x l path) r
      extractHole _ _ = mzero

  assignM x (Empty k path) = rebuild (singleton k x) path
  assignM x (Full k path l r) = rebuild (join k x l r) path

  clearM (Empty _ path) = rebuild Tip path
  clearM (Full _ path l r) = rebuild (merge l r) path

-- | Zip a subtree back up along its path, rebalancing at every ancestor.
rebuild :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
rebuild t Root = t
rebuild t (LeftBin kx x path r) = rebuild (balance kx x t r) path
rebuild t (RightBin kx x l path) = rebuild (balance kx x l t) path

-- | Standard binary-search-tree lookup.
lookup :: Ord k => k -> OrdMap k a -> Maybe a
lookup k (Bin _ k' v l r) = case compare k k' of
  LT -> lookup k l
  EQ -> Just v
  GT -> lookup k r
lookup _ _ = Nothing

-- | A one-element tree whose cached size is the value's 'getSize#'.
singleton :: Sized a => k -> a -> OrdMap k a
singleton k a = Bin (getSize# a) k a Tip Tip

-- | Traverse the entries in ascending key order: the left subtree's
-- effects run first, then the node's, then the right subtree's (matching
-- 'foldrWithKey').  The result is rebuilt with 'join' rather than a bare
-- 'balance' because the new values may have different 'getSize#', and
-- 'join' restores balance for arbitrarily sized subtrees.
traverseWithKey :: (Applicative f, Sized b) => (k -> a -> f b) -> OrdMap k a -> f (OrdMap k b)
traverseWithKey _ Tip = pure Tip
traverseWithKey f (Bin _ k a l r) =
  (\ l' b r' -> join k b l' r')
    <$> traverseWithKey f l
    <*> f k a
    <*> traverseWithKey f r

-- | Right fold over the entries in ascending key order.
foldrWithKey :: (k -> a -> b -> b) -> OrdMap k a -> b -> b
foldrWithKey _ Tip = id
foldrWithKey f (Bin _ k a l r) = foldrWithKey f l . f k a . foldrWithKey f r

-- | Left fold over the entries in ascending key order.
foldlWithKey :: (k -> b -> a -> b) -> OrdMap k a -> b -> b
foldlWithKey _ Tip = id
foldlWithKey f (Bin _ k a l r) = foldlWithKey f r . flip (f k) a . foldlWithKey f l

-- | Map over the values.  Rebuilt with 'join' since the mapped values may
-- have different sizes.
mapWithKey :: (Ord k, Sized b) => (k -> a -> b) -> OrdMap k a -> OrdMap k b
mapWithKey f (Bin _ k a l r) = join k (f k a) (mapWithKey f l) (mapWithKey f r)
mapWithKey _ _ = Tip

-- | Map over the values, dropping entries for which @f@ returns 'Nothing'.
mapMaybe :: (Ord k, Sized b) => (k -> a -> Maybe b) -> OrdMap k a -> OrdMap k b
mapMaybe f (Bin _ k a l r) = joinMaybe k (f k a) (mapMaybe f l) (mapMaybe f r)
mapMaybe _ _ = Tip

-- | Split every value into an optional left and optional right result,
-- producing two maps.
mapEither :: (Ord k, Sized b, Sized c) => EitherMap k a b c -> OrdMap k a -> (# OrdMap k b, OrdMap k c #)
mapEither f (Bin _ k a l r)
  | (# aL, aR #) <- f k a
  , (# lL, lR #) <- mapEither f l
  , (# rL, rR #) <- mapEither f r
  = (# joinMaybe k aL lL rL, joinMaybe k aR lR rR #)
mapEither _ _ = (# Tip, Tip #)

-- | Split the tree around key @k@.  If @k@ is present, @f@ splits its value
-- into an optional part kept as the maximum of the left result, an answer,
-- and an optional part kept as the minimum of the right result.
splitLookup :: (Ord k, Sized a) => SplitMap a x -> k -> OrdMap k a -> (# OrdMap k a, Maybe x, OrdMap k a #)
splitLookup f k m = case m of
  Tip -> (# Tip, Nothing, Tip #)
  Bin _ kx x l r -> case compare k kx of
    LT -> case splitLookup f k l of
      (# lL, ans, lR #) -> (# lL, ans, join kx x lR r #)
    EQ -> case f x of
      (# xL, ans, xR #) ->
        (# maybe l (\ xL' -> insertMax kx xL' l) xL
         , ans
         , maybe r (\ xR' -> insertMin kx xR' r) xR #)
    GT -> case splitLookup f k r of
      (# rL, ans, rR #) -> (# join kx x l rL, ans, rR #)

-- | Whether every key of the first map is in the second, with the values
-- related by the supplied comparison.
isSubmap :: (Ord k, Sized a, Sized b) => LEq a b -> LEq (OrdMap k a) (OrdMap k b)
isSubmap _ Tip _ = True
isSubmap _ _ Tip = False
isSubmap (<=) (Bin _ kx x l r) t =
  case splitLookup (\ v -> (# Nothing, Just (Elem v), Nothing #)) kx t of
    (# lt, found, gt #) -> case found of
      Nothing -> False
      Just (Elem y) -> x <= y && isSubmap (<=) l lt && isSubmap (<=) r gt

-- | Build from an ascending list, combining adjacent entries with equal
-- keys as @f key newer older@.
fromAscList :: (Eq k, Sized a) => (k -> a -> a -> a) -> [(k, a)] -> OrdMap k a
fromAscList f xs = fromDistinctAscList (combineEq xs)
  where
    combineEq (y:ys) = combineEq' y ys
    combineEq [] = []
    combineEq' z [] = [z]
    combineEq' (kz, zz) (y@(ky, yy):ys)
      | kz == ky = combineEq' (ky, f ky yy zz) ys
      | otherwise = (kz, zz) : combineEq' y ys

-- | Build from a list with strictly ascending keys, in linear time.
fromDistinctAscList :: Sized a => [(k, a)] -> OrdMap k a
fromDistinctAscList xs = build const (length xs) xs
  where
    -- 1) use continuations so that we use heap space instead of stack space.
    -- 2) special case for n==5 to build bushier trees.
    build c 0 xs' = c Tip xs'
    build c 5 xs' = case xs' of
      ((k1,x1):(k2,x2):(k3,x3):(k4,x4):(k5,x5):xx) ->
        c (bin k4 x4 (bin k2 x2 (singleton k1 x1) (singleton k3 x3)) (singleton k5 x5)) xx
      _ -> error "fromDistinctAscList build"
    build c n xs' = seq nr $ build (buildR nr c) nl xs'
      where
        nl = n `div` 2
        nr = n - nl - 1

    buildR n c l ((k,x):ys) = build (buildB l k x c) n ys
    buildR _ _ _ [] = error "fromDistinctAscList buildR []"

    buildB l k x c r zs = c (bin k x l r) zs

-- | Hedge union.  @cmplo@/@cmphi@ bound the keys under consideration
-- (@const LT@ / @const GT@ mean unbounded).  For keys in both maps, @f key
-- left right@ decides the result; 'Nothing' drops the key.
hedgeUnionWithKey :: (Ord k, Sized a) => (k -> a -> a -> Maybe a) -> (k -> Ordering) -> (k -> Ordering) -> OrdMap k a -> OrdMap k a -> OrdMap k a
hedgeUnionWithKey _ _ _ t1 Tip = t1
hedgeUnionWithKey _ cmplo cmphi Tip (Bin _ kx x l r) = join kx x (filterGt cmplo l) (filterLt cmphi r)
hedgeUnionWithKey f cmplo cmphi (Bin _ kx x l r) t2 =
    joinMaybe kx newx (hedgeUnionWithKey f cmplo cmpkx l lt) (hedgeUnionWithKey f cmpkx cmphi r gt)
  where
    cmpkx k = compare kx k
    lt = trim cmplo cmpkx t2
    (found, gt) = trimLookupLo kx cmphi t2
    newx = case found of
      Nothing -> Just x
      Just (_, y) -> f kx x y

-- | Keep the entries whose key @k@ satisfies @cmp k == LT@.
filterGt :: (Ord k, Sized a) => (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterGt _ Tip = Tip
filterGt cmp (Bin _ kx x l r) = case cmp kx of
  LT -> join kx x (filterGt cmp l) r
  GT -> filterGt cmp r
  EQ -> r

-- | Keep the entries whose key @k@ satisfies @cmp k == GT@.
filterLt :: (Ord k, Sized a) => (k -> Ordering) -> OrdMap k a -> OrdMap k a
filterLt _ Tip = Tip
filterLt cmp (Bin _ kx x l r) = case cmp kx of
  LT -> filterLt cmp l
  GT -> join kx x l (filterLt cmp r)
  EQ -> l

-- | The first subtree (searching from the root) whose root key @kx@ has
-- @cmplo kx == LT@ and @cmphi kx == GT@, or 'Tip' if there is none.
trim :: (k -> Ordering) -> (k -> Ordering) -> OrdMap k a -> OrdMap k a
trim _ _ Tip = Tip
trim cmplo cmphi t@(Bin _ kx _ l r) = case cmplo kx of
  LT -> case cmphi kx of
    GT -> t
    _ -> trim cmplo cmphi l
  _ -> trim cmplo cmphi r

-- | Like 'trim' with lower bound @lo@, also looking up the entry at @lo@.
trimLookupLo :: Ord k => k -> (k -> Ordering) -> OrdMap k a -> (Maybe (k,a), OrdMap k a)
trimLookupLo _ _ Tip = (Nothing, Tip)
trimLookupLo lo cmphi t@(Bin _ kx x l r) = case compare lo kx of
  LT -> case cmphi kx of
    GT -> ((lo,) <$> lookup lo t, t)
    _ -> trimLookupLo lo cmphi l
  GT -> trimLookupLo lo cmphi r
  EQ -> (Just (kx, x), trim (compare lo) cmphi r)

-- | Intersection: for each key of the second map, search the first and
-- recurse on the parts before and after that key.
isect :: (Ord k, Sized a, Sized b, Sized c) => IsectFunc k a b c -> OrdMap k a -> OrdMap k b -> OrdMap k c
isect f t1@Bin{} (Bin _ k2 x2 l2 r2)
  | (# found, hole #) <- search k2 Root t1
  = let tl = isect f (beforeM Nothing hole) l2
        tr = isect f (afterM Nothing hole) r2
    in joinMaybe k2 (found >>= \ x1' -> f k2 x1' x2) tl tr
isect _ _ _ = Tip

-- | Hedge difference.  For keys in both maps, @f key left right@ may keep
-- an updated value ('Just') or drop the key ('Nothing').
hedgeDiffWithKey :: (Ord k, Sized a) => (k -> a -> b -> Maybe a) -> (k -> Ordering) -> (k -> Ordering) -> OrdMap k a -> OrdMap k b -> OrdMap k a
hedgeDiffWithKey _ _ _ Tip _ = Tip
hedgeDiffWithKey _ cmplo cmphi (Bin _ kx x l r) Tip = join kx x (filterGt cmplo l) (filterLt cmphi r)
hedgeDiffWithKey f cmplo cmphi t (Bin _ kx x l r) =
    case found of
      Nothing -> merge tl tr
      Just (ky, y) -> case f ky y x of
        Nothing -> merge tl tr
        Just z -> join ky z tl tr
  where
    cmpkx k = compare kx k
    lt = trim cmplo cmpkx t
    (found, gt) = trimLookupLo kx cmphi t
    tl = hedgeDiffWithKey f cmplo cmpkx lt l
    tr = hedgeDiffWithKey f cmpkx cmphi gt r

-- | 'join' when a value is present, otherwise 'merge' the two subtrees.
joinMaybe :: (Ord k, Sized a) => k -> Maybe a -> OrdMap k a -> OrdMap k a -> OrdMap k a
joinMaybe kx = maybe merge (join kx)

-- | Link two trees with a middle entry, rebalancing as needed.  Assumes
-- (unchecked) that every key of @l@ is below @kx@ and every key of @r@ is
-- above it.
join :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
join kx x Tip r = insertMin kx x r
join kx x l Tip = insertMax kx x l
join kx x l@(Bin sL# ky y ly ry) r@(Bin sR# kz z lz rz)
  | DELTA *# sL# <=# sR# = balance kz z (join kx x l lz) rz
  | DELTA *# sR# <=# sL# = balance ky y ly (join kx x ry r)
  | otherwise = bin kx x l r

-- insertMin and insertMax don't perform potentially expensive comparisons.
-- | Add an entry whose key is assumed (unchecked) to be greater than every
-- key in the tree.  Only the right spine is walked, rebalancing on the way
-- back up.
insertMax, insertMin :: Sized a => k -> a -> OrdMap k a -> OrdMap k a
insertMax kx x t = case t of
  Tip -> singleton kx x
  Bin _ ky y l r -> balance ky y l (insertMax kx x r)

-- insertMin: mirror of insertMax; the new key is assumed to be smaller
-- than every key in the tree, and only the left spine is walked.
insertMin kx x t = case t of
  Tip -> singleton kx x
  Bin _ ky y l r -> balance ky y (insertMin kx x l) r

{--------------------------------------------------------------------
  [merge l r]: merges two trees.
  All keys of [l] are assumed to precede all keys of [r].
--------------------------------------------------------------------}
merge :: Sized a => OrdMap k a -> OrdMap k a -> OrdMap k a
merge Tip r = r
merge l Tip = l
merge l@(Bin sL# kx x lx rx) r@(Bin sR# ky y ly ry)
  | DELTA *# sL# <=# sR# = balance ky y (merge l ly) ry
  | DELTA *# sR# <=# sL# = balance kx x lx (merge rx r)
  | otherwise = glue l r

{--------------------------------------------------------------------
  [glue l r]: glues two trees together.
  Assumes that [l] and [r] are already balanced with respect to each other.
--------------------------------------------------------------------}
glue :: Sized a => OrdMap k a -> OrdMap k a -> OrdMap k a
glue Tip r = r
glue l Tip = l
glue l r
  -- Promote the extreme entry of the heavier side to be the new root.
  | size# l ># size# r = case deleteFindMax (\ k a -> (# balance k a, Nothing #)) l of
      (# f, l' #) -> f l' r
  | otherwise = case deleteFindMin (\ k a -> (# balance k a, Nothing #)) r of
      (# f, r' #) -> f l r'

-- | Apply @f@ to the minimum entry.  @f@ returns a result and either
-- 'Nothing' (remove the entry) or @Just v@ (keep it with value @v@).
deleteFindMin :: Sized a => (k -> a -> (# x, Maybe a #)) -> OrdMap k a -> (# x, OrdMap k a #)
deleteFindMin f t = case t of
  Bin _ k x Tip r -> onUnboxed (maybe r (\ y' -> bin k y' Tip r)) (f k) x
  Bin _ k x l r -> onUnboxed (\ l' -> balance k x l' r) (deleteFindMin f) l
  _ -> (# error "Map.deleteFindMin: can not return the minimal element of an empty map", Tip #)

-- | Mirror of 'deleteFindMin' for the maximum entry.
deleteFindMax :: Sized a => (k -> a -> (# x, Maybe a #)) -> OrdMap k a -> (# x, OrdMap k a #)
deleteFindMax f t = case t of
  Bin _ k x l Tip -> onUnboxed (maybe l (\ y -> bin k y l Tip)) (f k) x
  Bin _ k x l r -> onUnboxed (balance k x l) (deleteFindMax f) r
  Tip -> (# error "Map.deleteFindMax: can not return the maximal element of an empty map", Tip #)

-- | Cached weighted size of a tree (0 for 'Tip').
size# :: OrdMap k a -> Int#
size# Tip = 0#
size# (Bin sz _ _ _ _) = sz

-- | Build a node, rotating once or twice when one side outweighs the other
-- by the balance factor.
balance :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
balance k x l r
  -- Tiny trees need no rebalancing.  Without this guard, two empty
  -- subtrees satisfy 0 >= 5 * 0 and rotateL would be called on Tip, which
  -- errors (reachable via clearM/rebuild and deleteFindMin/deleteFindMax).
  | sL# +# sR# <=# 1# = Bin sX# k x l r
  | sR# >=# (DELTA *# sL#) = rotateL k x l r
  | sL# >=# (DELTA *# sR#) = rotateR k x l r
  | otherwise = Bin sX# k x l r
  where
    !sL# = size# l
    !sR# = size# r
    !sX# = sL# +# sR# +# getSize# x

-- rotate
-- Choose a single or double rotation from the inner/outer grandchild sizes.
rotateL :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateL k x l r@(Bin _ _ _ ly ry)
  | sL# <# (RATIO *# sR#) = singleL k x l r
  | otherwise = doubleL k x l r
  where
    !sL# = size# ly
    !sR# = size# ry
rotateL _ _ _ Tip = error "rotateL Tip"

rotateR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
rotateR k x l@(Bin _ _ _ ly ry) r
  | sR# <# (RATIO *# sL#) = singleR k x l r
  | otherwise = doubleR k x l r
  where
    !sL# = size# ly
    !sR# = size# ry
rotateR _ _ _ _ = error "rotateR Tip"

-- basic rotations
singleL, singleR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
singleL k1 x1 t1 (Bin _ k2 x2 t2 t3) = bin k2 x2 (bin k1 x1 t1 t2) t3
singleL k1 x1 t1 Tip = bin k1 x1 t1 Tip
singleR k1 x1 (Bin _ k2 x2 t1 t2) t3 = bin k2 x2 t1 (bin k1 x1 t2 t3)
singleR k1 x1 Tip t2 = bin k1 x1 Tip t2

-- Double rotations fall back to single ones when the inner grandchild is
-- missing.
doubleL, doubleR :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
doubleL k1 x1 t1 (Bin _ k2 x2 (Bin _ k3 x3 t2 t3) t4) = bin k3 x3 (bin k1 x1 t1 t2) (bin k2 x2 t3 t4)
doubleL k1 x1 t1 t2 = singleL k1 x1 t1 t2
doubleR k1 x1 (Bin _ k2 x2 t1 (Bin _ k3 x3 t2 t3)) t4 = bin k3 x3 (bin k2 x2 t1 t2) (bin k1 x1 t3 t4)
doubleR k1 x1 t1 t2 = singleR k1 x1 t1 t2

-- | Node constructor that computes the cached size; does not rebalance.
bin :: Sized a => k -> a -> OrdMap k a -> OrdMap k a -> OrdMap k a
bin k x l r = Bin (size# l +# size# r +# getSize# x) k x l r

-- | Walk a path back to the root, joining on every ancestor (and its left
-- subtree) that lies to the left of the hole.
before :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
before t (LeftBin _ _ path _) = before t path
before t (RightBin k a l path) = before (join k a l t) path
before t _ = t

-- | Mirror of 'before': joins the ancestors to the right of the hole.
after :: Sized a => OrdMap k a -> Path k a -> OrdMap k a
after t (LeftBin k a path r) = after (join k a t r) path
after t (RightBin _ _ _ path) = after t path
after t _ = t

-- | Descend towards @k@, recording the path, and return the value (if the
-- key is present) together with the hole at that position.
search :: Ord k => k -> Path k a -> OrdMap k a -> (# Maybe a, Hole (Ordered k) a #)
search k path Tip = (# Nothing, Empty k path #)
search k path (Bin _ kx x l r) = case compare k kx of
  LT -> search k (LeftBin kx x path r) l
  EQ -> (# Just x, Full k path l r #)
  GT -> search k (RightBin kx x l path) r