{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP          #-}
module Data.HashTable.ST.Linear
  ( HashTable
  , new
  , newSized
  , delete
  , lookup
  , insert
  , mapM_
  , foldM
  , computeOverhead
  ) where
import           Control.Monad                         hiding (foldM, mapM_)
import           Control.Monad.ST
import           Data.Bits
import           Data.Hashable
import           Data.STRef
import           Prelude                               hiding (lookup, mapM_)
import qualified Data.HashTable.Class                  as C
import           Data.HashTable.Internal.Array
import           Data.HashTable.Internal.Linear.Bucket (Bucket)
import qualified Data.HashTable.Internal.Linear.Bucket as Bucket
import           Data.HashTable.Internal.Utils
#ifdef DEBUG
import           Control.Monad.ST.Unsafe               (unsafeIOToST)
import           System.IO
#endif
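-- This is a linear hash table: instead of rehashing the whole table when it
-- fills up, the bucket array grows one bucket at a time (see 'split'), which
-- smooths out the cost of resizing across many inserts.
--
-- A minimal usage sketch (the qualified name @H@ is arbitrary):
--
-- > import           Control.Monad.ST         (runST)
-- > import qualified Data.HashTable.ST.Linear as H
-- >
-- > example :: Maybe Int
-- > example = runST $ do
-- >     ht <- H.new
-- >     H.insert ht ("foo" :: String) 1
-- >     H.lookup ht "foo"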
newtype HashTable s k v = HT (STRef s (HashTable_ s k v))
data HashTable_ s k v = HashTable
    { _level    ::  !Int  -- ^ current level; the table has @2^_level@ buckets
    , _splitptr ::  !Int  -- ^ index of the next bucket to be split
    , _buckets  ::  !(MutableArray s (Bucket s k v))
    }
instance C.HashTable HashTable where
    new             = new
    newSized        = newSized
    insert          = insert
    delete          = delete
    lookup          = lookup
    foldM           = foldM
    mapM_           = mapM_
    computeOverhead = computeOverhead
instance Show (HashTable s k v) where
    show _ = "<HashTable>"
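-- | Create a new, empty hash table with the minimum number of buckets (two).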
new :: ST s (HashTable s k v)
new = do
    v <- Bucket.newBucketArray 2
    newRef $ HashTable 1 0 v
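-- | Create a new, empty hash table pre-sized for @n@ elements. The bucket
-- count is @n * fillFactor / bucketSplitSize@, rounded to a power of two
-- and never less than two.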
newSized :: Int -> ST s (HashTable s k v)
newSized n = do
    v <- Bucket.newBucketArray sz
    newRef $ HashTable lvl 0 v
  where
    k   = ceiling (fromIntegral n * fillFactor / fromIntegral bucketSplitSize)
    lvl = max 1 (fromEnum $ log2 k)
    sz  = power2 lvl
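-- | Delete a key (and its associated value) from the table, if present.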
delete :: (Hashable k, Eq k) =>
          (HashTable s k v)
       -> k
       -> ST s ()
delete htRef !k = readRef htRef >>= work
  where
    work (HashTable lvl splitptr buckets) = do
        let !h0 = hashKey lvl splitptr k
        debug $ "delete: size=" ++ show (power2 lvl) ++ ", h0=" ++ show h0
                  ++ "splitptr: " ++ show splitptr
        delete' buckets h0 k
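-- | Look up a key, returning 'Just' its value if it is in the table.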
lookup :: (Eq k, Hashable k) => (HashTable s k v) -> k -> ST s (Maybe v)
lookup htRef !k = readRef htRef >>= work
  where
    work (HashTable lvl splitptr buckets) = do
        let h0 = hashKey lvl splitptr k
        bucket <- readArray buckets h0
        Bucket.lookup bucket k
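-- | Insert a key/value mapping, overwriting any existing mapping for the
-- key (a delete followed by a raw insert). If the destination bucket has
-- grown past 'bucketSplitSize', one bucket is split before returning.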
insert :: (Eq k, Hashable k) =>
          (HashTable s k v)
       -> k
       -> v
       -> ST s ()
insert htRef k v = do
    ht' <- readRef htRef >>= work
    writeRef htRef ht'
  where
    work ht@(HashTable lvl splitptr buckets) = do
        let !h0 = hashKey lvl splitptr k
        delete' buckets h0 k
        bsz <- primitiveInsert' buckets h0 k v
        if checkOverflow bsz
          then do
            debug $ "insert: splitting"
            h <- split ht
            debug $ "insert: done splitting"
            return h
          else do
            debug $ "insert: done"
            return ht
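-- | Run an action over every key/value pair in the table, discarding the
-- results.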
mapM_ :: ((k,v) -> ST s b) -> HashTable s k v -> ST s ()
mapM_ f htRef = readRef htRef >>= work
  where
    work (HashTable lvl _ buckets) = go 0
      where
        !sz = power2 lvl
        go !i | i >= sz = return ()
              | otherwise = do
            b <- readArray buckets i
            Bucket.mapM_ f b
            go $ i+1
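-- | Strict monadic left fold over every key/value pair in the table.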
foldM :: (a -> (k,v) -> ST s a)
      -> a -> HashTable s k v
      -> ST s a
foldM f seed0 htRef = readRef htRef >>= work
  where
    work (HashTable lvl _ buckets) = go seed0 0
      where
        !sz = power2 lvl
        go !seed !i | i >= sz   = return seed
                    | otherwise = do
            b <- readArray buckets i
            !seed' <- Bucket.foldM f seed b
            go seed' $ i+1
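-- | Estimate the storage overhead per element, in machine words: the size
-- of the bucket array, plus a small constant, plus per-bucket slack,
-- divided by the number of elements.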
computeOverhead :: HashTable s k v -> ST s Double
computeOverhead htRef = readRef htRef >>= work
  where
    work (HashTable lvl _ buckets) = do
        (totElems, overhead) <- go 0 0 0
        let n = fromIntegral totElems
        let o = fromIntegral overhead
        return $ (fromIntegral sz + constOverhead + o) / n
      where
        constOverhead = 5.0
        !sz = power2 lvl
        go !nelems !overhead !i | i >= sz = return (nelems, overhead)
                                | otherwise = do
            b <- readArray buckets i
            (!n,!o) <- Bucket.nelemsAndOverheadInWords b
            let !n' = n + nelems
            let !o' = o + overhead
            go n' o' (i+1)
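-- Remove a key from the bucket at the given index; it is not an error if
-- the key is absent.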
delete' :: Eq k =>
           MutableArray s (Bucket s k v)
        -> Int
        -> k
        -> ST s ()
delete' buckets h0 k = do
    bucket <- readArray buckets h0
    _ <- Bucket.delete bucket k
    return ()
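-- The split step of linear hashing: empty the bucket at the split pointer,
-- advance the pointer (doubling the bucket array once a full round of
-- splits has completed), then rehash the old bucket's contents so that
-- each element ends up either back in place or in its "buddy" bucket at
-- @splitptr + 2^(lvl-1)@.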
split :: (Hashable k) =>
         (HashTable_ s k v)
      -> ST s (HashTable_ s k v)
split ht@(HashTable lvl splitptr buckets) = do
    debug $ "split: start: nbuck=" ++ show (power2 lvl)
              ++ ", splitptr=" ++ show splitptr
    
    oldBucket <- readArray buckets splitptr
    nelems <- Bucket.size oldBucket
    let !bsz = max Bucket.newBucketSize $
                   ceiling $ (0.625 :: Double) * fromIntegral nelems
    
    dbucket1 <- Bucket.emptyWithSize bsz
    writeArray buckets splitptr dbucket1
    
    let lvl2 = power2 lvl
    let lvl1 = power2 $ lvl-1
    (!buckets',!lvl',!sp') <-
        if splitptr+1 >= lvl1
          then do
            debug $ "split: resizing bucket array"
            let lvl3 = 2*lvl2
            b <- Bucket.expandBucketArray lvl3 lvl2 buckets
            debug $ "split: resizing bucket array: done"
            return (b,lvl+1,0)
          else return (buckets,lvl,splitptr+1)
    let ht' = HashTable lvl' sp' buckets'
    
    let splitOffs = splitptr + lvl1
    db2   <- readArray buckets' splitOffs
    db2sz <- Bucket.size db2
    let db2sz' = db2sz + bsz
    db2'  <- Bucket.growBucketTo db2sz' db2
    debug $ "growing bucket at " ++ show splitOffs ++ " to size "
              ++ show db2sz'
    writeArray buckets' splitOffs db2'
    
    debug $ "split: rehashing bucket"
    let f = uncurry $ primitiveInsert ht'
    -- forceSameType only unifies f's type with that of an insert into the
    -- old table; it does nothing at runtime
    forceSameType f (uncurry $ primitiveInsert ht)
    Bucket.mapM_ f oldBucket
    debug $ "split: done"
    return ht'
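-- A bucket whose size has grown past the split threshold triggers a split.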
checkOverflow :: Int -> Bool
checkOverflow sz = sz > bucketSplitSize
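-- Insert without deleting any existing mapping for the key and without
-- checking for overflow; returns the size of the destination bucket after
-- the insert.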
primitiveInsert :: (Hashable k) =>
                   (HashTable_ s k v)
                -> k
                -> v
                -> ST s Int
primitiveInsert (HashTable lvl splitptr buckets) k v = do
    debug $ "primitiveInsert start: nbuckets=" ++ show (power2 lvl)
    let h0 = hashKey lvl splitptr k
    primitiveInsert' buckets h0 k v
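-- As 'primitiveInsert', but takes a precomputed bucket index. The bucket
-- may be reallocated as it grows, in which case the new bucket is written
-- back into the array.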
primitiveInsert' :: MutableArray s (Bucket s k v)
                 -> Int
                 -> k
                 -> v
                 -> ST s Int
primitiveInsert' buckets !h0 !k !v = do
    debug $ "primitiveInsert': bucket number=" ++ show h0
    bucket <- readArray buckets h0
    debug $ "primitiveInsert': snoccing bucket"
    (!hw,m) <- Bucket.snoc bucket k v
    debug $ "primitiveInsert': bucket snoc'd"
    maybe (return ())
          (writeArray buckets h0)
          m
    return hw
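-- | Target load factor used when pre-sizing the table in 'newSized'.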
fillFactor :: Double
fillFactor = 1.3
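-- | Bucket size past which an insert triggers a split; taken from the
-- bucket implementation.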
bucketSplitSize :: Int
bucketSplitSize = Bucket.bucketSplitSize
power2 :: Int -> Int
power2 i = 1 `iShiftL` i
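-- The two-level hash of linear hashing: hash first at the previous level
-- (2^(lvl-1) buckets); if that lands below the split pointer, the bucket
-- has already been split this round, so rehash at the current level
-- (2^lvl buckets).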
hashKey :: (Hashable k) => Int -> Int -> k -> Int
hashKey !lvl !splitptr !k = h1
  where
    !h0 = hashAtLvl (lvl-1) k
    !h1 = if (h0 < splitptr)
            then hashAtLvl lvl k
            else h0
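-- Hash a key into @[0, 2^lvl)@ by masking off the low bits of its hash code.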
hashAtLvl :: (Hashable k) => Int -> k -> Int
hashAtLvl !lvl !k = h
  where
    !h        = hashcode .&. mask
    !hashcode = hash k
    !mask     = power2 lvl - 1
newRef :: HashTable_ s k v -> ST s (HashTable s k v)
newRef = liftM HT . newSTRef
writeRef :: HashTable s k v -> HashTable_ s k v -> ST s ()
writeRef (HT ref) ht = writeSTRef ref ht
readRef :: HashTable s k v -> ST s (HashTable_ s k v)
readRef (HT ref) = readSTRef ref
debug :: String -> ST s ()
#ifdef DEBUG
debug s = unsafeIOToST $ do
              putStrLn s
              hFlush stdout
#else
#ifdef TESTSUITE
debug !s = do
    let !_ = length s
    return $! ()
#else
debug _ = return ()
#endif
#endif