-- |
-- Module      :  Data.IntervalSet
-- Copyright   :  (c) Christoph Breitkopf 2015 - 2017
-- License     :  BSD-style
-- Maintainer  :  chbreitkopf@gmail.com
-- Stability   :  experimental
-- Portability :  non-portable (MPTC with FD)
--
-- An implementation of sets of intervals. The intervals may
-- overlap, and the implementation contains efficient search functions
-- for all intervals containing a point or overlapping a given interval.
-- Closed, open, and half-open intervals can be contained in the same set.
--
-- It is an error to insert an empty interval into a set. This precondition is not
-- checked by the various construction functions.
--
-- Since many function names (but not the type name) clash with
-- /Prelude/ names, this module is usually imported @qualified@, e.g.
--
-- >  import Data.IntervalSet (IntervalSet)
-- >  import qualified Data.IntervalSet as IS
--
-- It offers most of the same functions as 'Data.Set', but the member type must be an
-- instance of 'Interval'. The 'findMin' and 'findMax' functions deviate from their
-- set counterparts in being total and returning a 'Maybe' value.
-- Some functions differ in asymptotic performance (for example 'size') or have not
-- been tuned for efficiency as much as their equivalents in 'Data.Set'.
--
-- In addition, there are functions specific to sets of intervals, for example to search
-- for all intervals containing a given point or contained in a given interval.
--
-- The implementation is a red-black tree augmented with the maximum upper bound
-- of all keys.
--
-- Parts of this implementation are based on code from the 'Data.Map' implementation,
-- (c) Daan Leijen 2002, (c) Andriy Palamarchuk 2008.
-- The red-black tree deletion is based on code from llrbtree by Kazu Yamamoto.
-- Of course, any errors are mine.
--
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.IntervalSet (
            -- * re-export
            Interval(..)
            -- * Set type
            , IntervalSet(..)      -- instance Eq,Show,Read

            -- * Operators
            , (\\)

            -- * Query
            , null
            , size
            , member
            , notMember
            , lookupLT
            , lookupGT
            , lookupLE
            , lookupGE

            -- ** Interval query
            , containing
            , intersecting
            , within
            
            -- * Construction
            , empty
            , singleton

            -- ** Insertion
            , insert
            
            -- ** Delete\/Update
            , delete

            -- * Combine
            , union
            , unions
            , difference
            , intersection

            -- * Traversal
            -- ** Map
            , map
            , mapMonotonic

            -- ** Fold
            , foldr, foldl
            , foldl', foldr'

            -- * Flatten
            , flattenWith, flattenWithMonotonic

            -- * Conversion
            , elems

            -- ** Lists
            , toList
            , fromList

            -- ** Ordered lists
            , toAscList
            , toDescList
            , fromAscList
            , fromDistinctAscList

            -- * Filter
            , filter
            , partition

            , split
            , splitMember
            , splitAt
            , splitIntersecting

            -- * Subset
            , isSubsetOf, isProperSubsetOf

            -- * Min\/Max
            , findMin
            , findMax
            , findLast
            , deleteMin
            , deleteMax
            , deleteFindMin
            , deleteFindMax
            , minView
            , maxView

            -- * Debugging
            , valid

            ) where

-- foldl' is hidden as well because Prelude exports it since base-4.20 (GHC 9.10),
-- which would make this module's own foldl' ambiguous.
import Prelude hiding (null, map, filter, foldr, foldl, foldl', splitAt)
import Data.Bits (shiftR, (.&.))
import qualified Data.Semigroup as Sem
import Data.Monoid (Monoid(..))
import qualified Data.Foldable as Foldable
import qualified Data.List as L
import Control.DeepSeq
import Control.Applicative ((<|>))

import Data.IntervalMap.Generic.Interval

{--------------------------------------------------------------------
  Operators
--------------------------------------------------------------------}
infixl 9 \\ --

-- | Operator form of 'difference': the elements of the left set that do
-- not occur in the right set.
(\\) :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
(\\) = difference


-- | The color of a red-black tree node: 'R' is red, 'B' is black.
data Color = R | B deriving (Eq)

-- | A set of intervals of type @k@.
--
-- Internally a red-black tree ordered by the 'Ord' instance of @k@. Every
-- node additionally caches the interval with the greatest upper bound in its
-- subtree, which lets the interval queries skip whole subtrees.
data IntervalSet k = Nil
                   | Node !Color
                          !k -- key
                          !k -- interval with maximum upper bound in this subtree
                          !(IntervalSet k) -- left subtree
                          !(IntervalSet k) -- right subtree

instance (Eq k) => Eq (IntervalSet k) where
  -- Equal when the ascending element lists agree; tree shapes may differ.
  (==) s1 s2 = toAscList s1 == toAscList s2

instance (Ord k) => Ord (IntervalSet k) where
  -- Lexicographic comparison of the ascending element lists.
  compare s1 s2 = toAscList s1 `compare` toAscList s2

-- | '<>' is the left-biased 'union'.
instance (Interval i k, Ord i) => Sem.Semigroup (IntervalSet i) where
  a <> b = union a b

-- | The empty set is the identity; combining is the left-biased 'union'.
instance (Interval i k, Ord i) => Monoid (IntervalSet i) where
    mempty  = empty
    -- Canonical definition: 'mappend' is the 'Sem.Semigroup' operation
    -- (which is 'union'), as GHC's -Wnoncanonical-monoid-instances expects.
    mappend = (Sem.<>)
    mconcat = unions
              
-- | Folds visit the elements in ascending order.
instance Foldable.Foldable IntervalSet where
    fold t = go t
      where go Nil = mempty
            go (Node _ k _ l r) = go l `mappend` (k `mappend` go r)
    foldr = foldr
    foldl = foldl
    -- Use the module's direct strict traversals instead of the class
    -- defaults, which are expressed in terms of 'foldr'.
    foldr' = foldr'
    foldl' = foldl'
    foldMap f t = go t
      where go Nil = mempty
            go (Node _ k _ l r) = go l `mappend` (f k `mappend` go r)
    -- O(1); the default descends the leftmost path before answering.
    null = null
    length = size
    toList = toAscList

-- | Fully evaluates every key stored in the tree, recursing into both subtrees.
instance (NFData k) => NFData (IntervalSet k) where
    rnf s = case s of
      Nil             -> ()
      Node _ kx _ l r -> rnf kx `seq` rnf l `seq` rnf r

-- | Parses the form @fromList [...]@ and rebuilds the set with 'fromList'.
instance (Interval i k, Ord i, Read i) => Read (IntervalSet i) where
  readsPrec p = readParen (p > 10) readSet
    where
      readSet input = [ (fromList xs, rest)
                      | ("fromList", afterTag) <- lex input
                      , (xs, rest) <- reads afterTag ]

-- | Shows the set as @fromList [...]@ with the elements in ascending order.
-- Using the ascending list (instead of the shape-dependent 'toList') makes the
-- output canonical: sets that are equal according to the 'Eq' instance always
-- show the same, whatever their internal tree shape.
instance (Show k) => Show (IntervalSet k) where
  showsPrec d m  = showParen (d > 10) $
    showString "fromList " . shows (toAscList m)


-- | Is the root a red node? 'Nil' is treated as black.
isRed :: IntervalSet k -> Bool
isRed t = case t of
  Node R _ _ _ _ -> True
  _              -> False

-- | Recolor a red root black; 'Nil' and black roots are returned unchanged.
turnBlack :: IntervalSet k -> IntervalSet k
turnBlack t = case t of
  Node R key mx lt rt -> Node B key mx lt rt
  _                   -> t

-- | Recolor a black root red. Calls 'error' on 'Nil'; red roots are unchanged.
turnRed :: IntervalSet k -> IntervalSet k
turnRed t = case t of
  Nil                 -> error "turnRed: Leaf"
  Node B key mx lt rt -> Node R key mx lt rt
  _                   -> t

-- | Build a node, computing its cached maximum-upper-bound interval from the
-- key and the cached maxima of the two children (see 'maxUpper').
mNode :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
mNode c k l r = Node c k mx l r
  where mx = maxUpper k l r

-- | The interval with the greatest upper bound among the key and the cached
-- maxima of the two subtrees. Ties go to the key, then to the left subtree.
maxUpper :: (Interval i k) => i -> IntervalSet i -> IntervalSet i -> i
maxUpper k lt rt = case (lt, rt) of
  (Nil,             Nil)             -> k
  (Nil,             Node _ _ m _ _)  -> maxByUpper k m
  (Node _ _ m _ _,  Nil)             -> maxByUpper k m
  (Node _ _ lm _ _, Node _ _ rm _ _) -> maxByUpper k (maxByUpper lm rm)

-- | Of two intervals, the one with the greater upper bound; the first one wins
-- ties. Only upper bounds are compared, lower bounds are ignored.
maxByUpper :: (Interval i e) => i -> i -> i
maxByUpper a b = a `seq` b `seq` pick (compareUpperBounds a b)
  where
    pick LT = b
    pick _  = a

-- ---------------------------------------------------------

-- | /O(1)/. The set without any elements.
empty :: IntervalSet k
empty = Nil

-- | /O(1)/. A set containing exactly the given interval.
singleton :: k -> IntervalSet k
singleton x = Node B x x Nil Nil


-- | /O(1)/. Does the set have no elements?
null :: IntervalSet k -> Bool
null s = case s of
  Nil -> True
  _   -> False

-- | /O(n)/. Number of keys in the set.
--
-- Caution: unlike 'Data.Set.size', this takes linear time, as the tree does
-- not store subtree sizes.
size :: IntervalSet k -> Int
size t = count t 0
  where
    -- strict accumulator: forced before each step
    count s acc = acc `seq` case s of
      Nil            -> acc
      Node _ _ _ l r -> count r (count l acc + 1)

-- | /O(log n)/. Is the value an element of the set? See also 'notMember'.
member :: (Ord k) => k -> IntervalSet k -> Bool
member x = search
  where
    search Nil = x `seq` False
    search (Node _ y _ l r) = case compare x y of
      EQ -> True
      LT -> search l
      GT -> search r

-- | /O(log n)/. Negation of 'member'.
notMember :: (Ord k) => k -> IntervalSet k -> Bool
notMember key = not . member key


-- | /O(log n)/. The largest key strictly smaller than the given one, if any.
lookupLT :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupLT k = descend
  where
    -- no candidate found yet
    descend Nil = Nothing
    descend (Node _ key _ l r)
      | k <= key  = descend l
      | otherwise = withCandidate key r
    -- c is the best candidate seen so far
    withCandidate c Nil = Just c
    withCandidate c (Node _ key _ l r)
      | k <= key  = withCandidate c l
      | otherwise = withCandidate key r

-- | /O(log n)/. The smallest key strictly larger than the given one, if any.
lookupGT :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupGT k = descend
  where
    -- no candidate found yet
    descend Nil = Nothing
    descend (Node _ key _ l r)
      | k >= key  = descend r
      | otherwise = withCandidate key l
    -- c is the best candidate seen so far
    withCandidate c Nil = Just c
    withCandidate c (Node _ key _ l r)
      | k >= key  = withCandidate c r
      | otherwise = withCandidate key l

-- | /O(log n)/. The largest key less than or equal to the given one, if any.
lookupLE :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupLE k = descend
  where
    descend Nil = Nothing
    descend (Node _ key _ l r) = case compare k key of
      LT -> descend l
      EQ -> Just key
      GT -> withCandidate key r
    withCandidate c Nil = Just c
    withCandidate c (Node _ key _ l r) = case compare k key of
      LT -> withCandidate c l
      EQ -> Just key
      GT -> withCandidate key r

-- | /O(log n)/. The smallest key greater than or equal to the given one, if any.
lookupGE :: (Ord k) => k -> IntervalSet k -> Maybe k
lookupGE k = descend
  where
    descend Nil = Nothing
    descend (Node _ key _ l r) = case compare k key of
      LT -> withCandidate key l
      EQ -> Just key
      GT -> descend r
    withCandidate c Nil = Just c
    withCandidate c (Node _ key _ l r) = case compare k key of
      LT -> withCandidate key l
      EQ -> Just key
      GT -> withCandidate c r

-- | The set of all intervals containing the given point.
-- This is the middle component of the result of 'splitAt':
--
-- > set `containing` p == let (_,s,_) = set `splitAt` p in s
--
-- /O(n)/, since potentially all intervals could contain the point.
-- /O(log n)/ average case. This is also the worst case for sets containing no overlapping intervals.
containing :: (Interval k e) => IntervalSet k -> e -> IntervalSet k
t `containing` p = p `seq` fromDistinctAscList (collect t [])
  where
    -- prepend the matches of a subtree, in ascending order, to acc
    collect Nil acc = acc
    collect (Node _ k m l r) acc
       | p `above` m  = acc                          -- beyond every upper bound in this subtree
       | p `below` k  = collect l acc                -- right subtree starts even later
       | p `inside` k = collect l (k : collect r acc)
       | otherwise    = collect l (collect r acc)

-- | The set of all intervals overlapping (intersecting) the given interval.
-- This is the middle component of the result of 'splitIntersecting':
--
-- > set `intersecting` i == let (_,s,_) = set `splitIntersecting` i in s
--
-- /O(n)/, since potentially all values could intersect the interval.
-- /O(log n)/ average case, if few values intersect the interval.
intersecting :: (Interval k e) => IntervalSet k -> k -> IntervalSet k
t `intersecting` i = i `seq` fromDistinctAscList (collect t [])
  where
    collect Nil acc = acc
    collect (Node _ k m l r) acc
       | i `after` m    = acc
       | i `before` k   = collect l acc
       | i `overlaps` k = collect l (k : collect r acc)
       | otherwise      = collect l (collect r acc)

-- | The set of all intervals lying completely inside the given interval.
--
-- /O(n)/, since potentially all values could be inside the interval.
-- /O(log n)/ average case, if few keys are inside the interval.
within :: (Interval k e) => IntervalSet k -> k -> IntervalSet k
t `within` i = i `seq` fromDistinctAscList (collect t [])
  where
    collect Nil acc = acc
    collect (Node _ k m l r) acc
       | i `after` m    = acc
       | i `before` k   = collect l acc
       | i `subsumes` k = collect l (k : collect r acc)
       | otherwise      = collect l (collect r acc)


-- | /O(log n)/. Insert a new value. If the set already contains an element equal to the value,
-- it is replaced by the new value.
insert :: (Interval k e, Ord k) => k -> IntervalSet k -> IntervalSet k
insert v s = v `seq` turnBlack (ins s)
  where
    singletonR k = Node R k k Nil Nil
    ins Nil = singletonR v
    ins (Node color k _ l r) =
      case compare v k of
        LT -> balanceL color k (ins l) r
        GT -> balanceR color k l (ins r)
        -- Replacing the key: recompute the cached maximum instead of reusing
        -- the old one. The old value may be the element being replaced, which
        -- would keep it alive, and it is only right if equal elements also
        -- have identical upper bounds.
        EQ -> mNode color v l r

-- Rebuild a node, repairing a red-red violation in its left subtree: a black
-- node whose red left child has a red child is restructured into a red node
-- with two black children. Every rebuilt node gets its cached maximum
-- recomputed by 'mNode'. All other combinations are rebuilt unchanged.
balanceL :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
-- red left child with a red left grandchild
balanceL B zk (Node R yk _ (Node R xk _ a b) c) d =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
-- red left child with a red right grandchild
balanceL B zk (Node R xk _ a (Node R yk _ b c)) d =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
balanceL c xk l r = mNode c xk l r

-- Mirror image of 'balanceL': rebuild a node, repairing a red-red violation
-- in its right subtree by restructuring into a red node with two black
-- children (cached maxima recomputed by 'mNode'). All other combinations are
-- rebuilt unchanged.
balanceR :: (Interval k e) => Color -> k -> IntervalSet k -> IntervalSet k -> IntervalSet k
-- red right child with a red right grandchild
balanceR B xk a (Node R yk _ b (Node R zk _ c d)) =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
-- red right child with a red left grandchild
balanceR B xk a (Node R zk _ (Node R yk _ b c) d) =
    mNode R yk (mNode B xk a b) (mNode B zk c d)
balanceR c xk l r = mNode c xk l r


-- min/max

-- | /O(log n)/. The smallest element of the set, or 'Nothing' if it is empty.
findMin :: IntervalSet k -> Maybe k
findMin Nil = Nothing
findMin (Node _ k _ l _) = case l of
  Nil -> Just k
  _   -> findMin l

-- | /O(log n)/. The largest element of the set, or 'Nothing' if it is empty.
findMax :: IntervalSet k -> Maybe k
findMax Nil = Nothing
findMax (Node _ k _ _ r) = case r of
  Nil -> Just k
  _   -> findMax r

-- | The interval with the largest endpoint. If several intervals share that
-- endpoint, the rightmost of them (the greatest in set order) is returned.
--
-- /O(n)/, since all intervals could have the same endpoint.
-- /O(log n)/ average case.
findLast :: (Interval k e) => IntervalSet k -> Maybe k
findLast Nil = Nothing
findLast t@(Node _ _ top _ _) = search t
  where
    sameU a b = compareUpperBounds a b == EQ
    search Nil = Nothing
    search (Node _ k m l r)
      -- subtree does not reach the overall largest endpoint: prune it
      | not (sameU m top) = Nothing
      -- prefer matches to the right of k, then k itself
      | sameU k m         = search r <|> Just k
      | otherwise         = search r <|> search l


-- Result of deleting from a subtree: the new subtree, tagged with whether the
-- number of black nodes on its paths stayed the same ('U') or dropped by one
-- ('S'). A shrunk result tells the caller it must rebalance.
data DeleteResult k = U !(IntervalSet k)   -- Unchanged
                    | S !(IntervalSet k)   -- Shrunk

-- | The tree inside a 'DeleteResult', ignoring whether it shrank.
unwrap :: DeleteResult k -> IntervalSet k
unwrap res = case res of
  U t -> t
  S t -> t

-- Like 'DeleteResult', but also carrying a value, e.g. the element removed
-- by 'deleteMin'' / 'deleteMax''.
data DeleteResult' k a = U' !(IntervalSet k) a
                       | S' !(IntervalSet k) a

-- | The tree inside a 'DeleteResult'', dropping the shrink flag and the value.
unwrap' :: DeleteResult' k a -> IntervalSet k
unwrap' res = case res of
  U' t _ -> t
  S' t _ -> t

-- | Attach a value to a 'DeleteResult', keeping its shrink flag.
annotate :: DeleteResult k -> a -> DeleteResult' k a
annotate res x = case res of
  U t -> U' t x
  S t -> S' t x


-- | /O(log n)/. The set without its smallest element; the empty set stays empty.
deleteMin :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k
deleteMin m = case m of
  Nil -> Nil
  _   -> turnBlack (unwrap' (deleteMin' m))

-- Remove the leftmost (smallest) element of a non-empty tree and return it
-- with the remaining tree. The constructor of the result reports whether the
-- tree lost one black node on its paths ('S'') or not ('U''), so callers can
-- rebalance. Calls 'error' on 'Nil'.
deleteMin' :: (Interval k e, Ord k) => IntervalSet k -> DeleteResult' k k
deleteMin' Nil = error "deleteMin': Nil"
-- a black leaf disappears: the tree shrinks
deleteMin' (Node B k _ Nil Nil) = S' Nil k
-- black node whose only child is a red right node: that child, recolored black, replaces it
deleteMin' (Node B k _ Nil r@(Node R _ _ _ _)) = U' (turnBlack r) k
-- removing a red node does not change the number of black nodes
deleteMin' (Node R k _ Nil r) = U' r k
-- otherwise descend to the left, rebalancing if the left subtree shrank
deleteMin' (Node c k _ l r) =
  case deleteMin' l of
    (U' l' kv) -> U' (mNode c k l' r) kv
    (S' l' kv) -> annotate (unbalancedR c k l' r) kv

-- Mirror image of 'deleteMin'': remove the rightmost (largest) element of a
-- non-empty tree and return it with the remaining tree and the shrink flag.
-- Calls 'error' on 'Nil'.
deleteMax' :: (Interval k e, Ord k) => IntervalSet k -> DeleteResult' k k
deleteMax' Nil = error "deleteMax': Nil"
-- a black leaf disappears: the tree shrinks
deleteMax' (Node B k _ Nil Nil) = S' Nil k
-- black node whose only child is a red left node: that child, recolored black, replaces it
deleteMax' (Node B k _ l@(Node R _ _ _ _) Nil) = U' (turnBlack l) k
-- removing a red node does not change the number of black nodes
deleteMax' (Node R k _ l Nil) = U' l k
-- otherwise descend to the right, rebalancing if the right subtree shrank
deleteMax' (Node c k _ l r) =
  case deleteMax' r of
    (U' r' kv) -> U' (mNode c k l r') kv
    (S' r' kv) -> annotate (unbalancedL c k l r') kv

-- The left tree lacks one Black node.
-- Rebuild a node whose left subtree has one black node fewer than its right
-- subtree (after a deletion on the left). The result is 'S' if the rebuilt
-- tree is still one black node short, 'U' otherwise. Calls 'error' for
-- shapes not covered below.
unbalancedR :: (Interval k e, Ord k) => Color -> k -> IntervalSet k -> IntervalSet k -> DeleteResult k
-- Decreasing one Black node in the right: recolor the black right child red
-- and rebalance. Only a black parent leaves the whole tree shrunk.
unbalancedR B k l r@(Node B _ _ _ _) = S (balanceR B k l (turnRed r))
unbalancedR R k l r@(Node B _ _ _ _) = U (balanceR B k l (turnRed r))
-- Red right child under a black parent: the red child becomes the new black
-- root and the old root, rebalanced, moves down to its left.
unbalancedR B k l (Node R rk _ rl@(Node B _ _ _ _) rr)
  = U (mNode B rk (balanceR B k l (turnRed rl)) rr)
unbalancedR _ _ _ _ = error "unbalancedR"

-- Mirror image of 'unbalancedR': rebuild a node whose right subtree has one
-- black node fewer than its left subtree (after a deletion on the right).
-- The result is 'S' if the rebuilt tree is still one black node short, 'U'
-- otherwise. Calls 'error' for shapes not covered below.
unbalancedL :: (Interval k e, Ord k) => Color -> k -> IntervalSet k -> IntervalSet k -> DeleteResult k
-- black left child: recolor it red and rebalance; only a black parent leaves the tree shrunk
unbalancedL R k l@(Node B _ _ _ _) r = U (balanceL B k (turnRed l) r)
unbalancedL B k l@(Node B _ _ _ _) r = S (balanceL B k (turnRed l) r)
-- red left child under a black parent: it becomes the new black root and the
-- old root, rebalanced, moves down to its right
unbalancedL B k (Node R lk _ ll lr@(Node B _ _ _ _)) r
  = U (mNode B lk ll (balanceL B k (turnRed lr) r))
unbalancedL _ _ _ _ = error "unbalancedL"


-- | /O(log n)/. The set without its largest element; the empty set stays empty.
deleteMax :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k
deleteMax m = case m of
  Nil -> Nil
  _   -> turnBlack (unwrap' (deleteMax' m))

-- | /O(log n)/. The smallest element together with the set without it.
-- Calls 'error' on an empty set (via 'deleteMin'').
deleteFindMin :: (Interval k e, Ord k) => IntervalSet k -> (k, IntervalSet k)
deleteFindMin mp = case deleteMin' mp of
    U' rest v -> done rest v
    S' rest v -> done rest v
  where
    done rest v = (v, turnBlack rest)

-- | /O(log n)/. The largest element together with the set without it.
-- Calls 'error' on an empty set (via 'deleteMax'').
deleteFindMax :: (Interval k e, Ord k) => IntervalSet k -> (k, IntervalSet k)
deleteFindMax mp = case deleteMax' mp of
    U' rest v -> done rest v
    S' rest v -> done rest v
  where
    done rest v = (v, turnBlack rest)

-- | /O(log n)/. The smallest element and the set without it, or 'Nothing'
-- for the empty set.
minView :: (Interval k e, Ord k) => IntervalSet k -> Maybe (k, IntervalSet k)
minView x = case x of
  Nil -> Nothing
  _   -> Just (deleteFindMin x)

-- | /O(log n)/. The largest element and the set without it, or 'Nothing'
-- for the empty set.
maxView :: (Interval k e, Ord k) => IntervalSet k -> Maybe (k, IntervalSet k)
maxView x = case x of
  Nil -> Nothing
  _   -> Just (deleteFindMax x)


-- folding

-- | /O(n)/. Right-associative fold over the elements in ascending order:
-- @'foldr' f z == 'Prelude.foldr' f z . 'elems'@.
foldr :: (k -> b -> b) -> b -> IntervalSet k -> b
foldr f = go
  where
    go z Nil = z
    go z (Node _ k _ l r) = go (f k (go z r)) l

-- | /O(n)/. Strict 'foldr': every intermediate accumulator (including the
-- starting value) is forced to WHNF before it is used further.
foldr' :: (k -> b -> b) -> b -> IntervalSet k -> b
foldr' f = go
  where
    go z s = z `seq` case s of
      Nil            -> z
      Node _ k _ l r -> go (f k (go z r)) l

-- | /O(n)/. Left-associative fold over the elements in ascending order:
-- @'foldl' f z == 'Prelude.foldl' f z . 'elems'@.
foldl :: (b -> k -> b) -> b -> IntervalSet k -> b
foldl f = go
  where
    go z Nil = z
    go z (Node _ k _ l r) = go (f (go z l) k) r

-- | /O(n)/. Strict 'foldl': every intermediate accumulator (including the
-- starting value) is forced to WHNF before it is used further.
foldl' :: (b -> k -> b) -> b -> IntervalSet k -> b
foldl' f = go
  where
    go z s = z `seq` case s of
      Nil            -> z
      Node _ k _ l r -> go (f (go z l) k) r

-- delete

-- | /O(log n)/. Remove an element from the set. A set that does not contain
-- the value is returned with the same contents.
delete :: (Interval k e, Ord k) => k -> IntervalSet k -> IntervalSet k
delete key = turnBlack . unwrap . delete' key

-- Delete @x@ from a subtree. The result reports whether the subtree lost one
-- black node on its paths ('S'), so the caller can rebalance. If @x@ is not
-- present, the path to where it would be is rebuilt with the same contents.
delete' :: (Interval k e, Ord k) => k -> IntervalSet k -> DeleteResult k
delete' x Nil = x `seq` U Nil
delete' x (Node c k _ l r) =
  case compare x k of
    LT -> case delete' x l of
            (U l') -> U (mNode c k l' r)
            (S l')    -> unbalancedR c k l' r
    GT -> case delete' x r of
            (U r') -> U (mNode c k l r')
            (S r')    -> unbalancedL c k l r'
    -- found: without a right subtree the left subtree takes the node's place
    -- (recolored black if possible); otherwise the node's key is replaced by
    -- the minimum of the right subtree
    EQ -> case r of
            Nil -> if c == B then blackify l else U l
            _ -> case deleteMin' r of
                   (U' r' rk) -> U (mNode c rk l r')
                   (S' r' rk) -> unbalancedL c rk l r'

-- | Recolor a red root black ('U': no black node lost); any other tree is
-- reported as shrunk ('S').
blackify :: IntervalSet k -> DeleteResult k
blackify t = case t of
  Node R k m l r -> U (Node B k m l r)
  _              -> S t


-- | /O(n+m)/. Left-biased union: when both sets contain equal elements, the
-- one from the first set is kept.
union :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
union m1 m2 = fromDistinctAscList merged
  where merged = ascListUnion (toAscList m1) (toAscList m2)

-- | The union of a list of sets:
--   (@'unions' == 'Prelude.foldl' 'union' 'empty'@).
-- When several sets contain equal elements, the one from the set that comes
-- first in the list is kept.
--
-- The ascending element lists are merged pairwise in rounds (a balanced
-- merge), so each element takes part in only /O(log k)/ merges for @k@ sets.
unions :: (Interval k e, Ord k) => [IntervalSet k] -> IntervalSet k
unions []  = empty
unions [s] = s
unions iss = fromDistinctAscList (mergeAll (L.map toAscList iss))
  where
    -- merge rounds until at most one list is left
    mergeAll []   = []
    mergeAll [xs] = xs
    mergeAll xss  = mergeAll (mergePairs xss)
    -- merge neighbouring lists; the earlier list is always the left
    -- argument, which preserves the left bias of 'ascListUnion'
    mergePairs (xs:ys:rest) = ascListUnion xs ys : mergePairs rest
    mergePairs rest         = rest

-- | /O(n+m)/. Set difference: the elements of the first set that do not
-- occur in the second.
difference :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
difference m1 m2 = fromDistinctAscList remaining
  where remaining = ascListDifference (toAscList m1) (toAscList m2)

-- | /O(n+m)/. Set intersection: the elements of the first set that also
-- occur in the second (the first set's copies are kept).
intersection :: (Interval k e, Ord k) => IntervalSet k -> IntervalSet k -> IntervalSet k
intersection m1 m2 = fromDistinctAscList common
  where common = ascListIntersection (toAscList m1) (toAscList m2)

-- | Merge two strictly ascending lists into one strictly ascending list.
-- Elements present in both lists appear once, taken from the first list.
ascListUnion :: Ord k => [k] -> [k] -> [k]
ascListUnion [] ys = ys
ascListUnion xs [] = xs
ascListUnion xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> x : ascListUnion xt ys
  EQ -> x : ascListUnion xt yt
  GT -> y : ascListUnion xs yt

-- | Elements of the first strictly ascending list that do not occur in the
-- second one, in ascending order.
ascListDifference :: Ord k => [k] -> [k] -> [k]
ascListDifference [] _ = []
ascListDifference xs [] = xs
ascListDifference xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> x : ascListDifference xt ys
  EQ -> ascListDifference xt yt
  GT -> ascListDifference xs yt

-- | Elements occurring in both strictly ascending lists (copies taken from
-- the first list), in ascending order.
ascListIntersection :: Ord k => [k] -> [k] -> [k]
ascListIntersection xs@(x:xt) ys@(y:yt) = case compare x y of
  LT -> ascListIntersection xt ys
  GT -> ascListIntersection xs yt
  EQ -> x : ascListIntersection xt yt
ascListIntersection _ _ = []


-- --- Conversion ---

-- | /O(n)/. All elements of the set, in ascending order.
toAscList :: IntervalSet k -> [k]
toAscList = flip toAscList' []

-- | The elements of the set in ascending order, prepended to the given list.
toAscList' :: IntervalSet k -> [k] -> [k]
toAscList' t acc = case t of
  Nil            -> acc
  Node _ k _ l r -> toAscList' l (k : toAscList' r acc)



-- | /O(n)/. All elements of the set, in no particular order (currently the
-- pre-order of the internal tree).
toList :: IntervalSet k -> [k]
toList s = preorder s []
  where
    preorder Nil acc = acc
    preorder (Node _ k _ l r) acc = k : preorder l (preorder r acc)

-- | /O(n)/. All elements of the set, in descending order.
toDescList :: IntervalSet k -> [k]
toDescList t = go t []
  where
    go Nil acc = acc
    go (Node _ k _ l r) acc = go r (k : go l acc)

-- | /O(n log n)/. Build a set by inserting the list elements one by one.
-- Of equal elements, the one appearing last in the list is kept.
-- See also 'fromAscList'.
fromList :: (Interval k e, Ord k) => [k] -> IntervalSet k
fromList = L.foldl' (\acc x -> insert x acc) empty

-- | /O(n)/. Build a set from an ascending list in linear time. Of a run of
-- adjacent equal elements, the first one is kept.
-- /The precondition (input list is ascending) is not checked./
fromAscList :: (Interval k e, Eq k) => [k] -> IntervalSet k
fromAscList = fromDistinctAscList . uniq

-- | Collapse each run of adjacent equal elements to its first element.
uniq :: Eq k => [k] -> [k]
uniq [] = []
uniq (first:rest) = collapse first rest
  where
    collapse cur [] = [cur]
    collapse cur (nxt:more)
      | cur == nxt = collapse cur more
      | otherwise  = cur : collapse nxt more

-- Strict pair (both fields forced). Used by 'fromDistinctAscList' to return a
-- built subtree together with the not yet consumed rest of the input list.
data T2 a b = T2 !a !b


-- | /O(n)/. Build a set from an ascending list of distinct elements in linear time.
-- /The precondition is not checked./
fromDistinctAscList :: (Interval k e) => [k] -> IntervalSet k
-- A tree with exactly 2^h - 1 nodes is perfect: it has height h and all nodes
-- can be black (buildB). From 2^h up to 2^(h+1) - 2 nodes the tree needs
-- height h + 1 (buildR): all nodes are black except those on the lowest
-- level, which are red.
fromDistinctAscList lyst = case h (length lyst) lyst of
                             (T2 result []) -> result
                             _ -> error "fromDistinctAscList: list not fully consumed"
  where
    h n xs | n == 0      = T2 Nil xs
           | isPerfect n = buildB n xs
           | otherwise   = buildR n (log2 n) xs

    -- all-black perfect tree of n = 2^m - 1 nodes; returns the unused input
    buildB n xs | xs `seq` n <= 0 = error "fromDistinctAscList: buildB 0"
                | n == 1     = case xs of (k:xs') -> T2 (Node B k k Nil Nil) xs'
                                          _ -> error "fromDistinctAscList: buildB 1"
                | otherwise  =
                     case n `quot` 2 of { n' ->
                     case buildB n' xs of { (T2 _ []) -> error "fromDistinctAscList: buildB n";
                                            (T2 l (k:xs')) ->
                     case buildB n' xs' of { (T2 r xs'') ->
                     T2 (mNode B k l r) xs'' }}}

    -- tree of n nodes; d counts the levels left above the lowest one, and
    -- single nodes placed on the lowest level (d == 0) are colored red
    buildR n d xs | d `seq` xs `seq` n == 0 = T2 Nil xs
                  | n == 1    = case xs of (k:xs') -> T2 (Node (if d==0 then R else B) k k Nil Nil) xs'
                                           _ -> error "fromDistinctAscList: buildR 1"
                  | otherwise =
                      case n `quot` 2 of { n' ->
                      case buildR n' (d-1) xs of { (T2 _ []) -> error "fromDistinctAscList: buildR n";
                                                   (T2 l (k:xs')) ->
                      case buildR (n - (n' + 1)) (d-1) xs' of { (T2 r xs'') ->
                      T2 (mNode B k l r) xs'' }}}


-- | Is @n@ the size of a perfect binary tree, i.e. @2^m - 1@ for some
-- @m >= 0@? (True for 0.) Checks that @n@ and @n + 1@ share no set bits.
isPerfect :: Int -> Bool
isPerfect n = 0 == n .&. (n + 1)

-- | Floor of the base-2 logarithm for positive arguments; -1 for arguments
-- <= 0. Counts how often the value can be halved before reaching 0.
log2 :: Int -> Int
log2 m = length (takeWhile (> 0) (iterate (`shiftR` 1) m)) - 1


-- | /O(n)/. All elements of the set in ascending order; same as 'toAscList'.
elems :: IntervalSet k -> [k]
elems = toAscList

-- --- Mapping ---

-- | /O(n log n)/. Apply a function to every element and rebuild the set.
--
-- The result may be smaller than the input if @f@ maps two or more distinct
-- elements to the same value.
map :: (Interval b e2, Ord b) => (a -> b) -> IntervalSet a -> IntervalSet b
map f = fromList . L.map f . toList

-- | /O(n)/. @'mapMonotonic' f s == 'map' f s@, provided @f@ is strictly
-- monotonic: for all @x < y@, @f x < f y@. The tree shape is reused and
-- only the cached upper-bound maxima are recomputed.
-- /The precondition is not checked./
mapMonotonic :: (Interval k2 e, Ord k2) => (k1 -> k2) -> IntervalSet k1 -> IntervalSet k2
mapMonotonic f = go
  where
    go Nil = Nil
    go (Node c k _ l r) = mNode c (f k) (go l) (go r)

-- | /O(n)/. Keep only the elements satisfying the predicate.
filter :: (Interval k e) => (k -> Bool) -> IntervalSet k -> IntervalSet k
filter p = fromDistinctAscList . L.filter p . toAscList

-- | /O(n)/. Split the set by a predicate: the first result holds the
-- elements satisfying it, the second the rest. See also 'split'.
partition :: (Interval k e) => (k -> Bool) -> IntervalSet k -> (IntervalSet k, IntervalSet k)
partition p s = (fromDistinctAscList yes, fromDistinctAscList no)
  where (yes, no) = L.partition p (toAscList s)

-- | /O(n)/. @'split' k set@ is a pair @(set1,set2)@: @set1@ holds the
-- elements smaller than @k@, @set2@ the elements larger than @k@. An element
-- equal to @k@ is in neither.
split :: (Interval i k, Ord i) => i -> IntervalSet i -> (IntervalSet i, IntervalSet i)
split x m = let (lo, _, hi) = splitMember x m in (lo, hi)
     
-- | /O(n)/. Like 'split', but additionally reports whether the set contains
-- an element equal to the given one (@'member' k set@).
splitMember :: (Interval i k, Ord i) => i -> IntervalSet i -> (IntervalSet i, Bool, IntervalSet i)
splitMember x s = case span (< x) (toAscList s) of
    -- nothing >= x: the whole (possibly empty) set is below x
    (_, [])          -> (s, False, empty)
    -- nothing < x: only the first element can be equal to x
    ([], y:_)
      | y == x       -> (empty, True, deleteMin s)
      | otherwise    -> (empty, False, s)
    (lt, ge@(y:gt))
      | y == x       -> (fromDistinctAscList lt, True, fromDistinctAscList gt)
      | otherwise    -> (fromDistinctAscList lt, False, fromDistinctAscList ge)

-- Helper for building sets from distinct ascending values and subsets.
-- Describes an ascending concatenation of single elements ('UCons') and
-- whole subtrees ('UAppend'); 'Union' concatenates two descriptions and
-- 'UEmpty' is the empty one. 'fromUnion' turns a description into a set.
data Union k = UEmpty | Union !(Union k) !(Union k)
             | UCons !k !(Union k)
             | UAppend !(IntervalSet k) !(Union k)

-- | Concatenate two 'Union' descriptions, dropping empty operands.
mkUnion :: Union a -> Union a -> Union a
mkUnion a b = case (a, b) of
  (UEmpty, _) -> b
  (_, UEmpty) -> a
  _           -> Union a b

-- | Build the set described by a 'Union'. A single element or a single
-- subtree is turned into a set directly; anything else is flattened to an
-- ascending list and rebuilt.
fromUnion :: Interval k e => Union k -> IntervalSet k
fromUnion u = case u of
    UEmpty           -> empty
    UCons k UEmpty   -> singleton k
    UAppend s UEmpty -> turnBlack s
    _                -> fromDistinctAscList (flatten u [])
  where
    flatten UEmpty        acc = acc
    flatten (Union a b)   acc = flatten a (flatten b acc)
    flatten (UCons k v)   acc = k : flatten v acc
    flatten (UAppend s v) acc = toAscList' s (flatten v acc)


-- | /O(n)/. Split around a point: the intervals below the point, the
-- intervals containing it (see 'containing'), and the intervals above it.
splitAt :: (Interval i k) => IntervalSet i -> k -> (IntervalSet i, IntervalSet i, IntervalSet i)
splitAt set p = (fromUnion (lows set), set `containing` p, fromUnion (highs set))
  where
    -- intervals lying below p, in ascending order
    lows Nil = UEmpty
    lows t@(Node _ k m l r)
      | p `above` m  = UAppend t UEmpty
      | p `below` k  = lows l
      | p `inside` k = lows l `mkUnion` lows r
      | otherwise    = lows l `mkUnion` UCons k (lows r)
    -- intervals lying above p, in ascending order
    highs Nil = UEmpty
    highs (Node _ k m l r)
      | p `above` m = UEmpty
      | p `below` k = highs l `mkUnion` UCons k (UAppend r UEmpty)
      | otherwise   = highs r

-- | /O(n)/. Split around an interval.
-- Splits the set into three subsets: intervals below the given interval,
-- intervals intersecting the given interval, and intervals above the
-- given interval.
splitIntersecting :: (Interval i k, Ord i) => IntervalSet i -> i -> (IntervalSet i, IntervalSet i, IntervalSet i)
splitIntersecting set i = (fromUnion (lower set), set `intersecting` i, fromUnion (higher set))
  where
    -- intervals lying before i, as an ascending 'Union'
    lower Nil = UEmpty
    lower s@(Node _ k m l r)
      -- i is after even the interval with the largest upper bound: the whole subtree is lower
      | i `after`  m   =  UAppend s UEmpty
      -- i sorts at or before k; assuming the 'Ord' instance orders by lower
      -- bound first, neither k nor the right subtree can lie before i
      | i <= k         =  lower l
      -- i overlaps k: k is excluded, but both subtrees could contain lower intervals
      | i `overlaps` k =  mkUnion (lower l) (lower r)
      -- k sorts before i without overlapping it: k plus the lower parts of both subtrees
      | otherwise      =  mkUnion (lower l) (UCons k (lower r))
    -- intervals lying after i, as an ascending 'Union'
    higher Nil = UEmpty
    higher (Node _ k m l r)
      -- i is after every interval of the subtree: nothing is higher
      | i `after` m    =  UEmpty
      -- i before k: k and the complete right subtree, plus maybe part of the left subtree
      | i `before`  k  =  mkUnion (higher l) (UCons k (UAppend r UEmpty))
      -- i overlaps k or lies after it: only the right subtree can contain higher intervals
      | otherwise      =  higher r


-- subsets

-- | /O(n+m)/. Is every element of the first set also in the second?
-- Always true for equal sets.
isSubsetOf :: (Ord k) => IntervalSet k -> IntervalSet k -> Bool
isSubsetOf set1 set2 = toAscList set1 `ascListSubset` toAscList set2

-- | Is every element of the first strictly ascending list contained in the
-- second strictly ascending list?
ascListSubset :: (Ord a) => [a] -> [a] -> Bool
ascListSubset [] _ = True
ascListSubset _ [] = False
ascListSubset xs@(x:xt) (y:yt) = case compare x y of
  LT -> False
  EQ -> ascListSubset xt yt
  GT -> ascListSubset xs yt

-- | /O(n+m)/. Is the first set a subset of the second set but not equal to it?
isProperSubsetOf :: (Ord k) => IntervalSet k -> IntervalSet k -> Bool
isProperSubsetOf set1 set2 = check (toAscList set1) (toAscList set2)
  where
    check _ [] = False
    check [] _ = True
    check xs@(x:xt) (y:yt) = case compare x y of
      LT -> False
      EQ -> check xt yt
      -- y is missing from the first set, so a plain subset test suffices now
      GT -> ascListSubset xs yt

-- | /O(n log n)/. Build a new set by combining successive values: whenever
-- @combine@ returns @Just@ for two neighbours, they are replaced by the result.
flattenWith :: (Ord a, Interval a e) => (a -> a -> Maybe a) -> IntervalSet a -> IntervalSet a
flattenWith combine = fromList . combineSuccessive combine

-- | /O(n)/. Like 'flattenWith', but the combined values must already be
-- strictly ascending (e.g. when the combining function returns strictly
-- monotonic values). /The precondition is not checked./
flattenWithMonotonic :: (Interval a e) => (a -> a -> Maybe a) -> IntervalSet a -> IntervalSet a
flattenWithMonotonic combine = fromDistinctAscList . combineSuccessive combine

-- | Walk the elements in ascending order, merging each pending value with
-- the next element while @combine@ returns @Just@.
combineSuccessive :: (a -> a -> Maybe a) -> IntervalSet a -> [a]
combineSuccessive combine set = case toAscList set of
    []       -> []
    (x : xs) -> merge x xs
  where
    merge acc []       = [acc]
    merge acc (y : ys) = case combine acc y of
      Nothing -> acc : merge y ys
      Just c  -> merge c ys


-- debugging

-- | Number of nodes on the longest root-to-leaf path. For testing/debugging only.
height :: IntervalSet k -> Int
height t = case t of
  Nil            -> 0
  Node _ _ _ l r -> 1 + max (height l) (height r)

-- | Upper bound on the height of a red-black tree with the given number of
-- nodes: @2 * log2 (nodes + 1)@. For testing/debugging only.
maxHeight :: Int -> Int
maxHeight = (2 *) . log2 . (+ 1)


-- | Check red-black-tree and interval search augmentation invariants.
-- For testing/debugging only.
--
-- Verifies that every node caches 'maxUpper' of its key and children, that
-- the elements are in strictly ascending order, that the height is within
-- the red-black bound, and that the color invariants hold (equal black depth
-- on all paths, no red node with a red child).
valid :: (Interval i k, Ord i) => IntervalSet i -> Bool
valid mp = test mp && validOrder && height mp <= maxHeight (size mp) && validColor mp
  where
    test Nil = True
    test n@(Node _ _ _ l r) = validMax n && test l && test r
    validMax (Node _ k m lo hi) =  m == maxUpper k lo hi
    validMax Nil = True

    -- Search-tree order must hold globally: the in-order traversal has to be
    -- strictly ascending. Comparing each node only with its direct children
    -- would miss e.g. a right grandchild of the left subtree that is greater
    -- than the root.
    validOrder = strictlyAscending (toAscList mp)
    strictlyAscending xs = and (zipWith (<) xs (drop 1 xs))

    -- the color invariants hold iff blackDepth reports no violation
    validColor n = blackDepth n >= 0

    -- return -1 if subtrees have different black depths or two consecutive red nodes are encountered
    blackDepth :: IntervalSet k -> Int
    blackDepth Nil  = 0
    blackDepth (Node c _ _ l r) = case blackDepth l of
                                      ld -> if ld < 0 then ld
                                            else
                                              case blackDepth r of
                                                rd | rd < 0       -> rd
                                                   | rd /= ld || (c == R && (isRed l || isRed r)) -> -1
                                                   | c == B       -> rd + 1
                                                   | otherwise    -> rd