{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
{-# LANGUAGE MagicHash    #-}

-- | This whole module is mostly copied from the @hashtables@ package.
--   You can find much better documentation there. Additionally, if you
--   encounter a problem in the implementation of a function, you should
--   probably open up the issue on that package's issue tracker, since
--   Gregory is the one who actually authored (and the one who actually
--   understands) this code. The main differences are as follows:
--
--   * The type is named @MHashMap@ instead of @HashTable@.
--   * All functions are generalized to work in any 'PrimMonad' instead
--     of only 'ST'.
--   * The functions 'mapM_' and 'foldM' have been curried. They do not
--     take tuples here.
--   * There is no type class used to overload the hashtable operations.
--
module Data.HashMap.Mutable.Basic
  ( MHashMap
  , new
  , newSized
  , delete
  , lookup
  , insert
  , mapM_
  , foldM
  , computeOverhead
  ) where

------------------------------------------------------------------------------
import           Control.Exception                          (assert)
import           Control.Monad                              hiding (foldM, mapM_)
import           Control.Monad.Primitive                    (PrimMonad, PrimState, unsafePrimToPrim)
import           Control.Monad.ST                           (ST)
import           Data.Bits
import           Data.Hashable                              (Hashable)
import           Data.Maybe
import           Data.Monoid
import           Data.Primitive.MutVar                      (MutVar, newMutVar, readMutVar, writeMutVar)
import           Data.STRef
import           Data.Semigroup                             (Semigroup)
import           GHC.Exts
import           Prelude                                    hiding (lookup, mapM_, read)
import qualified Data.Hashable                              as H
import qualified Data.Primitive.ByteArray                   as A
import qualified Data.Primitive.ByteArray                   as A
------------------------------------------------------------------------------
import           Data.HashMap.Mutable.Internal.Array
import           Data.HashMap.Mutable.Internal.CacheLine
import           Data.HashMap.Mutable.Internal.IntArray     (Elem)
import qualified Data.HashMap.Mutable.Internal.IntArray     as U
import           Data.HashMap.Mutable.Internal.Utils
import qualified Data.Semigroup                             as SG


------------------------------------------------------------------------------
-- | An open addressing hash table using linear probing.
newtype MHashMap s k v = HT (MutVar s (HashTable_ s k v))

-- | Two-slot mutable byte array: slot 0 holds the number of live entries,
-- slot 1 holds the number of deleted markers.
type SizeRefs s = A.MutableByteArray s

-- | Size of an 'Int' in bytes.
intSz :: Int
intSz = (finiteBitSize (0::Int) `div` 8)

-- | Read the live-entry counter (slot 0).
readLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readLoad = flip A.readByteArray 0

-- | Overwrite the live-entry counter (slot 0).
writeLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeLoad = flip A.writeByteArray 0

-- | Read the deleted-marker counter (slot 1).
readDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readDelLoad = flip A.readByteArray 1

-- | Overwrite the deleted-marker counter (slot 1).
writeDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeDelLoad = flip A.writeByteArray 1

-- | Allocate a fresh counter pair with both counters zeroed.
newSizeRefs :: PrimMonad m => m (SizeRefs (PrimState m))
newSizeRefs = do
    let asz = 2 * intSz
    a <- A.newAlignedPinnedByteArray asz intSz
    A.fillByteArray a 0 asz 0
    return a

-- | The actual table representation stored behind the 'MutVar'.
data HashTable_ s k v = MHashMap
    { _size   :: {-# UNPACK #-} !Int
    , _load   :: !(SizeRefs s)  -- ^ 2-element array, stores how many entries
                                -- and deleted entries are in the table.
    , _hashes :: !(U.IntArray s)
    , _keys   :: {-# UNPACK #-} !(MutableArray s k)
    , _values :: {-# UNPACK #-} !(MutableArray s v)
    }


------------------------------------------------------------------------------
-- NOTE(review): this renders as the empty string; a placeholder such as the
-- type name may have been intended -- confirm before changing.
instance Show (MHashMap s k v) where
    show _ = ""


------------------------------------------------------------------------------
-- | Create a new, empty table. Equivalent to @'newSized' 1@.
new :: PrimMonad m => m (MHashMap (PrimState m) k v)
new = newSized 1
{-# INLINE new #-}


------------------------------------------------------------------------------
-- | Create a new, empty table sized for roughly @n@ elements. The number of
-- buckets is @n / maxLoad@ rounded up and then passed through
-- 'nextBestPrime'.
newSized :: PrimMonad m => Int -> m (MHashMap (PrimState m) k v)
newSized n = do
    debug $ "entering: newSized " ++ show n
    let m = nextBestPrime $ ceiling (fromIntegral n / maxLoad)
    ht <- newSizedReal m
    newRef ht
{-# INLINE newSized #-}


------------------------------------------------------------------------------
-- | Allocate the arrays for a table with exactly @m@ buckets.
newSizedReal :: PrimMonad m => Int -> m (HashTable_ (PrimState m) k v)
newSizedReal m = do
    -- make sure the hash array is a multiple of cache-line sized so we can
    -- always search a whole cache line at once
    let m' = ((m + numElemsInCacheLine - 1) `div` numElemsInCacheLine)
             * numElemsInCacheLine
    h  <- U.newArray m'
    k  <- newArray m undefined
    v  <- newArray m undefined
    ld <- newSizeRefs
    return $! MHashMap m ld h k v


------------------------------------------------------------------------------
-- | Remove the mapping for the given key, if there is one.
delete :: (PrimMonad m, Hashable k, Eq k)
       => (MHashMap (PrimState m) k v)
       -> k
       -> m ()
delete htRef k = do
    debug $ "entered: delete: hash=" ++ show h
    ht <- readRef htRef
    _  <- delete' ht True k h
    return ()
  where
    !h = hash k
{-# INLINE delete #-}


------------------------------------------------------------------------------
-- | Look up the value stored under the given key. Returns 'Nothing' when the
-- key is not present.
lookup :: (PrimMonad m, Eq k, Hashable k)
       => (MHashMap (PrimState m) k v)
       -> k
       -> m (Maybe v)
lookup htRef !k = do
    ht <- readRef htRef
    lookup' ht
  where
    lookup' (MHashMap sz _ hashes keys values) = do
        let !b = whichBucket h sz
        debug $ "lookup h=" ++ show h ++ " sz=" ++ show sz ++ " b=" ++ show b
        go b 0 sz
      where
        !h  = hash k
        !he = hashToElem h

        -- Search from bucket @b@; only indices in @[start, end)@ count as
        -- valid hits.
        go !b !start !end = {-# SCC "lookup/go" #-} do
            debug $ concat [ "lookup'/go: "
                           , show b
                           , "/"
                           , show start
                           , "/"
                           , show end
                           ]
            idx <- forwardSearch2 hashes b end he emptyMarker
            debug $ "forwardSearch2 returned " ++ show idx
            if (idx < 0 || idx < start || idx >= end)
              then return Nothing
              else do
                h0 <- U.readArray hashes idx
                debug $ "h0 was " ++ show h0
                if recordIsEmpty h0
                  then do
                    debug $ "record empty, returning Nothing"
                    return Nothing
                  else do
                    k' <- readArray keys idx
                    if k == k'
                      then do
                        debug $ "value found at " ++ show idx
                        v <- readArray values idx
                        return $! Just v
                      else do
                        debug $ "value not found, recursing"
                        -- Once the hit index is below the starting bucket
                        -- the search has wrapped; restrict the remaining
                        -- range to the part before @b@.
                        if idx < b
                          then go (idx + 1) (idx + 1) b
                          else go (idx + 1) start end
{-# INLINE lookup #-}


------------------------------------------------------------------------------
-- | Insert a key/value mapping, replacing any existing mapping for the key.
-- Key and value are forced to weak head normal form.
insert :: (PrimMonad m, Eq k, Hashable k)
       => (MHashMap (PrimState m) k v)
       -> k
       -> v
       -> m ()
insert htRef !k !v = do
    ht <- readRef htRef
    !ht' <- insert' ht
    writeRef htRef ht'
  where
    insert' ht = do
        debug "insert': calling delete'"
        -- delete' removes any old mapping and tells us where to write.
        b <- delete' ht False k h
        debug $ concat [ "insert': writing h="
                       , show h
                       , " he="
                       , show he
                       , " b="
                       , show b
                       ]
        U.writeArray hashes b he
        writeArray keys b k
        writeArray values b v
        checkOverflow ht
      where
        !h     = hash k
        !he    = hashToElem h
        hashes = _hashes ht
        keys   = _keys ht
        values = _values ht
{-# INLINE insert #-}


------------------------------------------------------------------------------
-- | Strict monadic left fold over every key/value mapping in the table,
-- visiting buckets in index order and skipping empty and deleted slots.
foldM :: PrimMonad m
      => (a -> k -> v -> m a)
      -> a
      -> MHashMap (PrimState m) k v
      -> m a
foldM f seed0 htRef = readRef htRef >>= work
  where
    work (MHashMap sz _ hashes keys values) = go 0 seed0
      where
        go !i !seed
            | i >= sz   = return seed
            | otherwise = do
                h <- U.readArray hashes i
                if recordIsEmpty h || recordIsDeleted h
                  then go (i+1) seed
                  else do
                    k <- readArray keys i
                    v <- readArray values i
                    !seed' <- f seed k v
                    go (i+1) seed'


------------------------------------------------------------------------------
-- | Run a monadic action on every key/value mapping in the table, visiting
-- buckets in index order and skipping empty and deleted slots. The results
-- of the action are discarded.
mapM_ :: PrimMonad m
      => (k -> v -> m b)
      -> MHashMap (PrimState m) k v
      -> m ()
mapM_ f htRef = readRef htRef >>= work
  where
    work (MHashMap sz _ hashes keys values) = go 0
      where
        go !i
            | i >= sz   = return ()
            | otherwise = do
                h <- U.readArray hashes i
                if recordIsEmpty h || recordIsDeleted h
                  then go (i+1)
                  else do
                    k <- readArray keys i
                    v <- readArray values i
                    _ <- f k v
                    go (i+1)


------------------------------------------------------------------------------
-- | Compute @constOverhead / sz + (2 + 2*ws*(1-k)) / (k * ws)@, where @sz@
-- is the bucket count, @k@ is the current load factor (entries / buckets)
-- and @ws@ is the size of an 'Int' in bytes. Presumably this is the
-- per-mapping overhead in machine words, as in upstream -- confirm.
--
-- On a table with no entries @k@ is zero, so the result is not finite.
computeOverhead :: PrimMonad m => MHashMap (PrimState m) k v -> m Double
computeOverhead htRef = readRef htRef >>= work
  where
    work (MHashMap sz' loadRef _ _ _) = do
        !ld <- readLoad loadRef
        let k = fromIntegral ld / sz
        return $ constOverhead / sz + (2 + 2 * ws * (1 - k)) / (k * ws)
      where
        ws = fromIntegral $! intSz
        sz = fromIntegral sz'
        -- Change these if you change the representation
        constOverhead = 14


------------------------------
-- Private functions follow --
------------------------------


------------------------------------------------------------------------------
-- | Write a record into the first empty or deleted slot found when probing
-- from the bucket that hash @h@ maps to. Does not touch the load counters
-- and does not check whether the key is already present.
{-# INLINE insertRecord #-}
insertRecord :: PrimMonad m
             => Int
             -> U.IntArray (PrimState m)
             -> MutableArray (PrimState m) k
             -> MutableArray (PrimState m) v
             -> Int
             -> k
             -> v
             -> m ()
insertRecord !sz !hashes !keys !values !h !key !value = do
    let !b = whichBucket h sz
    debug $ "insertRecord sz=" ++ show sz ++ " h=" ++ show h ++ " b=" ++ show b
    probe b
  where
    he = hashToElem h

    probe !i = {-# SCC "insertRecord/probe" #-} do
        !idx <- forwardSearch2 hashes i sz emptyMarker deletedMarker
        debug $ "forwardSearch2 returned " ++ show idx
        assert (idx >= 0) $ do
            U.writeArray hashes idx he
            writeArray keys idx key
            writeArray values idx value


------------------------------------------------------------------------------
-- | Record one additional entry in the load counter, then rebuild the table
-- if live entries plus deleted markers exceed 'maxLoad': at the same size
-- when deleted markers outnumber half the (pre-increment) entry count,
-- otherwise at a larger size.
checkOverflow :: (PrimMonad m, Eq k, Hashable k)
              => (HashTable_ (PrimState m) k v)
              -> m (HashTable_ (PrimState m) k v)
checkOverflow ht@(MHashMap sz ldRef _ _ _) = do
    !ld <- readLoad ldRef
    let !ld' = ld + 1
    writeLoad ldRef ld'
    !dl <- readDelLoad ldRef
    debug $ concat [ "checkOverflow: sz="
                   , show sz
                   , " entries="
                   , show ld
                   , " deleted="
                   , show dl
                   ]
    if fromIntegral (ld + dl) / fromIntegral sz > maxLoad
      then if dl > ld `div` 2
             then rehashAll ht sz
             else growTable ht
      else return ht


------------------------------------------------------------------------------
-- | Build a new table with @sz'@ buckets and re-insert every live record of
-- the old one. The entry count is copied over; the new table starts with a
-- deleted-marker count of zero.
rehashAll :: (Hashable k, PrimMonad m)
          => HashTable_ (PrimState m) k v
          -> Int
          -> m (HashTable_ (PrimState m) k v)
rehashAll (MHashMap sz loadRef hashes keys values) sz' = do
    debug $ "rehashing: old size " ++ show sz ++ ", new size " ++ show sz'
    ht' <- newSizedReal sz'
    let (MHashMap _ loadRef' newHashes newKeys newValues) = ht'
    readLoad loadRef >>= writeLoad loadRef'
    rehash newHashes newKeys newValues
    return ht'
  where
    rehash newHashes newKeys newValues = go 0
      where
        go !i
            | i >= sz   = return ()
            | otherwise = {-# SCC "growTable/rehash" #-} do
                h0 <- U.readArray hashes i
                when (not (recordIsEmpty h0 || recordIsDeleted h0)) $ do
                    k <- readArray keys i
                    v <- readArray values i
                    insertRecord sz' newHashes newKeys newValues (hash k) k v
                go $ i+1


------------------------------------------------------------------------------
-- | Rehash into a table whose size is chosen by 'bumpSize'.
growTable :: (Hashable k, PrimMonad m)
          => HashTable_ (PrimState m) k v
          -> m (HashTable_ (PrimState m) k v)
growTable ht@(MHashMap sz _ _ _ _) = do
    let !sz' = bumpSize maxLoad sz
    rehashAll ht sz'


------------------------------------------------------------------------------
-- Helper data structure for delete'
data Slot = Slot
    { _slot       :: {-# UNPACK #-} !Int
    , _wasDeleted :: {-# UNPACK #-} !Int  -- we use Int because Bool won't
                                          -- unpack
    }
  deriving (Show)


------------------------------------------------------------------------------
-- | Left-biased choice: keeps the first 'Slot' whose index is set, where
-- 'maxBound' means "not set". Note that '<>' forces both arguments.
instance Semigroup Slot where
    Slot x1 b1 <> Slot x2 b2 =
        if x1 == maxBound then Slot x2 b2 else Slot x1 b1

instance Monoid Slot where
    mempty  = Slot maxBound 0
    mappend = (SG.<>)


------------------------------------------------------------------------------
-- Returns the slot in the array where it would be safe to write the given key.
--
-- If the key is present, the live-entry counter is decremented. The
-- @clearOut@ flag is 'True' when the caller only wants the key removed
-- ('delete') and 'False' when the caller is about to write a record into the
-- returned slot ('insert').
delete' :: (PrimMonad m, Hashable k, Eq k)
        => (HashTable_ (PrimState m) k v)
        -> Bool
        -> k
        -> Int
        -> m Int
delete' (MHashMap sz loadRef hashes keys values) clearOut k h = do
    debug $ "delete': h=" ++ show h ++ " he=" ++ show he
            ++ " sz=" ++ show sz ++ " b0=" ++ show b0
    pair@(found, slot) <- go mempty b0 False
    debug $ "go returned " ++ show pair
    let !b' = _slot slot
    when found $ bump loadRef (-1)
    -- bump the delRef lower if we're writing over a deleted marker
    when (not clearOut && _wasDeleted slot == 1) $ bumpDel loadRef (-1)
    return b'
  where
    he = hashToElem h

    bump ref i = do
        !ld <- readLoad ref
        writeLoad ref $! ld + i

    bumpDel ref i = do
        !ld <- readDelLoad ref
        writeDelLoad ref $! ld + i

    !b0 = whichBucket h sz

    -- Only ever True once a candidate slot has been recorded in fp.
    haveWrapped !(Slot fp _) !b = if fp == maxBound
                                    then False
                                    else b <= fp

    -- arguments:
    --   * fp    maintains the slot in the array where it would be safe to
    --           write the given key
    --   * b     search the buckets array starting at this index.
    --   * wrap  True if we've wrapped around, False otherwise
    go !fp !b !wrap = do
        debug $ concat [ "go: fp="
                       , show fp
                       , " b="
                       , show b
                       , ", wrap="
                       , show wrap
                       , ", he="
                       , show he
                       , ", emptyMarker="
                       , show emptyMarker
                       , ", deletedMarker="
                       , show deletedMarker
                       ]
        !idx <- forwardSearch3 hashes b sz he emptyMarker deletedMarker
        debug $ "forwardSearch3 returned " ++ show idx
                ++ " with sz=" ++ show sz ++ ", b=" ++ show b
        if wrap && idx >= b0
          -- we wrapped around in the search and didn't find our hash code;
          -- this means that the table is full of deleted elements. Just return
          -- the first place we'd be allowed to insert.
          --
          -- wrap can only be True after fp has been set (see haveWrapped), so
          -- fp already names that place. Return it directly: combining it
          -- with a placeholder Slot built from 'error' would force the
          -- placeholder's strict field and throw as soon as the result is
          -- inspected.
          --
          -- TODO: if we get in this situation we should probably just rehash
          -- the table, because every insert is going to be O(n).
          then return (False, fp)
          else do
            -- because the table isn't full, we know that there must be either
            -- an empty or a deleted marker somewhere in the table. Assert this
            -- here.
            assert (idx >= 0) $ return ()
            h0 <- U.readArray hashes idx
            debug $ "h0 was " ++ show h0
            if recordIsEmpty h0
              then do
                let pl = fp `mappend` (Slot idx 0)
                debug $ "empty, returning " ++ show pl
                return (False, pl)
              else do
                let !wrap' = haveWrapped fp idx
                if recordIsDeleted h0
                  then do
                    let pl = fp `mappend` (Slot idx 1)
                    debug $ "deleted, cont with pl=" ++ show pl
                    go pl (idx + 1) wrap'
                  else if he == h0
                    then do
                      debug $ "found he == h0 == " ++ show h0
                      k' <- readArray keys idx
                      if k == k'
                        then do
                          let samePlace = _slot fp == idx
                          debug $ "found at " ++ show idx
                          debug $ "clearout=" ++ show clearOut
                          debug $ "sp? " ++ show samePlace
                          -- "clearOut" is set when the caller only wants the
                          -- key removed; it is unset when the caller intends
                          -- to write a new element into the returned slot.
                          -- If we're doing an update and we found the old
                          -- key, instead of writing "deleted" and then
                          -- re-writing the new element there, we can just
                          -- write the new element. This only works if we
                          -- were planning on writing the new element here.
                          when (clearOut || not samePlace) $ do
                              bumpDel loadRef 1
                              U.writeArray hashes idx deletedMarker
                              writeArray keys idx undefined
                              writeArray values idx undefined
                          return (True, fp `mappend` (Slot idx 0))
                        else go fp (idx + 1) wrap'
                    else go fp (idx + 1) wrap'


------------------------------------------------------------------------------
-- | Load factor above which 'checkOverflow' rebuilds the table.
maxLoad :: Double
maxLoad = 0.82


------------------------------------------------------------------------------
-- | Hash-array value marking a slot that has never held a record.
emptyMarker :: Elem
emptyMarker = 0


------------------------------------------------------------------------------
-- | Hash-array value marking a slot whose record was removed.
deletedMarker :: Elem
deletedMarker = 1


------------------------------------------------------------------------------
{-# INLINE recordIsEmpty #-}
recordIsEmpty :: Elem -> Bool
recordIsEmpty = (== emptyMarker)


------------------------------------------------------------------------------
{-# INLINE recordIsDeleted #-}
recordIsDeleted :: Elem -> Bool
recordIsDeleted = (== deletedMarker)


------------------------------------------------------------------------------
{-# INLINE hash #-}
hash :: (Hashable k) => k -> Int
hash = H.hash


------------------------------------------------------------------------------
-- | Truncate a hash code to the width stored in the hash array (via
-- 'U.elemMask'). The mask @m#@ is built from comparing the truncated value
-- against 0 and 1 (the two marker values); where it is set the result is
-- taken from the constant 2 instead of the hash.
-- NOTE(review): presumably 'maskw#' yields an all-ones word on equality, so
-- that hashes colliding with a marker become 2 -- confirm against its
-- definition.
{-# INLINE hashToElem #-}
hashToElem :: Int -> Elem
hashToElem !h = out
  where
    !(I# lo#) = h .&. U.elemMask

    !m#  = maskw# lo# 0# `or#` maskw# lo# 1#
    !nm# = not# m#

    !r#  = ((int2Word# 2#) `and#` m#) `or#` (int2Word# lo# `and#` nm#)
    !out = U.primWordToElem r#


------------------------------------------------------------------------------
-- | Wrap a table representation in a fresh mutable reference.
newRef :: PrimMonad m
       => HashTable_ (PrimState m) k v
       -> m (MHashMap (PrimState m) k v)
newRef = liftM HT . newMutVar
{-# INLINE newRef #-}

-- | Replace the table representation held by the reference.
writeRef :: PrimMonad m
         => MHashMap (PrimState m) k v
         -> HashTable_ (PrimState m) k v
         -> m ()
writeRef (HT ref) ht = writeMutVar ref ht
{-# INLINE writeRef #-}

-- | Read the current table representation out of the reference.
readRef :: PrimMonad m
        => MHashMap (PrimState m) k v
        -> m (HashTable_ (PrimState m) k v)
readRef (HT ref) = readMutVar ref
{-# INLINE readRef #-}


------------------------------------------------------------------------------
-- | Print a trace message when compiled with @DEBUG@ defined; otherwise a
-- no-op.
{-# INLINE debug #-}
debug :: PrimMonad m => String -> m ()
#ifdef DEBUG
debug s = unsafePrimToPrim (putStrLn s)
#else
debug _ = return ()
#endif