-- |
-- Module : Data.IntervalSet
-- Copyright : (c) Christoph Breitkopf 2015 - 2017
-- License : BSD-style
-- Maintainer : chbreitkopf@gmail.com
-- Stability : experimental
-- Portability : non-portable (MPTC with FD)
--
-- An implementation of sets of intervals. The intervals may
-- overlap, and the implementation contains efficient search functions
-- for all intervals containing a point or overlapping a given interval.
-- Closed, open, and half-open intervals can be contained in the same set.
--
-- It is an error to insert an empty interval into a set. This precondition is not
-- checked by the various construction functions.
--
-- Since many function names (but not the type name) clash with
-- /Prelude/ names, this module is usually imported @qualified@, e.g.
--
-- > import Data.IntervalSet (IntervalSet)
-- > import qualified Data.IntervalSet as IS
--
-- It offers most of the same functions as 'Data.Set', but the member type must be an
-- instance of 'Interval'. The 'findMin' and 'findMax' functions deviate from their
-- set counterparts in being total and returning a 'Maybe' value.
-- Some functions differ in asymptotic performance (for example 'size') or have not
-- been tuned for efficiency as much as their equivalents in 'Data.Set'.
--
-- In addition, there are functions specific to sets of intervals, for example to search
-- for all intervals containing a given point or contained in a given interval.
--
-- The implementation is a red-black tree augmented with the maximum upper bound
-- of all keys.
--
-- Parts of this implementation are based on code from the 'Data.Map' implementation,
-- (c) Daan Leijen 2002, (c) Andriy Palamarchuk 2008.
-- The red-black tree deletion is based on code from llrbtree by Kazu Yamamoto.
-- Of course, any errors are mine.
--
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.IntervalSet (
-- * re-export
Interval(..)
-- * Set type
, IntervalSet(..) -- instance Eq,Show,Read
-- * Operators
, (\\)
-- * Query
, null
, size
, member
, notMember
, lookupLT
, lookupGT
, lookupLE
, lookupGE
-- ** Interval query
, containing
, intersecting
, within
-- * Construction
, empty
, singleton
-- ** Insertion
, insert
-- ** Delete\/Update
, delete
-- * Combine
, union
, unions
, difference
, intersection
-- * Traversal
-- ** Map
, map
, mapMonotonic
-- ** Fold
, foldr, foldl
, foldl', foldr'
-- * Flatten
, flattenWith, flattenWithMonotonic
-- * Conversion
, elems
-- ** Lists
, toList
, fromList
-- ** Ordered lists
, toAscList
, toDescList
, fromAscList
, fromDistinctAscList
-- * Filter
, filter
, partition
, split
, splitMember
, splitAt
, splitIntersecting
-- * Subset
, isSubsetOf, isProperSubsetOf
-- * Min\/Max
, findMin
, findMax
, findLast
, deleteMin
, deleteMax
, deleteFindMin
, deleteFindMax
, minView
, maxView
-- * Debugging
, valid
) where
import Prelude hiding (null, map, filter, foldr, foldl, splitAt)
import Data.Bits (shiftR, (.&.))
import qualified Data.Semigroup as Sem
import Data.Monoid (Monoid(..))
import qualified Data.Foldable as Foldable
import qualified Data.List as L
import Control.DeepSeq
import Control.Applicative ((<|>))
import Data.IntervalMap.Generic.Interval
{--------------------------------------------------------------------
Operators
--------------------------------------------------------------------}
infixl 9 \\ --

-- | Operator synonym for 'difference': the elements of the left set that
-- are not elements of the right set.
(\\) :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
(\\) = difference
-- | The color of a tree node: 'R' marks a red node, 'B' a black node.
-- Only equality is derived; the balancing code compares colors with '=='.
data Color = R | B deriving (Eq)
-- | A set of intervals of type @k@.
--
-- The set is a binary tree. 'Nil' is the empty tree; every 'Node' stores,
-- besides its own key, the interval with the greatest upper bound found in
-- the subtree rooted at that node. All fields are strict.
data IntervalSet k = Nil
                   | Node !Color
                          !k -- key
                          !k -- interval with maximum upper in tree
                          !(IntervalSet k) -- left subtree
                          !(IntervalSet k) -- right subtree
-- | Sets are equal when their ascending element lists are equal; the shape
-- and colouring of the trees are irrelevant.
instance (Eq k) => Eq (IntervalSet k) where
  s1 == s2 = elems s1 == elems s2
-- | Sets are ordered by lexicographic comparison of their ascending
-- element lists.
instance (Ord k) => Ord (IntervalSet k) where
  compare s1 s2 = elems s1 `compare` elems s2
-- | Combining two sets is their left-biased 'union'.
instance (Interval i k, Ord i) => Sem.Semigroup (IntervalSet i) where
  s1 <> s2 = union s1 s2
-- | The monoid of sets under 'union', with the empty set as identity.
instance (Interval i k, Ord i) => Monoid (IntervalSet i) where
  mempty = Nil
  mappend s1 s2 = union s1 s2
  mconcat sets = unions sets
-- | Folds visit the keys in ascending (in-order) sequence.
-- On the right-hand sides, 'foldr' and 'foldl' are this module's own
-- top-level functions (the class methods are only in scope qualified).
instance Foldable.Foldable IntervalSet where
  fold = Foldable.foldMap id
  foldr = foldr
  foldl = foldl
  foldMap f = walk
    where
      walk Nil = mempty
      walk (Node _ key _ lt rt) = walk lt `mappend` (f key `mappend` walk rt)
-- | Fully evaluate every key and both subtrees of each node. The cached
-- maximum-upper-bound field is not traversed separately.
instance (NFData k) => NFData (IntervalSet k) where
  rnf t = case t of
    Nil -> ()
    Node _ key _ lt rt -> rnf key `seq` rnf lt `seq` rnf rt
-- | Parses the textual form @fromList [...]@ (parenthesised when the
-- precedence is above 10) and builds the set with 'fromList'.
instance (Interval i k, Ord i, Read i) => Read (IntervalSet i) where
  readsPrec p = readParen (p > 10) parse
    where
      parse input =
        [ (fromList xs, rest)
        | ("fromList", afterKeyword) <- lex input
        , (xs, rest) <- reads afterKeyword
        ]
-- | Renders the set as @fromList [...]@, parenthesised when the precedence
-- is above 10. The elements appear in the order produced by 'toList'.
instance (Show k) => Show (IntervalSet k) where
  showsPrec d s =
    showParen (d > 10) (showString "fromList " . showList (toList s))
-- Is the root of the tree a red node? The empty tree is not red.
isRed :: IntervalSet k -> Bool
isRed t = case t of
  Node R _ _ _ _ -> True
  _              -> False
-- Colour the root black. Black roots and the empty tree are returned as is.
turnBlack :: IntervalSet k -> IntervalSet k
turnBlack t = case t of
  Node R k m l r -> Node B k m l r
  _              -> t
-- Colour the root red. A red root is returned as is; calling this on the
-- empty tree is an error.
turnRed :: IntervalSet k -> IntervalSet k
turnRed t = case t of
  Nil            -> error "turnRed: Leaf"
  Node B k m l r -> Node R k m l r
  _              -> t
-- Smart constructor: build a node from colour, key and subtrees, computing
-- the cached maximum-upper-bound interval from the key and both children.
mNode :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
mNode c k l r = Node c k cachedMax l r
  where
    cachedMax = maxUpper k l r
-- The interval with the greatest upper bound among a key and the cached
-- maxima of its two subtrees (an empty subtree contributes nothing).
maxUpper :: (Interval i k) => i -> IntervalSet i -> IntervalSet i -> i
maxUpper k l r = case (l, r) of
  (Nil, Nil)                         -> k
  (Nil, Node _ _ rm _ _)             -> maxByUpper k rm
  (Node _ _ lm _ _, Nil)             -> maxByUpper k lm
  (Node _ _ lm _ _, Node _ _ rm _ _) -> maxByUpper k (maxByUpper lm rm)
-- Of two intervals, the one with the greater upper bound; the first wins
-- on a tie. Lower bounds play no role. Both arguments are forced.
maxByUpper :: (Interval i e) => i -> i -> i
maxByUpper a b = a `seq` b `seq` pick
  where
    pick | compareUpperBounds a b == LT = b
         | otherwise                    = a
-- ---------------------------------------------------------
-- | /O(1)/. The empty set, represented by the 'Nil' tree.
empty :: IntervalSet k
empty = Nil
-- | /O(1)/. A set with one entry: a single black node without children
-- whose cached maximum-upper-bound interval is the key itself.
singleton :: k -> IntervalSet k
singleton k = Node B k k Nil Nil
-- | /O(1)/. 'True' exactly when the set has no elements.
null :: IntervalSet k -> Bool
null t = case t of
  Nil -> True
  _   -> False
-- | /O(n)/. Number of keys in the set.
--
-- Caution: unlike 'Data.Set.size', this takes linear time, because the
-- whole tree is traversed with a strict counter.
size :: IntervalSet k -> Int
size = count 0
  where
    count acc t = acc `seq` case t of
      Nil              -> acc
      Node _ _ _ lt rt -> count (count acc lt + 1) rt
-- | /O(log n)/. Is the value an element of the set? See also 'notMember'.
-- The value is forced even when the search ends at an empty subtree.
member :: (Ord k) => k -> IntervalSet k -> Bool
member k = go
  where
    go Nil = k `seq` False
    go (Node _ key _ lt rt) = case compare k key of
      EQ -> True
      LT -> go lt
      GT -> go rt
-- | /O(log n)/. Is the value absent from the set? Negation of 'member'.
notMember :: (Ord k) => k -> IntervalSet k -> Bool
notMember key = not . member key
-- | /O(log n)/. Find the largest key strictly smaller than the given one,
-- or 'Nothing' if there is no such key.
lookupLT :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupLT k = go Nothing
  where
    -- 'best' is the closest smaller key seen so far on the search path
    go best Nil = best
    go best (Node _ key _ lt rt)
      | k <= key  = go best lt
      | otherwise = go (Just key) rt
-- | /O(log n)/. Find the smallest key strictly larger than the given one,
-- or 'Nothing' if there is no such key.
lookupGT :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupGT k = go Nothing
  where
    -- 'best' is the closest larger key seen so far on the search path
    go best Nil = best
    go best (Node _ key _ lt rt)
      | k >= key  = go best rt
      | otherwise = go (Just key) lt
-- | /O(log n)/. Find the largest key equal to or smaller than the given
-- one, or 'Nothing' if there is no such key.
lookupLE :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupLE k = go Nothing
  where
    -- 'best' is the closest smaller key seen so far on the search path
    go best Nil = best
    go best (Node _ key _ lt rt) = case compare k key of
      EQ -> Just key
      LT -> go best lt
      GT -> go (Just key) rt
-- | /O(log n)/. Find the smallest key equal to or larger than the given
-- one, or 'Nothing' if there is no such key.
lookupGE :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupGE k = go Nothing
  where
    -- 'best' is the closest larger key seen so far on the search path
    go best Nil = best
    go best (Node _ key _ lt rt) = case compare k key of
      EQ -> Just key
      LT -> go (Just key) lt
      GT -> go best rt
-- | Return the set of all intervals containing the given point.
-- This is the second element of the value of 'splitAt':
--
-- > set `containing` p == let (_,s,_) = set `splitAt` p in s
--
-- /O(n)/, since potentially all intervals could contain the point.
-- /O(log n)/ average case. This is also the worst case for sets containing no overlapping intervals.
containing :: (Interval k e) => IntervalSet k -> e -> IntervalSet k
containing t p = p `seq` fromDistinctAscList (collect t [])
  where
    -- prepend the matches of a subtree, in ascending order, to 'acc'
    collect Nil acc = acc
    collect (Node _ k m l r) acc
      | p `above` m  = acc                              -- beyond every interval in this subtree
      | p `below` k  = collect l acc                    -- right subtree starts even later
      | p `inside` k = collect l (k : collect r acc)
      | otherwise    = collect l (collect r acc)
-- | Return the set of all intervals overlapping (intersecting) the given interval.
-- This is the second element of the result of 'splitIntersecting':
--
-- > set `intersecting` i == let (_,s,_) = set `splitIntersecting` i in s
--
-- /O(n)/, since potentially all values could intersect the interval.
-- /O(log n)/ average case, if few values intersect the interval.
intersecting :: (Interval k e) => IntervalSet k -> k -> IntervalSet k
intersecting t i = i `seq` fromDistinctAscList (collect t [])
  where
    -- prepend the matches of a subtree, in ascending order, to 'acc'
    collect Nil acc = acc
    collect (Node _ k m l r) acc
      | i `after` m    = acc                            -- query lies after the whole subtree
      | i `before` k   = collect l acc                  -- only the left subtree can match
      | i `overlaps` k = collect l (k : collect r acc)
      | otherwise      = collect l (collect r acc)
-- | Return the set of all intervals which are completely inside the given interval.
--
-- /O(n)/, since potentially all values could be inside the interval.
-- /O(log n)/ average case, if few keys are inside the interval.
within :: (Interval k e) => IntervalSet k -> k -> IntervalSet k
within t i = i `seq` fromDistinctAscList (collect t [])
  where
    -- prepend the matches of a subtree, in ascending order, to 'acc'
    collect Nil acc = acc
    collect (Node _ k m l r) acc
      | i `after` m    = acc                            -- query lies after the whole subtree
      | i `before` k   = collect l acc                  -- only the left subtree can match
      | i `subsumes` k = collect l (k : collect r acc)
      | otherwise      = collect l (collect r acc)
-- | /O(log n)/. Insert a new value. If the set already contains an element equal to the value,
-- it is replaced by the new value.
--
-- The value is forced before insertion, and the root of the result is
-- always black.
insert :: (Interval k e, Ord k) => k -> IntervalSet k -> IntervalSet k
insert v s = v `seq` turnBlack (ins s)
  where
    -- new elements enter the tree as red leaves
    singletonR k = Node R k k Nil Nil
    ins Nil = singletonR v
    ins (Node color k _ l r) =
      case compare v k of
        LT -> balanceL color k (ins l) r
        GT -> balanceR color k l (ins r)
        -- Replace the key and recompute the cached maximum: the old cache
        -- may be the element being replaced, and reusing it would keep
        -- that stale element (and its upper bound) alive in the tree.
        -- The ancestors are rebuilt through balanceL/balanceR and pick up
        -- the refreshed value.
        EQ -> mNode color v l r
-- Rebuild a node from colour, key, left and right subtree after an
-- operation on the left subtree.
-- If the node is black and its left child is red with a red child of its
-- own (left or right), the three keys are rearranged into a red node with
-- two black children. Every other shape is rebuilt as given.
-- All nodes are built with 'mNode', so cached maxima are recomputed.
balanceL :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
-- red left child with red left grandchild
balanceL B zk (Node R yk _ (Node R xk _ a b) c) d =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
-- red left child with red right grandchild
balanceL B zk (Node R xk _ a (Node R yk _ b c)) d =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
balanceL c xk l r = mNode c xk l r
-- Rebuild a node from colour, key, left and right subtree after an
-- operation on the right subtree.
-- If the node is black and its right child is red with a red child of its
-- own (left or right), the three keys are rearranged into a red node with
-- two black children. Every other shape is rebuilt as given.
-- All nodes are built with 'mNode', so cached maxima are recomputed.
balanceR :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
-- red right child with red right grandchild
balanceR B xk a (Node R yk _ b (Node R zk _ c d)) =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
-- red right child with red left grandchild
balanceR B xk a (Node R zk _ (Node R yk _ b c) d) =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
balanceR c xk l r = mNode c xk l r
-- min/max
-- | /O(log n)/. Returns the minimal value in the set (the leftmost key),
-- or 'Nothing' for the empty set.
findMin :: IntervalSet k -> Maybe k
findMin t = case t of
  Nil              -> Nothing
  Node _ k _ Nil _ -> Just k
  Node _ _ _ l _   -> findMin l
-- | /O(log n)/. Returns the maximal value in the set (the rightmost key),
-- or 'Nothing' for the empty set.
findMax :: IntervalSet k -> Maybe k
findMax t = case t of
  Nil              -> Nothing
  Node _ k _ _ Nil -> Just k
  Node _ _ _ _ r   -> findMax r
-- | Returns the interval with the largest endpoint.
-- If there is more than one interval with that endpoint,
-- return the rightmost.
--
-- /O(n)/, since all intervals could have the same endpoint.
-- /O(log n)/ average case.
findLast :: (Interval k e) => IntervalSet k -> Maybe k
findLast set = case set of
  Nil                -> Nothing
  Node _ _ top _ _   -> search top set
  where
    -- only subtrees whose cached maximum reaches the overall maximum can
    -- hold a candidate; the right subtree is tried first
    search _ Nil = Nothing
    search top (Node _ k m l r)
      | not (sameUpper m top) = Nothing
      | sameUpper k m         = search top r <|> Just k
      | otherwise             = search top r <|> search top l
    sameUpper a b = compareUpperBounds a b == EQ
-- Result of a deletion step: the resulting subtree, tagged with whether
-- the number of black nodes changed or stayed the same.
data DeleteResult k = U !(IntervalSet k) -- Unchanged
                    | S !(IntervalSet k) -- Shrunk
-- Extract the tree from a 'DeleteResult', dropping the changed/unchanged tag.
unwrap :: DeleteResult k -> IntervalSet k
unwrap res = case res of
  U t -> t
  S t -> t
-- DeleteResult with value: like 'DeleteResult' (U' = unchanged,
-- S' = shrunk), but additionally carrying a value of type @a@. The tree
-- field is strict, the value field is lazy.
data DeleteResult' k a = U' !(IntervalSet k) a
                       | S' !(IntervalSet k) a
-- Extract the tree from a 'DeleteResult'', dropping the tag and the value.
unwrap' :: DeleteResult' k a -> IntervalSet k
unwrap' res = case res of
  U' t _ -> t
  S' t _ -> t
-- Attach a value to a 'DeleteResult', keeping its unchanged/shrunk tag.
annotate :: DeleteResult k -> a -> DeleteResult' k a
annotate res x = case res of
  U t -> U' t x
  S t -> S' t x
-- | /O(log n)/. Remove the smallest element from the set. Return the empty set if the set is empty.
-- The root of the result is recoloured black.
deleteMin :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k
deleteMin s = case s of
  Nil -> Nil
  _   -> turnBlack (unwrap' (deleteMin' s))
-- Remove the smallest element of a non-empty tree. The result carries the
-- remaining tree and the removed key; U' means the number of black nodes
-- is unchanged, S' that it shrank. Calls 'error' on the empty tree.
deleteMin' :: (Interval k e, Ord k) => IntervalSet k -> DeleteResult' k k
deleteMin' Nil = error "deleteMin': Nil"
-- black leaf removed: this path has one black node less
deleteMin' (Node B k _ Nil Nil) = S' Nil k
-- black node with only a red right child: the child replaces it, recoloured black
deleteMin' (Node B k _ Nil r@(Node R _ _ _ _)) = U' (turnBlack r) k
-- red node without left child: removing it leaves the black count alone
deleteMin' (Node R k _ Nil r) = U' r k
-- otherwise the minimum is in the left subtree; repair if that subtree shrank
deleteMin' (Node c k _ l r) =
  case deleteMin' l of
    (U' l' kv) -> U' (mNode c k l' r) kv
    (S' l' kv) -> annotate (unbalancedR c k l' r) kv
-- Remove the largest element of a non-empty tree. The result carries the
-- remaining tree and the removed key; U' means the number of black nodes
-- is unchanged, S' that it shrank. Calls 'error' on the empty tree.
deleteMax' :: (Interval k e, Ord k) => IntervalSet k -> DeleteResult' k k
deleteMax' Nil = error "deleteMax': Nil"
-- black leaf removed: this path has one black node less
deleteMax' (Node B k _ Nil Nil) = S' Nil k
-- black node with only a red left child: the child replaces it, recoloured black
deleteMax' (Node B k _ l@(Node R _ _ _ _) Nil) = U' (turnBlack l) k
-- red node without right child: removing it leaves the black count alone
deleteMax' (Node R k _ l Nil) = U' l k
-- otherwise the maximum is in the right subtree; repair if that subtree shrank
deleteMax' (Node c k _ l r) =
  case deleteMax' r of
    (U' r' kv) -> U' (mNode c k l r') kv
    (S' r' kv) -> annotate (unbalancedL c k l r') kv
-- The left tree lacks one Black node.
-- Rebuild the node (colour, key, left, right) so that both sides agree
-- again, reporting whether the rebuilt subtree as a whole lost a black
-- node (S) or not (U).
unbalancedR :: (Interval k e, Ord k) => Color -> k -> IntervalSet k -> IntervalSet k -> DeleteResult k
-- Right child is black: recolour it red, which removes one Black node
-- from the right as well. The rebuilt node is always black, so a formerly
-- black node yields a shrunk subtree and a formerly red one does not.
unbalancedR B k l r@(Node B _ _ _ _) = S (balanceR B k l (turnRed r))
unbalancedR R k l r@(Node B _ _ _ _) = U (balanceR B k l (turnRed r))
-- Right child is red with a black left child: the red node's key becomes
-- the new black root, and the old key moves into the new left subtree
-- together with the reddened grandchild.
unbalancedR B k l (Node R rk _ rl@(Node B _ _ _ _) rr)
  = U (mNode B rk (balanceR B k l (turnRed rl)) rr)
-- any other shape is not expected here
unbalancedR _ _ _ _ = error "unbalancedR"
-- Mirror image of 'unbalancedR', used when a deletion in the right
-- subtree reported a shrunk result. Rebuilds the node (colour, key, left,
-- right) and reports whether the rebuilt subtree as a whole lost a black
-- node (S) or not (U).
unbalancedL :: (Interval k e, Ord k) => Color -> k -> IntervalSet k -> IntervalSet k -> DeleteResult k
-- Left child is black: recolour it red. The rebuilt node is always black,
-- so a formerly red node yields an unchanged subtree and a formerly black
-- one a shrunk subtree.
unbalancedL R k l@(Node B _ _ _ _) r = U (balanceL B k (turnRed l) r)
unbalancedL B k l@(Node B _ _ _ _) r = S (balanceL B k (turnRed l) r)
-- Left child is red with a black right child: the red node's key becomes
-- the new black root, and the old key moves into the new right subtree
-- together with the reddened grandchild.
unbalancedL B k (Node R lk _ ll lr@(Node B _ _ _ _)) r
  = U (mNode B lk ll (balanceL B k (turnRed lr) r))
-- any other shape is not expected here
unbalancedL _ _ _ _ = error "unbalancedL"
-- | /O(log n)/. Remove the largest element from the set. Return the empty set if the set is empty.
-- The root of the result is recoloured black.
deleteMax :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k
deleteMax s = case s of
  Nil -> Nil
  _   -> turnBlack (unwrap' (deleteMax' s))
-- | /O(log n)/. Delete and return the smallest element.
--
-- Calls 'error' when the set is empty; see 'minView' for a total variant.
deleteFindMin :: (Interval k e, Ord k) => IntervalSet k -> (k, IntervalSet k)
deleteFindMin mp = finish (deleteMin' mp)
  where
    finish (U' rest smallest) = (smallest, turnBlack rest)
    finish (S' rest smallest) = (smallest, turnBlack rest)
-- | /O(log n)/. Delete and return the largest element.
--
-- Calls 'error' when the set is empty; see 'maxView' for a total variant.
deleteFindMax :: (Interval k e, Ord k) => IntervalSet k -> (k, IntervalSet k)
deleteFindMax mp = finish (deleteMax' mp)
  where
    finish (U' rest largest) = (largest, turnBlack rest)
    finish (S' rest largest) = (largest, turnBlack rest)
-- | /O(log n)/. Retrieves the minimal element of the set, and
-- the set stripped of that element, or 'Nothing' if passed an empty set.
minView :: (Interval k e, Ord k) => IntervalSet k -> Maybe (k, IntervalSet k)
minView s = case s of
  Nil -> Nothing
  _   -> Just (deleteFindMin s)
-- | /O(log n)/. Retrieves the maximal element of the set, and
-- the set stripped of that element, or 'Nothing' if passed an empty set.
maxView :: (Interval k e, Ord k) => IntervalSet k -> Maybe (k, IntervalSet k)
maxView s = case s of
  Nil -> Nothing
  _   -> Just (deleteFindMax s)
-- folding
-- | /O(n)/. Fold the values in the set using the given right-associative
-- binary operator, such that @'foldr' f z == 'Prelude.foldr' f z . 'elems'@.
foldr :: (k -> b -> b) -> b -> IntervalSet k -> b
foldr f = go
  where
    -- fold the right subtree first, then the key, then the left subtree
    go z Nil = z
    go z (Node _ k _ l r) = go (f k (go z r)) l
-- | /O(n)/. A strict version of 'foldr'. Each application of the operator is
-- evaluated before using the result in the next application. This
-- function is strict in the starting value.
foldr' :: (k -> b -> b) -> b -> IntervalSet k -> b
foldr' f = go
  where
    -- the accumulator is forced on entry to every subtree
    go z t = z `seq` case t of
      Nil            -> z
      Node _ k _ l r -> go (f k (go z r)) l
-- | /O(n)/. Fold the values in the set using the given left-associative
-- binary operator, such that @'foldl' f z == 'Prelude.foldl' f z . 'elems'@.
foldl :: (b -> k -> b) -> b -> IntervalSet k -> b
foldl f = go
  where
    -- fold the left subtree first, then the key, then the right subtree
    go z Nil = z
    go z (Node _ k _ l r) = go (f (go z l) k) r
-- | /O(n)/. A strict version of 'foldl'. Each application of the operator is
-- evaluated before using the result in the next application. This
-- function is strict in the starting value.
foldl' :: (b -> k -> b) -> b -> IntervalSet k -> b
foldl' f = go
  where
    -- the accumulator is forced on entry to every subtree
    go z t = z `seq` case t of
      Nil            -> z
      Node _ k _ l r -> go (f (go z l) k) r
-- delete
-- | /O(log n)/. Delete an element from the set. If the set does not contain the value,
-- it is returned unchanged. The root of the result is recoloured black.
delete :: (Interval k e, Ord k) => k -> IntervalSet k -> IntervalSet k
delete key mp = turnBlack . unwrap $ delete' key mp
-- Worker for 'delete': remove the key from a subtree and report whether
-- the number of black nodes in it stayed the same (U) or shrank (S).
-- The key is forced even when it is not found.
delete' :: (Interval k e, Ord k) => k -> IntervalSet k -> DeleteResult k
delete' x Nil = x `seq` U Nil
delete' x (Node c k _ l r) =
  case compare x k of
    -- descend left; repair this node if the left subtree shrank
    LT -> case delete' x l of
            (U l') -> U (mNode c k l' r)
            (S l') -> unbalancedR c k l' r
    -- descend right; repair this node if the right subtree shrank
    GT -> case delete' x r of
            (U r') -> U (mNode c k l r')
            (S r') -> unbalancedL c k l r'
    EQ -> case r of
            -- no right subtree: the left subtree takes this node's place;
            -- removing a black node must be compensated via 'blackify'
            Nil -> if c == B then blackify l else U l
            -- otherwise the smallest key of the right subtree replaces
            -- the deleted key
            _ -> case deleteMin' r of
                   (U' r' rk) -> U (mNode c rk l r')
                   (S' r' rk) -> unbalancedL c rk l r'
-- Compensate for a removed black node: a red root is recoloured black and
-- the result reported as unchanged; anything else is reported as shrunk.
blackify :: IntervalSet k -> DeleteResult k
blackify t = case t of
  Node R k m l r -> U (Node B k m l r)
  _              -> S t
-- | /O(n+m)/. The expression (@'union' t1 t2@) takes the left-biased union of @t1@ and @t2@.
-- It prefers @t1@ when duplicate elements are encountered.
union :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
union m1 m2 = fromDistinctAscList merged
  where
    merged = toAscList m1 `ascListUnion` toAscList m2
-- | The union of a list of sets:
-- (@'unions' == 'Prelude.foldl' 'union' 'empty'@).
--
-- When an element occurs in several sets, the occurrence from the
-- leftmost set is kept. A single-element list returns that set as is.
unions :: (Interval k e, Ord k) => [IntervalSet k] -> IntervalSet k
unions []  = empty
unions [s] = s
unions iss = fromDistinctAscList (mergeAll (L.map toAscList iss))
  where
    -- Merge neighbouring lists pairwise, round after round, until a
    -- single list is left. Keeping the rounds balanced means each element
    -- takes part in a logarithmic number of merges, and neighbour-only
    -- merging preserves the left bias of 'ascListUnion'.
    mergeAll []   = []
    mergeAll [xs] = xs
    mergeAll xss  = mergeAll (mergePairs xss)
    -- one round: merge lists 1+2, 3+4, ...; an odd last list is kept
    mergePairs (xs:ys:rest) = ascListUnion xs ys : mergePairs rest
    mergePairs xss          = xss
-- | /O(n+m)/. Difference of two sets.
-- Return elements of the first set not existing in the second set.
difference :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
difference m1 m2 = fromDistinctAscList remaining
  where
    remaining = toAscList m1 `ascListDifference` toAscList m2
-- | /O(n+m)/. Intersection of two sets.
-- Return elements in the first set also existing in the second set.
intersection :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
intersection m1 m2 = fromDistinctAscList common
  where
    common = toAscList m1 `ascListIntersection` toAscList m2
-- Merge two ascending lists into one ascending list. When both lists hold
-- an equal element, only the one from the first list is kept.
ascListUnion :: Ord k => [k] -> [k] -> [k]
ascListUnion left right = case (left, right) of
  ([], _)          -> right
  (_, [])          -> left
  (x:xs', y:ys')   -> case compare x y of
    EQ -> x : ascListUnion xs' ys'
    LT -> x : ascListUnion xs' right
    GT -> y : ascListUnion left ys'
-- Elements of the first ascending list that do not occur in the second
-- ascending list, in ascending order.
ascListDifference :: Ord k => [k] -> [k] -> [k]
ascListDifference left right = case (left, right) of
  ([], _)            -> []
  (_, [])            -> left
  (xk:xs', yk:ys')   -> case compare xk yk of
    EQ -> ascListDifference xs' ys'
    LT -> xk : ascListDifference xs' right
    GT -> ascListDifference left ys'
-- Elements occurring in both ascending lists, in ascending order. The
-- representative kept for a common element is the one from the first list.
ascListIntersection :: Ord k => [k] -> [k] -> [k]
ascListIntersection left right = case (left, right) of
  ([], _)            -> []
  (_, [])            -> []
  (xk:xs', yk:ys')   -> case compare xk yk of
    EQ -> xk : ascListIntersection xs' ys'
    LT -> ascListIntersection xs' right
    GT -> ascListIntersection left ys'
-- --- Conversion ---
-- | /O(n)/. The list of all values contained in the set, in ascending order.
toAscList :: IntervalSet k -> [k]
toAscList set = foldr (:) [] set
-- Prepend all values of the set, in ascending order, to the given list.
toAscList' :: IntervalSet k -> [k] -> [k]
toAscList' = go
  where
    go Nil acc = acc
    go (Node _ k _ l r) acc = go l (k : go r acc)
-- | /O(n)/. The list of all values in the set, in no particular order.
-- (The current traversal yields each node's key before the keys of its
-- left and then its right subtree.)
toList :: IntervalSet k -> [k]
toList s = preorder s []
  where
    preorder t acc = case t of
      Nil            -> acc
      Node _ k _ l r -> k : preorder l (preorder r acc)
-- | /O(n)/. The list of all values in the set, in descending order.
toDescList :: IntervalSet k -> [k]
toDescList = foldl (\acc k -> k : acc) []
-- | /O(n log n)/. Build a set from a list of elements. See also 'fromAscList'.
-- If the list contains duplicate values, the last value is retained.
fromList :: (Interval k e, Ord k) => [k] -> IntervalSet k
fromList = L.foldl' (\acc x -> insert x acc) empty
-- | /O(n)/. Build a set from an ascending list in linear time.
-- Runs of equal adjacent elements are collapsed to their first element.
-- /The precondition (input list is ascending) is not checked./
fromAscList :: (Interval k e, Eq k) => [k] -> IntervalSet k
fromAscList = fromDistinctAscList . uniq
-- Collapse each run of equal adjacent elements to its first element.
uniq :: Eq k => [k] -> [k]
uniq xs = case xs of
  []           -> []
  (first:rest) -> collapse first rest
  where
    -- 'cur' is the representative of the run being skipped
    collapse cur remaining = case remaining of
      [] -> [cur]
      (nxt:more)
        | cur == nxt -> collapse cur more
        | otherwise  -> cur : collapse nxt more
-- Strict tuple: a pair whose two components are both strict fields.
data T2 a b = T2 !a !b
-- | /O(n)/. Build a set from an ascending list of distinct elements in linear time.
-- /The precondition is not checked./
fromDistinctAscList :: (Interval k e) => [k] -> IntervalSet k
-- A list of exactly 2^h - 1 elements becomes a perfect tree of height h in
-- which every node is black ('buildB').
-- Any other length is handled by 'buildR': it splits the list around a
-- middle element and colours a single-element subtree red when its depth
-- budget (starting at @log2 n@ and decreasing by one per level) has reached
-- zero, i.e. only the lowest row can be red.
fromDistinctAscList lyst = case h (length lyst) lyst of
                             (T2 result []) -> result
                             _ -> error "fromDistinctAscList: list not fully consumed"
  where
    h n xs | n == 0      = T2 Nil xs
           | isPerfect n = buildB n xs
           | otherwise   = buildR n (log2 n) xs

    -- build an all-black perfect tree from the first n elements (n = 2^h - 1),
    -- returning it together with the unconsumed rest of the list
    buildB n xs
      | xs `seq` n <= 0 = error "fromDistinctAscList: buildB 0"
      | n == 1 = case xs of
          (k:xs') -> T2 (Node B k k Nil Nil) xs'
          _       -> error "fromDistinctAscList: buildB 1"
      | otherwise =
          case n `quot` 2 of { n' ->
          case buildB n' xs of { (T2 _ []) -> error "fromDistinctAscList: buildB n";
                                 (T2 l (k:xs')) ->
          case buildB n' xs' of { (T2 r xs'') ->
          T2 (mNode B k l r) xs'' }}}

    -- build a tree from the first n elements with depth budget d,
    -- returning it together with the unconsumed rest of the list
    buildR n d xs
      | d `seq` xs `seq` n == 0 = T2 Nil xs
      | n == 1 = case xs of
          (k:xs') -> T2 (Node (if d==0 then R else B) k k Nil Nil) xs'
          _       -> error "fromDistinctAscList: buildR 1"
      | otherwise =
          case n `quot` 2 of { n' ->
          case buildR n' (d-1) xs of { (T2 _ []) -> error "fromDistinctAscList: buildR n";
                                       (T2 l (k:xs')) ->
          case buildR (n - (n' + 1)) (d-1) xs' of { (T2 r xs'') ->
          T2 (mNode B k l r) xs'' }}}
-- Is n the size of a perfect binary tree, i.e. of the form 2^m - 1?
-- Such numbers share no set bit with their successor.
isPerfect :: Int -> Bool
isPerfect n = 0 == n .&. (n + 1)
-- Integer base-2 logarithm, rounded down: the number of right shifts
-- needed to reach zero, minus one. Non-positive arguments yield -1.
log2 :: Int -> Int
log2 = count (-1)
  where
    count acc v = acc `seq` (if v > 0 then count (acc + 1) (v `shiftR` 1) else acc)
-- | /O(n)/. List of all values in the set, in ascending order.
-- Synonym for 'toAscList'.
elems :: IntervalSet k -> [k]
elems = toAscList
-- --- Mapping ---
-- | /O(n log n)/. Map a function over all values in the set.
--
-- The size of the result may be smaller if @f@ maps two or more distinct
-- elements to the same value.
map :: (Interval b e2, Ord b) => (a -> b) -> IntervalSet a -> IntervalSet b
map f = fromList . L.map f . toList
-- | /O(n)/. @'mapMonotonic' f s == 'map' f s@, but works only when @f@
-- is strictly monotonic.
-- That is, for any values @x@ and @y@, if @x@ < @y@ then @f x@ < @f y@.
-- /The precondition is not checked./
--
-- The tree shape and node colours are kept; cached maxima are recomputed.
mapMonotonic :: (Interval k2 e, Ord k2) => (k1 -> k2) -> IntervalSet k1 -> IntervalSet k2
mapMonotonic f = go
  where
    go Nil = Nil
    go (Node c k _ l r) = mNode c (f k) (go l) (go r)
-- | /O(n)/. Filter values satisfying a predicate.
filter :: (Interval k e) => (k -> Bool) -> IntervalSet k -> IntervalSet k
filter p = fromDistinctAscList . L.filter p . toAscList
-- | /O(n)/. Partition the set according to a predicate. The first
-- set contains all elements that satisfy the predicate, the second all
-- elements that fail the predicate. See also 'split'.
partition :: (Interval k e) => (k -> Bool) -> IntervalSet k -> (IntervalSet k, IntervalSet k)
partition p s = (fromDistinctAscList accepted, fromDistinctAscList rejected)
  where
    (accepted, rejected) = L.partition p (toAscList s)
-- | /O(n)/. The expression (@'split' k set@) is a pair @(set1,set2)@ where
-- the elements in @set1@ are smaller than @k@ and the elements in @set2@ larger than @k@.
-- Any key equal to @k@ is found in neither @set1@ nor @set2@.
split :: (Interval i k, Ord i) => i -> IntervalSet i -> (IntervalSet i, IntervalSet i)
split x m =
  let (smaller, _, larger) = splitMember x m
  in (smaller, larger)
-- | /O(n)/. The expression (@'splitMember' k set@) splits a set just
-- like 'split' but also returns @'member' k set@.
splitMember :: (Interval i k, Ord i) => i -> IntervalSet i -> (IntervalSet i, Bool, IntervalSet i)
splitMember x s =
  case (smaller, rest) of
    ([], [])  -> (empty, False, empty)
    -- nothing is smaller: the original set can be reused for the upper part
    ([], y:_) -> if y == x then (empty, True, deleteMin s)
                           else (empty, False, s)
    -- everything is smaller: the original set is the lower part
    (_, [])   -> (s, False, empty)
    (_, y:gt) -> if y == x then (fromDistinctAscList smaller, True, fromDistinctAscList gt)
                           else (fromDistinctAscList smaller, False, fromDistinctAscList rest)
  where
    (smaller, rest) = span (< x) (toAscList s)
-- Helper for building sets from distinct ascending values and subsets.
-- A 'Union' describes an ascending sequence of elements piecewise:
-- nothing ('UEmpty'), two sequences one after the other ('Union'), a
-- single key followed by a sequence ('UCons'), or a whole existing set
-- followed by a sequence ('UAppend'). All fields are strict.
data Union k = UEmpty | Union !(Union k) !(Union k)
             | UCons !k !(Union k)
             | UAppend !(IntervalSet k) !(Union k)
-- Concatenate two 'Union' descriptions, avoiding a 'Union' node when
-- either side is empty.
mkUnion :: Union a -> Union a -> Union a
mkUnion u1 u2 = case (u1, u2) of
  (UEmpty, _) -> u2
  (_, UEmpty) -> u1
  _           -> Union u1 u2
-- Turn a 'Union' description into a set. A lone key or a lone subset is
-- converted directly (the subset's root is recoloured black); anything
-- else is flattened to an ascending list and rebuilt.
fromUnion :: Interval k e => Union k -> IntervalSet k
fromUnion u = case u of
  UEmpty             -> empty
  UCons key UEmpty   -> singleton key
  UAppend set UEmpty -> turnBlack set
  _                  -> fromDistinctAscList (flatten u [])
  where
    -- prepend the elements described by a 'Union' to 'acc'
    flatten UEmpty acc        = acc
    flatten (Union a b) acc   = flatten a (flatten b acc)
    flatten (UCons k v) acc   = k : flatten v acc
    flatten (UAppend s v) acc = toAscList' s (flatten v acc)
-- | /O(n)/. Split around a point.
-- Splits the set into three subsets: intervals below the point,
-- intervals containing the point, and intervals above the point.
splitAt :: (Interval i k) => IntervalSet i -> k -> (IntervalSet i, IntervalSet i, IntervalSet i)
splitAt set p = (fromUnion (lowerPart set), set `containing` p, fromUnion (upperPart set))
  where
    -- intervals lying entirely below the point
    lowerPart Nil = UEmpty
    lowerPart s@(Node _ k m l r)
      | p `above` m = UAppend s UEmpty        -- the whole subtree is below the point
      | p `below` k = lowerPart l             -- key and right subtree start after the point
      | otherwise   = mkUnion (lowerPart l) fromKeyAndRight
      where
        -- a key containing the point is left out, any other key is kept
        fromKeyAndRight
          | p `inside` k = lowerPart r
          | otherwise    = UCons k (lowerPart r)

    -- intervals lying entirely above the point
    upperPart Nil = UEmpty
    upperPart (Node _ k m l r)
      | p `above` m = UEmpty
      | p `below` k = mkUnion (upperPart l) (UCons k (UAppend r UEmpty))
      | otherwise   = upperPart r
-- | /O(n)/. Split around an interval.
-- Splits the set into three subsets: intervals below the given interval,
-- intervals intersecting the given interval, and intervals above the
-- given interval.
splitIntersecting :: (Interval i k, Ord i) => IntervalSet i -> i -> (IntervalSet i, IntervalSet i, IntervalSet i)
splitIntersecting set i = (fromUnion (lowerPart set), set `intersecting` i, fromUnion (upperPart set))
  where
    lowerPart Nil = UEmpty
    lowerPart s@(Node _ k m l r)
      -- the query lies after every interval of the subtree: take all of it
      | i `after` m = UAppend s UEmpty
      -- the query sorts at or before the key: only the left subtree can contribute
      | i <= k      = lowerPart l
      -- otherwise both subtrees can contribute
      | otherwise   = mkUnion (lowerPart l) fromKeyAndRight
      where
        -- a key overlapping the query is left out, any other key is kept
        fromKeyAndRight
          | i `overlaps` k = lowerPart r
          | otherwise      = UCons k (lowerPart r)

    upperPart Nil = UEmpty
    upperPart (Node _ k m l r)
      -- the query lies after every interval of the subtree: nothing is above it
      | i `after` m  = UEmpty
      -- the query lies before the key: key and whole right subtree, plus part of the left
      | i `before` k = mkUnion (upperPart l) (UCons k (UAppend r UEmpty))
      -- the query overlaps or follows the key: only the right subtree can contribute
      | otherwise    = upperPart r
-- subsets
-- | /O(n+m)/. Is the first set a subset of the second set?
-- This is always true for equal sets.
isSubsetOf :: (Ord k) => IntervalSet k -> IntervalSet k -> Bool
isSubsetOf set1 set2 = toAscList set1 `ascListSubset` toAscList set2
-- Does every element of the first ascending list occur in the second
-- ascending list?
ascListSubset :: (Ord a) => [a] -> [a] -> Bool
ascListSubset sub super = case (sub, super) of
  ([], _)          -> True
  (_, [])          -> False
  (k1:r1, k2:r2)   -> case compare k1 k2 of
    LT -> False                     -- k1 can no longer appear in the second list
    EQ -> ascListSubset r1 r2
    GT -> ascListSubset sub r2      -- skip an element only the second list has
-- | /O(n+m)/. Is the first set a proper subset of the second set?
-- (i.e. a subset but not equal).
isProperSubsetOf :: (Ord k) => IntervalSet k -> IntervalSet k -> Bool
isProperSubsetOf set1 set2 = proper (toAscList set1) (toAscList set2)
  where
    proper sub super = case (sub, super) of
      ([], _:_)        -> True      -- the second set has an element left over
      (_, [])          -> False
      (k1:r1, k2:r2)   -> case compare k1 k2 of
        LT -> False
        EQ -> proper r1 r2
        -- an extra element was found in the second set: from here on a
        -- plain subset check is enough
        GT -> ascListSubset sub r2
-- | /O(n log n)/. Build a new set by combining successive values.
-- Neighbouring values (in ascending order) for which the function returns
-- 'Just' are replaced by the combined value.
flattenWith :: (Ord a, Interval a e) => (a -> a -> Maybe a) -> IntervalSet a -> IntervalSet a
flattenWith combine = fromList . combineSuccessive combine
-- | /O(n)/. Build a new set by combining successive values.
-- Same as 'flattenWith', but works only when the combining functions returns
-- strictly monotonic values.
flattenWithMonotonic :: (Interval a e) => (a -> a -> Maybe a) -> IntervalSet a -> IntervalSet a
flattenWithMonotonic combine = fromDistinctAscList . combineSuccessive combine
-- Walk the ascending element list, merging the current value with its
-- successor for as long as the function returns 'Just'; on 'Nothing' the
-- current value is emitted and the successor becomes the current value.
combineSuccessive :: (a -> a -> Maybe a) -> IntervalSet a -> [a]
combineSuccessive combine set = case toAscList set of
  []     -> []
  (x:xs) -> merge x xs
  where
    merge cur []         = [cur]
    merge cur (nxt:rest) =
      maybe (cur : merge nxt rest) (\joined -> merge joined rest) (combine cur nxt)
-- debugging
-- | The height of the tree: 0 for the empty tree, otherwise one more than
-- the height of the taller subtree. For testing/debugging only.
height :: IntervalSet k -> Int
height t = case t of
  Nil            -> 0
  Node _ _ _ l r -> 1 + max (height l) (height r)
-- | The maximum height of a red-black tree with the given number of nodes,
-- computed as twice the rounded-down base-2 logarithm of @nodes + 1@.
-- For testing/debugging only.
maxHeight :: Int -> Int
maxHeight nodes = log2 (nodes + 1) * 2
-- | Check red-black-tree and interval search augmentation invariants.
-- For testing/debugging only.
--
-- Verified are: key ordering between each node and its direct children,
-- the cached maximum of every node, the height bound from 'maxHeight',
-- equal black depth on all paths, and the absence of a red node with a
-- red child.
valid :: (Interval i k, Ord i) => IntervalSet i -> Bool
valid mp = and [nodesOk mp, height mp <= maxHeight (size mp), blackDepth mp >= 0]
  where
    nodesOk Nil = True
    nodesOk n@(Node _ _ _ l r) = validOrder n && validMax n && nodesOk l && nodesOk r

    -- the cached interval must be what 'maxUpper' computes for the node
    validMax Nil = True
    validMax (Node _ k m lo hi) = m == maxUpper k lo hi

    -- the key must lie strictly between the keys of its direct children
    validOrder Nil = True
    validOrder (Node _ k _ lo hi) = leftOk lo && rightOk hi
      where
        leftOk Nil = True
        leftOk (Node _ kl _ _ _) = kl < k
        rightOk Nil = True
        rightOk (Node _ kr _ _ _) = k < kr

    -- number of black nodes on every path to a leaf; -1 if the subtrees
    -- have different black depths or two consecutive red nodes are encountered
    blackDepth :: IntervalSet k -> Int
    blackDepth Nil = 0
    blackDepth (Node c _ _ l r)
      | ld < 0 = ld
      | rd < 0 = rd
      | rd /= ld || (c == R && (isRed l || isRed r)) = -1
      | c == B = rd + 1
      | otherwise = rd
      where
        ld = blackDepth l
        rd = blackDepth r