{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
module Data.HashTable.ST.Cuckoo
( HashTable
, new
, newSized
, delete
, lookup
, insert
, mapM_
, foldM
) where
import Control.Monad hiding
(foldM,
mapM_)
import Control.Monad.ST (ST)
import Data.Bits
import Data.Hashable hiding
(hash)
import qualified Data.Hashable as H
import Data.Int
import Data.Maybe
import Data.Primitive.Array
import Data.STRef
import GHC.Exts
import Prelude hiding
(lookup,
mapM_,
read)
import qualified Data.HashTable.Class as C
import Data.HashTable.Internal.CacheLine
import Data.HashTable.Internal.CheapPseudoRandomBitStream
import Data.HashTable.Internal.IntArray (Elem)
import qualified Data.HashTable.Internal.IntArray as U
import Data.HashTable.Internal.Utils
#ifdef DEBUG
import System.IO
#endif
-- | A cuckoo hash table in the 'ST' monad.
--
-- The table state lives behind an 'STRef'. Operations that may grow the
-- table ('insert') write a new 'HashTable_' back through the reference.
newtype HashTable s k v = HT (STRef s (HashTable_ s k v))
-- | The mutable state of a cuckoo hash table.
--
-- Entries live in three parallel arrays, each of length
-- @_size * numElemsInCacheLine@. Slot @i@ holds a hash code in @_hashes@,
-- a key in @_keys@ and a value in @_values@. A slot whose hash code equals
-- 'emptyMarker' is free.
data HashTable_ s k v = HashTable
    { _size        :: {-# UNPACK #-} !Int
      -- ^ number of buckets (cache lines) in the table
    , _rng         :: {-# UNPACK #-} !(BitStream s)
      -- ^ random bit source used to choose cuckoo eviction targets
    , _hashes      :: {-# UNPACK #-} !(U.IntArray s)
      -- ^ per-slot hash code ('emptyMarker' marks an empty slot)
    , _keys        :: {-# UNPACK #-} !(MutableArray s k)
      -- ^ per-slot key
    , _values      :: {-# UNPACK #-} !(MutableArray s v)
      -- ^ per-slot value
    , _maxAttempts :: {-# UNPACK #-} !Int
      -- ^ maximum number of cuckoo displacement rounds before an insertion
      -- gives up (after which the caller grows the table)
    }
-- | Expose this implementation through the generic "Data.HashTable.Class"
-- interface. Each class method delegates to the function of the same name
-- in this module.
instance C.HashTable HashTable where
    computeOverhead = computeOverhead
    delete          = delete
    foldM           = foldM
    insert          = insert
    lookup          = lookup
    mapM_           = mapM_
    new             = new
    newSized        = newSized
-- | Tables are opaque mutable references, so they render as a fixed
-- placeholder rather than showing their contents.
instance Show (HashTable s k v) where
    showsPrec _ _ = showString "<HashTable>"
-- | Create a new, empty table with the minimum number of buckets (2).
new :: ST s (HashTable s k v)
new = newRef =<< newSizedReal 2
{-# INLINE new #-}
-- | Create a new, empty table sized for roughly @n@ elements.
--
-- First the number of cache lines needed for @n@ elements is computed,
-- rounding up. That count is divided by 'maxLoad' so the table starts at
-- most about that full. 'nextBestPrime' then turns the result into a
-- bucket count.
newSized :: Int -> ST s (HashTable s k v)
newSized n = newSizedReal nBuckets >>= newRef
  where
    linesNeeded = (n + numElemsInCacheLine - 1) `div` numElemsInCacheLine
    nBuckets    = nextBestPrime $ ceiling $ fromIntegral linesNeeded / maxLoad
{-# INLINE newSized #-}
-- | Insert a key/value pair, replacing the value if the key is already
-- present.
--
-- Key and value are forced to WHNF. If the insertion requires growing the
-- table, the enlarged table is written back through the reference.
insert :: (Eq k, Hashable k) => HashTable s k v -> k -> v -> ST s ()
insert ht !k !v = do
    current <- readRef ht
    updated <- insert' current k v
    writeRef ht updated
-- | Estimate the space overhead per stored element, in machine words.
--
-- The estimate adds three terms and divides by the number of occupied
-- slots:
--
--   * the hash-code array, packed @bitSize (0 :: Int) `div` 16@ codes per
--     word;
--   * two words for every unoccupied key/value slot;
--   * a constant 12.
--
-- For an empty table the division is by zero, giving an infinite result.
computeOverhead :: HashTable s k v -> ST s Double
computeOverhead htRef = readRef htRef >>= measure
  where
    measure ht = do
        filled <- foldMWork countOne 0 ht
        let slots    = numElemsInCacheLine * _size ht
            perWord  = bitSize (0 :: Int) `div` 16
            overhead = slots `div` perWord + 2 * (slots - filled) + 12
        return $! fromIntegral (overhead :: Int) / fromIntegral filled

    countOne !n _ = return $! n + 1
-- | Remove a key and its value from the table.
--
-- If the key is not present, nothing is written and the table is left
-- unchanged.
delete :: (Hashable k, Eq k) =>
          HashTable s k v
       -> k
       -> ST s ()
delete htRef k = do
    ht <- readRef htRef
    let sz = _size ht
        h1 = hash1 k
        h2 = hash2 k
    void $ delete' ht False k (whichLine h1 sz) (whichLine h2 sz) h1 h2
-- | Look up the value stored for a key. Returns 'Nothing' if the key is
-- absent.
lookup :: (Eq k, Hashable k) =>
          HashTable s k v
       -> k
       -> ST s (Maybe v)
lookup htRef k = readRef htRef >>= \ht -> lookup' ht k
{-# INLINE lookup #-}
-- | Look up a key in an unwrapped table.
--
-- The key can live only in one of two cache lines, one chosen by 'hash1'
-- and one by 'hash2'. The first line is searched first. The second hash is
-- computed only if the first search misses.
lookup' :: (Eq k, Hashable k) =>
           HashTable_ s k v
        -> k
        -> ST s (Maybe v)
lookup' (HashTable sz _ hashes keys values _) !k = do
    i1 <- probe h1
    if i1 < 0
      then do
        i2 <- probe h2
        if i2 < 0 then return Nothing else fetch i2
      else fetch i1
  where
    h1 = hash1 k
    h2 = hash2 k

    -- Search the cache line selected by hash @h@ for the key.
    probe h = searchOne keys hashes k (whichLine h sz) (hashToElem h)

    -- Read the value stored in slot @i@.
    fetch i = do
        v <- readArray values i
        return $! Just v
{-# INLINE lookup' #-}
-- | Find the slot holding key @k@ within one cache line.
--
-- The search starts at slot @b0@ and looks for slots whose stored hash code
-- equals @h@. Each hit is then confirmed by comparing the actual key.
-- Returns the slot index, or @-1@ when the key is not in this line.
searchOne :: (Eq k) =>
             MutableArray s k
          -> U.IntArray s
          -> k
          -> Int
          -> Elem
          -> ST s Int
searchOne !keys !hashes !k !b0 !h = scan b0
  where
    -- Find the next slot at or after @pos@ carrying hash code @h@.
    scan !pos = do
        debug $ "searchOne: go/" ++ show pos ++ "/" ++ show h
        hit <- cacheLineSearch hashes pos h
        debug $ "searchOne: cacheLineSearch returned " ++ show hit
        if hit == -1 then return (-1) else verify hit

    -- A matching hash code may be a false positive; compare the real key.
    verify !idx = do
        candidate <- readArray keys idx
        if k == candidate then return idx else continueAfter (idx + 1)

    -- Keep scanning unless the next slot starts a new cache line.
    continueAfter !next
        | isCacheLineAligned next = return (-1)
        | otherwise               = scan next
{-# INLINE searchOne #-}
-- | Strict monadic left fold over every stored @(key, value)@ pair,
-- visited in slot order.
foldM :: (a -> (k,v) -> ST s a)
      -> a
      -> HashTable s k v
      -> ST s a
foldM f seed0 htRef = foldMWork f seed0 =<< readRef htRef
{-# INLINE foldM #-}
-- | Worker for 'foldM' that operates on the unwrapped table.
--
-- Walks every slot from 0 to @_size * numElemsInCacheLine - 1@ and skips
-- slots marked with 'emptyMarker'. The accumulator is forced to WHNF at
-- every step.
foldMWork :: (a -> (k,v) -> ST s a)
          -> a
          -> HashTable_ s k v
          -> ST s a
foldMWork f seed0 (HashTable sz _ hashes keys values _) = loop 0 seed0
  where
    limit = numElemsInCacheLine * sz

    loop !i !acc
        | i >= limit = return acc
        | otherwise  = do
            code <- U.readArray hashes i
            if code == emptyMarker
              then loop (i + 1) acc
              else do
                key   <- readArray keys i
                val   <- readArray values i
                !acc' <- f acc (key, val)
                loop (i + 1) acc'
{-# INLINE foldMWork #-}
-- | Run a monadic action on every stored @(key, value)@ pair, in slot
-- order, discarding the results.
mapM_ :: ((k,v) -> ST s a)
      -> HashTable s k v
      -> ST s ()
mapM_ f htRef = mapMWork f =<< readRef htRef
{-# INLINE mapM_ #-}
-- | Worker for 'mapM_' that operates on the unwrapped table.
--
-- Applies the action to each occupied slot, i.e. each slot whose hash code
-- is not 'emptyMarker'.
mapMWork :: ((k,v) -> ST s a)
         -> HashTable_ s k v
         -> ST s ()
mapMWork f (HashTable sz _ hashes keys values _) = loop 0
  where
    limit = numElemsInCacheLine * sz

    loop !i
        | i >= limit = return ()
        | otherwise  = do
            code <- U.readArray hashes i
            when (code /= emptyMarker) $ do
                key <- readArray keys i
                val <- readArray values i
                void $ f (key, val)
            loop (i + 1)
{-# INLINE mapMWork #-}
-- | Allocate an empty table with exactly @nbuckets@ cache lines.
--
-- The table gets @nbuckets * numElemsInCacheLine@ slots. Hash codes start
-- at whatever 'U.newArray' initialises them to. Keys and values start as
-- 'undefined' placeholders. The cuckoo attempt limit is
-- @12 + log2 nbuckets@.
newSizedReal :: Int -> ST s (HashTable_ s k v)
newSizedReal nbuckets = do
    let !slots    = nbuckets * numElemsInCacheLine
        !attempts = 12 + log2 (toEnum nbuckets)
    debug $ concat [ "creating cuckoo hash table with "
                   , show nbuckets
                   , " buckets having "
                   , show slots
                   , " total slots"
                   ]
    rng    <- newBitStream
    hashes <- U.newArray slots
    keys   <- newArray slots undefined
    values <- newArray slots undefined
    return $! HashTable nbuckets rng hashes keys values attempts
-- | Insert into an unwrapped table and return the table to use afterwards.
--
-- If 'updateOrFail' could not place every entry, the entry left over is
-- handed to 'grow', which returns a new, larger table. Otherwise the
-- original table is returned.
insert' :: (Eq k, Hashable k) =>
           HashTable_ s k v
        -> k
        -> v
        -> ST s (HashTable_ s k v)
insert' ht k v = do
    debug "insert': begin"
    leftover <- updateOrFail ht k v
    result <- case leftover of
                Nothing       -> return ht
                Just (k', v') -> grow ht k' v'
    debug "insert': end"
    return result
-- This pragma names 'insert' (the public wrapper), not insert'.
{-# INLINE insert #-}
-- | Insert or overwrite @k@ without growing the table.
--
-- If @k@ is already present, or either of its two cache lines has a free
-- slot, the entry is written in place and 'Nothing' is returned.
-- Otherwise cuckoo displacement is attempted via 'cuckooOrFail'. Its
-- result is returned unchanged: @Just (k', v')@ means that entry is left
-- without a slot and must be placed by the caller.
updateOrFail :: (Eq k, Hashable k) =>
                HashTable_ s k v
             -> k
             -> v
             -> ST s (Maybe (k,v))
updateOrFail ht@(HashTable sz _ hashes keys values _) k v = do
    debug $ "updateOrFail: begin: sz = " ++ show sz
    debug $ " h1=" ++ show h1 ++ ", h2=" ++ show h2
         ++ ", b1=" ++ show b1 ++ ", b2=" ++ show b2
    (slot, code) <- delete' ht True k b1 b2 h1 h2
    debug $ "delete' returned (" ++ show slot ++ "," ++ show code ++ ")"
    if slot < 0
      then viaCuckoo
      else do
        U.writeArray hashes slot code
        writeArray keys slot k
        writeArray values slot v
        return Nothing
  where
    h1 = hash1 k
    h2 = hash2 k
    b1 = whichLine h1 sz
    b2 = whichLine h2 sz

    -- No free or matching slot in either line: try displacing entries.
    viaCuckoo = do
        debug "cuckoo: calling cuckooOrFail"
        leftover <- cuckooOrFail ht h1 h2 b1 b2 k v
        debug $ "cuckoo: cuckooOrFail returned " ++
                (if isJust leftover then "Just _" else "Nothing")
        return leftover
{-# INLINE updateOrFail #-}
-- | Find @k@ in its two candidate cache lines, then either remove it or
-- report a slot where it can be written.
--
-- Arguments are the table, the @updating@ flag, the key, the start slots
-- @b1@/@b2@ of the key's two cache lines, and the corresponding full hash
-- values @h1@/@h2@.
--
-- Returns @(idx, code)@:
--
--   * If @k@ is found at slot @idx@: that index, paired with the hash code
--     of the line it was found in. When @updating@ is 'False' the slot is
--     also cleared: its hash code becomes 'emptyMarker', and key and value
--     are overwritten with 'undefined'.
--   * If @k@ is absent and @updating@ is 'True': the first empty slot in
--     line @b1@ (paired with @h1@'s code), else the first empty slot in
--     line @b2@ (paired with @h2@'s code).
--   * @(-1, 0)@ otherwise.
delete' :: (Hashable k, Eq k) =>
           HashTable_ s k v
        -> Bool
        -> k
        -> Int
        -> Int
        -> Int
        -> Int
        -> ST s (Int, Elem)
delete' (HashTable _ _ hashes keys values _) !updating !k b1 b2 h1 h2 = do
    debug $ "delete' b1=" ++ show b1
         ++ " b2=" ++ show b2
         ++ " h1=" ++ show h1
         ++ " h2=" ++ show h2
    prefetchWrite hashes b2
    let !he1 = hashToElem h1
    let !he2 = hashToElem h2
    idx1 <- searchOne keys hashes k b1 he1
    if idx1 < 0
      then do
        idx2 <- searchOne keys hashes k b2 he2
        if idx2 < 0
          then if updating
                 then do
                   debug $ "delete': looking for empty element"
                   idxE1 <- cacheLineSearch hashes b1 emptyMarker
                   debug $ "delete': idxE1 was " ++ show idxE1
                   if idxE1 >= 0
                     then return (idxE1, he1)
                     else do
                       idxE2 <- cacheLineSearch hashes b2 emptyMarker
                       -- Trace the second search's own result.
                       debug $ "delete': idxE2 was " ++ show idxE2
                       if idxE2 >= 0
                         then return (idxE2, he2)
                         else return (-1, 0)
                 else return (-1, 0)
          else deleteIt idx2 he2
      else deleteIt idx1 he1
  where
    -- When updating, the caller overwrites the slot itself, so only clear
    -- the slot on a genuine delete.
    deleteIt !idx !h = do
        unless updating $ do
            U.writeArray hashes idx emptyMarker
            writeArray keys idx undefined
            writeArray values idx undefined
        return $! (idx, h)
{-# INLINE delete' #-}
-- | Place @(k0, v0)@ by cuckoo displacement. This is used when neither of
-- the key's two cache lines (starting at @b1_0@ / @b2_0@) has a free slot.
--
-- One random bit selects one of the two lines. A randomly chosen slot in
-- that line is overwritten with the entry in hand, and the entry evicted
-- from that slot becomes the new entry in hand. The evicted entry is then
-- aimed at its other line: if that line has an empty slot it is written
-- there and 'Nothing' is returned. Otherwise the eviction repeats from
-- that line.
--
-- After @maxAttempts0@ unsuccessful rounds (the table's @_maxAttempts@),
-- the entry still in hand is returned as @Just (k, v)@. It is no longer
-- stored in the table, so the caller must place it elsewhere.
cuckooOrFail :: (Hashable k, Eq k) =>
                HashTable_ s k v
             -> Int
             -> Int
             -> Int
             -> Int
             -> k
             -> v
             -> ST s (Maybe (k,v))
cuckooOrFail (HashTable sz rng hashes keys values maxAttempts0)
             !h1_0 !h2_0 !b1_0 !b2_0 !k0 !v0 = do
    debug $ "cuckooOrFail h1_0=" ++ show h1_0
         ++ " h2_0=" ++ show h2_0
         ++ " b1_0=" ++ show b1_0
         ++ " b2_0=" ++ show b2_0
    -- Randomly pick which of the two candidate lines to evict from first.
    !lineChoice <- getNextBit rng
    debug $ "chose line " ++ show lineChoice
    let (!b, !h) = if lineChoice == 0 then (b1_0, h1_0) else (b2_0, h2_0)
    go b h k0 v0 maxAttempts0
  where
    -- A slot index in line @b@: @b@ plus @cacheLineIntBits@ random bits
    -- (presumably an offset within the cache line; confirm that
    -- numElemsInCacheLine == 2^cacheLineIntBits).
    randomIdx !b = do
        !z <- getNBits cacheLineIntBits rng
        return $! b + fromIntegral z

    -- Store (h, k, v) at slot @idx@ and return the evicted
    -- (hash code, key, value).
    bumpIdx !idx !h !k !v = do
        let !he = hashToElem h
        debug $ "bumpIdx idx=" ++ show idx ++ " h=" ++ show h
             ++ " he=" ++ show he
        !he' <- U.readArray hashes idx
        debug $ "bumpIdx: he' was " ++ show he'
        !k' <- readArray keys idx
        v' <- readArray values idx
        U.writeArray hashes idx he
        writeArray keys idx k
        writeArray values idx v
        debug $ "bumped key with he'=" ++ show he'
        return $! (he', k', v')

    -- Given the stored hash code @he@ of key @k@, return the full hash of
    -- the key's *other* line. If @he@ matches hash1's code, go to hash2;
    -- otherwise go to hash1. (If both hashes yield the same code, this
    -- always picks hash2.)
    otherHash he k = if hashToElem h1 == he then h2 else h1
      where
        h1 = hash1 k
        h2 = hash2 k

    -- Try to drop the evicted entry into an empty slot of line @b@.
    -- Otherwise run another eviction round with one fewer attempt left.
    tryWrite !b !h k v maxAttempts = do
        debug $ "tryWrite b=" ++ show b ++ " h=" ++ show h
        idx <- cacheLineSearch hashes b emptyMarker
        debug $ "cacheLineSearch returned " ++ show idx
        if idx >= 0
          then do
            U.writeArray hashes idx $! hashToElem h
            writeArray keys idx k
            writeArray values idx v
            return Nothing
          else go b h k v $! maxAttempts - 1

    -- One eviction round. When no attempts remain, give the entry in hand
    -- back to the caller.
    go !b !h !k v !maxAttempts
        | maxAttempts == 0 = return $! Just (k,v)
        | otherwise = do
            idx <- randomIdx b
            (!he0', !k', v') <- bumpIdx idx h k v
            let !h' = otherHash he0' k'
            let !b' = whichLine h' sz
            tryWrite b' h' k' v' maxAttempts
-- | Rebuild the table at a larger size, then insert a pending entry.
--
-- 'insert'' calls this after cuckoo insertion into the current table
-- failed and left @(k0, v0)@ without a slot.
--
-- Every live entry of the old table is rehashed into a freshly allocated
-- table of @bumpSize bumpFactor sz@ buckets; the rehash itself is retried
-- at a larger size whenever it fails. Then @(k0, v0)@ is inserted into the
-- result.
--
-- The old table's arrays are only read here, never written. That makes
-- the retry below safe. A failed 'updateOrFail' on the candidate table
-- leaves some displaced entry outside it, so the candidate is discarded.
-- The table is then rebuilt from the old arrays at the next size and
-- @(k0, v0)@ is inserted again. The retry must re-insert @(k0, v0)@,
-- otherwise that entry is silently lost.
--
-- NOTE(review): termination assumes 'bumpSize' always returns a larger
-- size; confirm in Data.HashTable.Internal.Utils.
grow :: (Eq k, Hashable k) =>
        HashTable_ s k v
     -> k
     -> v
     -> ST s (HashTable_ s k v)
grow (HashTable sz _ hashes keys values _) k0 v0 = growFrom sz
  where
    -- Build a bigger table from the old contents (starting from curSz
    -- buckets), then try to place the pending entry. On failure, start
    -- over from the size actually reached.
    growFrom !curSz = do
        newHt <- grow' $! bumpSize bumpFactor curSz
        mbR <- updateOrFail newHt k0 v0
        maybe (return newHt)
              (\_ -> growFrom (_size newHt))
              mbR

    -- Allocate an empty table of newSz buckets and copy the old entries in.
    grow' newSz = do
        debug $ "growing table, oldsz = " ++ show sz ++
                ", newsz=" ++ show newSz
        newHt <- newSizedReal newSz
        rehash newSz newHt

    -- Reinsert every occupied slot of the old table into newHt. If any
    -- reinsertion fails, abandon newHt and restart at a larger size.
    rehash !newSz !newHt = go 0
      where
        totSz = numElemsInCacheLine * sz

        go !i | i >= totSz = return newHt
              | otherwise  = do
                  h <- U.readArray hashes i
                  if (h /= emptyMarker)
                    then do
                      k <- readArray keys i
                      v <- readArray values i
                      mbR <- updateOrFail newHt k v
                      maybe (go $ i + 1)
                            (\_ -> grow' $ bumpSize bumpFactor newSz)
                            mbR
                    else go $ i + 1
-- | Salt passed to 'H.hashWithSalt' to derive the second hash function.
-- One constant is used when 'wordSize' is 32 and another otherwise.
hashPrime :: Int
hashPrime
    | wordSize == 32 = 0xedf2a025
    | otherwise      = 0x3971ca9c8b3722e9
-- | First hash function: the key's default 'H.hash'.
hash1 :: Hashable k => k -> Int
hash1 key = H.hash key
{-# INLINE hash1 #-}
-- | Second hash function: the key hashed with 'hashPrime' as the salt.
hash2 :: Hashable k => k -> Int
hash2 key = H.hashWithSalt hashPrime key
{-# INLINE hash2 #-}
-- | Reduce a full hash value to the compact code stored in @_hashes@.
--
-- The hash is masked with 'U.elemMask'. If the masked value is 0, it is
-- replaced by 1, so a stored code never equals 'emptyMarker' (0). This is
-- done without branching: @m#@ acts as a mask that selects between the
-- constant 1 and the masked hash.
--
-- NOTE(review): this reading assumes @maskw# a b@ (defined elsewhere)
-- yields all ones when @a == b@ and zero otherwise. Confirm in
-- Data.HashTable.Internal.Utils.
hashToElem :: Int -> Elem
hashToElem !h = out
  where
    !(I# lo#) = h .&. U.elemMask
    -- mask: all ones iff the masked hash is zero (see note above)
    !m# = maskw# lo# 0#
    !nm# = not# m#
    -- select 1 when the mask is set, else the masked hash
    !r# = ((int2Word# 1#) `and#` m#) `or#` (int2Word# lo# `and#` nm#)
    !out = U.primWordToElem r#
{-# INLINE hashToElem #-}
-- | Hash code that marks an unoccupied slot.
--
-- Folds skip slots holding this code, insertion searches for it to find a
-- free slot, and deletion writes it back.
emptyMarker :: Elem
emptyMarker = 0
-- | Target load factor used by 'newSized' when choosing an initial bucket
-- count (88%).
maxLoad :: Double
maxLoad = 88 / 100
-- | Growth factor handed to 'bumpSize' when 'grow' enlarges the table.
bumpFactor :: Double
bumpFactor = 73 / 100
-- | Debug trace.
--
-- In builds with @DEBUG@ defined, the message is printed to stdout and
-- stdout is flushed, via 'unsafeIOToST'. Otherwise the argument is ignored
-- and this is a no-op.
--
-- NOTE(review): 'unsafeIOToST' does not appear in the visible import list
-- (only "System.IO" is added under DEBUG). Confirm it is in scope for
-- DEBUG builds.
debug :: String -> ST s ()
#ifdef DEBUG
debug s = unsafeIOToST (putStrLn s >> hFlush stdout)
#else
debug _ = return ()
#endif
{-# INLINE debug #-}
-- | Map a full hash value to the start of its cache line in a table of
-- @sz@ buckets.
--
-- The result is the bucket index from 'whichBucket', scaled by
-- @2^cacheLineIntBits@, presumably the slot index where that line begins.
whichLine :: Int -> Int -> Int
whichLine !h !sz = iShiftL (whichBucket h sz) cacheLineIntBits
{-# INLINE whichLine #-}
-- | Wrap raw table state in a fresh mutable reference.
newRef :: HashTable_ s k v -> ST s (HashTable s k v)
newRef ht = fmap HT (newSTRef ht)
{-# INLINE newRef #-}
-- | Replace the table state held by the reference.
writeRef :: HashTable s k v -> HashTable_ s k v -> ST s ()
writeRef (HT ref) = writeSTRef ref
{-# INLINE writeRef #-}
-- | Read the current table state out of the reference.
readRef :: HashTable s k v -> ST s (HashTable_ s k v)
readRef ht = case ht of
    HT ref -> readSTRef ref
{-# INLINE readRef #-}