{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE PatternGuards #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
{-# OPTIONS -funbox-strict-fields #-}
module Data.TrieMap.WordMap
  ( SNode
  , WHole
  , TrieMap(WordMap)
  , Hole(Hole)
  , getWordMap
  , getHole
  ) where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized

import Control.Exception (assert)
import Control.Applicative (Applicative(..), (<$>))
import Control.Monad hiding (join)

import Data.Bits
import Data.Foldable
import Data.Maybe hiding (mapMaybe)
import Data.Monoid
import Data.TrieMap.Utils

import GHC.Exts

import Prelude hiding (lookup, null, map, foldl, foldr, foldl1, foldr1)

#include "MachDeps.h"

-- Pattern shorthands: each one matches a sized node by the shape of its
-- payload and ignores the cached size.
#define NIL SNode{node = Nil}
#define TIP(args) SNode{node = (Tip args)}
#define BIN(args) SNode{node = (Bin args)}

type Nat    = Word
type Prefix = Word
type Mask   = Word
type Key    = Word
type Size   = Int

-- | The route from the root of a tree down to one position in it.
data Path a
  = Root
    -- The position is the whole tree.
  | LeftBin !Prefix !Mask (Path a) !(SNode a)
    -- The position is the left child of a branch; the branch prefix and
    -- mask, the route to the branch and the right sibling are remembered.
  | RightBin !Prefix !Mask !(SNode a) (Path a)
    -- The position is the right child of a branch; the branch prefix and
    -- mask, the left sibling and the route to the branch are remembered.

-- | A tree node together with its cached size.
data SNode a = SNode
  { sz   :: !Size     -- size of the subtree, stored strictly
  , node :: (Node a)  -- the node itself, stored lazily
  }
{-# ANN type SNode ForceSpecConstr #-}

-- | The three node shapes: an empty tree, a single key with its value, and a
-- branch carrying a prefix, a mask and two sized children.
data Node a
  = Nil
  | Tip !Key a
  | Bin !Prefix !Mask !(SNode a) !(SNode a)
{-# ANN type Node ForceSpecConstr #-}

-- | The size of a sized node is simply the cached field.
instance Sized (SNode a) where
  getSize# (SNode size _) = unbox size

-- | The size of a bare node: zero when empty, the size of the value for a
-- tip, and the sum of both children for a branch.
instance Sized a => Sized (Node a) where
  getSize# nd = unbox size
    where
      size = case nd of
        Nil             -> 0
        Tip _ val       -> getSize val
        Bin _ _ lt rt   -> getSize lt + getSize rt

-- | Wrap a bare node, computing the size to cache from the node itself.
-- The node is forced first.
{-# INLINE sNode #-}
sNode :: Sized a => Node a -> SNode a
sNode !nd = SNode {sz = getSize nd, node = nd}

-- | A hole: the key at the focused position plus the route leading to it.
data WHole a = WHole !Key (Path a)

-- | Package a key and a route as the hole type of the trie-key instance.
{-# INLINE hole #-}
hole :: Key -> Path a -> Hole Word a
hole key pth = Hole (WHole key pth)

-- Pattern shorthand that looks through both hole wrappers at once.
#define HOLE(args) (Hole (WHole args))

-- | @'TrieMap' 'Word' a@ is based on "Data.IntMap".
instance TrieKey Word where
  newtype TrieMap Word a = WordMap {getWordMap :: SNode a}
  newtype Hole Word a = Hole {getHole :: WHole a}

  -- Construction and simple queries.
  emptyM = WordMap nil
  singletonM key val = WordMap (singleton key val)
  getSimpleM (WordMap t) = case node t of
    Nil     -> Null
    Tip _ v -> Singleton v
    _       -> NonSimple
  sizeM (WordMap t) = getSize t
  lookupM key (WordMap t) = lookup key t

  -- Element-wise transformations.
  traverseM f (WordMap t) = fmap WordMap (traverse f t)
  fmapM f (WordMap t) = WordMap (map f t)
  mapMaybeM f (WordMap t) = WordMap (mapMaybe f t)
  mapEitherM f (WordMap t) = both WordMap WordMap (mapEither f) t

  -- Binary set-like operations.
  unionM f (WordMap ta) (WordMap tb) = WordMap (unionWith f ta tb)
  isectM f (WordMap ta) (WordMap tb) = WordMap (intersectionWith f ta tb)
  diffM f (WordMap ta) (WordMap tb) = WordMap (differenceWith f ta tb)
  isSubmapM le (WordMap ta) (WordMap tb) = isSubmapOfBy le ta tb

  -- Holes: the part of the map before or after the focused position, with
  -- or without a value placed at the focused key.
  singleHoleM key = hole key Root
  beforeM HOLE(_ pth) = WordMap (before nil pth)
  beforeWithM val HOLE(key pth) = WordMap (before (singleton key val) pth)
  afterM HOLE(_ pth) = WordMap (after nil pth)
  afterWithM val HOLE(key pth) = WordMap (after (singleton key val) pth)

  {-# INLINE searchMC #-}
  searchMC !key (WordMap t) = mapSearch (hole key) (searchC key t)

  -- Walk down by cached sizes, subtracting the left size whenever the walk
  -- turns right; yields the remaining index, the value and its hole.
  indexM i0 (WordMap t0) = locate i0 t0 Root
    where
      locate !i TIP(kx x) pth = (# i, x, hole kx pth #)
      locate !i BIN(p m l r) pth =
        let !leftSize = getSize l
        in if i < leftSize
             then locate i l (LeftBin p m pth r)
             else locate (i - leftSize) r (RightBin p m l pth)
      locate _ NIL _ = indexFail ()

  -- Every value paired with the hole left behind at its position; the left
  -- subtree is combined before the right one.
  extractHoleM (WordMap t0) = walk Root t0
    where
      walk _ NIL = mzero
      walk pth TIP(kx x) = return (x, hole kx pth)
      walk pth BIN(p m l r) =
        walk (LeftBin p m pth r) l `mplus` walk (RightBin p m l pth) r

  clearM HOLE(_ pth) = WordMap (assign nil pth)

  {-# INLINE assignM #-}
  assignM val HOLE(key pth) = WordMap (assign (singleton key val) pth)

  {-# INLINE unifierM #-}
  unifierM k' k a = fmap Hole (unifier k' k a)

-- | Search for a key in continuation-passing style.  On a hit the value and
-- the route to it are handed to the success continuation; on a miss the
-- failure continuation receives the route at which the key would be placed.
{-# INLINE searchC #-}
searchC :: Key -> SNode a -> SearchCont (Path a) a r
searchC !k root onMiss onHit = descend Root root
  where
    descend pth here@BIN(p m l r)
      | nomatch k p m = onMiss (branchHole k p pth here)
      | zero k m      = descend (LeftBin p m pth r) l
      | otherwise     = descend (RightBin p m l pth) r
    descend pth here@TIP(ky y)
      | k == ky   = onHit y pth
      | otherwise = onMiss (branchHole k ky pth here)
    descend pth NIL = onMiss pth

-- | Rebuild what lies before a position: climbing the route, left siblings
-- are kept and right siblings are dropped.  The first argument is the tree
-- standing at the position itself.
before :: SNode a -> Path a -> SNode a
before !acc pth = case pth of
  Root              -> acc
  LeftBin _ _ up _  -> before acc up
  RightBin p m l up -> before (bin p m l acc) up

-- | Rebuild what lies after a position: climbing the route, right siblings
-- are kept and left siblings are dropped.
after :: SNode a -> Path a -> SNode a
after !acc pth = case pth of
  Root              -> acc
  RightBin _ _ _ up -> after acc up
  LeftBin p m up r  -> after (bin p m acc r) up

-- | Put a tree at the position described by a route and rebuild the whole
-- tree.  An empty tree makes the innermost branch collapse to the sibling.
assign :: Sized a => SNode a -> Path a -> SNode a
assign NIL pth = case pth of
  Root              -> nil
  LeftBin _ _ up r  -> assign' r up
  RightBin _ _ l up -> assign' l up
assign t pth = case pth of
  Root              -> t
  LeftBin p m up r  -> assign' (bin' p m t r) up
  RightBin p m l up -> assign' (bin' p m l t) up

-- | Rebuild the ancestors of a subtree; every branch is built with the
-- asserting branch constructor.
assign' :: Sized a => SNode a -> Path a -> SNode a
assign' !t pth = case pth of
  Root              -> t
  LeftBin p m up r  -> assign' (bin' p m t r) up
  RightBin p m l up -> assign' (bin' p m l t) up

-- | Extend a route with the branch that would separate a new key from an
-- existing subtree whose prefix (or key) is given.  The new key sits on the
-- left when its branching bit is zero, otherwise on the right.
branchHole :: Key -> Prefix -> Path a -> SNode a -> Path a
branchHole !k !p pth t =
  if zero k m
    then LeftBin q m pth t
    else RightBin q m t pth
  where
    m = branchMask k p
    q = mask k m

-- | Look a key up.  Branches are followed by the mask bit alone; the key
-- comparison at the tip decides whether the key is present.
lookup :: Key -> SNode a -> Lookup a
lookup !k = look
  where
    look BIN(_ m l r)
      | zeroN k m = look l
      | otherwise = look r
    look TIP(kx x) = if k == kx then some x else none
    look _ = none

-- | A tree holding exactly one key.
singleton :: Sized a => Key -> a -> SNode a
singleton key val = sNode (Tip key val)

-- | A one-key tree when a value is present, the empty tree otherwise.
singletonMaybe :: Sized a => Key -> Maybe a -> SNode a
singletonMaybe key mval = case mval of
  Nothing  -> nil
  Just val -> singleton key val

-- | Apply an effectful function to every value, left subtree first, keeping
-- the shape of the tree.
traverse :: (Applicative f, Sized b) => (a -> f b) -> SNode a -> f (SNode b)
traverse f = go
  where
    go NIL = pure nil
    go TIP(kx x) = fmap (singleton kx) (f x)
    go BIN(p m l r) = fmap (bin' p m) (go l) <*> go r

-- | Folds visit the left subtree before the right subtree.
instance Foldable SNode where
  foldMap f BIN(_ _ l r) = foldMap f l `mappend` foldMap f r
  foldMap f TIP(_ x) = f x
  foldMap _ NIL = mempty

  foldr _ acc NIL = acc
  foldr f acc TIP(_ x) = f x acc
  foldr f acc BIN(_ _ l r) = foldr f (foldr f acc r) l

  foldl _ acc NIL = acc
  foldl f acc TIP(_ x) = f acc x
  foldl f acc BIN(_ _ l r) = foldl f (foldl f acc l) r

  foldr1 f BIN(_ _ l r) = foldr f (foldr1 f r) l
  foldr1 _ TIP(_ x) = x
  foldr1 _ NIL = foldr1Empty

  foldl1 f BIN(_ _ l r) = foldl f (foldl1 f l) r
  foldl1 _ TIP(_ x) = x
  foldl1 _ NIL = foldl1Empty

-- | Folding the map wrapper just folds the tree inside it.
instance Foldable (TrieMap Word) where
  foldMap f = foldMap f . getWordMap
  foldr f acc = foldr f acc . getWordMap
  foldl f acc = foldl f acc . getWordMap
  foldr1 f = foldr1 f . getWordMap
  foldl1 f = foldl1 f . getWordMap

-- | Apply a function to every value, keeping the shape of the tree.
map :: Sized b => (a -> b) -> SNode a -> SNode b
map f = go
  where
    go BIN(p m l r) = bin' p m (go l) (go r)
    go TIP(kx x) = singleton kx (f x)
    go _ = nil

-- | Transform values and drop those mapped to 'Nothing'.  Branches are
-- rebuilt with the collapsing constructor because a child may become empty.
mapMaybe :: Sized b => (a -> Maybe b) -> SNode a -> SNode b
mapMaybe f = go
  where
    go BIN(p m l r) = bin p m (go l) (go r)
    go TIP(kx x) = singletonMaybe kx (f x)
    go _ = nil

-- | Split a tree in two: each value may contribute to either result, to
-- both, or to neither.  The left subtree is processed strictly first.
mapEither :: (Sized b, Sized c) => (a -> (# Maybe b, Maybe c #)) -> SNode a -> (# SNode b, SNode c #)
mapEither f BIN(p m l r) = case mapEither f l of
  (# lL, lR #) -> both (bin p m lL) (bin p m lR) (mapEither f) r
mapEither f TIP(kx x) = both (singletonMaybe kx) (singletonMaybe kx) f x
mapEither _ _ = (# nil, nil #)

-- | Union of two trees.  When a key occurs in both, the combining function
-- receives the value from the first tree and then the value from the second;
-- a result of 'Nothing' removes the key.
unionWith :: Sized a => (a -> a -> Maybe a) -> SNode a -> SNode a -> SNode a
unionWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
  (Nil, _) -> n2
  (_, Nil) -> n1
  (Tip k x, _) ->
    alter (\old -> case old of { Nothing -> Just x; Just y -> f x y }) k n2
  (_, Tip k x) ->
    alter (\old -> case old of { Nothing -> Just x; Just y -> f y x }) k n1
  (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
    | shorter m1 m2 -> mergeIntoFirst
    | shorter m2 m1 -> mergeIntoSecond
    | p1 == p2      -> bin p1 m1 (unionWith f l1 l2) (unionWith f r1 r2)
    | otherwise     -> join p1 n1 p2 n2
    where
      -- The first branch has the larger mask: the second tree either lies
      -- outside it or belongs entirely under one of its children.
      mergeIntoFirst
        | nomatch p2 p1 m1 = join p1 n1 p2 n2
        | zero p2 m1       = bin p1 m1 (unionWith f l1 n2) r1
        | otherwise        = bin p1 m1 l1 (unionWith f r1 n2)
      -- Mirror image: the second branch has the larger mask.
      mergeIntoSecond
        | nomatch p1 p2 m2 = join p1 n1 p2 n2
        | zero p1 m2       = bin p2 m2 (unionWith f n1 l2) r2
        | otherwise        = bin p2 m2 l2 (unionWith f n1 r2)

{-# INLINE alter #-}
-- | Alter the value at one key of a bare tree by going through the
-- map-level alteration of the trie-key class.
alter :: Sized a => (Maybe a -> Maybe a) -> Key -> SNode a -> SNode a
alter f k t = getWordMap (alterM f k (WordMap t))

-- | Intersection of two trees.  Only keys present in both survive, and only
-- when the combining function returns a value for them.
intersectionWith :: Sized c => (a -> b -> Maybe c) -> SNode a -> SNode b -> SNode c
intersectionWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
  (Nil, _) -> nil
  (_, Nil) -> nil
  (Tip k x, _) -> option (lookup k n2) nil (\y -> singletonMaybe k (f x y))
  (_, Tip k y) -> option (lookup k n1) nil (\x -> singletonMaybe k (f x y))
  (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
    | shorter m1 m2 -> narrowFirst
    | shorter m2 m1 -> narrowSecond
    | p1 == p2      -> bin p1 m1 (intersectionWith f l1 l2) (intersectionWith f r1 r2)
    | otherwise     -> nil
    where
      -- The first branch has the larger mask: only the child on the side of
      -- the second prefix can share keys with the second tree.
      narrowFirst
        | nomatch p2 p1 m1 = nil
        | zero p2 m1       = intersectionWith f l1 n2
        | otherwise        = intersectionWith f r1 n2
      -- Mirror image: the second branch has the larger mask.
      narrowSecond
        | nomatch p1 p2 m2 = nil
        | zero p1 m2       = intersectionWith f n1 l2
        | otherwise        = intersectionWith f n1 r2

-- | Difference of two trees.  A key of the first tree that also occurs in
-- the second is kept only when the combining function returns a value.
differenceWith :: Sized a => (a -> b -> Maybe a) -> SNode a -> SNode b -> SNode a
differenceWith f n1@(SNode _ t1) n2@(SNode _ t2) = case (t1, t2) of
  (Nil, _) -> nil
  (_, Nil) -> n1
  (Tip k x, _) -> option (lookup k n2) n1 (\y -> singletonMaybe k (f x y))
  (_, Tip k y) -> alter (\old -> old >>= \x -> f x y) k n1
  (Bin p1 m1 l1 r1, Bin p2 m2 l2 r2)
    | shorter m1 m2 -> subtractUnderFirst
    | shorter m2 m1 -> subtractFromSecond
    | p1 == p2      -> bin p1 m1 (differenceWith f l1 l2) (differenceWith f r1 r2)
    | otherwise     -> n1
    where
      -- The first branch has the larger mask: the second tree can only
      -- affect the child on the side of its prefix.
      subtractUnderFirst
        | nomatch p2 p1 m1 = n1
        | zero p2 m1       = bin p1 m1 (differenceWith f l1 n2) r1
        | otherwise        = bin p1 m1 l1 (differenceWith f r1 n2)
      -- The second branch has the larger mask: only one of its children can
      -- overlap the first tree.
      subtractFromSecond
        | nomatch p1 p2 m2 = n1
        | zero p1 m2       = differenceWith f n1 l2
        | otherwise        = differenceWith f n1 r2

-- | Submap test: every key of the first tree must occur in the second with
-- a value related by the supplied comparison.
isSubmapOfBy :: LEq a b -> LEq (SNode a) (SNode b)
isSubmapOfBy le t1@BIN(p1 m1 l1 r1) BIN(p2 m2 l2 r2)
  | shorter m1 m2 = False
  | shorter m2 m1 =
      match p1 p2 m2 && isSubmapOfBy le t1 (if zero p1 m2 then l2 else r2)
  | otherwise =
      p1 == p2 && isSubmapOfBy le l1 l2 && isSubmapOfBy le r1 r2
isSubmapOfBy _ BIN(_ _ _ _) _ = False
isSubmapOfBy le TIP(k x) t2 = option (lookup k t2) False (le x)
isSubmapOfBy _ NIL _ = True

-- | Is the masked bit of the key clear?
zero :: Key -> Mask -> Bool
zero key m = (key .&. m) == 0

-- | Does the key, cut down to the bits above the mask bit, differ from the
-- prefix?
nomatch :: Key -> Prefix -> Mask -> Bool
nomatch key p m = mask key m /= p

-- | Does the key, cut down to the bits above the mask bit, equal the prefix?
match :: Key -> Prefix -> Mask -> Bool
match key p m = mask key m == p

-- | Same bit test as 'zero', stated on the natural-number alias.
zeroN :: Nat -> Nat -> Bool
zeroN key m = (key .&. m) == 0

-- | Clear the mask bit and every bit below it.
mask :: Nat -> Nat -> Prefix
mask key m = key .&. compl ((m - 1) .|. m)

-- | A larger mask means an earlier (higher) branching bit.
shorter :: Mask -> Mask -> Bool
shorter ma mb = ma > mb

-- | The highest bit at which two prefixes differ.
branchMask :: Prefix -> Prefix -> Mask
branchMask pa pb = highestBitMask (pa `xor` pb)

-- | Keep only the highest set bit: smear it over all lower positions, then
-- cancel everything below it.
highestBitMask :: Nat -> Nat
highestBitMask x0 = x6 `xor` shiftR x6 1
  where
    x1 = x0 .|. shiftR x0 1
    x2 = x1 .|. shiftR x1 2
    x3 = x2 .|. shiftR x2 4
    x4 = x3 .|. shiftR x3 8
    x5 = x4 .|. shiftR x4 16
    x6 = x5 .|. shiftR x5 32  -- for 64 bit platforms

-- | Join two trees with different prefixes under a fresh branch.  The tree
-- whose prefix has a zero branching bit goes on the left.
{-# INLINE join #-}
join :: Prefix -> SNode a -> Prefix -> SNode a -> SNode a
join p1 t1 p2 t2 =
  let m = branchMask p1 p2
      p = mask p1 m
  in if zero p1 m
       then bin' p m t1 t2
       else bin' p m t2 t1

-- | The empty tree, of size zero.
nil :: SNode a
nil = SNode 0 Nil

-- | Collapsing branch constructor: when either child is empty the other
-- child is returned as is; otherwise a branch with the summed size is built.
bin :: Prefix -> Mask -> SNode a -> SNode a -> SNode a
bin p m l@(SNode sizeL nodeL) r@(SNode sizeR nodeR)
  | Nil <- nodeL = r
  | Nil <- nodeR = l
  | otherwise    = SNode (sizeL + sizeR) (Bin p m l r)

-- | Branch constructor for children asserted to be non-empty; the size is
-- the sum of the cached sizes of both children.
bin' :: Prefix -> Mask -> SNode a -> SNode a -> SNode a
bin' p m l@SNode{sz=sizeL} r@SNode{sz=sizeR} =
  assert (populated l && populated r) (SNode (sizeL + sizeR) (Bin p m l r))
  where
    populated t = case node t of
      Nil -> False
      _   -> True

-- | The hole for a new key next to an existing single binding, or 'Nothing'
-- when both keys are equal.
{-# INLINE unifier #-}
unifier :: Sized a => Key -> Key -> a -> Maybe (WHole a)
unifier k' k a =
  if k' == k
    then Nothing
    else Just (WHole k' (branchHole k' k Root (singleton k a)))