{-# LANGUAGE BangPatterns #-} {-# LANGUAGE CPP #-} {-# LANGUAGE MagicHash #-} module Data.HashMap.Mutable.Basic ( HashTable , new , newSized , delete , lookup , insert , mapM_ , foldM , computeOverhead ) where ------------------------------------------------------------------------------ import Control.Exception (assert) import Control.Monad hiding (foldM, mapM_) import Control.Monad.ST (ST) import Control.Monad.Primitive (PrimMonad,PrimState,unsafePrimToPrim) import Data.Bits import Data.Hashable (Hashable) import qualified Data.Hashable as H import Data.Maybe import Data.Monoid import qualified Data.Primitive.ByteArray as A import Data.Primitive.MutVar (MutVar,readMutVar,writeMutVar,newMutVar) import Data.STRef import GHC.Exts import Prelude hiding (lookup, mapM_, read) ------------------------------------------------------------------------------ import Data.HashMap.Mutable.Internal.Array import Data.HashMap.Mutable.Internal.CacheLine import Data.HashMap.Mutable.Internal.IntArray (Elem) import qualified Data.HashMap.Mutable.Internal.IntArray as U import Data.HashMap.Mutable.Internal.Utils ------------------------------------------------------------------------------ -- | An open addressing hash table using linear probing. 
newtype HashTable s k v = HT (MutVar s (HashTable_ s k v))

-- | Backing store for the table's two counters: element 0 holds the number
-- of live entries and element 1 the number of deleted markers.
type SizeRefs s = A.MutableByteArray s

-- | Size of an 'Int' in bytes.
intSz :: Int
intSz = finiteBitSize (0 :: Int) `div` 8

-- | Read the live-entry counter (element 0).
readLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readLoad refs = A.readByteArray refs 0

-- | Overwrite the live-entry counter (element 0).
writeLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeLoad refs = A.writeByteArray refs 0

-- | Read the deleted-marker counter (element 1).
readDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readDelLoad refs = A.readByteArray refs 1

-- | Overwrite the deleted-marker counter (element 1).
writeDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeDelLoad refs = A.writeByteArray refs 1

-- | Allocate a pinned, 'Int'-aligned array with room for two 'Int's and
-- zero every byte of it, so both counters start at 0.
newSizeRefs :: PrimMonad m => m (SizeRefs (PrimState m))
newSizeRefs = do
    refs <- A.newAlignedPinnedByteArray byteCount intSz
    A.fillByteArray refs 0 byteCount 0
    return refs
  where
    byteCount = 2 * intSz

-- | The table proper; 'HashTable' is a mutable reference to one of these so
-- that the whole record can be swapped out when the table is rebuilt.
data HashTable_ s k v = HashTable
    { _size   :: {-# UNPACK #-} !Int
      -- ^ number of key/value slots
    , _load   :: !(SizeRefs s)
      -- ^ 2-element array, stores how many entries
      -- and deleted entries are in the table.
    , _hashes :: !(U.IntArray s)
      -- ^ per-slot hash code, or the empty / deleted marker
    , _keys   :: {-# UNPACK #-} !(MutableArray s k)
    , _values :: {-# UNPACK #-} !(MutableArray s v)
    }


------------------------------------------------------------------------------
-- NOTE(review): the rendering is the empty string; presumably a placeholder
-- label was intended here -- confirm before relying on 'show' output.
instance Show (HashTable s k v) where
    show _ = ""


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:new".
new :: PrimMonad m => m (HashTable (PrimState m) k v)
new = newSized 1
{-# INLINE new #-}


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:newSized".
newSized :: PrimMonad m => Int -> m (HashTable (PrimState m) k v)
newSized n = do
    debug $ "entering: newSized " ++ show n
    newSizedReal slots >>= newRef
  where
    -- Enough slots that n entries stay at or below the maximum load factor.
    slots = nextBestPrime $ ceiling (fromIntegral n / maxLoad)
{-# INLINE newSized #-}


------------------------------------------------------------------------------
-- | Allocate an empty table record with exactly @m@ key/value slots and
-- zeroed counters.
newSizedReal :: PrimMonad m => Int -> m (HashTable_ (PrimState m) k v)
newSizedReal m = do
    hashArr  <- U.newArray paddedLen
    keyArr   <- newArray m undefined
    valArr   <- newArray m undefined
    counters <- newSizeRefs
    return $! HashTable m counters hashArr keyArr valArr
  where
    -- make sure the hash array is a multiple of cache-line sized so we can
    -- always search a whole cache line at once
    cacheLines = (m + numElemsInCacheLine - 1) `div` numElemsInCacheLine
    paddedLen  = cacheLines * numElemsInCacheLine


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:delete".
delete :: (PrimMonad m, Hashable k, Eq k) =>
          (HashTable (PrimState m) k v)
       -> k
       -> m ()
delete htRef k = do
    debug $ "entered: delete: hash=" ++ show h
    ht <- readRef htRef
    -- The slot index handed back by delete' is only of interest to insert.
    delete' ht True k h >> return ()
  where
    !h = hash k
{-# INLINE delete #-}


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:lookup".
lookup :: (PrimMonad m, Eq k, Hashable k) =>
          (HashTable (PrimState m) k v)
       -> k
       -> m (Maybe v)
lookup htRef !k = readRef htRef >>= search
  where
    search (HashTable sz _ hashes keys values) = do
        let !b = whichBucket h sz
        debug $ "lookup h=" ++ show h ++ " sz=" ++ show sz ++ " b=" ++ show b
        go b 0 sz
      where
        !h  = hash k
        !he = hashToElem h

        -- Probe from bucket b; a hit only counts if it lands in the
        -- half-open window [start, end).
        go !b !start !end = {-# SCC "lookup/go" #-} do
            debug $ concat [ "lookup'/go: "
                           , show b
                           , "/"
                           , show start
                           , "/"
                           , show end
                           ]
            idx <- forwardSearch2 hashes b end he emptyMarker
            debug $ "forwardSearch2 returned " ++ show idx
            if idx >= 0 && idx >= start && idx < end
              then inspect b start end idx
              else return Nothing

        -- Decide what the slot reported by the search means for this key.
        inspect !b !start !end !idx = do
            h0 <- U.readArray hashes idx
            debug $ "h0 was " ++ show h0
            if recordIsEmpty h0
              then do
                -- an empty slot ends the probe sequence
                debug $ "record empty, returning Nothing"
                return Nothing
              else do
                k' <- readArray keys idx
                if k == k'
                  then do
                    debug $ "value found at " ++ show idx
                    v <- readArray values idx
                    return $! Just v
                  else do
                    -- same hash code, different key: keep probing
                    debug $ "value not found, recursing"
                    if idx < b
                      -- the hit was below the starting bucket, so from now
                      -- on only slots strictly before b remain
                      then go (idx + 1) (idx + 1) b
                      else go (idx + 1) start end
{-# INLINE lookup #-}


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:insert".
insert :: (PrimMonad m, Eq k, Hashable k) =>
          (HashTable (PrimState m) k v)
       -> k
       -> v
       -> m ()
insert htRef !k !v = do
    ht   <- readRef htRef
    !ht' <- store ht
    -- checkOverflow may hand back a rebuilt table, so always write it back
    writeRef htRef ht'
  where
    store ht@(HashTable _ _ hashes keys values) = do
        debug "insert': calling delete'"
        -- removes any existing mapping and reports where it is safe to write
        b <- delete' ht False k h
        debug $ concat [ "insert': writing h="
                       , show h
                       , " he="
                       , show he
                       , " b="
                       , show b
                       ]
        U.writeArray hashes b he
        writeArray keys b k
        writeArray values b v
        checkOverflow ht
      where
        !h  = hash k
        !he = hashToElem h
{-# INLINE insert #-}


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:foldM".
foldM :: PrimMonad m =>
         (a -> (k,v) -> m a)
      -> a
      -> HashTable (PrimState m) k v
      -> m a
foldM f seed0 htRef = readRef htRef >>= work
  where
    work (HashTable sz _ hashes keys values) = go 0 seed0
      where
        -- Walk every slot in index order, skipping empty and deleted ones.
        go !i !seed | i >= sz   = return seed
                    | otherwise = do
            h <- U.readArray hashes i
            if recordIsEmpty h || recordIsDeleted h
              then go (i+1) seed
              else do
                k <- readArray keys i
                v <- readArray values i
                !seed' <- f seed (k, v)
                go (i+1) seed'


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:mapM_".
mapM_ :: PrimMonad m =>
         (k -> v -> m b)
      -> HashTable (PrimState m) k v
      -> m ()
mapM_ f htRef = readRef htRef >>= work
  where
    work (HashTable sz _ hashes keys values) = go 0
      where
        -- Walk every slot in index order, skipping empty and deleted ones.
        go !i | i >= sz   = return ()
              | otherwise = do
            h <- U.readArray hashes i
            if recordIsEmpty h || recordIsDeleted h
              then go (i+1)
              else do
                k <- readArray keys i
                v <- readArray values i
                _ <- f k v
                go (i+1)


------------------------------------------------------------------------------
-- | See the documentation for this function in
-- "Data.HashTable.Class#v:computeOverhead".
--
-- The result is computed from the live-entry counter and the slot count
-- only; for a table with no live entries the load factor is 0 and the
-- 'Double' division yields an infinite overhead.
computeOverhead :: PrimMonad m => HashTable (PrimState m) k v -> m Double
computeOverhead htRef = readRef htRef >>= work
  where
    work (HashTable sz' loadRef _ _ _) = do
        !ld <- readLoad loadRef
        let k = fromIntegral ld / sz
        return $ constOverhead/sz + (2 + 2*ws*(1-k)) / (k * ws)
      where
        -- machine word size in bytes
        ws = fromIntegral $! finiteBitSize (0::Int) `div` 8
        sz = fromIntegral sz'
        -- Change these if you change the representation
        constOverhead = 14


------------------------------
-- Private functions follow --
------------------------------


------------------------------------------------------------------------------
-- | Write a record into the first empty or deleted slot found when probing
-- from the bucket of hash @h@. No key comparison is performed and no counter
-- is touched, so this is only suitable for filling a fresh table (it is used
-- by 'rehashAll').
{-# INLINE insertRecord #-}
insertRecord :: PrimMonad m =>
                Int
             -> U.IntArray (PrimState m)
             -> MutableArray (PrimState m) k
             -> MutableArray (PrimState m) v
             -> Int
             -> k
             -> v
             -> m ()
insertRecord !sz !hashes !keys !values !h !key !value = do
    let !b = whichBucket h sz
    debug $ "insertRecord sz=" ++ show sz ++ " h=" ++ show h ++ " b=" ++ show b
    probe b
  where
    he = hashToElem h

    probe !i = {-# SCC "insertRecord/probe" #-} do
        !idx <- forwardSearch2 hashes i sz emptyMarker deletedMarker
        debug $ "forwardSearch2 returned " ++ show idx
        -- the caller guarantees that a free slot exists
        assert (idx >= 0) $ do
            U.writeArray hashes idx he
            writeArray keys idx key
            writeArray values idx value


------------------------------------------------------------------------------
-- | Account for one newly written entry and, if live entries plus deleted
-- markers now exceed 'maxLoad', rebuild the table: at the same size when
-- deleted markers outnumber half of the (pre-increment) live count, and at a
-- larger size otherwise. Returns the table to use from now on.
checkOverflow :: (PrimMonad m, Eq k, Hashable k) =>
                 (HashTable_ (PrimState m) k v)
              -> m (HashTable_ (PrimState m) k v)
checkOverflow ht@(HashTable sz ldRef _ _ _) = do
    !ld <- readLoad ldRef
    let !ld' = ld + 1
    writeLoad ldRef ld'
    !dl <- readDelLoad ldRef

    debug $ concat [ "checkOverflow: sz="
                   , show sz
                   , " entries="
                   , show ld
                   , " deleted="
                   , show dl ]

    if fromIntegral (ld + dl) / fromIntegral sz > maxLoad
      then if dl > ld `div` 2
             then rehashAll ht sz
             else growTable ht
      else return ht


------------------------------------------------------------------------------
-- | Build a fresh table with @sz'@ slots and re-insert every live record of
-- the old one. The live-entry counter is copied over; the deleted counter of
-- the new table starts at zero because deleted markers are not carried over.
rehashAll :: (Hashable k, PrimMonad m) =>
             HashTable_ (PrimState m) k v
          -> Int
          -> m (HashTable_ (PrimState m) k v)
rehashAll (HashTable sz loadRef hashes keys values) sz' = do
    debug $ "rehashing: old size " ++ show sz ++ ", new size " ++ show sz'
    ht' <- newSizedReal sz'
    let (HashTable _ loadRef' newHashes newKeys newValues) = ht'
    readLoad loadRef >>= writeLoad loadRef'
    rehash newHashes newKeys newValues
    return ht'

  where
    rehash newHashes newKeys newValues = go 0
      where
        go !i | i >= sz   = return ()
              | otherwise = {-# SCC "growTable/rehash" #-} do
                    h0 <- U.readArray hashes i
                    when (not (recordIsEmpty h0 || recordIsDeleted h0)) $ do
                        k <- readArray keys i
                        v <- readArray values i
                        insertRecord sz' newHashes newKeys newValues
                                     (hash k) k v
                    go $ i+1


------------------------------------------------------------------------------
-- | Rebuild the table at the next size chosen by 'bumpSize'.
growTable :: (Hashable k, PrimMonad m) =>
             HashTable_ (PrimState m) k v
          -> m (HashTable_ (PrimState m) k v)
growTable ht@(HashTable sz _ _ _ _) = do
    let !sz' = bumpSize maxLoad sz
    rehashAll ht sz'


------------------------------------------------------------------------------
-- Helper data structure for delete'
data Slot = Slot {
      _slot       :: {-# UNPACK #-} !Int
      -- ^ candidate index; 'maxBound' means "no candidate yet"
    , _wasDeleted :: {-# UNPACK #-} !Int  -- we use Int because Bool won't
                                          -- unpack
      -- ^ 1 if the candidate slot held a deleted marker, 0 otherwise
    }
  deriving (Show)


------------------------------------------------------------------------------
-- | Left-biased choice: keep the first slot unless it is the "no candidate"
-- value, in which case take the second. Both arguments are pattern-matched,
-- and 'Slot' has strict fields, so neither may contain a bottom.
firstFilledSlot :: Slot -> Slot -> Slot
firstFilledSlot (Slot x1 b1) (Slot x2 b2) =
    if x1 == maxBound then Slot x2 b2 else Slot x1 b1


------------------------------------------------------------------------------
-- From GHC 8.4 on, Semigroup is a superclass of Monoid (and is exported by
-- the Prelude), so the instance has to be split in two there. Older
-- compilers keep the single Monoid instance.
#if __GLASGOW_HASKELL__ >= 804
instance Semigroup Slot where
    (<>) = firstFilledSlot

instance Monoid Slot where
    mempty = Slot maxBound 0
#else
instance Monoid Slot where
    mempty  = Slot maxBound 0
    mappend = firstFilledSlot
#endif


------------------------------------------------------------------------------
-- Returns the slot in the array where it would be safe to write the given key.
--
-- Arguments: the table, the "clearOut" flag, the key and the key's hash.
-- "clearOut" is True when the caller is a plain delete (the slot of a found
-- key must be turned into a deleted marker) and False when the caller is an
-- insert that is about to write a record into the returned slot.
--
-- Counter effects: the live-entry counter is decremented when the key was
-- found; the deleted counter is incremented for every deleted marker written
-- and decremented when an insert is about to overwrite a deleted marker.
delete' :: (PrimMonad m, Hashable k, Eq k) =>
           (HashTable_ (PrimState m) k v)
        -> Bool
        -> k
        -> Int
        -> m Int
delete' (HashTable sz loadRef hashes keys values) clearOut k h = do
    debug $ "delete': h=" ++ show h ++ " he=" ++ show he
            ++ " sz=" ++ show sz ++ " b0=" ++ show b0
    pair@(found, slot) <- go mempty b0 False
    debug $ "go returned " ++ show pair
    let !b' = _slot slot
    when found $ bump loadRef (-1)
    -- bump the delRef lower if we're writing over a deleted marker
    when (not clearOut && _wasDeleted slot == 1) $ bumpDel loadRef (-1)
    return b'

  where
    he = hashToElem h

    bump ref i = do
        !ld <- readLoad ref
        writeLoad ref $! ld + i

    bumpDel ref i = do
        !ld <- readDelLoad ref
        writeDelLoad ref $! ld + i

    !b0 = whichBucket h sz

    -- True once the search has come back round to (or before) the first
    -- candidate slot; always False while there is no candidate.
    haveWrapped !(Slot fp _) !b = if fp == maxBound
                                    then False
                                    else b <= fp

    -- arguments:

    --   * fp    maintains the slot in the array where it would be safe to
    --           write the given key
    --   * b     search the buckets array starting at this index.
    --   * wrap  True if we've wrapped around, False otherwise

    go !fp !b !wrap = do
        debug $ concat [ "go: fp="
                       , show fp
                       , " b="
                       , show b
                       , ", wrap="
                       , show wrap
                       , ", he="
                       , show he
                       , ", emptyMarker="
                       , show emptyMarker
                       , ", deletedMarker="
                       , show deletedMarker ]

        !idx <- forwardSearch3 hashes b sz he emptyMarker deletedMarker
        debug $ "forwardSearch3 returned " ++ show idx
                ++ " with sz=" ++ show sz ++ ", b=" ++ show b

        if wrap && idx >= b0
          -- we wrapped around in the search and didn't find our hash code;
          -- this means that the table is full of deleted elements. Just return
          -- the first place we'd be allowed to insert.
          --
          -- wrap can only become True once fp holds a real candidate, so fp
          -- itself is the answer. It must not be combined with a Slot built
          -- around a bottom: Slot's fields are strict, so forcing such a
          -- value throws instead of falling back to fp.
          --
          -- TODO: if we get in this situation we should probably just rehash
          -- the table, because every insert is going to be O(n).
          then if _slot fp == maxBound
                 then error "impossible"
                 else return $! (False, fp)
          else do
            -- because the table isn't full, we know that there must be either
            -- an empty or a deleted marker somewhere in the table. Assert this
            -- here.
            assert (idx >= 0) $ return ()
            h0 <- U.readArray hashes idx
            debug $ "h0 was " ++ show h0

            if recordIsEmpty h0
              then do
                  let pl = fp `mappend` (Slot idx 0)
                  debug $ "empty, returning " ++ show pl
                  return (False, pl)
              else do
                let !wrap' = haveWrapped fp idx
                if recordIsDeleted h0
                  then do
                      let pl = fp `mappend` (Slot idx 1)
                      debug $ "deleted, cont with pl=" ++ show pl
                      go pl (idx + 1) wrap'
                  else
                    if he == h0
                      then do
                        debug $ "found he == h0 == " ++ show h0
                        k' <- readArray keys idx
                        if k == k'
                          then do
                            -- fp only ever records slots that hold a deleted
                            -- marker, so it can never equal idx here. The
                            -- slot returned below is idx itself exactly when
                            -- no earlier candidate exists.
                            let samePlace = _slot fp == maxBound
                            debug $ "found at " ++ show idx
                            debug $ "clearout=" ++ show clearOut
                            debug $ "sp? " ++ show samePlace
                            -- "clearOut" is set if the caller is a plain
                            -- delete. If we're doing an update
                            -- and we found the old key, instead of writing
                            -- "deleted" and then re-writing the new element
                            -- there, we can just write the new element. This
                            -- only works if we were planning on writing the
                            -- new element here. Skipping the marker also
                            -- keeps the deleted counter honest: nothing
                            -- would ever decrement it again for this slot.
                            when (clearOut || not samePlace) $ do
                                bumpDel loadRef 1
                                U.writeArray hashes idx deletedMarker
                                writeArray keys idx undefined
                                writeArray values idx undefined
                            return (True, fp `mappend` (Slot idx 0))
                          else go fp (idx + 1) wrap'
                      else go fp (idx + 1) wrap'


------------------------------------------------------------------------------
-- | Load factor above which 'checkOverflow' rebuilds the table.
maxLoad :: Double
maxLoad = 0.82


------------------------------------------------------------------------------
-- | Hash-array value of a slot that has never held a record.
emptyMarker :: Elem
emptyMarker = 0

------------------------------------------------------------------------------
-- | Hash-array value of a slot whose record was removed.
deletedMarker :: Elem
deletedMarker = 1


------------------------------------------------------------------------------
{-# INLINE recordIsEmpty #-}
recordIsEmpty :: Elem -> Bool
recordIsEmpty = (== emptyMarker)


------------------------------------------------------------------------------
{-# INLINE recordIsDeleted #-}
recordIsDeleted :: Elem -> Bool
recordIsDeleted = (== deletedMarker)


------------------------------------------------------------------------------
{-# INLINE hash #-}
hash :: (Hashable k) => k -> Int
hash = H.hash


------------------------------------------------------------------------------
-- | Narrow a hash to the element type stored in the hash array.
--
-- NOTE(review): presumably @maskw# a b@ is all-ones when @a == b@ and zero
-- otherwise, which would make this map the two marker values (0 and 1) to 2
-- and leave every other masked hash unchanged -- confirm against the Utils
-- module.
{-# INLINE hashToElem #-}
hashToElem :: Int -> Elem
hashToElem !h = out
  where
    !(I# lo#) = h .&. U.elemMask

    !m#  = maskw# lo# 0# `or#` maskw# lo# 1#
    !nm# = not# m#

    !r#  = ((int2Word# 2#) `and#` m#) `or#` (int2Word# lo# `and#` nm#)
    !out = U.primWordToElem r#


------------------------------------------------------------------------------
-- | Wrap a table record in a fresh mutable reference.
newRef :: PrimMonad m =>
          HashTable_ (PrimState m) k v
       -> m (HashTable (PrimState m) k v)
newRef = liftM HT . newMutVar
{-# INLINE newRef #-}

-- | Replace the table record held by the reference.
writeRef :: PrimMonad m =>
            HashTable (PrimState m) k v
         -> HashTable_ (PrimState m) k v
         -> m ()
writeRef (HT ref) ht = writeMutVar ref ht
{-# INLINE writeRef #-}

-- | Fetch the table record currently held by the reference.
readRef :: PrimMonad m =>
           HashTable (PrimState m) k v
        -> m (HashTable_ (PrimState m) k v)
readRef (HT ref) = readMutVar ref
{-# INLINE readRef #-}


------------------------------------------------------------------------------
-- | Trace output: prints the message when the module is compiled with the
-- DEBUG macro defined, and is a no-op otherwise.
{-# INLINE debug #-}
debug :: PrimMonad m => String -> m ()
#ifdef DEBUG
debug s = unsafePrimToPrim (putStrLn s)
#else
debug _ = return ()
#endif