{-# LANGUAGE GADTs, GeneralizedNewtypeDeriving, RecordWildCards #-}
{- |
Module: Data.TTLHashTable
Description: Adds TTL entry expiration to the excellent mutable hash tables from the hashtables package
Copyright: (c) Erick Gonzalez, 2019
License: BSD3
Maintainer: erick@codemonkeylabs.de

This library extends fast mutable hashtables so that entries added can be expired after a given TTL (time to live). This TTL can be specified as a default property of the table or on a per entry basis.

-}
module Data.TTLHashTable (
-- * How to use this module:
-- |
-- Import one of the hash table modules from the hashtables package (e.g. Basic, Cuckoo, etc.)
-- and "wrap" it in a TTLHashTable:
--
-- @
-- import Data.HashTable.ST.Basic as Basic
--
-- type HashTable k v = TTLHashTable Basic.HashTable k v
--
-- @
--
-- You can then use the functions in this module with this hashtable type. Note that the
-- functions in this module which can fail offer a flexible error handling strategy by virtue of
-- working in the context of a 'Failable' monad. So for example, if the function is used directly
-- in the IO monad and a failure occurs it would then result in an exception being thrown. However
-- if the context supports the possibility of failure like a 'MaybeT' or 'ExceptT'
-- transformer, it would then instead return something like @IO Nothing@ or @Left NotFound@
-- respectively (depending on the actual failure of course).
--
-- None of the functions in this module are thread safe, and neither are the underlying
-- mutable hash tables in the ST monad. If concurrent threads need to operate on the same
-- table, you need to provide external means of synchronization to guarantee exclusive access
-- to the table.
                          TimeStamp,
                          TTLHashTable,
                          TTLHashTableError(..),
                          Settings(..),
                          insert,
                          insert_,
                          insertWithTTL,
                          insertWithTTL_,
                          delete,
                          find,
                          foldM,
                          getSettings,
                          getTimeStamp,
                          Data.TTLHashTable.mapM_,
                          lookup,
                          lookupAndRenew,
                          lookupMaybeExpired,
                          mutate,
                          new,
                          newWithSettings,
                          reconfigure,
                          removeExpired,
                          size) where

import Prelude                 hiding (lookup)
import Control.Exception              (Exception)
import Control.Monad                  (void, forM_, when)
import Control.Monad.Except           (runExcept, throwError)
import Control.Monad.Failable         (Failable, failure)
import Control.Monad.IO.Class         (MonadIO, liftIO)
import Control.Monad.Trans.Maybe      (MaybeT(..), runMaybeT)
import Data.Bifunctor                 (first)
import Data.Bits                      (finiteBitSize)
import Data.Default                   (Default, def)
import Data.Hashable                  (Hashable)
import Data.IntMap.Strict             (IntMap)
import Data.IORef                     (IORef,
                                       atomicModifyIORef',
                                       modifyIORef',
                                       newIORef,
                                       readIORef,
                                       writeIORef)
import Data.Typeable                  (Typeable)
import System.Clock                   (Clock(..), TimeSpec(..), getTime)

import qualified Data.HashTable.Class as C
import qualified Data.HashTable.IO    as H
import qualified Data.IntMap.Strict   as M

-- | The TTL hash table type, parameterized on the type of table, key and value.
--
-- Besides the underlying mutable table, it carries a handful of mutable references holding
-- the current settings, the entry counter and an index of expiration timestamps.
data TTLHashTable h k v where
    TTLHashTable :: (C.HashTable h)
                    -- hashTable_: the underlying table; every value is wrapped in a 'Value'
                    -- carrying its expiration metadata.
                    => { hashTable_        :: H.IOHashTable h k (Value v),
                        -- Maximum number of entries accepted (see 'maxSize' in 'Settings').
                        maxSizeRef_       :: IORef Int,
                        -- Number of entries currently stored, expired ones included.
                        numEntriesRef_    :: IORef Int,
                        -- Index from expiration timestamp (as an Int) to the key expiring then.
                        timeStampsRef_    :: IORef (IntMap k),
                        -- TTL in milliseconds used when none is given at insertion time.
                        defaultTTLRef_    :: IORef Int,
                        -- Upper bound on entries collected by a single 'removeExpired' call.
                        gcMaxEntriesRef_  :: IORef Int }
                    -> TTLHashTable h k v

-- Internal wrapper stored in the underlying table for every key: the user supplied value
-- together with the metadata needed to expire (and renew) it.
data Value v = Value { expiresAt :: TimeStamp, -- moment after which the entry counts as expired
                       ttl       :: Int,       -- lifetime in milliseconds, reused when renewing
                       value     :: v }        -- the value supplied by the user

-- | A representation of a point in time, used to track entry lifetime. It wraps an 'Int';
-- the values produced by 'getTimeStamp' are nanosecond readings of the monotonic clock, so
-- timestamps are only meaningful relative to each other.
newtype TimeStamp = TimeStamp Int deriving (Num, Integral, Real, Enum, Eq, Ord)

-- | The 'Settings' type allows for specifying how the hash table should behave. Start from
-- the 'Default' instance ('def') and override the fields you care about.
data Settings = Settings {
                           -- | Maximum size of the hash table. Once reached, insertion of keys
                           -- will fail. Defaults to @maxBound@. Any other value is also used
                           -- as the size hint when the underlying table is allocated.
                           maxSize       :: Int,
                           -- | Default TTL value in milliseconds to be used for an entry if none
                           -- is specified at insertion time. Defaults to one day.
                           defaultTTL    :: Int,
                           -- | Maximum number of entries that can be garbage collected in one
                           -- single call to removeExpired. This setting is provided so that
                           -- the possibility of long running garbage collection can be managed
                           -- by the user of the library. Default is @maxBound@
                           gcMaxEntries  :: Int }

-- | Exception type used to report failures (depending on calling context). In a plain 'IO'
-- context these surface as thrown exceptions; in a 'MaybeT' or 'ExceptT' context they are
-- reported through that transformer instead.
data TTLHashTableError =
    NotFound      -- ^ The entry was not found in the table
  | ExpiredEntry  -- ^ The entry did exist but is no longer valid
  | HashTableFull -- ^ The maximum size for the table has been reached
  | UnsupportedPlatform String -- ^ The platform is not supported
  | HashTableTooLarge -- ^ The hash table is too large for the provided settings
    deriving (Eq, Typeable, Show)

-- Allows the error type to be thrown and caught as a regular exception.
instance Exception TTLHashTableError

-- Default settings: no practical size limit, a one day TTL and no limit on the number of
-- entries collected per 'removeExpired' call.
instance Default Settings where
    def = Settings unlimited oneDayInMillis unlimited
      where unlimited      = maxBound
            oneDayInMillis = 24 * 60 * 60 * 1000

-- Fail with 'UnsupportedPlatform' unless 'Int' is at least 64 bits wide. Timestamps are
-- nanosecond counts kept in an 'Int', so a narrower 'Int' cannot be supported.
assertIntSize :: (Failable m) => m ()
assertIntSize
    | intBits < 64 = failure $ UnsupportedPlatform "Int size on this platform is < 64 bits"
    | otherwise    = pure ()
  where intBits = finiteBitSize (maxBound :: Int)


-- | Creates an empty hash table configured with the 'Default' 'Settings' ('def').
-- See 'newWithSettings' for the ways in which creation can fail.
new :: (C.HashTable h, MonadIO m, Failable m) => m (TTLHashTable h k v)
new =
    newWithSettings def

-- | Creates an empty hash table configured with the given settings. Start from the 'Default'
-- instance of 'Settings' and fine tune the parameters as needed, e.g.:
--
-- @
-- newWithSettings def { maxSize = 64 }
-- @
--
-- Fails with 'UnsupportedPlatform' if 'Int' is narrower than 64 bits on this platform.
newWithSettings :: (C.HashTable h, MonadIO m, Failable m) => Settings -> m (TTLHashTable h k v)
newWithSettings Settings {..} = do
    assertIntSize
    liftIO $ do
      storage    <- allocate
      entriesRef <- newIORef 0
      stampsRef  <- newIORef M.empty
      limitRef   <- newIORef maxSize
      ttlRef     <- newIORef defaultTTL
      gcRef      <- newIORef gcMaxEntries
      return TTLHashTable { hashTable_        = storage,
                            maxSizeRef_       = limitRef,
                            numEntriesRef_    = entriesRef,
                            timeStampsRef_    = stampsRef,
                            defaultTTLRef_    = ttlRef,
                            gcMaxEntriesRef_  = gcRef
                          }
  where -- Only pre-size the underlying table when an explicit limit was requested.
        allocate = if maxSize == maxBound
                     then H.new
                     else H.newSized maxSize

-- | Insert an entry into the hash table, using the table's current default TTL. Take note of
-- the fact that __this function can fail__, for example if the table has reached @maxSize@
-- entries. Failure is signaled according to the calling 'Failable' context: in plain IO a
-- regular exception (of type 'TTLHashTableError') is thrown. For this reason,
-- __you probably want to call this function in a 'MaybeT' or 'ExceptT' monad__
insert :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert ht@TTLHashTable {..} k v =
  liftIO (readIORef defaultTTLRef_) >>= \tableTTL ->
    insertWithTTL ht tableTTL k v

-- | Same as 'insert', except that a failed insertion is silently discarded. It spares the
-- caller from ignoring the outcome of 'insert' by hand (or from catching and dropping the
-- exception in the case of IO).
insert_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> k
          -> v
          -> m ()
insert_ h k v = void . runMaybeT $ insert h k v

-- | Like 'insert' but an entry specific TTL in milliseconds can be provided.
--
-- If the key is already present its value, TTL and expiration are replaced; the entry count
-- is left untouched in that case and the timestamp of the replaced entry is dropped from the
-- expiration index.
--
-- When the table already holds @maxSize@ entries, the oldest tracked entry is evicted if it
-- has expired; otherwise the call fails with 'HashTableFull'.
insertWithTTL :: (Eq k, Hashable k, C.HashTable h, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL ht@TTLHashTable {..} ttl k v = do
  numEntries <- liftIO $ readIORef numEntriesRef_
  now        <- getTimeStamp
  let expiry = now + fromIntegral (ttl * 1000000) -- milliseconds to nanoseconds
  maxSize <- liftIO $ readIORef maxSizeRef_
  if numEntries < maxSize
    then store expiry
    else do
      -- Table is full: try to make room by evicting the oldest entry if it has expired.
      madeSpace <- checkOldest ht now
      maybe (failure HashTableFull) (const $ store expiry) madeSpace
  where store expiry = liftIO $ do
          -- Replace the entry and learn in the same pass whether the key was already there.
          previous <- H.mutate hashTable_ k $ \old ->
                        (Just (Value expiry ttl v), old)
          case previous of
            -- Brand new key: the table grew by one entry.
            Nothing ->
              modifyIORef' numEntriesRef_ (+1)
            -- Overwrite: the size is unchanged, but the replaced entry's timestamp must not
            -- linger in the expiration index.
            Just (Value oldExpiry _ _) ->
              modifyIORef' timeStampsRef_ $ M.delete (fromIntegral oldExpiry)
          modifyIORef' timeStampsRef_ $ M.insert (fromIntegral expiry) k

-- | Same as 'insertWithTTL', except that a failed insertion is silently discarded.
insertWithTTL_ :: (Eq k, Hashable k, C.HashTable h, MonadIO m)
          => TTLHashTable h k v
          -> Int
          -> k
          -> v
          -> m ()
insertWithTTL_ h ttl k v = void . runMaybeT $ insertWithTTL h ttl k v


-- Try to free one slot by evicting the entry with the smallest tracked timestamp. That
-- timestamp is popped from the index only if it is not later than @now@; the matching table
-- entry is then removed provided its expiration still equals the popped timestamp. Yields
-- @Just ()@ when an entry was actually removed and @Nothing@ otherwise.
checkOldest :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> TimeStamp -> m (Maybe ())
checkOldest ht@TTLHashTable {..} now =
    liftIO . runMaybeT $ do
      (stamp, key) <- MaybeT $ atomicModifyIORef' timeStampsRef_ popIfExpired
      MaybeT $ mutateWith (deleteExpired $ fromIntegral stamp) ht key
  where popIfExpired stamps =
          case M.minViewWithKey stamps of
            Just (oldest@(stamp, _), rest)
              | fromIntegral stamp <= now -> (rest, Just oldest)
            -- Empty index, or the oldest stamp is still in the future: leave it untouched.
            _ -> (stamps, Nothing)

-- | Look up a key in the hash table without touching its lifetime. The call fails with
-- 'NotFound' if the key is absent and with 'ExpiredEntry' if the entry has outlived its TTL.
-- Called straight in the IO monad the failure is thrown as an exception, whereas under
-- @MaybeT IO@ or @ExceptT SomeException IO@ it is returned as @IO Nothing@ or
-- @IO (Left NotFound)@ respectively. So you probably want to
-- __execute this function in one of these transformer monads__
lookup :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookup ht k = lookup' False ht k

-- | Like 'lookup', but a successful lookup also restarts the lifetime of the entry using the
-- TTL it was stored with. Note that this is therefore not a read only operation: the entry
-- is rewritten with its new expiration timestamp.
lookupAndRenew :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m v
lookupAndRenew ht k = lookup' True ht k

-- | Look up a key without checking whether the entry is still alive. The value is returned
-- together with its expiration timestamp (or the call fails with 'NotFound'), leaving the
-- expiration check to the caller. This is useful for batch lookups, or when lifetime
-- resolution can be traded for performance by not fetching a fresh timestamp on every single
-- lookup; compare the returned timestamp against one obtained from 'getTimeStamp'.
lookupMaybeExpired :: (Eq k, Hashable k, MonadIO m, Failable m) => TTLHashTable h k v -> k -> m (v, TimeStamp)
lookupMaybeExpired TTLHashTable {..} k = do
  found <- liftIO $ H.lookup hashTable_ k
  case found of
    Nothing                    -> failure NotFound
    Just (Value expiry _ item) -> return (item, expiry)

-- Shared implementation of 'lookup' (flag @False@) and 'lookupAndRenew' (flag @True@).
--
-- Without renewal the entry is only read and checked against the current time.
--
-- With renewal the entry is rewritten in place: a live entry (one 'checkLookedUp' would
-- accept, i.e. @expiresAt >= now@) gets a fresh expiration computed from its own TTL and its
-- timestamp in the expiration index is replaced. An expired entry is removed from the table,
-- its timestamp is dropped from the index and the lookup fails with 'ExpiredEntry'.
lookup' :: (Eq k, Hashable k, MonadIO m, Failable m) => Bool -> TTLHashTable h k v -> k -> m v
lookup' False TTLHashTable {..} k = do
        now        <- getTimeStamp
        mValue     <- liftIO $ H.lookup hashTable_ k
        Value {..} <- checkLookedUp mValue now
        return value
lookup' True ht@TTLHashTable {..} k = do
        now               <- getTimeStamp
        (mExpire, mValue) <- mutateWith (refreshEntry now) ht k
        -- Drop the timestamp the entry was indexed under before this call.
        removeTimeStamp mExpire
        -- Fails here with NotFound / ExpiredEntry; past this point the entry was renewed.
        Value {..}        <- checkLookedUp mValue now
        liftIO $ modifyIORef' timeStampsRef_ $ M.insert (fromIntegral expiresAt) k
        return value
    where refreshEntry _ Nothing =
              (Nothing, (Nothing, Nothing))
          refreshEntry now (Just v@Value{..}) =
              -- Same liveness test as checkLookedUp, so an entry is never deleted here and
              -- reported as alive afterwards.
              if expiresAt >= now then
                  let v' = Value { expiresAt = now + (fromIntegral ttl * 1000000), -- nanoseconds
                                   ttl = ttl,
                                   value = value }
                  in (Just v', (Just expiresAt, Just v'))
              else
                  -- Expired: remove the entry and hand back its timestamp so that the index
                  -- is cleaned up as well.
                  (Nothing, (Just expiresAt, Just v))
          removeTimeStamp Nothing =
              return ()
          removeTimeStamp (Just timeStamp) =
              liftIO . modifyIORef' timeStampsRef_ $ M.delete (fromIntegral timeStamp)

-- Turn the outcome of a raw table lookup into a result or a failure: a missing entry fails
-- with 'NotFound', an entry whose expiration lies strictly before @now@ fails with
-- 'ExpiredEntry', and anything else is returned unchanged.
checkLookedUp :: (MonadIO m, Failable m)
                 => Maybe (Value v)
                 -> TimeStamp
                 -> m (Value v)
checkLookedUp mEntry now =
  case mEntry of
    Nothing -> failure NotFound
    Just entry
      | expiresAt entry < now -> failure ExpiredEntry
      | otherwise             -> return entry

-- | A lookup function which simply returns a 'Maybe' wrapped in the calling 'MonadIO'
-- context, to accommodate more conventional usage. Missing and expired entries both yield
-- 'Nothing'.
find :: (Eq k, Hashable k, MonadIO m) => TTLHashTable h k v -> k -> m (Maybe v)
find ht k = runMaybeT (lookup ht k)

-- | Delete an entry from the hash table, together with its timestamp in the expiration
-- index. Deleting a key which is not present is a no-op.
delete :: (C.HashTable h, Eq k, Hashable k, MonadIO m, Failable m) =>
          TTLHashTable h k v -> k -> m ()
delete ht@TTLHashTable {..} k = do
    removedStamp <- mutateWith dropEntry ht k
    liftIO . forM_ removedStamp $ \stamp ->
      modifyIORef' timeStampsRef_ (M.delete (fromIntegral stamp))
  where -- Always clear the slot, reporting the expiration of whatever was stored there.
        dropEntry existing = (Nothing, fmap expiresAt existing)

-- Mutator which removes an entry only if it still expires exactly at the given timestamp.
-- An entry stored under a different expiration is kept as is. The second component is
-- @Just ()@ when the entry was removed and @Nothing@ otherwise.
deleteExpired :: TimeStamp -> Maybe (Value v) -> (Maybe (Value v), Maybe ())
deleteExpired timeStamp mEntry =
  case mEntry of
    Just entry
      | expiresAt entry == timeStamp -> (Nothing, Just ())
      | otherwise                    -> (Just entry, Nothing)
    Nothing                          -> (Nothing, Nothing)

-- Apply a pure mutator to the entry stored under a key while keeping the entry counter in
-- sync with the table.
--
-- The mutator receives the current entry (if any) and returns the entry to store ('Nothing'
-- removes it) together with a result for the caller. Creating an entry where none existed is
-- rejected with 'HashTableFull' unless the entry count read before the mutation is below the
-- maximum size; in that case the stored entry is left as it was. On success the counter is
-- adjusted by +1, -1 or 0 depending on whether an entry was created, removed or merely
-- replaced/left alone. The expiration index (timeStampsRef_) is not touched here.
mutateWith :: (Eq k, Hashable k, MonadIO m, Failable m)
              => (Maybe (Value v) -> (Maybe (Value v), a))
              -> TTLHashTable h k v
              -> k
              -> m a
mutateWith mutator TTLHashTable {..} k = do
  iResult <- liftIO $ do
              -- Snapshot the counters up front; the function handed to H.mutate is pure.
              numEntries <- readIORef numEntriesRef_
              maxSize    <- readIORef maxSizeRef_
              H.mutate hashTable_ k $ mutate_ numEntries maxSize
  -- Left: report the error in the caller's Failable context.
  -- Right: apply the size delta and hand back the mutator's result.
  flip (either failure) iResult $ \(n, result) -> do
    liftIO $ modifyIORef' numEntriesRef_ $ (+) n
    return result
        where mutate_ numEntries maxSize mValue =
                  let (mValue', result) = mutator mValue
                      -- Either an error, or the size delta paired with the mutator's result.
                      result' = runExcept $ do
                                  n <- case (mValue, mValue') of
                                        -- An entry is being created: only allowed with room left.
                                        (Nothing, Just _) ->
                                          if numEntries < maxSize
                                            then return 1
                                            else throwError HashTableFull
                                        -- An entry is being removed.
                                        (Just _, Nothing) -> return (-1)
                                        -- Replaced in place, or absent before and after.
                                        _                 -> return 0
                                  return (n, result)
                  -- On error keep the original entry; otherwise store the mutated one.
                  in either (\_ -> (mValue, result')) (\_ -> (mValue', result')) result'

-- | Report the number of entries currently held by the table. Entries which have expired
-- but have not been garbage collected yet are included in the count.
size :: (MonadIO m) => TTLHashTable h k v -> m Int
size TTLHashTable { numEntriesRef_ = counter } = liftIO (readIORef counter)

-- | Run garbage collection of expired entries in the table, processing at most
-- 'gcMaxEntries' of them. The result is the number of expired entries still waiting to be
-- removed because that limit was reached before the cleanup was complete. Note that this
-- function, like every other operation on a hash table, is __not__ thread safe. If
-- concurrent threads need to operate on the table, some concurrency primitive must be used
-- to guarantee exclusive access.
removeExpired :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> m Int
removeExpired ht@TTLHashTable {..} =
  liftIO (readIORef gcMaxEntriesRef_) >>= removeExpired' ht

-- Worker behind 'removeExpired'. It takes every timestamp strictly older than the current
-- time out of the expiration index, removes up to @gcMaxEntries@ of the corresponding
-- entries (oldest first) and puts the remaining timestamps back. The result is the number
-- of timestamps that were put back, i.e. expired entries left for a later run.
--
-- A timestamp exactly equal to the current time is kept in the index: such an entry is still
-- reported as alive by lookups and will be collected by a later run.
removeExpired' :: (MonadIO m, Eq k, Hashable k) => TTLHashTable h k v -> Int -> m Int
removeExpired' ht@TTLHashTable {..} gcMaxEntries =
  liftIO $ do
    now          <- getTimeStamp
    (n, expired) <- atomicModifyIORef' timeStampsRef_ $ selectedEntries (fromIntegral now)
    Prelude.mapM_ remove expired
    return n
        where -- Only removes the entry if it still expires at the indexed timestamp.
              remove (timeStamp, k) = mutateWith (deleteExpired $ fromIntegral timeStamp) ht k
              selectedEntries now m =
                  let -- splitLookup rather than split: split would silently discard a
                      -- timestamp equal to now, losing track of that entry for good.
                      (old, exact, later) = M.splitLookup now m
                      active              = maybe later (\k -> M.insert now k later) exact
                      (selected, notYet)  = splitAt gcMaxEntries $ M.toAscList old
                      -- notYet is an ascending slice of the index, so it can be rebuilt
                      -- directly without folding element by element.
                      toReinsert          = M.fromDistinctAscList notYet
                  in (M.union active toReinsert, (length notYet, selected))

-- | Returns a timestamp for the current moment (__now__), read from the monotonic clock
-- and expressed in nanoseconds. This value can be used when comparing expiration times
-- manually, as returned by 'lookupMaybeExpired' for example, when it makes sense to do so
-- in the name of performance.
getTimeStamp :: (MonadIO m) => m TimeStamp
getTimeStamp = liftIO $ toStamp <$> getTime Monotonic
  where toStamp (TimeSpec secs ns) = fromIntegral (secs * 1000000000 + ns)

-- | A strict fold in IO over the @(key, value)@ records in a hash table. Every stored
-- record is visited; no expiration check is performed.
foldM :: (MonadIO m) => (a -> (k, v) -> IO a) -> a -> TTLHashTable h k v -> m a
foldM f x TTLHashTable {..} = liftIO $ H.foldM step x hashTable_
    where -- Strip the expiration metadata before handing the record to the caller.
          step acc (key, entry) = f acc (key, value entry)

-- | A side-effecting map over the @(key, value)@ records in a hash table. Every stored
-- record is visited; no expiration check is performed.
mapM_ :: (MonadIO m) => ((k, v) -> IO a) -> TTLHashTable h k v -> m ()
mapM_ f TTLHashTable {..} = liftIO (H.mapM_ visit hashTable_)
    where -- Strip the expiration metadata before handing the record to the caller.
          visit (key, entry) = f (key, value entry)

-- | Replace the settings of an existing hash table. Fails with 'HashTableTooLarge', leaving
-- the current settings untouched, if the table already holds more entries than the new
-- 'maxSize' allows.
reconfigure :: (MonadIO m, Failable m) => TTLHashTable h k v -> Settings -> m ()
reconfigure TTLHashTable {..} settings = do
  occupied <- liftIO $ readIORef numEntriesRef_
  when (occupied > maxSize settings) $ failure HashTableTooLarge
  liftIO $ do
    writeIORef maxSizeRef_      (maxSize settings)
    writeIORef defaultTTLRef_   (defaultTTL settings)
    writeIORef gcMaxEntriesRef_ (gcMaxEntries settings)

-- | Retrieve the settings currently in effect for a hash table.
getSettings :: (MonadIO m) => TTLHashTable h k v -> m Settings
getSettings TTLHashTable {..} =
  liftIO $ Settings <$> readIORef maxSizeRef_
                    <*> readIORef defaultTTLRef_
                    <*> readIORef gcMaxEntriesRef_

-- | Mutate an entry with the provided modification function. The tuple returned corresponds
-- to the new value mapped to the key and a result to return from the mutate operation. Note
-- that if the new value is @Nothing@ then the entry is deleted if it exists or no change is
-- performed if it didn't. If the value is @Just v@ then the value is replaced or inserted
-- depending on whether it was found or not respectively for that key.
--
-- A stored entry starts a fresh lifetime: an existing entry keeps the TTL it was inserted
-- with, a new one gets the table's default TTL. The expiration index is updated accordingly
-- so that entries written through this function are seen by 'removeExpired'.
--
-- Fails with 'HashTableFull' if a new entry would have to be created in a full table; the
-- table is left unchanged in that case.
mutate :: (Eq k, Hashable k, MonadIO m, Failable m)
          => TTLHashTable h k v
          -> k
          -> (Maybe v -> (Maybe v, a))
          -> m a
mutate ht@TTLHashTable {..} k f = do
  now <- getTimeStamp
  defaultTTL <- liftIO $ readIORef defaultTTLRef_
  (oldExpiry, newExpiry, result) <- mutateWith (tracked now defaultTTL) ht k
  -- Only reached when the mutation was applied: bring the expiration index in line with it.
  liftIO $ do
    forM_ oldExpiry $ \stamp ->
      modifyIORef' timeStampsRef_ $ M.delete (fromIntegral stamp)
    forM_ newExpiry $ \stamp ->
      modifyIORef' timeStampsRef_ $ M.insert (fromIntegral stamp) k
  return result
  where -- Run the user's mutation and also report the expiration of the entry before and
        -- after it, so that the index can be maintained.
        tracked now defaultTTL mOld =
          let (mNew, result) = mutate' now defaultTTL f mOld
          in (mNew, (expiresAt <$> mOld, expiresAt <$> mNew, result))

-- Lift a mutation over plain values to a mutation over stored entries. The user function
-- sees the bare value (if any); whatever value it returns is wrapped again with a fresh
-- expiration of @now@ plus the TTL, taking the TTL from the existing entry or, when there
-- was none, from the supplied default.
mutate' :: TimeStamp -> Int -> (Maybe v -> (Maybe v, a)) -> Maybe (Value v) -> (Maybe (Value v), a)
mutate' now defaultTTL f mV = first (fmap restamp) (f mOld)
    where (lifetime, mOld) =
            maybe (defaultTTL, Nothing) (\Value {..} -> (ttl, Just value)) mV
          deadline = now + fromIntegral (lifetime * 1000000) -- milliseconds to nanoseconds
          restamp payload =
            Value { expiresAt = deadline, ttl = lifetime, value = payload }