{-# LANGUAGE GADTs, RecordWildCards #-}
module Data.TTLHashTable (
TTLHashTable,
TTLHashTableError(..),
Settings(..),
insert,
insert_,
insertWithTTL,
insertWithTTL_,
delete,
find,
foldM,
getSettings,
Data.TTLHashTable.mapM_,
lookup,
mutate,
new,
newWithSettings,
reconfigure,
removeExpired,
size) where
import Prelude hiding (lookup)
import Control.Exception (Exception)
import Control.Monad (void, forM_, when)
import Control.Monad.Except (runExcept, throwError)
import Control.Monad.Failable (Failable, failure)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Trans.Maybe (MaybeT(..), runMaybeT)
import Data.Bifunctor (first)
import Data.Bits (finiteBitSize)
import Data.Default (Default, def)
import Data.Hashable (Hashable)
import Data.IntMap.Strict (IntMap)
import Data.IORef (IORef,
atomicModifyIORef',
modifyIORef',
newIORef,
readIORef,
writeIORef)
import Data.Typeable (Typeable)
import System.Clock (Clock(..), TimeSpec(..), getTime)
import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO as H
import qualified Data.IntMap.Strict as M
-- | A mutable hash table (backed by "Data.HashTable.IO") whose entries expire
-- after a per-entry time-to-live.
--
-- Besides the table itself the structure keeps a running entry count and an
-- index from expiry timestamp to key, which 'removeExpired' and the
-- full-table eviction path ('checkOldest') use to find expired entries. The
-- index is an 'IntMap' keyed by timestamp, so it can hold only one key per
-- timestamp: two entries with an identical expiry overwrite each other there.
--
-- NOTE(review): operations read and write several 'IORef's in sequence
-- without a lock; concurrent use presumably needs external synchronisation
-- — confirm before sharing a table between threads.
data TTLHashTable h k v where
  TTLHashTable :: (C.HashTable h)
               => { -- underlying table; every value is wrapped with its expiry metadata
                    hashTable_ :: H.IOHashTable h k (Value v),
                    -- maximum number of entries (see 'Settings')
                    maxSizeRef_ :: IORef Int,
                    -- number of entries currently stored, maintained by the insert/delete paths
                    numEntriesRef_ :: IORef Int,
                    -- expiry timestamp (nanoseconds, monotonic clock) -> key
                    timeStampsRef_ :: IORef (IntMap k),
                    -- whether 'lookup' renews an entry's expiry
                    renewUponReadRef_ :: IORef Bool,
                    -- TTL in milliseconds used by 'insert' and 'mutate'
                    defaultTTLRef_ :: IORef Int,
                    -- per-call limit of timestamps processed by 'removeExpired'
                    gcMaxEntriesRef_ :: IORef Int }
               -> TTLHashTable h k v
-- | A stored value together with the metadata needed to expire it.
data Value v = Value { -- expiry time in nanoseconds on the monotonic clock read by 'getTimeStamp'
                       expiresAt :: Int,
                       -- time-to-live in milliseconds; reused whenever the expiry is recomputed
                       ttl :: Int,
                       -- the caller's value
                       value :: v }
-- | Tunable parameters of a 'TTLHashTable'; see 'def' for the defaults,
-- 'newWithSettings' to create a table with them and 'reconfigure' /
-- 'getSettings' to change or inspect them later.
data Settings = Settings {
  -- | Maximum number of entries. With 'maxBound' the underlying table is
  -- created unsized; any other value pre-sizes it to this capacity.
  maxSize :: Int,
  -- | When 'True', a successful 'lookup' resets the entry's expiry to
  -- now plus the entry's TTL.
  renewUponRead :: Bool,
  -- | TTL in milliseconds used by 'insert' and for entries created through
  -- 'mutate'.
  defaultTTL :: Int,
  -- | Upper bound on the number of expired timestamps processed by a single
  -- 'removeExpired' call.
  gcMaxEntries :: Int }
-- | Failures reported (through 'Failable') by the operations of this module.
data TTLHashTableError =
    NotFound                    -- ^ the key is not present in the table
  | ExpiredEntry                -- ^ the key was present but expired; the entry is removed as part of the failed lookup
  | HashTableFull               -- ^ adding a new key would exceed 'maxSize' and no expired entry could be evicted
  | UnsupportedPlatform String  -- ^ the platform cannot be used (raised when 'Int' has fewer than 64 bits)
  | HashTableTooLarge           -- ^ 'reconfigure' was asked for a 'maxSize' below the current number of entries
  deriving (Eq, Typeable, Show)
-- | Allows the errors to be thrown in 'IO' (or any other 'Failable' monad).
instance Exception TTLHashTableError
-- | Default settings: no size limit, no renewal on read, a one-year default
-- TTL (expressed in milliseconds) and no limit on garbage-collection batches.
instance Default Settings where
  def = Settings
    { maxSize       = maxBound
    , renewUponRead = False
    , defaultTTL    = oneYearInMs
    , gcMaxEntries  = maxBound
    }
    where
      -- days * hours * minutes * seconds * milliseconds
      oneYearInMs = 365 * 24 * 60 * 60 * 1000
-- | Fail with 'UnsupportedPlatform' unless 'Int' is at least 64 bits wide.
--
-- Timestamps are nanosecond counts stored in 'Int', so a narrower 'Int'
-- would overflow.
assertIntSize :: (Failable m) => m ()
assertIntSize
  | intWidth >= 64 = pure ()
  | otherwise =
      failure $ UnsupportedPlatform "Int size on this platform is < 64 bits"
  where
    intWidth = finiteBitSize (maxBound :: Int)
-- | Create an empty table using the default 'Settings' (see 'def').
new :: (C.HashTable h, MonadIO m, Failable m) => m (TTLHashTable h k v)
new = newWithSettings (def :: Settings)
-- | Create an empty table with the given 'Settings'.
--
-- Fails with 'UnsupportedPlatform' when 'Int' is narrower than 64 bits.
-- With @maxSize == maxBound@ the underlying table is created unsized;
-- otherwise it is pre-sized to @maxSize@.
newWithSettings :: (C.HashTable h, MonadIO m, Failable m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = do
  assertIntSize
  -- Constructor fields, in order: table, maxSize, entry count,
  -- timestamp index, renewUponRead, defaultTTL, gcMaxEntries.
  liftIO $
    TTLHashTable
      <$> (if maxSize == maxBound then H.new else H.newSized maxSize)
      <*> newIORef maxSize
      <*> newIORef 0
      <*> newIORef M.empty
      <*> newIORef renewUponRead
      <*> newIORef defaultTTL
      <*> newIORef gcMaxEntries
-- | Insert @v@ under @k@ using the table's current default TTL.
--
-- See 'insertWithTTL' for the capacity rules and failure modes.
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
       => TTLHashTable h k v
       -> k
       -> v
       -> m ()
insert ht@TTLHashTable { defaultTTLRef_ = ttlRef } k v =
  liftIO (readIORef ttlRef) >>= \currentTTL -> insertWithTTL ht currentTTL k v
-- | Best-effort variant of 'insert': any failure (e.g. 'HashTableFull') is
-- silently discarded.
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
        => TTLHashTable h k v
        -> k
        -> v
        -> m ()
insert_ h k v = void (runMaybeT (insert h k v))
-- | Insert @v@ under @k@ with an explicit time-to-live.
--
-- The TTL is in milliseconds: the entry expires at @now + ttl * 1000000@,
-- where @now@ is the monotonic clock in nanoseconds (see 'getTimeStamp').
--
-- * If @k@ is already present, its value, TTL and expiry are replaced. The
--   entry count is unchanged and the old expiry timestamp is removed from
--   the expiry index, so replacing a key never needs free capacity.
-- * If @k@ is new and the table holds fewer than @maxSize@ entries, it is
--   added.
-- * Otherwise 'checkOldest' tries to evict one expired entry; if none can
--   be evicted the call fails with 'HashTableFull'.
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
              => TTLHashTable h k v
              -> Int
              -> k
              -> v
              -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  now <- getTimeStamp
  let newExpiry = now + ttl * 1000000
  existing <- liftIO $ H.lookup hashTable_ k
  case existing of
    Just old -> liftIO $ do
      -- Replacement: the count must not change, and the old timestamp no
      -- longer refers to a live entry. Only untrack it if it still points
      -- at this key (a colliding key may own that slot).
      modifyIORef' timeStampsRef_ $ M.update releaseIfOwned (expiresAt old)
      store newExpiry
    Nothing -> do
      numEntries <- liftIO $ readIORef numEntriesRef_
      maxSize <- liftIO $ readIORef maxSizeRef_
      if numEntries < maxSize
        then liftIO $ storeNew newExpiry
        else do
          madeSpace <- checkOldest ht now
          case madeSpace of
            Just () -> liftIO $ storeNew newExpiry
            Nothing -> failure HashTableFull
  where
    -- Write the entry and index its expiry timestamp.
    store expiry = do
      H.insert hashTable_ k (Value expiry ttl v)
      modifyIORef' timeStampsRef_ $ M.insert expiry k
    -- Like 'store', for a key that was not in the table before.
    storeNew expiry = do
      store expiry
      modifyIORef' numEntriesRef_ (+1)
    releaseIfOwned owner
      | owner == k = Nothing
      | otherwise = Just owner
-- | Best-effort variant of 'insertWithTTL': any failure is silently
-- discarded.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
               => TTLHashTable h k v
               -> Int
               -> k
               -> v
               -> m ()
insertWithTTL_ h ttl k v = void (runMaybeT (insertWithTTL h ttl k v))
-- | Try to make room for one new entry by evicting the entry with the oldest
-- indexed expiry timestamp, provided that timestamp is @<= now@.
--
-- Timestamps that no longer match a live entry (the entry was refreshed,
-- replaced or already removed) are dropped from the index and the
-- next-oldest timestamp is tried. A stale timestamp therefore cannot by
-- itself cause a spurious 'HashTableFull'.
--
-- Returns @Just ()@ once an entry has been evicted, and 'Nothing' when the
-- index holds no expired timestamp with a matching entry. Each iteration
-- removes one timestamp from the index, so the loop terminates.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> Int -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now = liftIO evictOne
  where
    evictOne = do
      candidate <- atomicModifyIORef' timeStampsRef_ popExpired
      case candidate of
        Nothing -> return Nothing
        Just (timeStamp, k) -> do
          -- deleteExpired only removes the entry if its expiry still equals
          -- the indexed timestamp.
          evicted <- mutateWith (deleteExpired timeStamp) ht k
          case evicted of
            Just () -> return (Just ())
            Nothing -> evictOne
    -- Pop the smallest timestamp if it has expired; otherwise leave the index alone.
    popExpired timeStamps =
      case M.minViewWithKey timeStamps of
        Just ((timeStamp, k), rest)
          | timeStamp <= now -> (rest, Just (timeStamp, k))
        _ -> (timeStamps, Nothing)
-- | Look up @k@, failing with 'NotFound' or 'ExpiredEntry'.
--
-- Whether a successful read renews the entry's expiry is controlled by the
-- table's 'renewUponRead' setting at the time of the call.
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht@TTLHashTable { renewUponReadRef_ = renewRef } k =
  liftIO (readIORef renewRef) >>= \renew -> lookup' ht renew k
-- | Implementation of 'lookup'. The 'Bool' selects whether a successful read
-- renews the entry's expiry.
--
-- * @False@: plain read. An entry with @expiresAt < now@ is deleted and
--   'ExpiredEntry' is raised (see 'checkLookedUp').
-- * @True@: in a single table mutation, a live entry (@expiresAt > now@)
--   gets its expiry reset to @now + ttl * 1000000@, and an expired one is
--   removed. In both cases the entry's old timestamp is dropped from the
--   expiry index. A missing key fails with 'NotFound', an expired one with
--   'ExpiredEntry'.
lookup' :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> Bool -> k -> m v
lookup' ht@TTLHashTable {..} False k = do
  now <- getTimeStamp
  mValue <- liftIO $ H.lookup hashTable_ k
  Value {..} <- checkLookedUp ht k mValue now
  return value
lookup' ht@TTLHashTable {..} True k = do
  now <- getTimeStamp
  (mOldExpiry, outcome) <- mutateWith (refreshEntry now) ht k
  -- Whether the entry was renewed or dropped, its previous timestamp is
  -- stale. Only untrack it if the index still maps it to this key.
  forM_ mOldExpiry $ \ts ->
    liftIO . modifyIORef' timeStampsRef_ $ M.update releaseIfOwned ts
  Value {..} <- either failure return outcome
  -- Index the renewed expiry so removeExpired/checkOldest can find it.
  liftIO . modifyIORef' timeStampsRef_ $ M.insert expiresAt k
  return value
  where
    -- Decide the outcome inside the mutation itself, so the stored state and
    -- the reported result cannot disagree (e.g. at expiresAt == now).
    refreshEntry _ Nothing = (Nothing, (Nothing, Left NotFound))
    refreshEntry now (Just Value {..})
      | expiresAt > now =
          let renewed = Value { expiresAt = now + ttl * 1000000,
                                ttl = ttl,
                                value = value }
          in (Just renewed, (Just expiresAt, Right renewed))
      | otherwise = (Nothing, (Just expiresAt, Left ExpiredEntry))
    releaseIfOwned owner
      | owner == k = Nothing
      | otherwise = Just owner
-- | Validate the result of a plain table lookup for @k@ at time @now@.
--
-- Fails with 'NotFound' when nothing was found. When the entry's expiry is
-- strictly before @now@, it is deleted from the table and the call fails
-- with 'ExpiredEntry'. Otherwise the entry is returned unchanged.
checkLookedUp :: (Eq k, Hashable k, MonadIO m, C.HashTable h, Failable m)
              => TTLHashTable h k v
              -> k
              -> Maybe (Value v)
              -> Int
              -> m (Value v)
checkLookedUp ht k found now =
  case found of
    Nothing -> failure NotFound
    Just entry
      | expiresAt entry < now -> delete ht k >> failure ExpiredEntry
      | otherwise -> return entry
-- | Like 'lookup', but reports any failure (missing or expired key) as
-- 'Nothing' instead of raising it.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find ht k = runMaybeT (lookup ht k)
-- | Remove @k@ from the table; a no-op when the key is absent.
--
-- The entry count is adjusted by 'mutateWith'. The removed entry's expiry
-- timestamp is dropped from the index only if the index still maps it to
-- @k@. The index holds one key per timestamp, so after a collision another
-- key may own that slot, and deleting it would stop that key from ever being
-- garbage-collected.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m, Failable m) =>
          TTLHashTable h k v -> k -> m ()
delete ht@TTLHashTable {..} k = do
  removedExpiry <- mutateWith (\old -> (Nothing, expiresAt <$> old)) ht k
  forM_ removedExpiry $ \timeStamp ->
    liftIO . modifyIORef' timeStampsRef_ $ M.update releaseIfOwned timeStamp
  where
    releaseIfOwned owner
      | owner == k = Nothing
      | otherwise = Just owner
-- | Mutator for 'mutateWith': remove the entry only if its expiry equals the
-- given timestamp.
--
-- Returns @Just ()@ as the result when the entry was removed, and 'Nothing'
-- when the slot was empty or the entry has a different expiry (i.e. the
-- indexed timestamp is stale). In that case the entry is kept.
deleteExpired :: Int -> Maybe (Value v) -> (Maybe (Value v), Maybe ())
deleteExpired _ Nothing = (Nothing, Nothing)
deleteExpired timeStamp (Just entry)
  | expiresAt entry == timeStamp = (Nothing, Just ())
  | otherwise = (Just entry, Nothing)
-- | Apply a pure mutator to the slot of @k@ in a single 'H.mutate' call,
-- keeping the entry count in sync.
--
-- Before/after states of the slot and their effect:
--
-- * absent -> present: allowed only while the count is below @maxSize@
--   (both read just before the mutation); otherwise the slot is left
--   untouched and the call fails with 'HashTableFull'. Count +1.
-- * present -> absent: count -1.
-- * anything else: count unchanged.
--
-- The timestamp index is /not/ touched here; callers maintain it.
mutateWith :: (Eq k, Hashable k, MonadIO m, Failable m)
           => (Maybe (Value v) -> (Maybe (Value v), a))
           -> TTLHashTable h k v
           -> k
           -> m a
mutateWith mutator TTLHashTable {..} k = do
  outcome <- liftIO $ do
    count <- readIORef numEntriesRef_
    limit <- readIORef maxSizeRef_
    H.mutate hashTable_ k (step count limit)
  either failure commit outcome
  where
    commit (delta, result) = do
      liftIO $ modifyIORef' numEntriesRef_ (+ delta)
      return result
    step count limit before =
      let (after, result) = mutator before
          sizeChange = case (before, after) of
            (Nothing, Just _)
              | count < limit -> return 1
              | otherwise     -> throwError HashTableFull
            (Just _, Nothing) -> return (-1)
            _                 -> return 0
          verdict = runExcept $ fmap (\delta -> (delta, result)) sizeChange
          -- on failure keep the slot as it was
      in (either (const before) (const after) verdict, verdict)
-- | Number of entries currently stored, as tracked by the table's counter.
-- Entries that have expired but were not removed yet are still counted.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = countRef } = liftIO (readIORef countRef)
-- | Garbage-collect expired entries, processing at most the table's current
-- 'gcMaxEntries' timestamps. See 'removeExpired'' for the return value.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m Int
removeExpired ht@TTLHashTable { gcMaxEntriesRef_ = limitRef } =
  liftIO (readIORef limitRef) >>= removeExpired' ht
-- | Remove up to @gcMaxEntries@ entries whose indexed expiry timestamp is
-- strictly before now.
--
-- The oldest qualifying timestamps are taken out of the index and their
-- entries are deleted (via 'deleteExpired', which skips stale timestamps).
-- Expired timestamps beyond the limit stay in the index for a later call.
-- A timestamp exactly equal to @now@ is kept in the index; 'M.split' alone
-- would silently drop it.
--
-- Returns the number of expired timestamps left pending because of the
-- limit (0 when everything expired was processed).
-- NOTE(review): callers may expect the number of /removed/ entries instead
-- — confirm which is intended.
removeExpired' :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> Int -> m Int
removeExpired' ht@TTLHashTable {..} gcMaxEntries =
  liftIO $ do
    now <- getTimeStamp
    (pending, batch) <- atomicModifyIORef' timeStampsRef_ $ selectEntries now
    Prelude.mapM_ remove batch
    return $! pending
  where
    remove (timeStamp, k) = mutateWith (deleteExpired timeStamp) ht k
    selectEntries now m =
      let (old, atNow, active) = M.splitLookup now m
          active' = maybe active (\owner -> M.insert now owner active) atNow
          (selected, notYet) = splitAt gcMaxEntries (M.toAscList old)
          -- notYet is a suffix of an ascending key list, so it can be
          -- rebuilt in linear time without a lazy left fold.
          kept = M.union active' (M.fromDistinctAscList notYet)
      in (kept, (length notYet, selected))
-- | Current reading of the monotonic clock, in nanoseconds.
--
-- Expiry times in this module are computed as @now + ttl * 1000000@ on top
-- of this value, i.e. TTLs are expressed in milliseconds.
getTimeStamp :: (MonadIO m) => m Int
getTimeStamp = liftIO $ toNanos <$> getTime Monotonic
  where
    toNanos (TimeSpec s ns) = fromIntegral (s * 1000000000 + ns)
-- | Left-fold an 'IO' action over all stored @(key, value)@ pairs.
--
-- Entries are visited regardless of their expiry; no expiry check is made.
foldM :: (MonadIO m) => (a -> (k, v) -> IO a) -> a -> TTLHashTable h k v -> m a
foldM f x TTLHashTable {..} = liftIO $ H.foldM step x hashTable_
  where
    step acc (key, Value { value = payload }) = f acc (key, payload)
-- | Run an 'IO' action on every stored @(key, value)@ pair, discarding the
-- results.
--
-- Entries are visited regardless of their expiry; no expiry check is made.
mapM_ :: (MonadIO m) => ((k, v) -> IO a) -> TTLHashTable h k v -> m ()
mapM_ f TTLHashTable {..} =
  liftIO $ H.mapM_ (\(key, Value { value = payload }) -> f (key, payload)) hashTable_
-- | Replace all 'Settings' of an existing table.
--
-- Fails with 'HashTableTooLarge' (changing nothing) when the table already
-- holds more entries than the new 'maxSize'. The underlying table is not
-- resized.
reconfigure :: (MonadIO m, Failable m) => TTLHashTable h k v -> Settings -> m ()
reconfigure TTLHashTable {..} Settings {..} = do
  currentEntries <- liftIO $ readIORef numEntriesRef_
  when (maxSize < currentEntries) $ failure HashTableTooLarge
  liftIO $ sequence_
    [ writeIORef maxSizeRef_ maxSize
    , writeIORef renewUponReadRef_ renewUponRead
    , writeIORef defaultTTLRef_ defaultTTL
    , writeIORef gcMaxEntriesRef_ gcMaxEntries
    ]
-- | Snapshot of the table's current 'Settings'.
getSettings :: (MonadIO m) => TTLHashTable h k v -> m Settings
getSettings TTLHashTable {..} =
  -- Settings fields in declaration order: maxSize, renewUponRead,
  -- defaultTTL, gcMaxEntries.
  liftIO $ Settings <$> readIORef maxSizeRef_
                    <*> readIORef renewUponReadRef_
                    <*> readIORef defaultTTLRef_
                    <*> readIORef gcMaxEntriesRef_
-- | Apply a pure update function to the value stored under @k@.
--
-- @f@ receives the current value ('Nothing' if absent) and returns the new
-- value ('Nothing' to delete) together with a result, which is returned.
-- A kept or created entry has its expiry reset to @now + ttl * 1000000@,
-- keeping its existing TTL; new entries use the table's default TTL.
-- Creating a new key fails with 'HashTableFull' when the table is at
-- @maxSize@.
--
-- The expiry index is updated to match: the previous timestamp is dropped
-- (only if it still maps to @k@) and the new one, if any, is added. Without
-- this, entries touched through 'mutate' would never be found by
-- 'removeExpired' or 'checkOldest'.
--
-- NOTE(review): an entry that has expired but not yet been removed is passed
-- to @f@ as present; confirm that is intended.
mutate :: (Eq k, Hashable k, MonadIO m, Failable m)
       => TTLHashTable h k v
       -> k
       -> (Maybe v -> (Maybe v, a))
       -> m a
mutate ht@TTLHashTable {..} k f = do
  now <- getTimeStamp
  defaultTTL <- liftIO $ readIORef defaultTTLRef_
  (oldExpiry, newExpiry, result) <- mutateWith (tracked now defaultTTL) ht k
  liftIO $ do
    forM_ oldExpiry $ modifyIORef' timeStampsRef_ . M.update releaseIfOwned
    forM_ newExpiry $ \ts -> modifyIORef' timeStampsRef_ (M.insert ts k)
  return result
  where
    -- Run the user mutation and also report the expiry before and after,
    -- so the index can be brought in line once the mutation succeeded.
    tracked now defaultTTL current =
      let (updated, result) = mutate' now defaultTTL f current
      in (updated, (expiresAt <$> current, expiresAt <$> updated, result))
    releaseIfOwned owner
      | owner == k = Nothing
      | otherwise = Just owner
-- | Lift a user mutation on plain values to one on stored 'Value's.
--
-- The user function sees the bare value (if any). A value it returns is
-- re-wrapped with the existing entry's TTL, or @defaultTTL@ when there was
-- no entry, and an expiry of @now + ttl * 1000000@.
mutate' :: Int -> Int -> (Maybe v -> (Maybe v, a)) -> Maybe (Value v) -> (Maybe (Value v), a)
mutate' now defaultTTL f current = first (fmap rewrap) (f (value <$> current))
  where
    keptTTL = maybe defaultTTL ttl current
    rewrap payload = Value { expiresAt = now + keptTTL * 1000000,
                             ttl = keptTTL,
                             value = payload }