{-# LANGUAGE ScopedTypeVariables #-}

module Stamina
  ( -- functions
    retry,
    retryFor,
    -- types
    RetrySettings (..),
    defaults,
    indefiniteDefaults,
    RetryAction (..),
    RetryStatus (..),
    -- raising exceptions
    escalateWith,
    escalate,
    withLeft,
  )
where

import Control.Concurrent (MVar, isEmptyMVar, newEmptyMVar, threadDelay, tryPutMVar)
import Control.Exception (Exception (..), SomeAsyncException (SomeAsyncException), SomeException, throwIO)
import Control.Monad (void)
import Control.Monad.Catch (MonadCatch, throwM, try)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.Maybe (isJust)
import Data.Time.Clock (NominalDiffTime, UTCTime, diffUTCTime, getCurrentTime, nominalDiffTimeToSeconds, secondsToNominalDiffTime)
import System.Random (randomRIO)

-- | Configuration shared by 'retry' and 'retryFor'.
--
-- 'defaults' and 'indefiniteDefaults' provide ready-made values that can be
-- tweaked with record update syntax.
data RetrySettings = RetrySettings
  { initialRetryStatus :: RetryStatus,
    -- ^ Status the retry loop starts from, and restarts from after
    -- 'resetInitial' has been called. Override it to resume an earlier retry.
    maxAttempts :: Maybe Int,
    -- ^ Limit on the number of retries: once 'attempts' reaches this value the
    -- last exception is rethrown. Can be combined with 'maxTime'; 'Nothing'
    -- means no limit. Defaults to @Just 10@.
    maxTime :: Maybe NominalDiffTime,
    -- ^ Limit on the accumulated back-off delay (the sum of the waits between
    -- attempts; time spent running the action itself is not counted). A retry
    -- whose delay would push the total past this limit is not attempted and the
    -- last exception is rethrown instead. Can be combined with 'maxAttempts';
    -- 'Nothing' means no limit. Defaults to 60 seconds.
    backoffMaxRetryDelay :: Maybe NominalDiffTime,
    -- ^ Cap on every back-off delay computed for a 'Retry' action. Delays
    -- requested explicitly through 'RetryDelay' or 'RetryTime' are not capped.
    -- Defaults to 60 seconds.
    backoffJitter :: Double,
    -- ^ Upper bound, in seconds, of the random jitter (a number between 0 and
    -- this value) added to each back-off delay. Defaults to 1.0.
    backoffExpBase :: Double
    -- ^ Base of the exponential back-off. Defaults to 2.0.
  }

-- | Progress of a retry loop, handed to the action on every attempt.
--
-- Starting from the initial status in 'defaults', the numeric fields are zero
-- and 'lastException' is 'Nothing' until the first retry happens.
data RetryStatus = RetryStatus
  { attempts :: Int,
    -- ^ Number of retries made so far.
    delay :: NominalDiffTime,
    -- ^ Delay that was waited right before the current attempt.
    totalDelay :: NominalDiffTime,
    -- ^ Sum of all delays waited so far.
    resetInitial :: IO (),
    -- ^ Request a reset: inside 'retryFor', the next failure restarts the
    -- status from 'initialRetryStatus' before the next delay is computed.
    lastException :: Maybe SomeException
    -- ^ Exception thrown by the previous failed attempt, if any.
  }

-- | Bounded retry policy with the following settings:
--
-- * at most 10 retries;
-- * at most 60 seconds of accumulated back-off;
-- * each back-off capped at 60 seconds;
-- * exponential base 2 with up to 1 second of jitter;
-- * starting from an all-zero 'RetryStatus'.
defaults :: RetrySettings
defaults =
  RetrySettings
    { initialRetryStatus = freshStatus,
      maxAttempts = Just 10,
      maxTime = Just sixtySeconds,
      backoffMaxRetryDelay = Just sixtySeconds,
      backoffJitter = 1.0,
      backoffExpBase = 2.0
    }
  where
    sixtySeconds = secondsToNominalDiffTime 60

    -- Nothing has happened yet: no attempts, no waiting, no exception.
    freshStatus =
      RetryStatus
        { attempts = 0,
          delay = 0,
          totalDelay = 0,
          resetInitial = pure (),
          lastException = Nothing
        }

-- | 'defaults' without any limit on the number of retries or on the
-- accumulated delay.
--
-- Retrying only stops when one of the following happens:
--
-- * the action succeeds;
-- * the handler returns 'RaiseException';
-- * an exception the handler does not cover escapes.
indefiniteDefaults :: RetrySettings
indefiniteDefaults = defaults {maxAttempts = Nothing, maxTime = Nothing}

-- | Decision a retry handler makes after a failed attempt.
data RetryAction
  = -- | Propagate the exception.
    RaiseException
  | -- | Retry after the exponential back-off delay derived from the settings.
    Retry
  | -- | Retry after the given delay.
    RetryDelay NominalDiffTime
  | -- | Retry once the given time has been reached.
    RetryTime UTCTime
  deriving (Eq, Show)

-- | Retry an action on every synchronous exception.
--
-- Asynchronous exceptions (anything wrapped in 'SomeAsyncException') are
-- rethrown immediately.
--
-- Between retries the action backs off exponentially, with random jitter.
-- The delay before retry number /n/ (counting from 1) is:
--
-- @
--    backoffExpBase ** (n - 1) + random(0, backoffJitter)
-- @
--
-- This delay is capped at 'backoffMaxRetryDelay' when that is set. With
-- 'defaults', the delays before the first five retries are:
--
-- @
--    2 ** 0 + random(0, 1) = 1 + random(0, 1)
--    2 ** 1 + random(0, 1) = 2 + random(0, 1)
--    2 ** 2 + random(0, 1) = 4 + random(0, 1)
--    2 ** 3 + random(0, 1) = 8 + random(0, 1)
--    2 ** 4 + random(0, 1) = 16 + random(0, 1)
-- @
--
-- Once the limits in the settings are exhausted, the last exception is
-- rethrown.
retry :: forall m a. (MonadCatch m, MonadIO m) => RetrySettings -> (RetryStatus -> m a) -> m a
retry settings = retryFor settings classify
  where
    -- An asynchronous exception means the thread is being interrupted, so it
    -- must propagate instead of being retried.
    classify :: SomeException -> m RetryAction
    classify exc
      | isAsync exc = pure RaiseException
      | otherwise = pure Retry

    isAsync :: SomeException -> Bool
    isAsync exc = isJust (fromException exc :: Maybe SomeAsyncException)

-- | Same as 'retry', but only exceptions of type @exc@ are caught and passed to
-- the handler; every other exception propagates immediately.
--
-- After each failure the handler decides what happens next:
--
-- * 'RaiseException' rethrows the exception.
-- * 'Retry' waits @backoffExpBase ** attempts + random(0, backoffJitter)@
--   seconds, where @attempts@ is the number of retries already made (0
--   before the first retry). The wait is capped at 'backoffMaxRetryDelay'.
-- * 'RetryDelay' waits for the given delay.
-- * 'RetryTime' waits until the given time.
--
-- Negative delays (e.g. a 'RetryTime' in the past) are treated as zero, so
-- they never shrink 'totalDelay'. The last exception is rethrown instead of
-- retrying in either of these cases:
--
-- * the next delay would push 'totalDelay' past 'maxTime';
-- * 'attempts' has reached 'maxAttempts'.
retryFor ::
  forall m exc a.
  (Exception exc, MonadIO m, MonadCatch m) =>
  RetrySettings ->
  (exc -> m RetryAction) ->
  (RetryStatus -> m a) ->
  m a
retryFor settings handler action = initialize >>= go
  where
    -- Fresh status taken from the settings. Its 'resetInitial' fills a new,
    -- empty MVar, which is how a reset request is detected later.
    initialize :: m (RetryStatus, MVar ())
    initialize = do
      resetMVar <- liftIO newEmptyMVar
      let status = (initialRetryStatus settings) {resetInitial = void $ tryPutMVar resetMVar ()}
      return (status, resetMVar)

    go :: (RetryStatus, MVar ()) -> m a
    go (retryStatus, currentResetMVar) = do
      result <- try $ action retryStatus
      case result of
        Right out -> return out
        Left exception -> do
          -- A full MVar means the action called 'resetInitial': start over.
          resetRequested <- liftIO $ not <$> isEmptyMVar currentResetMVar
          (newRetryStatus, newResetMVar) <-
            if resetRequested
              then initialize
              else return (retryStatus, currentResetMVar)
          exceptionAction <- handler exception
          requestedDelay <- case exceptionAction of
            RaiseException -> throwM exception
            Retry -> increaseDelay newRetryStatus
            RetryDelay d -> return d
            RetryTime time -> liftIO $ diffUTCTime time <$> getCurrentTime
          -- A target time in the past means "retry now". Clamping keeps
          -- totalDelay monotonic so the maxTime budget cannot be extended.
          let delay_ = max 0 requestedDelay
              timeExhausted =
                maybe False (totalDelay newRetryStatus + delay_ >) (maxTime settings)
              -- '>=' rather than '==': a resumed status that already exceeds
              -- the limit (or a negative limit) must not retry forever.
              attemptsExhausted =
                maybe False (attempts newRetryStatus >=) (maxAttempts settings)
          if timeExhausted || attemptsExhausted
            then throwM exception
            else do
              liftIO $ threadDelay $ round $ 1000 * 1000 * nominalDiffTimeToSeconds delay_
              go (updateRetryStatus newRetryStatus delay_ (toException exception), newResetMVar)

    -- Record one more retry: the delay just waited and the exception behind it.
    updateRetryStatus :: RetryStatus -> NominalDiffTime -> SomeException -> RetryStatus
    updateRetryStatus status delay_ exception =
      status
        { attempts = attempts status + 1,
          delay = delay_,
          totalDelay = totalDelay status + delay_,
          lastException = Just exception
        }

    -- Exponential back-off for a 'Retry' action. The exponent is the number of
    -- retries already made, so the first retry waits backoffExpBase ** 0 plus
    -- jitter, as documented on 'retry'.
    increaseDelay :: RetryStatus -> m NominalDiffTime
    increaseDelay status = do
      jitter <- liftIO $ randomRIO (0, backoffJitter settings)
      let retriesSoFar = fromIntegral (attempts status) :: Double
          backoff =
            secondsToNominalDiffTime $
              realToFrac $
                backoffExpBase settings ** retriesSoFar + jitter
      return $ maybe backoff (min backoff) (backoffMaxRetryDelay settings)

-- | Unwrap an 'Either' in 'IO'.
--
-- A 'Right' value is returned unchanged. A 'Left' value is converted with
-- the given function and thrown with 'throwIO'.
escalateWith :: (Exception exc) => (err -> exc) -> Either err a -> IO a
escalateWith toException' outcome = case outcome of
  Left err -> throwIO (toException' err)
  Right value -> return value

-- | Turn a 'Maybe' into an 'Either'.
--
-- The given value becomes the 'Left' when the 'Maybe' is 'Nothing'.
withLeft :: a -> Maybe b -> Either a b
withLeft fallback optional = case optional of
  Just found -> Right found
  Nothing -> Left fallback

-- | Unwrap an 'Either' whose 'Left' is already an exception.
--
-- The 'Left' is thrown with 'throwIO'; a 'Right' value is returned unchanged.
escalate :: (Exception exc) => Either exc a -> IO a
escalate = either throwIO pure