{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_HADDOCK not-home #-}
module Data.Set.NonEmpty.Internal (
NESet(..)
, nonEmptySet
, withNonEmpty
, toSet
, singleton
, fromList
, toList
, size
, union
, unions
, foldr
, foldl
, foldr'
, foldl'
, MergeNESet(..)
, merge
, valid
, insertMinSet
, insertMaxSet
, disjointSet
, powerSetSet
, disjointUnionSet
, cartesianProductSet
) where
import Control.DeepSeq
import Data.Data
import Data.Function
import Data.Functor.Classes
import Data.List.NonEmpty (NonEmpty(..))
import Data.Semigroup
import Data.Semigroup.Foldable (Foldable1)
import Data.Set.Internal (Set(..))
import Prelude hiding (foldr, foldr1, foldl, foldl1)
import Text.Read
import qualified Data.Foldable as F
import qualified Data.Semigroup.Foldable as F1
import qualified Data.Set as S
import qualified Data.Set.Internal as S
#if !MIN_VERSION_containers(0,5,11)
import Utils.Containers.Internal.StrictPair
#endif
-- | A set that always contains at least one element.
--
-- One element is held outside of the underlying 'Set'.  The invariant
-- checked by 'valid' is that 'nesV0' is strictly smaller than every element
-- of 'nesSet', so it is the minimum of the whole set.  Both fields are
-- strict.
data NESet a = NESet
    { nesV0  :: !a        -- ^ the smallest element
    , nesSet :: !(Set a)  -- ^ every other element
    }
  deriving (Typeable)
-- | Equal when both sets hold the same number of elements and the same
-- ascending element sequence.  The size comparison runs first so that sets
-- of different sizes are rejected without walking their elements.
instance Eq a => Eq (NESet a) where
    n1 == n2 = sizesMatch && toList n1 == toList n2
      where
        sizesMatch = S.size (nesSet n1) == S.size (nesSet n2)
-- | Lexicographic ordering on the ascending element lists.
instance Ord a => Ord (NESet a) where
    compare = on compare toList
    (<)     = on (<) toList
    (>)     = on (>) toList
    (<=)    = on (<=) toList
    (>=)    = on (>=) toList
-- | Renders as @fromList (x :| [...])@, wrapped in parentheses when shown
-- above application precedence.
instance Show a => Show (NESet a) where
    showsPrec d n = showParen (d > 10) body
      where
        body = showString "fromList (" . shows (toList n) . showString ")"
-- | Parses the format produced by 'show': the identifier @fromList@ followed
-- by a 'NonEmpty' list, which is then passed through 'fromList'.
instance (Read a, Ord a) => Read (NESet a) where
    readPrec = parens . prec 10 $ do
        Ident "fromList" <- lexP
        fromList <$> (parens . prec 10) readPrec
    readListPrec = readListPrecDefault
-- | Lifted equality: sizes must agree, then the ascending element lists are
-- compared with the supplied predicate.
instance Eq1 NESet where
    liftEq eq n1 n2
        | size n1 /= size n2 = False
        | otherwise          = liftEq eq (toList n1) (toList n2)
-- | Lifted comparison of the ascending element lists.
instance Ord1 NESet where
    liftCompare cmp = liftCompare cmp `on` toList
-- | Lifted 'Show': @fromList@ applied to the 'NonEmpty' element list.
instance Show1 NESet where
    liftShowsPrec sp sl d =
        showsUnaryWith (liftShowsPrec sp sl) "fromList" d . toList
-- | Fully evaluates the stored minimum and then the remaining set.
instance NFData a => NFData (NESet a) where
    rnf n = rnf (nesV0 n) `seq` rnf (nesSet n)
-- | Treats an 'NESet' abstractly as @fromList@ applied to its element list,
-- so generic traversals never see (or break) the internal representation.
instance (Data a, Ord a) => Data (NESet a) where
    gfoldl f z n = f (z fromList) (toList n)
    toConstr _   = fromListConstr
    gunfold k z c
        | constrIndex c == 1 = k (z fromList)
        | otherwise          = error "gunfold"
    dataTypeOf _ = setDataType
    dataCast1 f  = gcast1 f
-- | The single pseudo-constructor exposed by the 'Data' instance: a prefix
-- @fromList@ with no named fields.
fromListConstr :: Constr
fromListConstr = mkConstr setDataType "fromList" noFields Prefix
  where
    noFields = []
-- | 'DataType' description for 'NESet', whose only constructor is
-- 'fromListConstr'.
setDataType :: DataType
setDataType = mkDataType typeName [fromListConstr]
  where
    typeName = "Data.Set.NonEmpty.Internal.NESet"
-- | Convert an ordinary 'Set' into an 'NESet'.  Returns 'Nothing' for the
-- empty set; otherwise the minimum is split off with 'S.minView' and stored
-- separately.
nonEmptySet :: Set a -> Maybe (NESet a)
nonEmptySet s = case S.minView s of
    Nothing        -> Nothing
    Just (x, rest) -> Just (NESet x rest)
{-# INLINE nonEmptySet #-}
-- | Eliminator for a possibly-empty 'Set': returns @def@ when the set is
-- empty, otherwise applies @f@ to its 'NESet' view.
withNonEmpty
    :: r                -- ^ result for the empty set
    -> (NESet a -> r)   -- ^ continuation for a non-empty set
    -> Set a
    -> r
withNonEmpty def f s = case nonEmptySet s of
    Nothing -> def
    Just n  -> f n
{-# INLINE withNonEmpty #-}
-- | Forget non-emptiness: re-insert the stored minimum at the far left of
-- the underlying tree (see 'insertMinSet').
toSet :: NESet a -> Set a
toSet n = insertMinSet (nesV0 n) (nesSet n)
{-# INLINE toSet #-}
-- | A set holding exactly one element.
singleton :: a -> NESet a
singleton x = NESet { nesV0 = x, nesSet = S.empty }
{-# INLINE singleton #-}
-- | Build a set from a non-empty list.
--
-- The tail is converted with 'S.fromList' and the head is then combined on
-- the /right/ of the left-biased '<>' ('union').  If the head equals an
-- element already in the tail, the tail's element is therefore kept.
fromList :: Ord a => NonEmpty a -> NESet a
fromList (x :| xs) = case nonEmptySet (S.fromList xs) of
    Nothing   -> singleton x
    Just rest -> rest <> singleton x
{-# INLINE fromList #-}
-- | All elements in ascending order; the stored minimum is the head.
toList :: NESet a -> NonEmpty a
toList (NESet x0 rest) = (:|) x0 (S.toList rest)
{-# INLINE toList #-}
-- | Number of elements: the stored minimum plus the size of the rest.
size :: NESet a -> Int
size = (+ 1) . S.size . nesSet
{-# INLINE size #-}
-- | Lazy right fold over the elements in ascending order.
foldr :: (a -> b -> b) -> b -> NESet a -> b
foldr f z (NESet x0 rest) = f x0 folded
  where
    folded = S.foldr f z rest
{-# INLINE foldr #-}
-- | Right fold whose accumulator over the remaining elements is forced
-- before the stored minimum is combined in.
foldr' :: (a -> b -> b) -> b -> NESet a -> b
foldr' f z (NESet x0 rest) =
    let !acc = S.foldr' f z rest
    in  f x0 acc
{-# INLINE foldr' #-}
-- | Right fold with no starting value.  The largest element seeds the fold
-- over the elements between the minimum and the maximum, and the stored
-- minimum is combined last on the left.
foldr1 :: (a -> a -> a) -> NESet a -> a
foldr1 f (NESet x0 rest) = case S.maxView rest of
    Nothing           -> x0
    Just (hi, middle) -> f x0 (S.foldr f hi middle)
{-# INLINE foldr1 #-}
-- | Lazy left fold over the elements in ascending order.
foldl :: (a -> b -> a) -> a -> NESet b -> a
foldl f z (NESet x0 rest) = S.foldl f start rest
  where
    start = f z x0
{-# INLINE foldl #-}
-- | Strict left fold; the accumulator is forced after combining the stored
-- minimum and then throughout the remaining elements.
foldl' :: (a -> b -> a) -> a -> NESet b -> a
foldl' f z (NESet x0 rest) =
    let !start = f z x0
    in  S.foldl' f start rest
{-# INLINE foldl' #-}
-- | Left fold seeded with the stored minimum.
foldl1 :: (a -> a -> a) -> NESet a -> a
foldl1 f n = S.foldl f (nesV0 n) (nesSet n)
{-# INLINE foldl1 #-}
-- | Left-biased union.  The smaller of the two stored minima becomes the new
-- minimum, and everything else is merged into the remaining set.  When the
-- minima are equal, the one from the first argument is kept.
union :: Ord a => NESet a -> NESet a -> NESet a
union n1@(NESet x1 s1) n2@(NESet x2 s2) = case compare x1 x2 of
    LT -> NESet x1 (s1 `S.union` toSet n2)
    EQ -> NESet x1 (s1 `S.union` s2)
    GT -> NESet x2 (toSet n1 `S.union` s2)
{-# INLINE union #-}
-- | Union of a non-empty collection of sets, folded strictly from the left
-- with 'union'.
unions :: (Foldable1 f, Ord a) => f (NESet a) -> NESet a
unions fs = case F1.toNonEmpty fs of
    n :| ns -> F.foldl' union n ns
{-# INLINE unions #-}
-- | Set union ('union'), left-biased on equal elements.
instance Ord a => Semigroup (NESet a) where
    (<>) = union
    {-# INLINE (<>) #-}
    sconcat = unions
    {-# INLINE sconcat #-}
    -- 'union' is idempotent (@s <> s@ has exactly the elements of @s@), so
    -- repeating a set any positive number of times yields the set itself.
    -- Without this, the class default would build a chain of unions.
    stimes = stimesIdempotent
-- | Folds visit the elements in ascending order, starting with the stored
-- minimum.  'null' is always 'False', 'length' is 'size', and 'minimum' is
-- the stored element.
instance Foldable NESet where
#if MIN_VERSION_base(4,11,0)
    fold (NESet x s) = x <> F.fold s
    {-# INLINE fold #-}
    foldMap f (NESet x s) = f x <> foldMap f s
    {-# INLINE foldMap #-}
#else
    fold (NESet x s) = x `mappend` F.fold s
    {-# INLINE fold #-}
    foldMap f (NESet x s) = f x `mappend` foldMap f s
    {-# INLINE foldMap #-}
#endif
    foldr = foldr
    {-# INLINE foldr #-}
    foldr' = foldr'
    {-# INLINE foldr' #-}
    foldr1 = foldr1
    {-# INLINE foldr1 #-}
    foldl = foldl
    {-# INLINE foldl #-}
    foldl' = foldl'
    {-# INLINE foldl' #-}
    foldl1 = foldl1
    {-# INLINE foldl1 #-}
    null _ = False
    {-# INLINE null #-}
    length = size
    {-# INLINE length #-}
    -- Only 'Eq' is available here, so no ordered lookup into @s@ is possible.
    -- Testing the stored minimum first answers that case without touching
    -- the tree.
    elem x (NESet x0 s) = x == x0
                       || F.elem x s
    {-# INLINE elem #-}
    minimum (NESet x _) = x
    {-# INLINE minimum #-}
    -- The stored element is the minimum, so it is also the maximum only when
    -- the rest is empty.  'S.findMax' just walks the right spine, whereas
    -- 'S.maxView' would also rebuild the remaining tree.
    maximum (NESet x s)
        | S.null s  = x
        | otherwise = S.findMax s
    {-# INLINE maximum #-}
    toList = F.toList . toList
    {-# INLINE toList #-}
-- | Non-empty folds.  The semigroup results for the remaining elements are
-- accumulated right-to-left in a 'Maybe'.  'Nothing' stands for an empty
-- rest, so no 'Monoid' instance is needed.  This avoids the @Option@ type
-- that was removed from "Data.Semigroup" in base-4.16, and it is lazy in
-- the accumulated tail.
instance Foldable1 NESet where
    fold1 (NESet x s) = maybe x (x <>)
                      . S.foldr (\y -> Just . maybe y (y <>)) Nothing
                      $ s
    {-# INLINE fold1 #-}
    foldMap1 f (NESet x s) = maybe (f x) (f x <>)
                           . S.foldr (\y -> Just . maybe (f y) (f y <>)) Nothing
                           $ s
    {-# INLINE foldMap1 #-}
    toNonEmpty = toList
    {-# INLINE toNonEmpty #-}
-- | Wrapper whose 'Semigroup' instance joins sets with 'merge' instead of
-- 'union'.  It assumes each left operand's elements are all smaller than
-- the right operand's (not checked).
newtype MergeNESet a = MergeNESet
    { getMergeNESet :: NESet a
    }
-- | Concatenation via 'merge'; see 'MergeNESet' for the ordering assumption.
instance Semigroup (MergeNESet a) where
    l <> r = MergeNESet (merge (getMergeNESet l) (getMergeNESet r))
    {-# INLINE (<>) #-}
-- | Join two sets without comparing elements.  The first set's minimum is
-- kept as the overall minimum, and the second set is appended after the
-- first's remaining elements with 'S.merge'.
--
-- NOTE(review): assumes every element of the first set is smaller than every
-- element of the second; this is not checked.
merge :: NESet a -> NESet a -> NESet a
merge n1 n2 = NESet (nesV0 n1) (S.merge (nesSet n1) (toSet n2))
-- | Internal invariant check.  The underlying 'Set' must be valid, and the
-- stored element must be strictly smaller than the rest's minimum (trivially
-- true when the rest is empty).
valid :: Ord a => NESet a -> Bool
valid (NESet x0 rest) = S.valid rest && maybe True ((x0 <) . fst) (S.minView rest)
-- | Insert an element at the far left of a tree without comparing it to
-- anything, rebalancing with 'balanceL' on the way back up.  Only valid when
-- the element is smaller than everything already in the set.
insertMinSet :: a -> Set a -> Set a
insertMinSet x = go
  where
    go Tip           = S.singleton x
    go (Bin _ y l r) = balanceL y (go l) r
{-# INLINABLE insertMinSet #-}
-- | Insert an element at the far right of a tree without comparing it to
-- anything, rebalancing with 'balanceR' on the way back up.  Only valid when
-- the element is larger than everything already in the set.
insertMaxSet :: a -> Set a -> Set a
insertMaxSet x = go
  where
    go Tip           = S.singleton x
    go (Bin _ y l r) = balanceR y l (go r)
{-# INLINABLE insertMaxSet #-}
-- | Whether two sets share no elements.  Uses 'S.disjoint' with
-- containers >= 0.5.11 and an intersection-based fallback otherwise.
disjointSet :: Ord a => Set a -> Set a -> Bool
#if MIN_VERSION_containers(0,5,11)
disjointSet = S.disjoint
#else
disjointSet xs ys = S.null (S.intersection xs ys)
#endif
{-# INLINE disjointSet #-}
-- | The set of all subsets of a set (including the empty set).
--
-- With containers >= 0.5.11 this is 'S.powerSet'.  For older containers a
-- local copy of the algorithm is used, together with the private helpers it
-- needs ('minViewSure', 'maxViewSure' and 'glue').
powerSetSet :: Set a -> Set (Set a)
#if MIN_VERSION_containers(0,5,11)
powerSetSet = S.powerSet
{-# INLINE powerSetSet #-}
#else
-- The fold runs from the largest element down, so @pxs@ always holds the
-- non-empty subsets of the elements larger than @x@.  step' prepends the
-- subsets whose minimum is @x@: @{x}@, plus @x@ inserted at the left of each
-- subset in @pxs@.  Both orders are preserved because @x@ is smaller than
-- everything in them.  These sort before @pxs@, so they can be joined with
-- 'glue'.  The empty set is the smallest subset and is added last at the
-- far left.
powerSetSet xs0 = insertMinSet S.empty (S.foldr' step' Tip xs0) where
  step' x pxs = insertMinSet (S.singleton x) (insertMinSet x `S.mapMonotonic` pxs) `glue` pxs
{-# INLINABLE powerSetSet #-}

-- | Remove the minimum of the non-empty tree given by its root element and
-- subtrees.  Returns that minimum together with the rebalanced remainder.
minViewSure :: a -> Set a -> Set a -> StrictPair a (Set a)
minViewSure = go
  where
    go x Tip r = x :*: r
    go x (Bin _ xl ll lr) r =
      case go xl ll lr of
        -- the left side shrank by one, so rebalance towards the right
        xm :*: l' -> xm :*: balanceR x l' r

-- | Remove the maximum of the non-empty tree given by its root element and
-- subtrees.  Returns that maximum together with the rebalanced remainder.
maxViewSure :: a -> Set a -> Set a -> StrictPair a (Set a)
maxViewSure = go
  where
    go x l Tip = x :*: l
    go x l (Bin _ xr rl rr) =
      case go xr rl rr of
        -- the right side shrank by one, so rebalance towards the left
        xm :*: r' -> xm :*: balanceL x l r'

-- | Join two trees into one.  Assumes every element of the left tree is
-- smaller than every element of the right tree.  The new root is taken
-- from the larger side: the left's maximum when the left is bigger,
-- otherwise the right's minimum.
glue :: Set a -> Set a -> Set a
glue Tip r = r
glue l Tip = l
glue l@(Bin sl xl ll lr) r@(Bin sr xr rl rr)
  | sl > sr = let !(m :*: l') = maxViewSure xl ll lr in balanceR m l' r
  | otherwise = let !(m :*: r') = minViewSure xr rl rr in balanceL m l r'
#endif
-- | Tag the elements of the first set with 'Left' and those of the second
-- with 'Right', then combine them.  Every 'Left' sorts before every
-- 'Right', so the fallback can join the two halves with 'S.merge'.
disjointUnionSet :: Set a -> Set b -> Set (Either a b)
#if MIN_VERSION_containers(0,5,11)
disjointUnionSet = S.disjointUnion
#else
disjointUnionSet xs ys = tagged `S.merge` taggedR
  where
    tagged  = S.mapMonotonic Left xs
    taggedR = S.mapMonotonic Right ys
#endif
{-# INLINE disjointUnionSet #-}
-- | All pairs @(a, b)@ with @a@ from the first set and @b@ from the second.
-- Uses 'S.cartesianProduct' with containers >= 0.5.11 and a local fallback
-- otherwise.
cartesianProductSet :: Set a -> Set b -> Set (a, b)
#if MIN_VERSION_containers(0,5,11)
cartesianProductSet = S.cartesianProduct
#else
-- Pairing a fixed @a@ with every @b@ via 'S.mapMonotonic' keeps the order,
-- because tuples compare on their first component first.  'foldMap' combines
-- the per-@a@ chunks in ascending order of @a@, so adjacent chunks can be
-- joined with 'S.merge' (through 'MergeSet') instead of a general union.
cartesianProductSet as bs =
  getMergeSet $ foldMap (\a -> MergeSet $ S.mapMonotonic ((,) a) bs) as

-- Monoid wrapper that combines sets with 'S.merge'.  Assumes every element
-- of the left operand is smaller than every element of the right.
newtype MergeSet a = MergeSet { getMergeSet :: Set a }

instance Semigroup (MergeSet a) where
  MergeSet xs <> MergeSet ys = MergeSet (S.merge xs ys)

instance Monoid (MergeSet a) where
  mempty = MergeSet S.empty
  mappend = (<>)
#endif
{-# INLINE cartesianProductSet #-}
-- | Smart constructor for the node @Bin _ x l r@.  It restores the size
-- balance when the /right/ subtree may be too heavy: after @r@ has grown by
-- one (as in 'insertMaxSet') or @l@ has shrunk by one (as when the minimum
-- is removed).  It performs a single or double left rotation, chosen with
-- 'delta' and 'ratio', and recomputes node sizes from the stored subtree
-- sizes.
--
-- NOTE(review): this appears to mirror @balanceR@ from containers'
-- @Data.Set.Internal@; keep the two in sync.
balanceR :: a -> Set a -> Set a -> Set a
balanceR x l r = case l of
  Tip -> case r of
    Tip -> Bin 1 x Tip Tip
    Bin _ _ Tip Tip -> Bin 2 x Tip r
    Bin _ rx Tip rr@Bin{} -> Bin 3 rx (Bin 1 x Tip Tip) rr
    Bin _ rx (Bin _ rlx _ _) Tip -> Bin 3 rlx (Bin 1 x Tip Tip) (Bin 1 rx Tip Tip)
    Bin rs rx rl@(Bin rls rlx rll rlr) rr@(Bin rrs _ _ _)
      | rls < ratio*rrs -> Bin (1+rs) rx (Bin (1+rls) x Tip rl) rr
      | otherwise -> Bin (1+rs) rlx (Bin (1+S.size rll) x Tip rll) (Bin (1+rrs+S.size rlr) rx rlr rr)
  Bin ls _ _ _ -> case r of
    Tip -> Bin (1+ls) x l Tip
    Bin rs rx rl rr
      -- right side too heavy: rotate (single or double, decided by 'ratio')
      | rs > delta*ls -> case (rl, rr) of
          (Bin rls rlx rll rlr, Bin rrs _ _ _)
            | rls < ratio*rrs -> Bin (1+ls+rs) rx (Bin (1+ls+rls) x l rl) rr
            | otherwise -> Bin (1+ls+rs) rlx (Bin (1+ls+S.size rll) x l rll) (Bin (1+rrs+S.size rlr) rx rlr rr)
          -- unreachable for balanced input: a subtree heavier than
          -- 'delta' times a non-empty sibling has two non-empty children
          (_, _) -> error "Failure in Data.Set.NonEmpty.Internal.balanceR"
      | otherwise -> Bin (1+ls+rs) x l r
{-# NOINLINE balanceR #-}
-- | Smart constructor for the node @Bin _ x l r@.  It restores the size
-- balance when the /left/ subtree may be too heavy: after @l@ has grown by
-- one (as in 'insertMinSet') or @r@ has shrunk by one (as when the maximum
-- is removed).  It performs a single or double right rotation, chosen with
-- 'delta' and 'ratio', and recomputes node sizes from the stored subtree
-- sizes.
--
-- NOTE(review): this appears to mirror @balanceL@ from containers'
-- @Data.Set.Internal@; keep the two in sync.
balanceL :: a -> Set a -> Set a -> Set a
balanceL x l r = case r of
  -- empty right subtree: each small shape of @l@ is rebuilt explicitly
  Tip -> case l of
    Tip -> Bin 1 x Tip Tip
    Bin _ _ Tip Tip -> Bin 2 x l Tip
    Bin _ lx Tip (Bin _ lrx _ _) -> Bin 3 lrx (Bin 1 lx Tip Tip) (Bin 1 x Tip Tip)
    Bin _ lx ll@Bin{} Tip -> Bin 3 lx ll (Bin 1 x Tip Tip)
    Bin ls lx ll@(Bin lls _ _ _) lr@(Bin lrs lrx lrl lrr)
      | lrs < ratio*lls -> Bin (1+ls) lx ll (Bin (1+lrs) x lr Tip)
      | otherwise -> Bin (1+ls) lrx (Bin (1+lls+S.size lrl) lx ll lrl) (Bin (1+S.size lrr) x lrr Tip)
  Bin rs _ _ _ -> case l of
    Tip -> Bin (1+rs) x Tip r
    Bin ls lx ll lr
      -- left side too heavy: rotate (single or double, decided by 'ratio')
      | ls > delta*rs -> case (ll, lr) of
          (Bin lls _ _ _, Bin lrs lrx lrl lrr)
            | lrs < ratio*lls -> Bin (1+ls+rs) lx ll (Bin (1+rs+lrs) x lr r)
            | otherwise -> Bin (1+ls+rs) lrx (Bin (1+lls+S.size lrl) lx ll lrl) (Bin (1+rs+S.size lrr) x lrr r)
          (_, _) -> error "Failure in Data.Set.NonEmpty.Internal.balanceL"
      | otherwise -> Bin (1+ls+rs) x l r
{-# NOINLINE balanceL #-}
-- | Imbalance threshold used by 'balanceL' and 'balanceR': a node is rotated
-- when one subtree's size exceeds 'delta' times the other's.
delta :: Int
delta = 3

-- | Rotation selector used by 'balanceL' and 'balanceR': compares the heavy
-- subtree's inner child against 'ratio' times its outer child to choose
-- between a single and a double rotation.
ratio :: Int
ratio = 2