{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
{-# LANGUAGE MagicHash    #-}

module Data.HashMap.Mutable.Basic
  ( HashTable
  , new
  , newSized
  , delete
  , lookup
  , insert
  , mapM_
  , foldM
  , computeOverhead
  ) where


------------------------------------------------------------------------------
import           Control.Exception                 (assert)
import           Control.Monad                     hiding (foldM, mapM_)
import           Control.Monad.ST                  (ST)
import           Control.Monad.Primitive           (PrimMonad,PrimState,unsafePrimToPrim)
import           Data.Bits
import           Data.Hashable                     (Hashable)
import qualified Data.Hashable                     as H
import           Data.Maybe
import           Data.Monoid
import qualified Data.Primitive.ByteArray          as A
import           Data.Primitive.MutVar             (MutVar,readMutVar,writeMutVar,newMutVar)
import           Data.STRef
import           GHC.Exts
import           Prelude                           hiding (lookup, mapM_, read)
------------------------------------------------------------------------------
import           Data.HashMap.Mutable.Internal.Array
import           Data.HashMap.Mutable.Internal.CacheLine
import           Data.HashMap.Mutable.Internal.IntArray  (Elem)
import qualified Data.HashMap.Mutable.Internal.IntArray  as U
import           Data.HashMap.Mutable.Internal.Utils


------------------------------------------------------------------------------
-- | A mutable hash table using open addressing with linear probing.
--
-- The handle is a 'MutVar' pointing at the current table representation
-- ('HashTable_'). Operations that may rebuild the table (see 'insert') store
-- the new representation back into this variable, so the handle stays valid
-- across resizes.
newtype HashTable s k v = HT (MutVar s (HashTable_ s k v))

-- | Raw byte array holding two 'Int' counters: element 0 is the number of
-- live entries ('readLoad' / 'writeLoad') and element 1 the number of
-- deleted entries ('readDelLoad' / 'writeDelLoad').
type SizeRefs s = A.MutableByteArray s

-- | Size of a machine 'Int', in bytes.
intSz :: Int
intSz = finiteBitSize (0 :: Int) `quot` 8

-- | Read the live-entry counter (element 0 of the 'SizeRefs' array).
readLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readLoad refs = A.readByteArray refs 0

-- | Overwrite the live-entry counter (element 0 of the 'SizeRefs' array).
writeLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeLoad refs count = A.writeByteArray refs 0 count

-- | Read the deleted-entry counter (element 1 of the 'SizeRefs' array).
readDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> m Int
readDelLoad refs = A.readByteArray refs 1

-- | Overwrite the deleted-entry counter (element 1 of the 'SizeRefs' array).
writeDelLoad :: PrimMonad m => SizeRefs (PrimState m) -> Int -> m ()
writeDelLoad refs count = A.writeByteArray refs 1 count

-- | Allocate a 'SizeRefs' array with room for two 'Int's, pinned and aligned
-- to the size of an 'Int', with every byte set to zero (both counters start
-- at 0).
newSizeRefs :: PrimMonad m => m (SizeRefs (PrimState m))
newSizeRefs = do
    refs <- A.newAlignedPinnedByteArray byteCount intSz
    A.fillByteArray refs 0 byteCount 0
    return refs
  where
    byteCount = intSz * 2


-- | The representation behind a 'HashTable' handle: parallel arrays of hash
-- elements, keys and values, plus the entry counters.
data HashTable_ s k v = HashTable
    { _size   :: {-# UNPACK #-} !Int
      -- ^ Number of key/value slots (the @m@ given to 'newSizedReal').
      -- Probing and iteration only consider indices below this bound.
    , _load   :: !(SizeRefs s)
      -- ^ 2-element counter array: element 0 counts live entries and
      -- element 1 counts deleted (tombstoned) entries.
    , _hashes :: !(U.IntArray s)
      -- ^ Per-slot hash element; 'emptyMarker' and 'deletedMarker' flag
      -- unused and deleted slots. Allocated rounded up to a whole number of
      -- cache lines, so it may be longer than '_size'.
    , _keys   :: {-# UNPACK #-} !(MutableArray s k)
      -- ^ Keys, indexed like '_hashes'.
    , _values :: {-# UNPACK #-} !(MutableArray s v)
      -- ^ Values, indexed like '_hashes'.
    }


------------------------------------------------------------------------------
-- | Tables are opaque: their contents live in mutable arrays, so 'show'
-- renders a fixed placeholder instead.
instance Show (HashTable s k v) where
    showsPrec _ _ = showString "<HashTable>"


------------------------------------------------------------------------------
-- | Create an empty table; equivalent to @'newSized' 1@. The table is
-- rebuilt larger as entries are inserted (see 'insert').
new :: PrimMonad m => m (HashTable (PrimState m) k v)
new = newSized 1
{-# INLINE new #-}


------------------------------------------------------------------------------
-- | Create an empty table intended to hold about @n@ entries before its
-- first resize.
--
-- The slot count is @n / maxLoad@, rounded up and then passed through
-- 'nextBestPrime' (presumably rounding up to a nearby prime; defined in the
-- Utils module).
newSized :: PrimMonad m => Int -> m (HashTable (PrimState m) k v)
newSized n = do
    debug $ "entering: newSized " ++ show n
    newSizedReal slotCount >>= newRef
  where
    slotCount = nextBestPrime $ ceiling (fromIntegral n / maxLoad)
{-# INLINE newSized #-}


------------------------------------------------------------------------------
-- | Allocate the arrays for a table with exactly @m@ key/value slots.
--
-- The hash array is padded up to a whole number of cache lines
-- ('numElemsInCacheLine' elements each), so a search can always scan a full
-- line. Key and value slots start as 'undefined' placeholders, and both
-- counters start at zero.
newSizedReal :: PrimMonad m => Int -> m (HashTable_ (PrimState m) k v)
newSizedReal m = do
    hashArr <- U.newArray paddedLen
    keyArr  <- newArray m undefined
    valArr  <- newArray m undefined
    counts  <- newSizeRefs
    return $! HashTable m counts hashArr keyArr valArr
  where
    lineLen   = numElemsInCacheLine
    paddedLen = lineLen * ((m + lineLen - 1) `div` lineLen)


------------------------------------------------------------------------------
-- | Remove the entry for key @k@, if there is one.
--
-- A found entry's slot is marked with 'deletedMarker' rather than emptied,
-- so that probe sequences running through it stay intact.
delete :: (PrimMonad m, Hashable k, Eq k) =>
          (HashTable (PrimState m) k v)
       -> k
       -> m ()
delete htRef k = do
    debug $ "entered: delete: hash=" ++ show keyHash
    table <- readRef htRef
    delete' table True k keyHash >> return ()
  where
    !keyHash = hash k
{-# INLINE delete #-}


------------------------------------------------------------------------------
-- | Look up the value stored for a key, or 'Nothing' if the key is absent.
--
-- Probing starts at the key's home bucket (@whichBucket h sz@) and uses
-- 'forwardSearch2' to find slots whose hash element equals the key's (@he@)
-- or is 'emptyMarker'; presumably this skips deleted slots. Hitting an empty
-- slot ends the search with 'Nothing'.
lookup :: (PrimMonad m, Eq k, Hashable k) => (HashTable (PrimState m) k v) -> k -> m (Maybe v)
lookup htRef !k = do
    ht <- readRef htRef
    lookup' ht
  where
    lookup' (HashTable sz _ hashes keys values) = do
        let !b = whichBucket h sz
        debug $ "lookup h=" ++ show h ++ " sz=" ++ show sz ++ " b=" ++ show b
        go b 0 sz

      where
        !h  = hash k
        !he = hashToElem h

        -- go b start end: search from index b, accepting only hits that fall
        -- inside the window [start, end). The first call uses the whole
        -- table [0, sz).
        go !b !start !end = {-# SCC "lookup/go" #-} do
            debug $ concat [ "lookup'/go: "
                           , show b
                           , "/"
                           , show start
                           , "/"
                           , show end
                           ]
            -- NOTE(review): forwardSearch2's contract is defined in the
            -- CacheLine module; a negative result is assumed to mean "no
            -- matching slot".
            idx <- forwardSearch2 hashes b end he emptyMarker
            debug $ "forwardSearch2 returned " ++ show idx
            if (idx < 0 || idx < start || idx >= end)
               then return Nothing
               else do
                 h0  <- U.readArray hashes idx
                 debug $ "h0 was " ++ show h0

                 -- an empty slot terminates the probe sequence
                 if recordIsEmpty h0
                   then do
                       debug $ "record empty, returning Nothing"
                       return Nothing
                   else do
                     k' <- readArray keys idx
                     if k == k'
                       then do
                         debug $ "value found at " ++ show idx
                         v <- readArray values idx
                         return $! Just v
                       else do
                         debug $ "value not found, recursing"
                         -- Hash element matched but the key did not: keep
                         -- probing after idx. A hit below b presumably means
                         -- the search wrapped past the end of the array, so
                         -- the window shrinks to [idx+1, b) to avoid
                         -- revisiting slots from the starting bucket onward.
                         if idx < b
                           then go (idx + 1) (idx + 1) b
                           else go (idx + 1) start end
{-# INLINE lookup #-}


------------------------------------------------------------------------------
-- | Insert a key/value pair, replacing any value already stored for the key.
--
-- Both the key and the value are forced to weak head normal form. If the
-- insertion leaves the table too full, 'checkOverflow' rebuilds it (at the
-- same size or larger), and the handle is updated to point at the result.
insert :: (PrimMonad m, Eq k, Hashable k) =>
          (HashTable (PrimState m) k v)
       -> k
       -> v
       -> m ()
insert htRef !k !v = do
    current  <- readRef htRef
    !updated <- store current
    writeRef htRef updated

  where
    store table = do
        debug "insert': calling delete'"
        slot <- delete' table False k keyHash

        debug $ concat [ "insert': writing h="
                       , show keyHash
                       , " he="
                       , show keyElem
                       , " b="
                       , show slot
                       ]
        U.writeArray (_hashes table) slot keyElem
        writeArray (_keys table) slot k
        writeArray (_values table) slot v

        checkOverflow table

      where
        !keyHash = hash k
        !keyElem = hashToElem keyHash
{-# INLINE insert #-}


------------------------------------------------------------------------------
-- | Strict left monadic fold over every live @(key, value)@ pair, in slot
-- order.
--
-- Slots marked 'emptyMarker' or 'deletedMarker' are skipped. Each new
-- accumulator is forced to weak head normal form before the next step.
foldM :: PrimMonad m => (a -> (k,v) -> m a) -> a -> HashTable (PrimState m) k v -> m a
foldM f seed0 htRef = do
    table <- readRef htRef
    let limit = _size table
        step !acc !i
          | i >= limit = return acc
          | otherwise  = do
              tag <- U.readArray (_hashes table) i
              if recordIsEmpty tag || recordIsDeleted tag
                then step acc (i + 1)
                else do
                  key <- readArray (_keys table) i
                  val <- readArray (_values table) i
                  !acc' <- f acc (key, val)
                  step acc' (i + 1)
    step seed0 0


------------------------------------------------------------------------------
-- | Run an action on every live key/value pair, in slot order, discarding
-- the action's results. Empty and deleted slots are skipped.
mapM_ :: PrimMonad m => (k -> v -> m b) -> HashTable (PrimState m) k v -> m ()
mapM_ f htRef = do
    table <- readRef htRef
    forM_ [0 .. _size table - 1] $ \i -> do
        tag <- U.readArray (_hashes table) i
        unless (recordIsEmpty tag || recordIsDeleted tag) $ do
            key <- readArray (_keys table) i
            val <- readArray (_values table) i
            f key val >> return ()


------------------------------------------------------------------------------
-- | Estimate the table's space overhead.
--
-- The estimate is built from the load factor @fill@ (live entries / slots),
-- the machine word size in bytes, and a fixed per-table constant. An empty
-- table has @fill = 0@, so the result is infinite.
computeOverhead :: PrimMonad m => HashTable (PrimState m) k v -> m Double
computeOverhead htRef = readRef htRef >>= overheadOf
  where
    overheadOf table = do
        !entries <- readLoad (_load table)
        let slots = fromIntegral (_size table) :: Double
            fill  = fromIntegral entries / slots
        return $ fixedCost / slots + (2 + 2 * wordBytes * (1 - fill)) / (fill * wordBytes)

    wordBytes :: Double
    wordBytes = fromIntegral (finiteBitSize (0 :: Int) `div` 8)

    -- Change these if you change the representation
    fixedCost :: Double
    fixedCost = 14


------------------------------
-- Private functions follow --
------------------------------


------------------------------------------------------------------------------
{-# INLINE insertRecord #-}
-- | Store a key/value pair (with precomputed hash @h@) in the first empty or
-- deleted slot found by probing from the key's home bucket.
--
-- No duplicate check is done and the counters are not touched; this is the
-- raw write 'rehashAll' uses to fill a freshly allocated table.
insertRecord :: PrimMonad m
             => Int
             -> U.IntArray (PrimState m)
             -> MutableArray (PrimState m) k
             -> MutableArray (PrimState m) v
             -> Int
             -> k
             -> v
             -> m ()
insertRecord !sz !hashes !keys !values !h !key !value = do
    let !home = whichBucket h sz
    debug $ "insertRecord sz=" ++ show sz ++ " h=" ++ show h ++ " b=" ++ show home
    placeFrom home

  where
    placeFrom !from = {-# SCC "insertRecord/probe" #-} do
        !slot <- forwardSearch2 hashes from sz emptyMarker deletedMarker
        debug $ "forwardSearch2 returned " ++ show slot
        -- the destination table is expected to have a free slot
        assert (slot >= 0) $ do
            U.writeArray hashes slot (hashToElem h)
            writeArray keys slot key
            writeArray values slot value


------------------------------------------------------------------------------
-- | Account for one newly written entry, and rebuild the table if it is now
-- too full.
--
-- Increments the live-entry count, then compares (live + deleted) / size
-- against 'maxLoad'. When over the limit, the table is rehashed at the same
-- size if deleted entries exceed half the live ones (purging tombstones);
-- otherwise it is grown. Returns the table to use from now on, which is the
-- argument itself when no rebuild happened.
--
-- The decision uses the count *after* this insertion, so the limit is
-- checked against the load the table actually carries.
checkOverflow :: (PrimMonad m, Eq k, Hashable k) =>
                 (HashTable_ (PrimState m) k v)
              -> m (HashTable_ (PrimState m) k v)
checkOverflow ht@(HashTable sz ldRef _ _ _) = do
    !ld <- readLoad ldRef
    let !ld' = ld + 1
    writeLoad ldRef ld'
    !dl <- readDelLoad ldRef

    debug $ concat [ "checkOverflow: sz="
                   , show sz
                   , " entries="
                   , show ld'
                   , " deleted="
                   , show dl ]

    -- Use ld' (which includes the entry just written), not the stale ld;
    -- otherwise the table can end up one entry past maxLoad.
    if fromIntegral (ld' + dl) / fromIntegral sz > maxLoad
      then if dl > ld' `div` 2
             then rehashAll ht sz
             else growTable ht
      else return ht


------------------------------------------------------------------------------
-- | Copy every live entry into a newly allocated table with @sz'@ slots and
-- return that table.
--
-- The live-entry count is carried over. Tombstones are not copied, so the
-- new table's deleted count keeps its initial value of zero. The source
-- table is only read.
rehashAll :: (Hashable k, PrimMonad m) => HashTable_ (PrimState m) k v -> Int -> m (HashTable_ (PrimState m) k v)
rehashAll (HashTable oldSz oldLoad oldHashes oldKeys oldValues) sz' = do
    debug $ "rehashing: old size " ++ show oldSz ++ ", new size " ++ show sz'
    fresh <- newSizedReal sz'
    let HashTable _ freshLoad freshHashes freshKeys freshValues = fresh
    writeLoad freshLoad =<< readLoad oldLoad
    forM_ [0 .. oldSz - 1] $ \i -> {-# SCC "growTable/rehash" #-} do
        tag <- U.readArray oldHashes i
        unless (recordIsEmpty tag || recordIsDeleted tag) $ do
            key <- readArray oldKeys i
            val <- readArray oldValues i
            insertRecord sz' freshHashes freshKeys freshValues (hash key) key val
    return fresh


------------------------------------------------------------------------------
-- | Rebuild the table with more slots. The new slot count is
-- @bumpSize maxLoad@ applied to the current size.
growTable :: (Hashable k, PrimMonad m) => HashTable_ (PrimState m) k v -> m (HashTable_ (PrimState m) k v)
growTable ht = rehashAll ht $! bumpSize maxLoad (_size ht)


------------------------------------------------------------------------------
-- | Candidate write position tracked by delete' while it probes.
--
-- A '_slot' of 'maxBound' means "no candidate yet" (this is 'mempty').
data Slot = Slot {
      _slot       :: {-# UNPACK #-} !Int
      -- ^ index into the table's arrays
    , _wasDeleted :: {-# UNPACK #-} !Int
      -- ^ 1 if the slot held 'deletedMarker' when it was recorded, 0
      -- otherwise. An 'Int' is used because a 'Bool' field won't unpack.
    }
  deriving (Show)


------------------------------------------------------------------------------
-- | Left-biased choice between two candidate slots. Keep the first slot
-- unless it is still the "no candidate" sentinel (@_slot == maxBound@); in
-- that case, fall back to the second.
--
-- The second argument is only evaluated when it is the one chosen. Because
-- 'Slot' has strict fields, a placeholder such as @Slot (error ..) 0@
-- would otherwise be forced, and would crash even when the first slot is
-- usable.
firstSlot :: Slot -> Slot -> Slot
firstSlot s1@(Slot x1 _) s2
    | x1 == maxBound = s2
    | otherwise      = s1
{-# INLINE firstSlot #-}

-- Since base 4.11, Semigroup is a superclass of Monoid, and 'mappend'
-- defaults to '(<>)'.
#if MIN_VERSION_base(4,11,0)
instance Semigroup Slot where
    (<>) = firstSlot
#endif

instance Monoid Slot where
    mempty = Slot maxBound 0
#if !MIN_VERSION_base(4,11,0)
    mappend = firstSlot
#endif


------------------------------------------------------------------------------
-- | Probe for key @k@ (whose hash is @h@). Return the index of the slot where
-- it would be safe to write that key.
--
-- @clearOut@ selects the caller's intent:
--
--   * 'True' (used by 'delete'): if the key is found, its slot is tombstoned
--     (hash set to 'deletedMarker', key and value cleared).
--
--   * 'False' (used by 'insert'): the caller will write the key into the
--     returned slot. If the key is found in that very slot, it is left in
--     place for the caller to overwrite. If an earlier deleted slot is
--     returned instead, the old copy is tombstoned so the key is stored once.
--
-- Counter bookkeeping:
--
--   * The live count is decremented when the key is found (an inserting
--     caller re-increments it).
--   * The deleted count is incremented for each slot tombstoned here.
--   * When @clearOut@ is 'False', the deleted count is decremented if the
--     returned slot currently holds a deleted marker.
delete' :: (PrimMonad m, Hashable k, Eq k) =>
           (HashTable_ (PrimState m) k v)
        -> Bool
        -> k
        -> Int
        -> m Int
delete' (HashTable sz loadRef hashes keys values) clearOut k h = do
    debug $ "delete': h=" ++ show h ++ " he=" ++ show he
            ++ " sz=" ++ show sz ++ " b0=" ++ show b0
    pair@(found, slot) <- go mempty b0 False
    debug $ "go returned " ++ show pair

    let !b' = _slot slot

    when found $ bump loadRef (-1)

    -- The caller is about to overwrite a deleted marker in the returned
    -- slot, so that tombstone stops counting.
    when (not clearOut && _wasDeleted slot == 1) $ bumpDel loadRef (-1)
    return b'

  where
    he = hashToElem h
    bump ref i = do
        !ld <- readLoad ref
        writeLoad ref $! ld + i
    bumpDel ref i = do
        !ld <- readDelLoad ref
        writeDelLoad ref $! ld + i

    -- home bucket for this hash
    !b0 = whichBucket h sz

    -- True once a candidate slot has been recorded and the probe index has
    -- come back to (or before) it.
    haveWrapped !(Slot fp _) !b = if fp == maxBound
                                    then False
                                    else b <= fp

    -- go fp b wrap
    --   * fp    first slot seen so far where it would be safe to write the
    --           key ('mempty' until one is found)
    --   * b     index to resume searching from
    --   * wrap  True if we've wrapped around (see haveWrapped)
    go !fp !b !wrap = do
        debug $ concat [ "go: fp="
                       , show fp
                       , " b="
                       , show b
                       , ", wrap="
                       , show wrap
                       , ", he="
                       , show he
                       , ", emptyMarker="
                       , show emptyMarker
                       , ", deletedMarker="
                       , show deletedMarker ]

        !idx <- forwardSearch3 hashes b sz he emptyMarker deletedMarker
        debug $ "forwardSearch3 returned " ++ show idx ++ " with sz=" ++ show sz ++ ", b=" ++ show b

        if wrap && idx >= b0
          -- We wrapped around in the search and didn't find our hash code.
          -- This means the table is full of deleted elements, so return the
          -- first place we'd be allowed to insert.
          --
          -- wrap can only be True once fp is set (see haveWrapped), so fp is
          -- returned directly. Combining it with a Slot holding an 'error'
          -- field would force that error, since Slot's fields are strict.
          --
          -- TODO: if we get in this situation we should probably just rehash
          -- the table, because every insert is going to be O(n).
          then if _slot fp == maxBound
                 then error "impossible"
                 else return $! (False, fp)
          else do
            -- because the table isn't full, we know that there must be either
            -- an empty or a deleted marker somewhere in the table. Assert this
            -- here.
            assert (idx >= 0) $ return ()
            h0 <- U.readArray hashes idx
            debug $ "h0 was " ++ show h0

            if recordIsEmpty h0
              then do
                  let pl = fp `mappend` (Slot idx 0)
                  debug $ "empty, returning " ++ show pl
                  return (False, pl)
              else do
                let !wrap' = haveWrapped fp idx
                if recordIsDeleted h0
                  then do
                      let pl = fp `mappend` (Slot idx 1)
                      debug $ "deleted, cont with pl=" ++ show pl
                      go pl (idx + 1) wrap'
                  else
                    if he == h0
                      then do
                        debug $ "found he == h0 == " ++ show h0
                        k' <- readArray keys idx
                        if k == k'
                          then do
                            -- The slot this call reports: the earlier deleted
                            -- slot in fp if there is one, otherwise this one.
                            -- (Comparing fp itself with idx was always False:
                            -- fp is either maxBound or a deleted slot, never
                            -- the slot holding the live key.)
                            let !target   = fp `mappend` (Slot idx 0)
                                samePlace = _slot target == idx
                            debug $ "found at " ++ show idx
                            debug $ "clearout=" ++ show clearOut
                            debug $ "sp? " ++ show samePlace
                            -- When inserting (clearOut is False) into the
                            -- very slot that already holds the key, skip the
                            -- tombstone: the caller overwrites the slot at
                            -- once. Writing one would inflate the deleted
                            -- count on every update. Deleting, or moving the
                            -- key to an earlier deleted slot, tombstones this
                            -- copy.
                            when (clearOut || not samePlace) $ do
                                bumpDel loadRef 1
                                U.writeArray hashes idx deletedMarker
                                writeArray keys idx undefined
                                writeArray values idx undefined
                            return (True, target)
                          else go fp (idx + 1) wrap'
                      else go fp (idx + 1) wrap'

------------------------------------------------------------------------------
-- | Maximum fraction of slots that live plus deleted entries may occupy
-- before 'checkOverflow' rebuilds the table. 'newSized' also uses it to
-- choose an initial slot count.
maxLoad :: Double
maxLoad = 0.82


------------------------------------------------------------------------------
-- | Hash-array element marking an unused slot. Lookups stop probing when
-- they reach one.
emptyMarker :: Elem
emptyMarker = 0

------------------------------------------------------------------------------
-- | Hash-array element marking a tombstone, i.e. a slot whose entry was
-- deleted. Probes continue past it, and insertions may reuse it.
deletedMarker :: Elem
deletedMarker = 1


------------------------------------------------------------------------------
{-# INLINE recordIsEmpty #-}
-- | Is this hash-array element the unused-slot marker?
recordIsEmpty :: Elem -> Bool
recordIsEmpty e = e == emptyMarker


------------------------------------------------------------------------------
{-# INLINE recordIsDeleted #-}
-- | Is this hash-array element the tombstone marker?
recordIsDeleted :: Elem -> Bool
recordIsDeleted e = e == deletedMarker


------------------------------------------------------------------------------
{-# INLINE hash #-}
-- | Hash a key with 'Data.Hashable.hash'.
hash :: (Hashable k) => k -> Int
hash key = H.hash key


------------------------------------------------------------------------------
{-# INLINE hashToElem #-}
-- | Reduce a full hash to the element stored in the hash array.
--
-- The hash is first masked with 'U.elemMask'. Judging by the branch-free
-- arithmetic below (which relies on maskw# from the Utils module), a masked
-- value of 0 or 1 is then replaced by 2. Presumably this keeps stored
-- elements distinct from 'emptyMarker' (0) and 'deletedMarker' (1).
hashToElem :: Int -> Elem
hashToElem !h = out
  where
    !(I# lo#) = h .&. U.elemMask

    -- NOTE(review): assumes maskw# x y yields an all-ones word when x == y
    -- and 0 otherwise. m# is then all ones exactly when lo# is 0 or 1.
    !m#  = maskw# lo# 0# `or#` maskw# lo# 1#
    !nm# = not# m#

    -- select 2 where m# is set, lo# everywhere else
    !r#  = ((int2Word# 2#) `and#` m#) `or#` (int2Word# lo# `and#` nm#)
    !out = U.primWordToElem r#


------------------------------------------------------------------------------
-- | Wrap a table representation in a fresh handle.
newRef :: PrimMonad m => HashTable_ (PrimState m) k v -> m (HashTable (PrimState m) k v)
newRef ht = newMutVar ht >>= return . HT
{-# INLINE newRef #-}

-- | Point a handle at a (possibly rebuilt) table representation.
writeRef :: PrimMonad m => HashTable (PrimState m) k v -> HashTable_ (PrimState m) k v -> m ()
writeRef (HT var) = writeMutVar var
{-# INLINE writeRef #-}

-- | Fetch the table representation a handle currently points at.
readRef :: PrimMonad m => HashTable (PrimState m) k v -> m (HashTable_ (PrimState m) k v)
readRef handle = case handle of
    HT var -> readMutVar var
{-# INLINE readRef #-}


------------------------------------------------------------------------------
{-# INLINE debug #-}
-- | Print a trace line when compiled with @-DDEBUG@; otherwise do nothing.
debug :: PrimMonad m => String -> m ()
#ifdef DEBUG
debug = unsafePrimToPrim . putStrLn
#else
debug = const (return ())
#endif