{-# language DeriveFoldable #-}
{-# language FlexibleInstances #-}
{-# language FunctionalDependencies #-}
module Data.SplayTree where
import Data.Monoid
import qualified Data.Semigroup as Semigroup
-- Both operators bind at precedence 5. '<|' is right-associative, so
-- @a <| b <| t@ parses as @a <| (b <| t)@; '|>' is left-associative, so
-- @t |> a |> b@ parses as @(t |> a) |> b@.
infixr 5 <|
infixl 5 |>
-- | Types whose values have an associated monoidal measure of type @v@.
--
-- The functional dependency @a -> v@ means the element type alone
-- determines which measure type is used.
class Monoid v => Measured v a | a -> v where
  -- | The measure of a single value.
  measure
    :: a
    -> v
-- | A binary tree of elements of type @a@ whose nodes carry a cached
-- measure of type @v@.
data SplayTree v a
  = Leaf
    -- ^ The empty tree.
  | Fork
      (SplayTree v a) -- ^ Left subtree (elements before this node's element).
      a               -- ^ The element stored at this node.
      (SplayTree v a) -- ^ Right subtree (elements after this node's element).
      !v              -- ^ Cached measure of the node, as computed by 'fork'.
  deriving (Eq, Ord, Show, Foldable)
instance Measured v a => Measured v (SplayTree v a) where
  {-# INLINE measure #-}
  -- An empty tree measures as the monoid identity; a node returns the
  -- measure cached in its last field without recomputing anything.
  measure tree = case tree of
    Leaf -> mempty
    Fork _ _ _ cached -> cached
instance Measured v a => Semigroup.Semigroup (SplayTree v a) where
  {-# INLINE (<>) #-}
  Leaf <> right = right
  left <> Leaf = left
  Fork l1 a1 r1 m1 <> Fork l2 a2 r2 m2 =
    -- Keep the first root on top and hang the second root to its right.
    -- The first tree's right subtree is merged with the second tree's left
    -- subtree, so the in-order sequence stays l1, a1, r1, l2, a2, r2.
    -- Assuming the cached measures are consistent, m2 already covers
    -- l2, a2 and r2, so only r1's measure needs to be prepended.
    Fork l1 a1 hung (m1 <> m2)
    where
      hung = Fork (r1 <> l2) a2 r2 (measure r1 <> m2)
instance Measured v a => Monoid (SplayTree v a) where
  {-# INLINE mempty #-}
  -- The empty tree is the identity for '<>'.
  mempty = Leaf
  {-# INLINE mappend #-}
  -- Delegate to the Semigroup instance via the qualified operator. On
  -- base < 4.11, Data.Monoid's '<>' is an alias for 'mappend' itself,
  -- so spelling it unqualified here could loop.
  mappend x y = x Semigroup.<> y
-- | 'True' exactly when the tree is 'Leaf' (holds no elements).
null :: SplayTree v a -> Bool
null tree = case tree of
  Leaf -> True
  Fork {} -> False
{-# INLINE singleton #-}
-- | A tree holding exactly one element, with that element's measure cached
-- at the node.
singleton :: Measured v a => a -> SplayTree v a
singleton x = Fork Leaf x Leaf (measure x)
{-# INLINE (<|) #-}
-- | Prepend an element: it becomes the root, with an empty left subtree and
-- the given tree as its right subtree.
(<|) :: Measured v a => a -> SplayTree v a -> SplayTree v a
x <| tree = fork Leaf x tree
{-# INLINE (|>) #-}
-- | Append an element: it becomes the root, with the given tree as its left
-- subtree and an empty right subtree.
(|>) :: Measured v a => SplayTree v a -> a -> SplayTree v a
tree |> x = fork tree x Leaf
{-# INLINE fork #-}
-- | Smart constructor for 'Fork'. It builds a node from a left subtree, an
-- element and a right subtree. The cached measure is the left-to-right
-- combination of the three parts.
fork :: Measured v a => SplayTree v a -> a -> SplayTree v a -> SplayTree v a
fork left x right = Fork left x right total
  where
    total = measure left <> measure x <> measure right
{-# INLINE uncons #-}
-- | Detach the leftmost element. Returns it together with the remaining
-- tree, or 'Nothing' for an empty tree.
--
-- The search walks down the left spine. Each visited node's element and
-- right subtree are folded (via 'fork') into the remainder, so the result
-- is rebuilt as the walk proceeds.
uncons :: Measured v a => SplayTree v a -> Maybe (a, SplayTree v a)
uncons tree = case tree of
  Leaf -> Nothing
  Fork l x r _ -> Just (leftmost l x r)
  where
    -- The arguments represent the sequence (lefts, pivot, rest) in order.
    leftmost Leaf pivot rest = (pivot, rest)
    leftmost (Fork ll y lr _) pivot rest = leftmost ll y (fork lr pivot rest)
{-# INLINE unsnoc #-}
-- | Detach the rightmost element. Returns the remaining tree together with
-- it, or 'Nothing' for an empty tree.
--
-- Mirror image of 'uncons': the walk follows the right spine. Each visited
-- node's left subtree and element are folded (via 'fork') into the
-- remainder.
unsnoc :: Measured v a => SplayTree v a -> Maybe (SplayTree v a, a)
unsnoc tree = case tree of
  Leaf -> Nothing
  Fork l x r _ -> Just (rightmost l x r)
  where
    -- The arguments represent the sequence (rest, pivot, rights) in order.
    rightmost rest pivot Leaf = (rest, pivot)
    rightmost rest pivot (Fork rl y rr _) = rightmost (fork rest pivot rl) y rr
-- | The outcome of 'split'.
data SplitResult v a
  = Outside
    -- ^ No element was selected.
  | Inside (SplayTree v a) a (SplayTree v a)
    -- ^ The elements before the selected one, the selected element, and
    -- the elements after it.
  deriving (Eq, Ord, Show)
{-# INLINE split #-}
-- | Search for the first element at which the predicate, applied to the
-- accumulated measure of everything up to and including that element,
-- holds.
--
-- On success the result is @'Inside' before x after@. The two trees are
-- rebuilt with 'fork' on the way back up, so the selected element ends up
-- separated from its neighbours. The result is 'Outside' when no element is
-- found, e.g. when the predicate is 'False' for every prefix.
--
-- NOTE(review): the descent assumes the predicate is monotone in the prefix
-- measure (once 'True', stays 'True'); confirm with callers.
split :: Measured v a => (v -> Bool) -> SplayTree v a -> SplitResult v a
split p = search mempty
  where
    search _ Leaf = Outside
    search acc (Fork l x r _)
      | p withLeft = case search acc l of
          Outside -> Outside
          Inside before hit after -> Inside before hit (fork after x r)
      | p withX = Inside l x r
      | otherwise = case search withX r of
          Outside -> Outside
          Inside before hit after -> Inside (fork l x before) hit after
      where
        withLeft = acc <> measure l
        withX = withLeft <> measure x
{-# INLINE map #-}
-- | Apply a function to every element, preserving the tree shape and
-- recomputing each node's measure (in the new measure type) with 'fork'.
--
-- The recursion lives in a local worker. This makes 'map' itself
-- non-recursive, so its INLINE pragma is not defeated by the binding being
-- a loop breaker.
map
  :: (Measured v a, Measured w b)
  => (a -> b)
  -> SplayTree v a
  -> SplayTree w b
map f = go
  where
    go Leaf = Leaf
    go (Fork l x r _) = fork (go l) (f x) (go r)
{-# INLINE mapWithPos #-}
-- | Like 'map', but the function also receives the combined measure of all
-- elements strictly to the left of the current one (its position).
mapWithPos
  :: (Measured v a, Measured w b)
  => (v -> a -> b)
  -> SplayTree v a
  -> SplayTree w b
mapWithPos f = walk mempty
  where
    -- 'before' is the measure of everything to the left of this subtree.
    walk _ Leaf = Leaf
    walk before (Fork l x r _) =
      let beforeX = before <> measure l
          afterX = beforeX <> measure x
       in fork (walk before l) (f beforeX x) (walk afterX r)
{-# INLINE mapWithContext #-}
-- | Like 'map', but the function receives three arguments:
--
-- * the combined measure of all elements strictly to the left of the
--   current element;
-- * the element itself;
-- * the combined measure of all elements strictly to the right of it.
mapWithContext
  :: (Measured v a, Measured w b)
  => (v -> a -> v -> b)
  -> SplayTree v a
  -> SplayTree w b
mapWithContext f t = go mempty t mempty
  where
    -- i: measure of everything left of this subtree;
    -- j: measure of everything right of this subtree.
    go _ Leaf _ = Leaf
    go i (Fork l a r _) j = fork (go i l arj) (f il a rj) (go ila r j)
      where
        ma = measure a
        il = i <> measure l
        ila = il <> ma
        rj = measure r <> j
        -- The right context of the left subtree is this element, the right
        -- subtree, and the outer right context. (Previously only 'ma' was
        -- used, which dropped the right subtree and everything beyond it.)
        arj = ma <> rj
{-# INLINE traverse #-}
-- | Apply an effectful function to every element in order (left subtree,
-- element, right subtree). The tree is rebuilt with 'fork', so measures are
-- recomputed in the new measure type.
--
-- The recursion lives in a local worker so that 'traverse' itself is
-- non-recursive and its INLINE pragma can take effect.
traverse
  :: (Measured v a, Measured w b, Applicative f)
  => (a -> f b)
  -> SplayTree v a
  -> f (SplayTree w b)
traverse f = go
  where
    go Leaf = pure Leaf
    go (Fork l x r _) = fork <$> go l <*> f x <*> go r
{-# INLINE traverseWithPos #-}
-- | Like 'traverse', but the function also receives the combined measure of
-- all elements strictly to the left of the current one. Effects run in
-- element order.
traverseWithPos
  :: (Measured v a, Measured w b, Applicative f)
  => (v -> a -> f b)
  -> SplayTree v a
  -> f (SplayTree w b)
traverseWithPos f = walk mempty
  where
    -- 'before' is the measure of everything to the left of this subtree.
    walk _ Leaf = pure Leaf
    walk before (Fork l x r _) =
      let beforeX = before <> measure l
          afterX = beforeX <> measure x
       in fork <$> walk before l <*> f beforeX x <*> walk afterX r
{-# INLINE traverseWithContext #-}
-- | Like 'traverse', but the function receives three arguments, with
-- effects running in element order:
--
-- * the combined measure of all elements strictly to the left of the
--   current element;
-- * the element itself;
-- * the combined measure of all elements strictly to the right of it.
traverseWithContext
  :: (Measured v a, Measured w b, Applicative f)
  => (v -> a -> v -> f b)
  -> SplayTree v a
  -> f (SplayTree w b)
traverseWithContext f t = go mempty t mempty
  where
    -- i: measure of everything left of this subtree;
    -- j: measure of everything right of this subtree.
    go _ Leaf _ = pure Leaf
    go i (Fork l a r _) j
      = fork <$> go i l arj <*> f il a rj <*> go ila r j
      where
        ma = measure a
        il = i <> measure l
        ila = il <> ma
        rj = measure r <> j
        -- The right context of the left subtree is this element, the right
        -- subtree, and the outer right context. (Previously only 'ma' was
        -- used, which dropped the right subtree and everything beyond it.)
        arj = ma <> rj