module Data.EnumMapMap.Base(
(:&)(..), K(..), N(..), Z(..),
d1, d2, d3, d4, d5, d6, d7, d8, d9, d10,
IsSplit(..),
Plus,
EMM(..),
IsEmm(..),
EnumMapMap(..),
mergeWithKey',
mapWithKey_,
foldrWithKey_,
foldlStrict,
Key,
bin,
tip,
nomatch,
match,
join,
zero
) where
import Prelude hiding (lookup,
map,
filter,
foldr, foldl,
null, init,
head, tail)
import Control.DeepSeq (NFData(rnf))
import Data.Bits
import Data.Monoid (Monoid(..))
import GHC.Exts (Word(..), Int(..), shiftRL#)
-- | Patricia-trie node indexed by the 'Int' image ('fromEnum') of keys of
-- type @k@; @k@ itself is only a phantom index here.
data EMM k v
    = Bin !Prefix !Mask !(EMM k v) !(EMM k v)
      -- ^ Branch: common prefix, branching-bit mask, then the subtree whose
      -- keys have the mask bit clear and the subtree where it is set.
    | Tip !Int v
      -- ^ Leaf: the key as an 'Int' together with its value.
    | Nil
      -- ^ The empty trie.
    deriving Show
-- | A key after conversion with 'fromEnum'.
type Key = Int
-- | The shared high-order bits of all keys stored below a 'Bin'.
type Prefix = Int
-- | A mask selecting the bit on which a 'Bin' branches.
type Mask = Int
-- | Unsigned view of keys, used for the bit-twiddling helpers.
type Nat = Word
-- | Key concatenation: @k :& t@ is a key whose first component is @k@ and
-- whose remaining components are @t@.  Right-associative, so
-- @a :& b :& c@ parses as @a :& (b :& c)@.
infixr 3 :&
data k :& t = !k :& !t
    deriving (Eq, Show)

-- | Wrapper for a single key component.
data K k = K !k
    deriving (Eq, Show)

-- | Type-level zero; used with 'N' to select a split position for
-- 'IsSplit'.
data Z = Z

-- | Type-level successor.
data N n = N n
-- | Type-level position indices.  @dN@ is 'Z' wrapped in @N - 1@ layers
-- of 'N'; presumably used to pick a key position for 'splitKey' — see the
-- 'IsSplit' instances.
d1 :: Z
d1 = Z

d2 :: N Z
d2 = N d1

d3 :: N (N Z)
d3 = N d2

d4 :: N (N (N Z))
d4 = N d3

d5 :: N (N (N (N Z)))
d5 = N d4

d6 :: N (N (N (N (N Z))))
d6 = N d5

d7 :: N (N (N (N (N (N Z)))))
d7 = N d6

d8 :: N (N (N (N (N (N (N Z))))))
d8 = N d7

d9 :: N (N (N (N (N (N (N (N Z)))))))
d9 = N d8

d10 :: N (N (N (N (N (N (N (N (N Z))))))))
d10 = N d9
-- | Split a multi-component key type @k@ at the position selected by the
-- index @z@: the result is a map over the 'Head' part of the key whose
-- values are maps over the 'Tail' part.
class IsSplit k z where
    -- | Key components that stay in the outer map.
    type Head k z :: *
    -- | Key components moved into the inner maps.
    type Tail k z :: *
    -- | Perform the split; the first argument only selects the position.
    splitKey :: z
             -> EnumMapMap k v
             -> EnumMapMap (Head k z) (EnumMapMap (Tail k z) v)
-- Splitting one position further in: the outer key component stays in the
-- head, and every nested map is split recursively at position @n@.
instance (IsSplit t n, Enum k) => IsSplit (k :& t) (N n) where
    type Head (k :& t) (N n) = k :& Head t n
    type Tail (k :& t) (N n) = Tail t n
    splitKey (N n) (KCC emm) = KCC (mapWithKey_ (const (splitKey n)) emm)
-- | Append key type @k2@ after all components of key type @k1@.
-- Only the ':&' case is given here; other cases are defined elsewhere.
type family Plus k1 k2 :: *
type instance Plus (k1 :& t) k2 = k1 :& Plus t k2
-- | Finite maps keyed by (possibly multi-component) keys of type @k@.
-- Each key shape supplies its own 'EnumMapMap' representation.
class IsEmm k where
    -- | The map type for key type @k@ (applied to a value type).
    data EnumMapMap k :: * -> *

    -- | Whether the map contains empty nested sub-maps (exact definition is
    -- per instance).
    emptySubTrees :: EnumMapMap k v -> Bool
    -- | Helper for 'emptySubTrees' used on nested levels.
    emptySubTrees_ :: EnumMapMap k v -> Bool
    -- | Remove empty nested sub-maps.
    removeEmpties :: EnumMapMap k v -> EnumMapMap k v

    -- | Flatten a map of maps into one map over the combined key
    -- @'Plus' k k2@, then apply 'removeEmpties'.
    joinKey :: (IsEmm (Plus k k2)) =>
               EnumMapMap k (EnumMapMap k2 v)
            -> EnumMapMap (Plus k k2) v
    joinKey nested = removeEmpties (unsafeJoinKey nested)
    -- | Like 'joinKey' but without the 'removeEmpties' pass.
    unsafeJoinKey :: EnumMapMap k (EnumMapMap k2 v)
                  -> EnumMapMap (Plus k k2) v

    -- | The empty map.
    empty :: EnumMapMap k v
    -- | Is the map empty?
    null :: EnumMapMap k v -> Bool
    -- | Number of entries.
    size :: EnumMapMap k v -> Int
    -- | Is the key present?
    member :: k -> EnumMapMap k v -> Bool
    -- | A map with exactly one entry.
    singleton :: k -> v -> EnumMapMap k v
    -- | Find the value at a key.
    lookup :: k -> EnumMapMap k v -> Maybe v

    -- | Insert a key/value pair.
    insert :: k -> v -> EnumMapMap k v -> EnumMapMap k v
    -- | 'insertWithKey' with a combining function that ignores the key.
    insertWith :: (v -> v -> v)
               -> k -> v -> EnumMapMap k v -> EnumMapMap k v
    insertWith combine = insertWithKey (const combine)
    -- | Insert, combining with an existing value via the given function.
    insertWithKey :: (k -> v -> v -> v)
                  -> k -> v -> EnumMapMap k v -> EnumMapMap k v
    -- | Remove a key.
    delete :: k -> EnumMapMap k v -> EnumMapMap k v
    -- | Insert, update or delete the value at a key.
    alter :: (Maybe v -> Maybe v) -> k -> EnumMapMap k v -> EnumMapMap k v

    -- | 'mapWithKey' with a function that ignores the key.
    map :: (v -> t) -> EnumMapMap k v -> EnumMapMap k t
    map g = mapWithKey (const g)
    -- | Map a function over every key/value pair.
    mapWithKey :: (k -> v -> t) -> EnumMapMap k v -> EnumMapMap k t
    -- | Right fold over key/value pairs.
    foldrWithKey :: (k -> v -> t -> t) -> t -> EnumMapMap k v -> t

    -- | All key/value pairs, in the order produced by 'foldrWithKey'.
    toList :: EnumMapMap k v -> [(k, v)]
    toList = foldrWithKey (\key val acc -> (key, val) : acc) []
    -- | Build a map by inserting the pairs left to right (later pairs win,
    -- as with 'insert').
    fromList :: [(k, v)] -> EnumMapMap k v
    fromList = foldlStrict (\acc (key, val) -> insert key val acc) empty

    -- | Union of two maps.
    union :: EnumMapMap k v -> EnumMapMap k v -> EnumMapMap k v
    -- | Left-to-right 'union' of a list of maps.
    unions :: [EnumMapMap k v] -> EnumMapMap k v
    unions maps = foldlStrict union empty maps
    -- | 'unionWithKey' with a combining function that ignores the key.
    unionWith :: (v -> v -> v)
              -> EnumMapMap k v -> EnumMapMap k v -> EnumMapMap k v
    unionWith combine = unionWithKey (const combine)
    -- | Union, combining values present in both maps.
    unionWithKey :: (k -> v -> v -> v)
                 -> EnumMapMap k v -> EnumMapMap k v -> EnumMapMap k v

    -- | Entries of the first map whose keys are absent from the second.
    difference :: EnumMapMap k v1 -> EnumMapMap k v2 -> EnumMapMap k v1
    -- | 'differenceWithKey' with a function that ignores the key.
    differenceWith :: (v1 -> v2 -> Maybe v1)
                   -> EnumMapMap k v1
                   -> EnumMapMap k v2
                   -> EnumMapMap k v1
    differenceWith g = differenceWithKey (const g)
    -- | Difference where shared keys are resolved by the given function.
    differenceWithKey :: (k -> v1 -> v2 -> Maybe v1)
                      -> EnumMapMap k v1
                      -> EnumMapMap k v2
                      -> EnumMapMap k v1

    -- | Entries of the first map whose keys are also in the second.
    intersection :: EnumMapMap k v1
                 -> EnumMapMap k v2
                 -> EnumMapMap k v1
    -- | 'intersectionWithKey' with a function that ignores the key.
    intersectionWith :: (v1 -> v2 -> v3)
                     -> EnumMapMap k v1
                     -> EnumMapMap k v2
                     -> EnumMapMap k v3
    intersectionWith g = intersectionWithKey (const g)
    -- | Intersection, combining the values of shared keys.
    intersectionWithKey :: (k -> v1 -> v2 -> v3)
                        -> EnumMapMap k v1
                        -> EnumMapMap k v2
                        -> EnumMapMap k v3

    -- | Structural equality; used by the 'Eq' instance.
    equal :: Eq v => EnumMapMap k v -> EnumMapMap k v -> Bool
    -- | Structural inequality; used by the 'Eq' instance.
    nequal :: Eq v => EnumMapMap k v -> EnumMapMap k v -> Bool
-- Multi-level maps: the outer key component @k@ indexes an 'EMM' trie whose
-- values are maps over the remaining key components @t@.
instance (Enum k, IsEmm t) => IsEmm (k :& t) where
    data EnumMapMap (k :& t) v = KCC (EMM k (EnumMapMap t v))

    -- An empty map reports False; otherwise defer to 'emptySubTrees_'.
    emptySubTrees e@(KCC emm) =
        case emm of
            Nil -> False
            _   -> emptySubTrees_ e

    -- True if a 'Nil' occurs anywhere in the trie, or if some nested map
    -- reports True from its own 'emptySubTrees_'.
    emptySubTrees_ (KCC emm) = go emm
      where
        go t = case t of
            Bin _ _ l r -> go l || go r
            Tip _ v     -> emptySubTrees_ v
            Nil         -> True

    -- Clean nested maps bottom-up; 'tip' turns a now-null nested map into
    -- 'Nil' and 'bin' then collapses the parent branch.
    removeEmpties (KCC emm) = KCC $ go emm
      where
        go t = case t of
            Bin p m l r -> bin p m (go l) (go r)
            Tip k v     -> tip k (removeEmpties v)
            Nil         -> Nil

    unsafeJoinKey (KCC emm) = KCC $ mapWithKey_ (\_ -> unsafeJoinKey) emm

    empty = KCC Nil

    null (KCC t) =
        case t of
            Nil -> True
            _   -> False

    size (KCC t) = go t
      where
        go (Bin _ _ l r) = go l + go r
        go (Tip _ y)     = size y
        go Nil           = 0

    member !(key' :& nxt) (KCC emm) = go emm
      where
        go t = case t of
            Bin _ m l r -> case zero key m of
                True  -> go l
                False -> go r
            Tip kx x -> case key == kx of
                True  -> member nxt x
                False -> False
            Nil -> False
        key = fromEnum key'

    singleton (key :& nxt) = KCC . Tip (fromEnum key) . singleton nxt

    -- 'fromEnum key' is computed once (as in 'member') instead of at every
    -- 'Bin' on the way down.
    lookup (key :& nxt) (KCC emm) = go emm
      where
        k = fromEnum key
        go (Bin _ m l r)
            | zero k m  = go l
            | otherwise = go r
        go (Tip kx x)
            = case kx == k of
                True  -> lookup nxt x
                False -> Nothing
        go Nil = Nothing

    insert (key :& nxt) val (KCC emm)
        = KCC $ insertWith_ (insert nxt val) key (singleton nxt val) emm

    -- The combining function receives the full multi-component key @k@.
    insertWithKey f k@(key :& nxt) val (KCC emm) =
        KCC $ insertWith_ go key (singleton nxt val) emm
      where
        go = insertWithKey (\_ -> f k) nxt val

    delete !(key :& nxt) (KCC emm) =
        KCC $ alter_ (delete nxt) (fromEnum key) emm

    alter f !(key :& nxt) (KCC emm) =
        KCC $ alter_ (alter f nxt) (fromEnum key) emm

    mapWithKey f (KCC emm) = KCC $ mapWithKey_ go emm
      where
        go k = mapWithKey (\nxt -> f $ k :& nxt)

    foldrWithKey f init (KCC emm) = foldrWithKey_ go init emm
      where
        go k val z = foldrWithKey (\nxt -> f $ k :& nxt) z val

    -- The merge callbacks below are only ever applied by 'mergeWithKey''
    -- to two 'Tip's with equal keys, so the lambda patterns cover every
    -- call that can happen.
    union (KCC emm1) (KCC emm2) = KCC $ mergeWithKey' binD go id id emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) -> tip k1 $ union x1 x2

    -- Uses the 'tip' smart constructor (like 'union', 'difference' and
    -- 'intersection') so that a merged nested map that is null is dropped
    -- rather than stored as an empty tip.
    unionWithKey f (KCC emm1) (KCC emm2) =
        KCC $ mergeWithKey' binD go id id emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) ->
            tip k1 $ unionWithKey (g k1) x1 x2
        g k1 nxt = f $ (toEnum k1) :& nxt

    difference (KCC emm1) (KCC emm2) =
        KCC $ mergeWithKey' binD go id (const Nil) emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) ->
            tip k1 (difference x1 x2)

    differenceWithKey f (KCC emm1) (KCC emm2) =
        KCC $ mergeWithKey' binD go id (const Nil) emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) ->
            tip k1 $ differenceWithKey (\nxt ->
                f $ (toEnum k1) :& nxt) x1 x2

    intersection (KCC emm1) (KCC emm2) =
        KCC $ mergeWithKey' binD go (const Nil) (const Nil) emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) ->
            tip k1 $ intersection x1 x2

    intersectionWithKey f (KCC emm1) (KCC emm2) =
        KCC $ mergeWithKey' binD go (const Nil) (const Nil) emm1 emm2
      where
        go = \(Tip k1 x1) (Tip _ x2) ->
            tip k1 $ intersectionWithKey (\nxt ->
                f $ (toEnum k1) :& nxt) x1 x2

    equal (KCC emm1) (KCC emm2) = emm1 == emm2
    nequal (KCC emm1) (KCC emm2) = emm1 /= emm2
-- | Insert @val@ at @key'@ in the trie.  If the key is already present its
-- existing value @y@ is replaced by @f y@ (and @val@ is not used).  Both the
-- key and its 'fromEnum' image are forced before the descent.
insertWith_ :: Enum k => (v -> v) -> k -> v -> EMM k v -> EMM k v
insertWith_ f key' val emm = key' `seq` key `seq` go emm
  where
    key = fromEnum key'
    fresh = Tip key val
    go t@(Bin p m l r)
        | nomatch key p m = join key fresh p t
        | zero key m      = Bin p m (go l) r
        | otherwise       = Bin p m l (go r)
    go t@(Tip ky y)
        | key == ky = Tip key (f y)
        | otherwise = join key fresh ky t
    go Nil = fresh
-- | Apply @f@ to the nested map stored at key @k@, or to 'empty' if the key
-- is absent.  Results are rebuilt with 'tip' / 'binD' / 'joinD', so a
-- nested map that ends up null does not get stored.
alter_ :: (IsEmm b) =>
          (EnumMapMap b v -> EnumMapMap b v)
       -> Key
       -> EMM a (EnumMapMap b v)
       -> EMM a (EnumMapMap b v)
alter_ f k = go
  where
    fresh = tip k (f empty)
    go t@(Bin p m l r)
        | nomatch k p m = joinD k fresh p t
        | zero k m      = binD p m (go l) r
        | otherwise     = binD p m l (go r)
    go t@(Tip ky y)
        | k == ky   = tip k (f y)
        | otherwise = joinD k fresh ky t
    go Nil = fresh
-- | Map over every leaf, passing the leaf key converted back with 'toEnum'.
-- The trie shape (prefixes and masks) is left unchanged.
mapWithKey_ :: Enum k => (k -> v -> t) -> EMM k v -> EMM k t
mapWithKey_ f = go
  where
    go t = case t of
        Bin p m l r -> Bin p m (go l) (go r)
        Tip k x     -> Tip k (f (toEnum k) x)
        Nil         -> Nil
-- | Right fold over the leaves, passing each key converted with 'toEnum'.
-- At the root, a negative mask means the branch is on the sign bit; the
-- subtrees are then visited in swapped order, so that (as the 'Int'
-- ordering suggests) negative keys come before non-negative ones.
foldrWithKey_ :: (Enum k) => (k -> v -> t -> t) -> t -> EMM k v -> t
foldrWithKey_ f z = \emm -> case emm of
    Bin _ m l r
        | m < 0     -> step (step z l) r
        | otherwise -> step (step z r) l
    _ -> step z emm
  where
    step acc Nil           = acc
    step acc (Tip kx tx)   = f (toEnum kx) tx acc
    step acc (Bin _ _ l r) = step (step acc r) l
-- | Generic merge of two tries.
--
-- * @bin'@ rebuilds a branch from merged children (the callers in this
--   module pass 'binD').
-- * @f@ is applied to two 'Tip's that carry the same key; it receives both
--   'Tip' nodes.
-- * @g1@ / @g2@ are applied to whole subtrees found only in the first /
--   second trie.
--
-- The @Enum a@ constraint is not used by the body.
mergeWithKey' :: (Enum a) =>
                 (Prefix -> Mask -> EMM a v3 -> EMM a v3 -> EMM a v3)
              -> (EMM a v1 -> EMM a v2 -> EMM a v3)
              -> (EMM a v1 -> EMM a v3)
              -> (EMM a v2 -> EMM a v3)
              -> EMM a v1 -> EMM a v2 -> EMM a v3
mergeWithKey' bin' f g1 g2 = go
  where
    -- Two branches: descend into whichever has the shorter prefix (larger
    -- mask); equal masks and prefixes merge child by child; otherwise the
    -- tries are disjoint and are simply joined.
    go t1@(Bin p1 m1 l1 r1) t2@(Bin p2 m2 l2 r2)
        | shorter m1 m2 = merge1
        | shorter m2 m1 = merge2
        | p1 == p2      = bin' p1 m1 (go l1 l2) (go r1 r2)
        | otherwise     = maybe_join p1 (g1 t1) p2 (g2 t2)
      where
        -- t2 lies below t1's branch bit (or outside t1 entirely).
        merge1 | nomatch p2 p1 m1 = maybe_join p1 (g1 t1) p2 (g2 t2)
               | zero p2 m1       = bin' p1 m1 (go l1 t2) (g1 r1)
               | otherwise        = bin' p1 m1 (g1 l1) (go r1 t2)
        -- t1 lies below t2's branch bit (or outside t2 entirely).
        merge2 | nomatch p1 p2 m2 = maybe_join p1 (g1 t1) p2 (g2 t2)
               | zero p1 m2       = bin' p2 m2 (go t1 l2) (g2 r2)
               | otherwise        = bin' p2 m2 (g2 l2) (go t1 r2)

    -- Branch against a single tip: walk t1 towards the tip's key.
    go t1'@(Bin _ _ _ _) t2'@(Tip k2' _) = merge t2' k2' t1'
      where
        merge t2 k2 t1@(Bin p1 m1 l1 r1)
            | nomatch k2 p1 m1 = maybe_join p1 (g1 t1) k2 (g2 t2)
            | zero k2 m1       = bin' p1 m1 (merge t2 k2 l1) (g1 r1)
            | otherwise        = bin' p1 m1 (g1 l1) (merge t2 k2 r1)
        merge t2 k2 t1@(Tip k1 _)
            | k1 == k2  = f t1 t2
            | otherwise = maybe_join k1 (g1 t1) k2 (g2 t2)
        merge t2 _ Nil = g2 t2

    go t1@(Bin _ _ _ _) Nil = g1 t1

    -- Single tip against anything: walk t2 towards the tip's key.
    go t1'@(Tip k1' _) t2' = merge t1' k1' t2'
      where
        merge t1 k1 t2@(Bin p2 m2 l2 r2)
            | nomatch k1 p2 m2 = maybe_join k1 (g1 t1) p2 (g2 t2)
            | zero k1 m2       = bin' p2 m2 (merge t1 k1 l2) (g2 r2)
            | otherwise        = bin' p2 m2 (g2 l2) (merge t1 k1 r2)
        merge t1 k1 t2@(Tip k2 _)
            | k1 == k2  = f t1 t2
            | otherwise = maybe_join k1 (g1 t1) k2 (g2 t2)
        merge t1 _ Nil = g1 t1

    go Nil t2 = g2 t2

    -- Join two disjoint subtrees, but if either side became 'Nil' (e.g.
    -- because @g1@/@g2@ discarded it) just return the other side.
    maybe_join _ Nil _ t2 = t2
    maybe_join _ t1 _ Nil = t1
    maybe_join p1 t1 p2 t2 = join p1 t1 p2 t2
-- Equality delegates to the class-provided structural comparison.
instance (Eq v, IsEmm k) => Eq (EnumMapMap k v) where
    (==) = equal
    (/=) = nequal
-- Structural equality on tries, via 'equalE' and 'nequalE'.
instance Eq v => Eq (EMM k v) where
    (==) = equalE
    (/=) = nequalE
-- | Structural equality: same node shape, masks, prefixes, keys and values.
equalE :: Eq v => EMM k v -> EMM k v -> Bool
equalE a b = case (a, b) of
    (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2) ->
        m1 == m2 && p1 == p2 && equalE l1 l2 && equalE r1 r2
    (Tip kx x, Tip ky y) -> kx == ky && x == y
    (Nil, Nil)           -> True
    _                    -> False
-- | Structural inequality; the dual of 'equalE', using '/=' on the parts.
nequalE :: Eq v => EMM k v -> EMM k v -> Bool
nequalE a b = case (a, b) of
    (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2) ->
        m1 /= m2 || p1 /= p2 || nequalE l1 l2 || nequalE r1 r2
    (Tip kx x, Tip ky y) -> kx /= ky || x /= y
    (Nil, Nil)           -> False
    _                    -> True
-- 'fmap' is the class method 'map' (the Prelude 'map' is hidden).
instance (IsEmm k) => Functor (EnumMapMap k) where
    fmap g m = map g m
-- Maps form a monoid under 'union' with 'empty' as the identity.
-- NOTE(review): on base >= 4.11 'Monoid' requires a 'Semigroup' superclass
-- instance, which is not present here — confirm the supported GHC range.
instance (IsEmm k) => Monoid (EnumMapMap k v) where
    mempty = empty
    mappend m1 m2 = union m1 m2
    mconcat = unions
-- Shows the underlying trie directly (the 'KCC' wrapper is not shown).
instance (Show v, Show (EnumMapMap t v)) => Show (EnumMapMap (k :& t) v) where
    show (KCC trie) = show trie
-- Fully evaluates every nested map stored in the trie's leaves.
instance (NFData v, NFData (EnumMapMap t v)) => NFData (EnumMapMap (k :& t) v) where
    rnf (KCC emm) = force emm
      where
        force node = case node of
            Nil         -> ()
            Tip _ v     -> rnf v
            Bin _ _ l r -> force l `seq` force r
-- | View an 'Int' as an unsigned 'Nat' (via 'fromIntegral').
natFromInt :: Int -> Nat
natFromInt i = fromIntegral i

-- | Convert an unsigned 'Nat' back to an 'Int' (via 'fromIntegral').
intFromNat :: Nat -> Int
intFromNat w = fromIntegral w
-- | Logical (zero-filling) right shift of a 'Nat', via the 'shiftRL#'
-- primop.
shiftRL :: Nat -> Int -> Nat
shiftRL (W# w) (I# n) = W# (w `shiftRL#` n)
-- | Combine two subtrees with distinct prefixes @p1@ and @p2@ under a new
-- 'Bin' that branches on their highest differing bit.  @t1@ goes on the
-- left when @p1@ has that bit clear.
join :: Prefix -> EMM a v -> Prefix -> EMM a v -> EMM a v
join p1 t1 p2 t2 =
    if zero p1 branch
        then Bin prefix branch t1 t2
        else Bin prefix branch t2 t1
  where
    branch = branchMask p1 p2
    prefix = mask p1 branch
-- | Like 'join', but builds the branch with 'binD' so that children holding
-- empty nested maps are handled by that smart constructor.
joinD :: (IsEmm b) =>
         Prefix -> EMM a (EnumMapMap b v)
      -> Prefix -> EMM a (EnumMapMap b v)
      -> EMM a (EnumMapMap b v)
joinD p1 t1 p2 t2 =
    if zero p1 branch
        then binD prefix branch t1 t2
        else binD prefix branch t2 t1
  where
    branch = branchMask p1 p2
    prefix = mask p1 branch
-- | Smart 'Bin' constructor: a 'Nil' child is elided by returning the other
-- child.
bin :: Prefix -> Mask -> EMM k v -> EMM k v -> EMM k v
bin p m l r = case (l, r) of
    (_, Nil) -> l
    (Nil, _) -> r
    _        -> Bin p m l r
-- | Smart 'Bin' constructor for tries whose values are nested maps.
--
-- A child that is 'Nil', or a 'Tip' whose nested map is 'null', counts as
-- absent.  If one child is absent the other is returned; if both are absent
-- the result is 'Nil'; otherwise a 'Bin' is built.
--
-- Both children are normalised before the decision.  Testing the right
-- child first would otherwise keep an empty left 'Tip' whenever the right
-- child is a non-empty 'Tip'.
binD :: (IsEmm b) =>
        Prefix -> Mask
     -> EMM a (EnumMapMap b v)
     -> EMM a (EnumMapMap b v)
     -> EMM a (EnumMapMap b v)
binD p m l r =
    case (dropEmptyTip l, dropEmptyTip r) of
        (Nil, r') -> r'
        (l', Nil) -> l'
        (l', r')  -> Bin p m l' r'
  where
    -- A tip whose nested map is empty carries no entries.
    dropEmptyTip (Tip _ y) | null y = Nil
    dropEmptyTip t                  = t
-- | Smart 'Tip' constructor: a 'null' nested map yields 'Nil'.
tip :: (IsEmm b) => Key -> EnumMapMap b v -> EMM a (EnumMapMap b v)
tip k val = if null val then Nil else Tip k val
-- | True when the bits selected by mask @m@ are all clear in key @i@
-- (compared as unsigned words).
zero :: Key -> Mask -> Bool
zero i m = natFromInt i .&. natFromInt m == 0
-- | Does key @i@ (not) share prefix @p@ above mask bit @m@?
nomatch, match :: Key -> Prefix -> Mask -> Bool
match i p m = mask i m == p
nomatch i p m = not (match i p m)
-- | The prefix of key @i@ relative to mask @m@; see 'maskW'.
mask :: Key -> Mask -> Prefix
mask i = maskW (natFromInt i) . natFromInt
-- | Keep only the bits of @i@ strictly above the mask bit @m@, clearing the
-- mask bit and everything below it.
--
-- For a single-bit @m@ (as built by 'branchMask'), @m - 1@ has all lower
-- bits set, so @complement (m - 1)@ has @m@ and all higher bits set.
-- XOR-ing with @m@ then clears the mask bit itself.
maskW :: Nat -> Nat -> Prefix
maskW i m
    = intFromNat (i .&. (complement (m - 1) `xor` m))
-- | True when mask @m1@ selects a higher bit than @m2@ (compared unsigned,
-- so the sign bit is the highest), i.e. @m1@'s branch has the shorter
-- prefix.
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = natFromInt m2 < natFromInt m1
-- | Single-bit mask for the highest bit in which the two prefixes differ.
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 p2 = intFromNat (highestBitMask diff)
  where
    diff = natFromInt p1 `xor` natFromInt p2
-- | Isolate the highest set bit of a word (0 for 0).  The shifts copy the
-- top set bit into every lower position; the final XOR leaves only the top
-- bit.
-- NOTE(review): on a 32-bit 'Word' the 32-place shift relies on 'shiftRL#'
-- yielding 0 for out-of-range shift amounts — confirm for target platforms.
highestBitMask :: Nat -> Nat
highestBitMask x0 =
    let x1 = x0 .|. shiftRL x0 1
        x2 = x1 .|. shiftRL x1 2
        x3 = x2 .|. shiftRL x2 4
        x4 = x3 .|. shiftRL x3 8
        x5 = x4 .|. shiftRL x4 16
        x6 = x5 .|. shiftRL x5 32
    in x6 `xor` shiftRL x6 1
-- | Left fold that forces each intermediate accumulator to WHNF before
-- continuing, so no chain of thunks builds up.  The initial accumulator is
-- not forced.
foldlStrict :: (a -> b -> a) -> a -> [b] -> a
foldlStrict step = loop
  where
    loop acc xs = case xs of
        []     -> acc
        (y:ys) -> let acc' = step acc y in acc' `seq` loop acc' ys