-- | Example usage:
--
-- > -- Downloads a large payload from an external data store.
-- > downloadData :: IO ByteString
-- >
-- > cachedDownloadData :: IO (Cached IO ByteString)
-- > cachedDownloadData = cachedIO (secondsToNominalDiffTime 600) downloadData
--
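-- Running @cachedDownloadData@ creates a new, empty cache and returns a
-- 'Cached' handle. Run it once and share the handle: every run creates a
-- separate cache. The first time the inner action (@runCached handle@) is
-- run, it calls @downloadData@, stores the result, and returns it. If the
-- inner action is run again:
--
-- * before 10 minutes have passed, it returns the stored value.
-- * after 10 minutes have passed, it calls @downloadData@ and stores the
--   result again.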
--
module Control.Concurrent.CachedIO (
    Cached(..),
    cachedIO,
    cachedIOWith,
    cachedIO',
    cachedIOWith'
    ) where

import Control.Concurrent.STM (atomically, newTVar, readTVar, writeTVar, retry, TVar)
import Control.Monad (join)
import Control.Monad.Catch (MonadCatch, onException)
import Control.Monad.IO.Class (liftIO, MonadIO)
import Data.Time.Clock (NominalDiffTime, addUTCTime, getCurrentTime, UTCTime)

-- | A handle to an action in the monad @m@ whose result may be served from a
-- cache. Pull the action out with 'runCached' and run it whenever the value
-- is needed.
--
-- The action that /produces/ a 'Cached' allocates a brand-new cache each time
-- it runs. If the cached action and the outer monad are the same and the two
-- are collapsed into one action (for example with 'Control.Monad.join' after
-- unwrapping with 'runCached'), every run sets up a fresh cache, so caching is
-- effectively ignored.
newtype Cached m a = Cached
  { runCached :: m a
  }

-- | Internal state of a cache cell.
data State a
  = Uninitialized    -- ^ No value has been stored yet.
  | Initializing     -- ^ A first value is being computed; readers block (via
                     --   'retry') until the state changes.
  | Updating a       -- ^ A refresh is in progress; holds the stale value that
                     --   is handed to other readers meanwhile.
  | Fresh UTCTime a  -- ^ The last computed value and the time it was stored.

-- | Cache an action so that its result is reused for @interval@ seconds
-- before being recomputed. The cache starts out empty.
--
-- Running the outer action (in @m@) allocates the cache. Run the inner action
-- (obtained with 'runCached') to get the cached value. It is recomputed once
-- the stored value is @interval@ seconds old or older.
cachedIO :: (MonadIO m, MonadIO t, MonadCatch t)
         => NominalDiffTime -- ^ Seconds a stored value stays fresh
         -> t a             -- ^ Action whose result is cached
         -> m (Cached t a)
cachedIO = cachedIOWith . secondsPassed

-- | Like 'cachedIO', but the action is given the value it is about to replace,
-- together with the time that value was stored. It receives 'Nothing' when
-- there is no previous value. The action can then perform its own, external
-- staleness checks. The cache starts out empty.
--
-- Running the outer action (in @m@) allocates the cache. Run the inner action
-- (obtained with 'runCached') to get the cached value. It is recomputed once
-- the stored value is @interval@ seconds old or older.
cachedIO' :: (MonadIO m, MonadIO t, MonadCatch t)
          => NominalDiffTime -- ^ Seconds a stored value stays fresh
          -> (Maybe (UTCTime, a) -> t a) -- ^ Action to cache. It receives the
          -- stale value and the time it was stored, if any
          -> m (Cached t a)
cachedIO' interval = cachedIOWith' isFresh
  where
    isFresh = secondsPassed interval

-- | Freshness test used by 'cachedIO' and 'cachedIO''. Despite its name, it
-- returns 'True' while the interval has /not/ yet elapsed, that is, when
-- @start + seconds@ is strictly after @end@. At exactly the boundary it
-- returns 'False'.
secondsPassed :: NominalDiffTime  -- ^ Seconds
               -> UTCTime         -- ^ Start time
               -> UTCTime         -- ^ End time
               -> Bool
secondsPassed interval start end = end < interval `addUTCTime` start

-- | Cache an action, deciding freshness with a caller-supplied test. The cache
-- starts out empty.
--
-- Running the outer action (in @m@) allocates the cache. Run the inner action
-- (obtained with 'runCached') to get the cached value, or to refresh it when
-- the test says the stored value is stale.
cachedIOWith
    :: (MonadIO m, MonadIO t, MonadCatch t)
    => (UTCTime -> UTCTime -> Bool) -- ^ Freshness test, called with the time
    --   the value was stored and the current time. When it returns 'True', the
    --   stored value is returned without running the action.
    -> t a -- ^ Action to cache.
    -> m (Cached t a)
cachedIOWith isCacheStillFresh action =
  cachedIOWith' isCacheStillFresh (\_previous -> action)

-- | Cache an action, deciding freshness with a caller-supplied test. The cache
-- starts out empty.
--
-- Running the outer action (in @m@) allocates the cache. Each run of the inner
-- action (obtained with 'runCached') does one of the following:
--
-- * Empty cache: the calling thread marks the cache as initializing and runs
--   the action. Concurrent callers block until the state changes.
-- * Stored value that passes the freshness test: that value is returned.
-- * Stored value that fails the test: the calling thread runs the action to
--   refresh it. Concurrent callers get the stale value meanwhile.
--
-- If the action throws, the cache is put back into the state it had before
-- the refresh started, and the exception propagates.
cachedIOWith'
    :: (MonadIO m, MonadIO t, MonadCatch t)
    => (UTCTime -> UTCTime -> Bool) -- ^ Freshness test, called as
    --   @isCacheStillFresh lastUpdated now@. When it returns 'True', the
    --   stored value is returned without running the action.
    -> (Maybe (UTCTime, a) -> t a) -- ^ Action to cache. It receives the
    --   stale value and the time it was stored ('Nothing' if there is none),
    --   so it can perform external staleness checks.
    -> m (Cached t a)
cachedIOWith' isCacheStillFresh io = do
  cachedT <- liftIO (atomically (newTVar Uninitialized))
  pure . Cached $ do
    now <- liftIO getCurrentTime
    -- Decide atomically what to do, then run the chosen action outside STM.
    join . liftIO . atomically $ do
      cached <- readTVar cachedT
      case cached of
        previousState@(Fresh lastUpdated value)
          -- There's data in the cache and it's recent. Just return.
          | isCacheStillFresh lastUpdated now -> pure (pure value)
          -- There's data in the cache, but it's stale. Update the cache state
          -- to prevent a second thread from also executing the action. The
          -- second thread will get the stale data instead.
          | otherwise -> do
              writeTVar cachedT (Updating value)
              pure (refreshCache previousState cachedT)
        -- Another thread is already updating the cache, just return the
        -- stale value.
        Updating value -> pure (pure value)
        -- The cache is uninitialized. Mark it as initializing so concurrent
        -- callers block (see the 'Initializing' branch) instead of also
        -- running the action, then initialize and return.
        Uninitialized -> do
          writeTVar cachedT Initializing
          pure (refreshCache Uninitialized cachedT)
        -- Another thread is initializing the cache. Block until the state
        -- changes: either Fresh on success, or back to Uninitialized on
        -- failure, in which case one of the waiters takes over.
        Initializing -> retry
  where
    -- Run the action, store its result as fresh, and return it. If the action
    -- throws, restore the state that was current before the refresh started.
    refreshCache previousState cachedT = do
      let previous = case previousState of
            Fresh lastUpdated value -> Just (lastUpdated, value)
            _                       -> Nothing
      newValue <- io previous `onException` restoreState previousState cachedT
      now <- liftIO getCurrentTime
      liftIO (atomically (writeTVar cachedT (Fresh now newValue)))
      pure newValue

-- | Atomically overwrite the cache cell with @previousState@. Used to roll
-- back a failed refresh.
restoreState :: (MonadIO m) => State a -> TVar (State a) -> m ()
restoreState previousState cachedT =
  liftIO . atomically $ writeTVar cachedT previousState