{-# LANGUAGE GADTs, RecordWildCards #-}
module Data.TTLHashTable (
TTLHashTable,
TTLHashTableError(..),
Settings(..),
insert,
insert_,
insertWithTTL,
insertWithTTL_,
delete,
find,
lookup,
new,
newWithSettings,
removeExpired,
size) where
import Prelude hiding (lookup)
import Control.Exception (Exception)
import Control.Monad (void)
import Control.Monad.Failable (Failable, failure)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT)
import Data.Default (Default, def)
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.IORef (IORef,
atomicModifyIORef',
modifyIORef',
newIORef,
readIORef)
import Data.Typeable (Typeable)
import System.Clock (Clock(Monotonic), TimeSpec(..), getTime)
import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO as H
import qualified Data.IntMap.Strict as M
-- | A mutable hash table whose entries carry an expiration time.
--
-- @h@ is the underlying hash table implementation from the @hashtables@
-- package, @k@ the key type and @v@ the type of the stored values.
-- Timestamps and TTLs are expressed in milliseconds of the monotonic clock.
data TTLHashTable h k v where
    TTLHashTable
        :: (C.HashTable h)
        => { -- The underlying table; every value is wrapped together with
             -- its expiration data.
             hashTable_     :: H.IOHashTable h k (Value v)
             -- Number of entries at which an insertion first has to evict
             -- an expired entry in order to succeed.
           , maxSize_       :: Int
             -- Running count of the entries held in 'hashTable_'.
           , numEntriesRef_ :: IORef Int
             -- Index from expiration timestamp to the key expiring then.
           , timeStampsRef_ :: IORef (IntMap k)
             -- When set, a successful lookup restarts the entry's TTL.
           , renewUponRead_ :: Bool
             -- TTL applied by the insertion functions that take none.
           , defaultTTL_    :: Int
             -- Upper bound on the entries removed by one garbage
             -- collection pass.
           , gcMaxEntries_  :: Int
           }
        -> TTLHashTable h k v
-- | Internal wrapper pairing a stored value with its expiration data.
data Value v = Value
    { expiresAt :: Int  -- ^ Timestamp (in milliseconds) at which the entry expires.
    , ttl       :: Int  -- ^ Time to live the entry was stored with, in milliseconds.
    , value     :: v    -- ^ The payload handed back to the caller.
    }
-- | Configuration accepted by 'newWithSettings'.
data Settings = Settings
    { maxSize       :: Int
      -- ^ Entry count above which insertions must evict an expired entry
      -- first. Any value other than 'maxBound' is also used to pre-size
      -- the underlying table.
    , renewUponRead :: Bool
      -- ^ Whether a successful 'lookup' restarts the TTL of the entry.
    , defaultTTL    :: Int
      -- ^ TTL, in milliseconds, used by 'insert' and 'insert_'.
    , gcMaxEntries  :: Int
      -- ^ Maximum number of entries removed by a single 'removeExpired' call.
    }
-- | Failures reported by the table operations.
data TTLHashTableError
    = NotFound       -- ^ The key has no entry in the table.
    | ExpiredEntry   -- ^ The key had an entry, but its TTL had elapsed.
    | HashTableFull  -- ^ The table is at capacity and no expired entry could be evicted.
    deriving (Eq, Typeable, Show)

-- | Allows the error to be raised through any 'Failable' monad.
instance Exception TTLHashTableError
-- | Default settings: no size limit, no renewal on read, a TTL of 365 days
-- (expressed in milliseconds) and no cap on a garbage collection pass.
instance Default Settings where
    def = Settings
        { maxSize       = maxBound
        , renewUponRead = False
        , defaultTTL    = oneYearInMillis
        , gcMaxEntries  = maxBound
        }
      where
        oneYearInMillis = 365 * 24 * 60 * 60 * 1000
-- | Create an empty table configured with the default 'Settings'
-- (see the 'Default' instance).
new
    :: (C.HashTable h, MonadIO m)
    => m (TTLHashTable h k v)
new = newWithSettings defaultSettings
  where
    defaultSettings = def
-- | Create an empty table configured with the given 'Settings'.
--
-- The underlying table is pre-sized to 'maxSize' unless that is 'maxBound',
-- in which case it is created with the implementation's default size.
newWithSettings
    :: (C.HashTable h, MonadIO m)
    => Settings
    -> m (TTLHashTable h k v)
newWithSettings (Settings capacity renew fallbackTTL gcLimit) = liftIO $ do
    table      <- if capacity == maxBound
                    then H.new
                    else H.newSized capacity
    counter    <- newIORef 0
    timeStamps <- newIORef M.empty
    return TTLHashTable
        { hashTable_     = table
        , maxSize_       = capacity
        , numEntriesRef_ = counter
        , timeStampsRef_ = timeStamps
        , renewUponRead_ = renew
        , defaultTTL_    = fallbackTTL
        , gcMaxEntries_  = gcLimit
        }
-- | Store a value under the given key using the table's default TTL.
--
-- Behaves exactly like 'insertWithTTL' called with that TTL, including the
-- way failures are reported.
insert
    :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
    => TTLHashTable h k v
    -> k
    -> v
    -> m ()
insert ht@TTLHashTable { defaultTTL_ = fallbackTTL } =
    insertWithTTL ht fallbackTTL
-- | Variant of 'insert' that silently ignores failures: the insertion runs
-- in 'MaybeT' and its outcome is discarded.
insert_
    :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
    => TTLHashTable h k v
    -> k
    -> v
    -> m ()
insert_ table key val = void (runMaybeT (insert table key val))
-- | Store a value under the given key with an explicit TTL in milliseconds.
--
-- When the table already holds 'maxSize_' entries and the key is new, the
-- oldest entry is evicted if it has expired; otherwise the call fails with
-- 'HashTableFull'. Overwriting a key that is already present never needs
-- extra room, so it always succeeds and leaves the entry count untouched.
--
-- NOTE: the timestamp index keeps a single key per expiration timestamp, so
-- two keys expiring on the very same millisecond share one index slot (the
-- last writer wins).
insertWithTTL
    :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
    => TTLHashTable h k v
    -> Int
    -> k
    -> v
    -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
    numEntries <- liftIO $ readIORef numEntriesRef_
    now        <- getTimeStamp
    let expiry = now + ttl
    if numEntries < maxSize_
      then store expiry
      else do
          -- At capacity: an overwrite is still fine, only a new key needs room.
          present <- liftIO $ H.lookup hashTable_ k
          case present of
            Just _  -> store expiry
            Nothing -> do
                madeSpace <- checkOldest ht now
                maybe (failure HashTableFull) (const $ store expiry) madeSpace
  where
    -- Write the entry and keep the entry count and timestamp index in sync.
    store expiry = liftIO $ do
        -- 'H.mutate' hands back the previous binding, which tells an
        -- overwrite apart from the insertion of a new key.
        previous <- H.mutate hashTable_ k $ \old -> (Just (Value expiry ttl v), old)
        case previous of
          -- New key: the table grew by one entry.
          Nothing -> modifyIORef' numEntriesRef_ (+1)
          -- Overwrite: the count is unchanged; retire the index slot of the
          -- replaced entry so that it does not linger in the index.
          Just (Value oldExpiry _ _) ->
              modifyIORef' timeStampsRef_ $ M.update keepUnlessOurs oldExpiry
        modifyIORef' timeStampsRef_ $ M.insert expiry k

    -- Only drop an index slot that still refers to this key; a different key
    -- may have taken over the same timestamp in the meantime.
    keepUnlessOurs owner
        | owner == k = Nothing
        | otherwise  = Just owner
-- | Variant of 'insertWithTTL' that silently ignores failures: the insertion
-- runs in 'MaybeT' and its outcome is discarded.
insertWithTTL_
    :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
    => TTLHashTable h k v
    -> Int
    -> k
    -> v
    -> m ()
insertWithTTL_ table life key val =
    void (runMaybeT (insertWithTTL table life key val))
-- | Try to free one slot by evicting the entry with the smallest expiration
-- timestamp, provided that timestamp is not later than @now@.
--
-- The timestamp is taken out of the index atomically; the entry itself is
-- then removed from the table only if its expiration still matches that
-- timestamp. Yields @Just ()@ when an entry was removed, 'Nothing' otherwise.
checkOldest
    :: (Eq k, Hashable k, MonadIO m)
    => TTLHashTable h k v
    -> Int
    -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now = liftIO . runMaybeT $ do
    (stamp, key) <- MaybeT $ atomicModifyIORef' timeStampsRef_ popIfExpired
    MaybeT $ mutateWith (deleteExpired stamp) ht key
  where
    -- Detach the oldest index slot only when it is already due.
    popIfExpired index =
        case M.minViewWithKey index of
          Just (oldest@(stamp, _), rest)
              | stamp <= now -> (rest, Just oldest)
          _                  -> (index, Nothing)
-- | Fetch the value stored under a key.
--
-- Fails with 'NotFound' when the key has no entry and with 'ExpiredEntry'
-- when the entry's expiration lies before the current time (the stale entry
-- is removed from the table in that case).
--
-- When the table was created with 'renewUponRead' set, a successful lookup
-- also restarts the entry's TTL and moves its slot in the timestamp index.
-- An entry expiring exactly at the current timestamp counts as still valid
-- in both modes.
lookup
    :: (Eq k, Hashable k, MonadIO m, Failable m)
    => TTLHashTable h k v
    -> k
    -> m v
lookup ht@TTLHashTable {..} k
    | not renewUponRead_ = do
          now   <- getTimeStamp
          found <- liftIO $ H.lookup hashTable_ k
          Value _ _ payload <- checkLookedUp ht k found now
          return payload
    | otherwise = do
          now <- getTimeStamp
          (replacedExpiry, found) <- mutateWith (refreshEntry now) ht k
          dropTimeStamp replacedExpiry
          Value newExpiry _ payload <- checkLookedUp ht k found now
          liftIO $ modifyIORef' timeStampsRef_ $ M.insert newExpiry k
          return payload
  where
    -- Mutator handed to 'mutateWith'. It yields the new binding, the number
    -- of entries removed, the expiration that got replaced (if any) and the
    -- entry to be validated by 'checkLookedUp'.
    refreshEntry _ Nothing =
        ((Nothing, 0), (Nothing, Nothing))
    refreshEntry now (Just old@(Value oldExpiry life payload))
        -- Same boundary as 'checkLookedUp' (expired iff expiration < now), so
        -- an entry is never deleted here and then reported as valid.
        | oldExpiry >= now =
              let renewed = Value (now + life) life payload
              in ((Just renewed, 0), (Just oldExpiry, Just renewed))
        | otherwise =
              ((Nothing, 1), (Nothing, Just old))

    -- Retire the index slot of the replaced expiration, but only if it still
    -- refers to this key: another key may share the same timestamp.
    dropTimeStamp Nothing =
        return ()
    dropTimeStamp (Just stamp) =
        liftIO . modifyIORef' timeStampsRef_ $ M.update keepUnlessOurs stamp

    keepUnlessOurs owner
        | owner == k = Nothing
        | otherwise  = Just owner
-- | Validate the outcome of a raw table lookup against the current time.
--
-- A missing entry fails with 'NotFound'. An entry whose expiration lies
-- strictly before @now@ is deleted from the table and reported as
-- 'ExpiredEntry'. Any other entry is returned unchanged.
checkLookedUp
    :: (Eq k, Hashable k, MonadIO m, C.HashTable h, Failable m)
    => TTLHashTable h k v
    -> k
    -> Maybe (Value v)
    -> Int
    -> m (Value v)
checkLookedUp _ _ Nothing _ = failure NotFound
checkLookedUp ht k (Just entry) now
    | expiresAt entry < now = delete ht k >> failure ExpiredEntry
    | otherwise             = return entry
-- | Like 'lookup', but reports every failure as 'Nothing' instead of going
-- through a 'Failable' monad.
find
    :: (Eq k, Hashable k, MonadIO m)
    => TTLHashTable h k v
    -> k
    -> m (Maybe v)
find table key = runMaybeT (lookup table key)
-- | Remove the entry stored under a key, if there is one, and lower the
-- entry count accordingly.
--
-- Only the hash table and the entry count are touched here; the timestamp
-- index is left as it is.
delete
    :: (C.HashTable h, Eq k, Hashable k, MonadIO m)
    => TTLHashTable h k v
    -> k
    -> m ()
delete = mutateWith dropEntry
  where
    -- Always clear the binding; count one removal only if there was one.
    dropEntry existing = ((Nothing, maybe 0 (const 1) existing), ())
-- | Mutator for 'mutateWith' that removes an entry only when its expiration
-- equals the given timestamp.
--
-- The first component is the new binding paired with the number of entries
-- removed; the second is @Just ()@ when a removal took place.
deleteExpired :: Int -> Maybe (Value v) -> ((Maybe (Value v), Int), Maybe ())
deleteExpired _ Nothing = ((Nothing, 0), Nothing)
deleteExpired timeStamp (Just entry)
    | expiresAt entry == timeStamp = ((Nothing, 1), Just ())
    | otherwise                    = ((Just entry, 0), Nothing)
-- | Apply a mutator to the binding of a key and keep the entry count in sync.
--
-- The mutator receives the current binding and returns the new binding, the
-- number of entries it removed and an arbitrary result. The count is
-- subtracted from the table's entry counter and the result is handed back.
mutateWith
    :: (Eq k, Hashable k, MonadIO m)
    => (Maybe (Value v) -> ((Maybe (Value v), Int), a))
    -> TTLHashTable h k v
    -> k
    -> m a
mutateWith mutator TTLHashTable {..} k = liftIO $ do
    (removed, outcome) <- H.mutate hashTable_ k regroup
    modifyIORef' numEntriesRef_ (subtract removed)
    return outcome
  where
    -- Rearrange the mutator's output into the shape 'H.mutate' expects.
    regroup current =
        let ((updated, removed), outcome) = mutator current
        in (updated, (removed, outcome))
-- | Current value of the table's entry counter.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = counter } = liftIO (readIORef counter)
-- | Garbage collect expired entries.
--
-- Every index slot whose timestamp lies strictly before the current time is
-- a candidate. At most 'gcMaxEntries_' of them, oldest first, are removed
-- from the table in one call; the rest stay in the index for a later pass.
--
-- Returns the number of expired index slots that were left behind because
-- of that cap, so a result of @0@ means nothing expired is pending.
removeExpired
    :: (MonadIO m, Eq k, Hashable k)
    => TTLHashTable h k v
    -> m Int
removeExpired ht@TTLHashTable {..} =
    liftIO $ do
        now <- getTimeStamp
        (pending, expired) <- atomicModifyIORef' timeStampsRef_ $ selectEntries now
        mapM_ remove expired
        return pending
  where
    -- Remove the entry only if its expiration still matches the index slot.
    remove (timeStamp, k) = mutateWith (deleteExpired timeStamp) ht k

    selectEntries now index =
        let -- 'M.splitLookup' also reports the slot keyed exactly @now@,
            -- which a plain split would drop from both halves.
            (old, atNow, later) = M.splitLookup now index
            -- An entry expiring exactly now is not expired yet: keep it.
            active              = maybe later (\k -> M.insert now k later) atNow
            (selected, notYet)  = splitAt gcMaxEntries_ $ M.toList old
            toReinsert          = M.fromList notYet
        in (M.union active toReinsert, (length notYet, selected))
-- | Current reading of the monotonic clock, in whole milliseconds.
getTimeStamp :: (MonadIO m) => m Int
getTimeStamp = liftIO (toMillis <$> getTime Monotonic)
  where
    -- Combine seconds and nanoseconds into nanoseconds, then truncate to
    -- milliseconds.
    toMillis (TimeSpec secs nanos) =
        fromIntegral ((secs * 1000000000 + nanos) `div` 1000000)