{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
module Data.HashMap.Mutable.Basic
( MHashMap
, new
, newSized
, delete
, lookup
, insert
, mapM_
, foldM
, computeOverhead
) where
import Control.Exception (assert)
import Control.Monad hiding (foldM, mapM_)
import Control.Monad.Primitive (PrimMonad, PrimState, unsafePrimToPrim)
import Control.Monad.ST (ST)
import Data.Bits
import Data.Hashable (Hashable)
import Data.Maybe
import Data.Monoid
import Data.Primitive.MutVar (MutVar, newMutVar, readMutVar, writeMutVar)
import Data.STRef
import Data.Semigroup (Semigroup)
import GHC.Exts
import Prelude hiding (lookup, mapM_, read)
import qualified Data.Hashable as H
import qualified Data.Primitive.ByteArray as A
import qualified Data.Primitive.ByteArray as A
import Data.HashMap.Mutable.Internal.Array
import Data.HashMap.Mutable.Internal.CacheLine
import Data.HashMap.Mutable.Internal.IntArray (Elem)
import qualified Data.HashMap.Mutable.Internal.IntArray as U
import Data.HashMap.Mutable.Internal.Utils
import qualified Data.Semigroup as SG
-- | Handle to a mutable hash map: a mutable variable holding the current
-- table state, which @insert@ replaces through @writeRef@ after every write.
newtype MHashMap s k v = HT (MutVar s (HashTable_ s k v))
-- | Mutable byte array holding the two table counters: index 0 is accessed by
-- @readLoad@ and @writeLoad@, index 1 by @readDelLoad@ and @writeDelLoad@.
type SizeRefs s = A.MutableByteArray s
-- | Width of an @Int@ in bytes on the current platform.
intSz :: Int
intSz = bitsPerInt `div` 8
  where
    bitsPerInt = finiteBitSize (0 :: Int)
-- | Read the counter stored at index 0 of the size references. This is the
-- counter that @checkOverflow@ increments and @delete'@ decrements.
readLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readLoad refs = A.readByteArray refs 0
-- | Overwrite the counter stored at index 0 of the size references.
writeLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeLoad refs n = A.writeByteArray refs 0 n
-- | Read the counter stored at index 1 of the size references. This is the
-- counter that @delete'@ adjusts whenever it writes or reuses a deleted marker.
readDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readDelLoad refs = A.readByteArray refs 1
-- | Overwrite the counter stored at index 1 of the size references.
writeDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeDelLoad refs n = A.writeByteArray refs 1 n
-- | Allocate the counter array: room for two @Int@ values, pinned and aligned
-- to the size of an @Int@, with every byte set to zero.
newSizeRefs :: PrimMonad m => m (SizeRefs (PrimState m))
newSizeRefs = do
    refs <- A.newAlignedPinnedByteArray byteLen intSz
    -- Zero both counters before handing the array out.
    A.fillByteArray refs 0 byteLen 0
    return refs
  where
    byteLen = 2 * intSz
-- | The table state referenced by an @MHashMap@.
data HashTable_ s k v = MHashMap
    { _size   :: {-# UNPACK #-} !Int
      -- ^ Number of slots in the key and value arrays.
    , _load   :: !(SizeRefs s)
      -- ^ Counters for stored entries and for deleted markers.
    , _hashes :: !(U.IntArray s)
      -- ^ Per-slot hash elements, including the empty and deleted markers.
    , _keys   :: {-# UNPACK #-} !(MutableArray s k)
      -- ^ Keys, indexed by slot.
    , _values :: {-# UNPACK #-} !(MutableArray s v)
      -- ^ Values, indexed by slot.
    }
-- | Opaque rendering: always the fixed placeholder, without inspecting the map.
instance Show (MHashMap s k v) where
    show = const "<MHashMap>"
{-# INLINE new #-}
-- | Create an empty map by delegating to @newSized@ with a size hint of 1.
new :: PrimMonad m => m (MHashMap (PrimState m) k v)
new = newSized 1
{-# INLINE newSized #-}
-- | Create an empty map sized for roughly @n@ entries.
--
-- The slot count is @nextBestPrime@ applied to @n@ divided by @maxLoad@,
-- rounded up.
newSized :: PrimMonad m => Int -> m (MHashMap (PrimState m) k v)
newSized n = do
    debug $ "entering: newSized " ++ show n
    newSizedReal slots >>= newRef
  where
    slots = nextBestPrime $ ceiling (fromIntegral n / maxLoad)
-- | Allocate a table with exactly @m@ key and value slots.
--
-- The hash array length is @m@ rounded up to a multiple of
-- @numElemsInCacheLine@. Key and value slots start out as @undefined@.
newSizedReal :: PrimMonad m => Int -> m (HashTable_ (PrimState m) k v)
newSizedReal m = do
    hashes <- U.newArray paddedLen
    keys   <- newArray m undefined
    values <- newArray m undefined
    sizes  <- newSizeRefs
    return $! MHashMap m sizes hashes keys values
  where
    cacheLines = (m + numElemsInCacheLine - 1) `div` numElemsInCacheLine
    paddedLen  = cacheLines * numElemsInCacheLine
{-# INLINE delete #-}
-- | Remove a key from the map. The slot index reported by @delete'@ is
-- discarded.
delete :: (PrimMonad m, Hashable k, Eq k) =>
          (MHashMap (PrimState m) k v)
       -> k
       -> m ()
delete htRef k = do
    debug $ "entered: delete: hash=" ++ show h
    table <- readRef htRef
    -- clearOut is True: the matching slot is turned into a deleted marker.
    void (delete' table True k h)
  where
    !h = hash k
{-# INLINE lookup #-}
-- | Look up the value stored under a key.
--
-- Returns @Nothing@ when the search leaves the permitted index range or
-- reaches an empty slot, and @Just@ the value when a slot with an equal key
-- is found.
lookup :: (PrimMonad m, Eq k, Hashable k) => (MHashMap (PrimState m) k v) -> k -> m (Maybe v)
lookup htRef !k = readRef htRef >>= search
  where
    search (MHashMap sz _ hashes keys values) = do
        let !home = whichBucket h sz
        debug $ "lookup h=" ++ show h ++ " sz=" ++ show sz ++ " b=" ++ show home
        probe home 0 sz
      where
        !h  = hash k
        !he = hashToElem h

        -- from: index the next search starts at; lo and hi bound the indices
        -- that may still be accepted (lo inclusive, hi exclusive).
        probe !from !lo !hi = {-# SCC "lookup/go" #-} do
            debug $ concat [ "lookup'/go: "
                           , show from
                           , "/"
                           , show lo
                           , "/"
                           , show hi
                           ]
            idx <- forwardSearch2 hashes from hi he emptyMarker
            debug $ "forwardSearch2 returned " ++ show idx
            if idx >= 0 && idx >= lo && idx < hi
              then inspect idx
              else return Nothing
          where
            inspect idx = do
                h0 <- U.readArray hashes idx
                debug $ "h0 was " ++ show h0
                if recordIsEmpty h0
                  then do
                    debug $ "record empty, returning Nothing"
                    return Nothing
                  else do
                    k' <- readArray keys idx
                    if k == k'
                      then do
                        debug $ "value found at " ++ show idx
                        found <- readArray values idx
                        return $! Just found
                      else do
                        debug $ "value not found, recursing"
                        -- A hit below the start index narrows the remaining
                        -- range to the indices between the hit and the start.
                        if idx < from
                          then probe (idx + 1) (idx + 1) from
                          else probe (idx + 1) lo hi
{-# INLINE insert #-}
-- | Insert a key/value pair, replacing any value already stored under the key.
--
-- Both the key and the value are forced before the table is touched. The
-- table returned by @checkOverflow@ is written back into the map reference.
insert :: (PrimMonad m, Eq k, Hashable k) =>
          (MHashMap (PrimState m) k v)
       -> k
       -> v
       -> m ()
insert htRef !k !v = do
    !updated <- readRef htRef >>= store
    writeRef htRef updated
  where
    store table = do
        debug "insert': calling delete'"
        -- clearOut is False: delete' reports the slot to write into.
        target <- delete' table False k h
        debug $ concat [ "insert': writing h="
                       , show h
                       , " he="
                       , show he
                       , " b="
                       , show target
                       ]
        U.writeArray (_hashes table) target he
        writeArray (_keys table) target k
        writeArray (_values table) target v
        checkOverflow table
      where
        !h  = hash k
        !he = hashToElem h
-- | Monadic left fold over every occupied slot, in slot order.
--
-- Slots holding the empty or deleted marker are skipped. The accumulator is
-- forced after each step.
foldM :: PrimMonad m => (a -> k -> v -> m a) -> a -> MHashMap (PrimState m) k v -> m a
foldM f seed0 htRef = readRef htRef >>= walk
  where
    walk (MHashMap sz _ hashes keys values) = loop 0 seed0
      where
        loop !i !acc
          | i < sz = do
              h <- U.readArray hashes i
              if not (recordIsEmpty h) && not (recordIsDeleted h)
                then do
                  k <- readArray keys i
                  v <- readArray values i
                  !acc' <- f acc k v
                  loop (i + 1) acc'
                else loop (i + 1) acc
          | otherwise = return acc
-- | Run an action on every occupied slot, in slot order, discarding results.
--
-- Slots holding the empty or deleted marker are skipped.
mapM_ :: PrimMonad m => (k -> v -> m b) -> MHashMap (PrimState m) k v -> m ()
mapM_ f htRef = readRef htRef >>= walk
  where
    walk (MHashMap sz _ hashes keys values) = loop 0
      where
        loop !i
          | i < sz = do
              h <- U.readArray hashes i
              unless (recordIsEmpty h || recordIsDeleted h) $ do
                  k <- readArray keys i
                  v <- readArray values i
                  _ <- f k v
                  return ()
              loop (i + 1)
          | otherwise = return ()
-- | Compute an overhead figure from the slot count and the stored-entry
-- counter.
--
-- NOTE(review): presumably the memory overhead in machine words per stored
-- element, with @constOverhead@ covering the fixed per-table cost; confirm
-- against the data representation. With zero stored entries the fill ratio
-- is 0 and the division yields a non-finite result.
computeOverhead :: PrimMonad m => MHashMap (PrimState m) k v -> m Double
computeOverhead htRef = readRef htRef >>= measure
  where
    measure table = do
        !entries <- readLoad (_load table)
        let fill = fromIntegral entries / slots
        return $ constOverhead/slots + (2 + 2*wordBytes*(1-fill)) / (fill * wordBytes)
      where
        slots         = fromIntegral (_size table)
        wordBytes     = fromIntegral $! finiteBitSize (0::Int) `div` 8
        constOverhead = 14
{-# INLINE insertRecord #-}
-- | Write a key/value pair into the first slot, searching forward from the
-- bucket chosen for its hash, that holds the empty or deleted marker.
--
-- No check for an existing equal key is made and no counter is updated; in
-- this file it is only called by @rehashAll@ on a freshly allocated table.
insertRecord :: PrimMonad m
             => Int
             -> U.IntArray (PrimState m)
             -> MutableArray (PrimState m) k
             -> MutableArray (PrimState m) v
             -> Int
             -> k
             -> v
             -> m ()
insertRecord !sz !hashes !keys !values !h !key !value = do
    let !home = whichBucket h sz
    debug $ "insertRecord sz=" ++ show sz ++ " h=" ++ show h ++ " b=" ++ show home
    place home
  where
    he = hashToElem h

    place !from = {-# SCC "insertRecord/probe" #-} do
        !free <- forwardSearch2 hashes from sz emptyMarker deletedMarker
        debug $ "forwardSearch2 returned " ++ show free
        -- The caller is expected to guarantee that a free slot exists.
        assert (free >= 0) $ do
            U.writeArray hashes free he
            writeArray keys free key
            writeArray values free value
-- | Account for one newly written entry and resize the table if needed.
--
-- The stored-entry counter is incremented first. The load test then uses the
-- counter value from before the increment plus the deleted-marker count. When
-- that ratio exceeds @maxLoad@, the table is rehashed at the same size if
-- deleted markers outnumber half the entries, and grown otherwise. Returns
-- the table to use from now on.
checkOverflow :: (PrimMonad m, Eq k, Hashable k) =>
                 (HashTable_ (PrimState m) k v)
              -> m (HashTable_ (PrimState m) k v)
checkOverflow ht@(MHashMap sz ldRef _ _ _) = do
    !ld <- readLoad ldRef
    let !ld' = ld + 1
    writeLoad ldRef ld'
    !dl <- readDelLoad ldRef
    debug $ concat [ "checkOverflow: sz="
                   , show sz
                   , " entries="
                   , show ld
                   , " deleted="
                   , show dl ]
    let overloaded       = fromIntegral (ld + dl) / fromIntegral sz > maxLoad
        mostlyTombstones = dl > ld `div` 2
    case (overloaded, mostlyTombstones) of
        (False, _)    -> return ht
        (True, True)  -> rehashAll ht sz
        (True, False) -> growTable ht
-- | Build a fresh table with @sz'@ slots and re-insert every occupied slot of
-- the old table into it.
--
-- The stored-entry counter is copied over. The deleted-marker counter of the
-- new table keeps the zero it was allocated with.
rehashAll :: (Hashable k, PrimMonad m) => HashTable_ (PrimState m) k v -> Int -> m (HashTable_ (PrimState m) k v)
rehashAll (MHashMap sz loadRef hashes keys values) sz' = do
    debug $ "rehashing: old size " ++ show sz ++ ", new size " ++ show sz'
    fresh <- newSizedReal sz'
    readLoad loadRef >>= writeLoad (_load fresh)
    copyInto fresh
    return fresh
  where
    copyInto fresh = loop 0
      where
        newHashes = _hashes fresh
        newKeys   = _keys fresh
        newValues = _values fresh

        loop !i
          | i >= sz   = return ()
          | otherwise = {-# SCC "growTable/rehash" #-} do
              h0 <- U.readArray hashes i
              unless (recordIsEmpty h0 || recordIsDeleted h0) $ do
                  k <- readArray keys i
                  v <- readArray values i
                  insertRecord sz' newHashes newKeys newValues
                               (hash k) k v
              loop $ i+1
-- | Rehash the table into a larger one, with the new slot count chosen by
-- @bumpSize@ from @maxLoad@ and the current size.
growTable :: (Hashable k, PrimMonad m) => HashTable_ (PrimState m) k v -> m (HashTable_ (PrimState m) k v)
growTable table = rehashAll table $! bumpSize maxLoad (_size table)
-- | A candidate slot found while probing in @delete'@.
--
-- @_slot@ is the slot index, with @maxBound@ meaning no candidate yet.
-- @_wasDeleted@ is 1 when the candidate slot held the deleted marker and 0
-- otherwise; it is an @Int@ so that the field can be unpacked.
data Slot = Slot
    { _slot       :: {-# UNPACK #-} !Int
    , _wasDeleted :: {-# UNPACK #-} !Int
    }
  deriving (Show)

-- | Left-biased choice: the first operand wins unless it is the empty
-- candidate (index @maxBound@).
--
-- The right operand is deliberately NOT pattern-matched. Both fields are
-- strict, so matching on a value such as @Slot (error ...) 0@ would raise
-- the error even when the left operand is already a valid candidate and the
-- right one is never needed.
instance Semigroup Slot where
    Slot x1 b1 <> later
      | x1 == maxBound = later
      | otherwise      = Slot x1 b1

-- | The identity is the empty candidate.
instance Monoid Slot where
    mempty  = Slot maxBound 0
    mappend = (SG.<>)
-- | Shared worker for @delete@ and @insert@: search for a key and report the
-- slot where an entry for it may be written.
--
-- Arguments: the table; @clearOut@ (@True@ from @delete@, @False@ from
-- @insert@); the key; the hash of the key.
--
-- Result: the first slot holding a deleted marker seen during the probe if
-- there was one, otherwise the slot where the key was found or the empty
-- slot that ended the probe.
--
-- Counter effects: the stored-entry counter drops by one when the key was
-- found. The deleted-marker counter rises by one when the found slot is
-- turned into a deleted marker, and drops by one when the caller is about to
-- overwrite a deleted marker (that is, @clearOut@ is @False@ and the reported
-- slot was a deleted one).
delete' :: (PrimMonad m, Hashable k, Eq k) =>
           (HashTable_ (PrimState m) k v)
        -> Bool
        -> k
        -> Int
        -> m Int
delete' (MHashMap sz loadRef hashes keys values) clearOut k h = do
    debug $ "delete': h=" ++ show h ++ " he=" ++ show he
            ++ " sz=" ++ show sz ++ " b0=" ++ show b0
    pair@(found, slot) <- go mempty b0 False
    debug $ "go returned " ++ show pair
    let !b' = _slot slot
    when found $ bump loadRef (-1)
    -- The caller will overwrite a deleted marker, so it no longer counts.
    when (not clearOut && _wasDeleted slot == 1) $ bumpDel loadRef (-1)
    return b'
  where
    he = hashToElem h

    bump ref i = do
        !ld <- readLoad ref
        writeLoad ref $! ld + i

    bumpDel ref i = do
        !ld <- readDelLoad ref
        writeDelLoad ref $! ld + i

    !b0 = whichBucket h sz

    -- True once a candidate slot exists and the probe index is at or before it.
    haveWrapped !(Slot fp _) !b = if fp == maxBound
                                    then False
                                    else b <= fp

    -- fp:   first reusable (deleted) slot seen so far, or mempty
    -- b:    index at which the next forward search starts
    -- wrap: whether the probe has already wrapped past the candidate slot
    go !fp !b !wrap = do
        debug $ concat [ "go: fp="
                       , show fp
                       , " b="
                       , show b
                       , ", wrap="
                       , show wrap
                       , ", he="
                       , show he
                       , ", emptyMarker="
                       , show emptyMarker
                       , ", deletedMarker="
                       , show deletedMarker ]
        !idx <- forwardSearch3 hashes b sz he emptyMarker deletedMarker
        debug $ "forwardSearch3 returned " ++ show idx ++ " with sz=" ++ show sz ++ ", b=" ++ show b
        if wrap && idx >= b0
          -- The probe went all the way round without finding the key or an
          -- empty slot. wrap is only ever True once fp holds a candidate, so
          -- fp is the answer. It is returned directly rather than combined
          -- with a placeholder Slot, whose strict field would raise an error
          -- as soon as the placeholder was evaluated.
          then return (False, fp)
          else do
            assert (idx >= 0) $ return ()
            h0 <- U.readArray hashes idx
            debug $ "h0 was " ++ show h0
            if recordIsEmpty h0
              then do
                  let pl = fp `mappend` (Slot idx 0)
                  debug $ "empty, returning " ++ show pl
                  return (False, pl)
              else do
                let !wrap' = haveWrapped fp idx
                if recordIsDeleted h0
                  then do
                      let pl = fp `mappend` (Slot idx 1)
                      debug $ "deleted, cont with pl=" ++ show pl
                      go pl (idx + 1) wrap'
                  else
                    if he == h0
                      then do
                        debug $ "found he == h0 == " ++ show h0
                        k' <- readArray keys idx
                        if k == k'
                          then do
                            -- target is the slot reported to the caller: an
                            -- earlier deleted slot if one was seen, otherwise
                            -- this slot. The entry can be overwritten in
                            -- place only when target is this very slot.
                            -- Comparing fp itself with idx could never
                            -- succeed, because fp is either unset or the
                            -- index of a deleted slot.
                            let target    = fp `mappend` (Slot idx 0)
                                samePlace = _slot target == idx
                            debug $ "found at " ++ show idx
                            debug $ "clearout=" ++ show clearOut
                            debug $ "sp? " ++ show samePlace
                            -- Leave a deleted marker unless the caller is
                            -- about to rewrite this same slot.
                            when (clearOut || not samePlace) $ do
                                bumpDel loadRef 1
                                U.writeArray hashes idx deletedMarker
                                writeArray keys idx undefined
                                writeArray values idx undefined
                            return (True, target)
                          else go fp (idx + 1) wrap'
                      else go fp (idx + 1) wrap'
-- | Load-factor threshold: @newSized@ divides the requested size by it, and
-- @checkOverflow@ rehashes or grows the table once the ratio of counted
-- entries plus deleted markers to slots exceeds it.
maxLoad :: Double
maxLoad = 0.82
-- | Hash-array element value that marks an empty slot.
emptyMarker :: Elem
emptyMarker = 0
-- | Hash-array element value that marks a slot whose entry was deleted.
deletedMarker :: Elem
deletedMarker = 1
{-# INLINE recordIsEmpty #-}
-- | Whether a hash-array element is the empty marker.
recordIsEmpty :: Elem -> Bool
recordIsEmpty = (== emptyMarker)
{-# INLINE recordIsDeleted #-}
-- | Whether a hash-array element is the deleted marker.
recordIsDeleted :: Elem -> Bool
recordIsDeleted = (== deletedMarker)
{-# INLINE hash #-}
-- | Hash a key with the @hash@ function from "Data.Hashable".
hash :: (Hashable k) => k -> Int
hash = H.hash
{-# INLINE hashToElem #-}
-- | Convert a full hash code into the element stored in the hash array.
--
-- The hash is first masked with @U.elemMask@. The result is then built from
-- the literal 2 under a mask and the masked hash under the complement of
-- that mask.
--
-- NOTE(review): presumably @maskw#@ yields an all-ones word when its
-- arguments are equal and zero otherwise, so that masked hashes of 0 and 1
-- (the empty and deleted marker values) are replaced by 2 and every other
-- value passes through unchanged. Confirm against the definition of @maskw#@.
hashToElem :: Int -> Elem
hashToElem !h = stored
  where
    !(I# low#)    = h .&. U.elemMask
    !isReserved#  = maskw# low# 0# `or#` maskw# low# 1#
    !isOrdinary#  = not# isReserved#
    !patched#     = ((int2Word# 2#) `and#` isReserved#)
                        `or#` (int2Word# low# `and#` isOrdinary#)
    !stored       = U.primWordToElem patched#
{-# INLINE newRef #-}
-- | Wrap a table in a fresh mutable variable to form a map handle.
newRef :: PrimMonad m => HashTable_ (PrimState m) k v -> m (MHashMap (PrimState m) k v)
newRef table = liftM HT (newMutVar table)
{-# INLINE writeRef #-}
-- | Replace the table held by a map handle.
writeRef :: PrimMonad m => MHashMap (PrimState m) k v -> HashTable_ (PrimState m) k v -> m ()
writeRef (HT var) table = writeMutVar var table
{-# INLINE readRef #-}
-- | Fetch the table currently held by a map handle.
readRef :: PrimMonad m => MHashMap (PrimState m) k v -> m (HashTable_ (PrimState m) k v)
readRef (HT var) = readMutVar var
-- | Emit a trace line when the module is compiled with the DEBUG CPP macro
-- defined; otherwise do nothing.
{-# INLINE debug #-}
debug :: PrimMonad m => String -> m ()
#ifdef DEBUG
-- Runs putStrLn inside the caller monad via unsafePrimToPrim.
debug s = unsafePrimToPrim (putStrLn s)
#else
debug _ = return ()
#endif