{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -funbox-strict-fields #-}

-- | 'TrieKey' instance for 'Either' keys.
--
-- A map keyed by @'Either' k1 k2@ is stored as two sub-maps, one over the
-- 'Left' keys (a @'TrieMap' k1@) and one over the 'Right' keys (a
-- @'TrieMap' k2@), plus a cached total size. Every 'Left' key orders before
-- every 'Right' key. The module exports nothing except the instance.
module Data.TrieMap.UnionMap () where

import Control.Applicative
import Data.Either (partitionEithers)

import Data.TrieMap.Sized
import Data.TrieMap.TrieKey

-- | Smart constructor for the 'Either'-keyed map.
--
-- If both sub-maps are null, the result is 'Empty'. Otherwise the result is
-- a 'Union' that caches the sum of the two sub-map sizes, each measured with
-- the given 'Sized' function.
union
  :: (TrieKey k1, TrieKey k2)
  => Sized a -> TrieMap k1 a -> TrieMap k2 a -> TrieMap (Either k1 k2) a
union _ (nullM -> True) (nullM -> True) = Empty
union s m1@(sizeM s -> s1) m2@(sizeM s -> s2) = Union (s1 + s2) m1 m2

-- | One-entry map from an optional value; 'Nothing' yields 'Empty'.
singletonMaybe
  :: (TrieKey k1, TrieKey k2)
  => Sized a -> Either k1 k2 -> Maybe a -> TrieMap (Either k1 k2) a
singletonMaybe s k a = maybe Empty (singletonM s k) a

-- | One-entry map for a 'Left' key: the 'Right' sub-map is empty and the
-- cached size is the 'Sized' measure of the value.
singletonL
  :: (TrieKey k1, TrieKey k2)
  => Sized a -> k1 -> a -> TrieMap (Either k1 k2) a
singletonL s k a = Union (s a) (singletonM s k a) emptyM

-- | One-entry map for a 'Right' key: the 'Left' sub-map is empty and the
-- cached size is the 'Sized' measure of the value.
singletonR
  :: (TrieKey k1, TrieKey k2)
  => Sized a -> k2 -> a -> TrieMap (Either k1 k2) a
singletonR s k a = Union (s a) emptyM (singletonM s k a)

instance (TrieKey k1, TrieKey k2) => TrieKey (Either k1 k2) where
  -- 'Empty' holds no entries. 'Union' stores the cached total size, the
  -- sub-map for 'Left' keys and the sub-map for 'Right' keys. 'union' is
  -- the smart constructor that collapses two null sub-maps to 'Empty'.
  data TrieMap (Either k1 k2) a
    = Empty
    | Union !Int (TrieMap k1 a) (TrieMap k2 a)

  emptyM = Empty

  singletonM s = either (singletonL s) (singletonR s)

  nullM Empty = True
  nullM _     = False

  sizeM _ Empty         = 0
  sizeM _ (Union s _ _) = s

  -- Look the key up in the sub-map for its side.
  lookupM k (Union _ m1 m2) = either (`lookupM` m1) (`lookupM` m2) k
  lookupM _ _               = Nothing

  -- Alter the sub-map for the key's side and rebuild with 'union'. On
  -- 'Empty' the result is built from @f Nothing@ alone.
  alterM s f k (Union _ m1 m2) = case k of
    Left k1  -> union s (alterM s f k1 m1) m2
    Right k2 -> union s m1 (alterM s f k2 m2)
  alterM s f k _ = singletonMaybe s k (f Nothing)

  alterLookupM s f k Empty = onUnboxed (singletonMaybe s k) f Nothing
  alterLookupM s f (Left k) (Union _ m1 m2) =
    onUnboxed (flip (union s) m2) (alterLookupM s f k) m1
  alterLookupM s f (Right k) (Union _ m1 m2) =
    onUnboxed (union s m1) (alterLookupM s f k) m2

  -- The 'Left' sub-map's effects are sequenced before the 'Right' sub-map's.
  traverseWithKeyM s f (Union _ m1 m2) =
    union s <$> traverseWithKeyM s (f . Left) m1
            <*> traverseWithKeyM s (f . Right) m2
  traverseWithKeyM _ _ _ = pure Empty

  -- Right fold: the 'Left' sub-map is folded over the result of folding
  -- the 'Right' sub-map, so 'Left' entries end up outermost.
  foldWithKeyM f (Union _ m1 m2) =
    foldWithKeyM (f . Left) m1 . foldWithKeyM (f . Right) m2
  foldWithKeyM _ _ = id

  -- Left fold: fold the 'Left' sub-map first, then the 'Right' sub-map.
  foldlWithKeyM f (Union _ m1 m2) =
    foldlWithKeyM (f . Right) m2 . foldlWithKeyM (f . Left) m1
  foldlWithKeyM _ _ = id

  mapMaybeM s f (Union _ m1 m2) =
    union s (mapMaybeM s (f . Left) m1) (mapMaybeM s (f . Right) m2)
  mapMaybeM _ _ _ = Empty

  mapEitherM s1 s2 f (Union _ m1 m2)
    | (# m1L, m1R #) <- mapEitherM s1 s2 (f . Left) m1
    , (# m2L, m2R #) <- mapEitherM s1 s2 (f . Right) m2
    = (# union s1 m1L m2L, union s2 m1R m2R #)
  mapEitherM _ _ _ _ = (# Empty, Empty #)

  -- Candidates from the 'Left' sub-map sit on the left of '<|>'. The inner
  -- 'fmap' re-attaches the untouched sub-map to whatever remains (presumably
  -- the second component of the extracted result; see 'TrieKey').
  extractM s f (Union _ m1 m2) =
    let (&) = union s
    in fmap (& m2) <$> extractM s (f . Left) m1
         <|> fmap (m1 &) <$> extractM s (f . Right) m2
  extractM _ _ _ = empty

  -- Every 'Left' key orders before every 'Right' key. Splitting at a 'Left'
  -- key therefore puts all of @m2@ in the upper part, and splitting at a
  -- 'Right' key puts all of @m1@ in the lower part.
  splitLookupM s f k (Union _ m1 m2) =
    let (&) = union s
    in case k of
         Left k1
           | (# m1L, x, m1R #) <- splitLookupM s f k1 m1
           -> (# m1L & emptyM, x, m1R & m2 #)
         Right k2
           | (# m2L, x, m2R #) <- splitLookupM s f k2 m2
           -> (# m1 & m2L, x, emptyM & m2R #)
  splitLookupM _ _ _ _ = (# emptyM, Nothing, emptyM #)

  unionM s f (Union _ m11 m12) (Union _ m21 m22) =
    union s (unionM s (f . Left) m11 m21) (unionM s (f . Right) m12 m22)
  unionM _ _ Empty m2 = m2
  unionM _ _ m1 Empty = m1

  isectM _ _ Empty _ = Empty
  isectM _ _ _ Empty = Empty
  isectM s f (Union _ m11 m12) (Union _ m21 m22) =
    union s (isectM s (f . Left) m11 m21) (isectM s (f . Right) m12 m22)

  diffM _ _ Empty _  = Empty
  diffM _ _ m1 Empty = m1
  diffM s f (Union _ m11 m12) (Union _ m21 m22) =
    union s (diffM s (f . Left) m11 m21) (diffM s (f . Right) m12 m22)

  -- @le@ compares values; it is renamed from @(<=)@ to avoid shadowing the
  -- Prelude operator.
  isSubmapM _ Empty _ = True
  isSubmapM le (Union _ m11 m12) (Union _ m21 m22) =
    isSubmapM le m11 m21 && isSubmapM le m12 m22
  isSubmapM _ Union{} Empty = False

  -- Each builder splits the associations by key side with 'partEithers'.
  -- 'partEithers' preserves relative order, so ascending input stays
  -- ascending in each part. Each part is built separately and the two are
  -- joined with 'union'.
  fromListM s f =
    onPair (union s) (fromListM s (f . Left)) (fromListM s (f . Right))
      . partEithers
  fromAscListM s f =
    onPair (union s) (fromAscListM s (f . Left)) (fromAscListM s (f . Right))
      . partEithers
  fromDistAscListM s =
    onPair (union s) (fromDistAscListM s) (fromDistAscListM s) . partEithers

-- | Apply one function to each component of a pair and combine the results.
onPair :: (c -> d -> e) -> (a -> c) -> (b -> d) -> (a, b) -> e
onPair f g h (a, b) = f (g a) (h b)

-- | Split associations by the side of their 'Either' key, keeping the
-- original relative order within each side.
--
-- Delegates to 'partitionEithers'. That function matches its accumulator
-- lazily, so unlike a @foldr@ with a strict tuple pattern it does not have to
-- recurse through the entire list before producing either result list.
partEithers :: [(Either a b, x)] -> ([(a, x)], [(b, x)])
partEithers = partitionEithers . map distribute
  where
    distribute (e, x) = either (\a -> Left (a, x)) (\b -> Right (b, x)) e