{-|
Module      : Irc.RateLimit
Description : Rate limit operations for IRC
Copyright   : (c) Eric Mertens, 2016
License     : ISC
Maintainer  : emertens@gmail.com

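This module implements a simple rate limiter, based on the flood
control algorithm described in the IRC RFC, that keeps an IRC client
from being disconnected for flooding. On average it allows one event
per @penalty@ seconds, while permitting short bursts whose size is
bounded by the @threshold@.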

This algorithm keeps track of the time at which the client may
start sending messages. Each message sent advances that time into
the future by the @penalty@. The client is allowed to transmit
up to @threshold@ seconds ahead of this time.

-}
module Irc.RateLimit
  ( RateLimit
  , newRateLimit
  , tickRateLimit
  ) where

import Control.Concurrent
import Control.Monad
import Data.Time

-- | The 'RateLimit' keeps track of rate limit settings as well
-- as the current state of the limit.
data RateLimit = RateLimit
  { rateStamp     :: !(MVar UTCTime)
    -- ^ Earliest time at which the client may send; each tick
    -- advances it by 'ratePenalty'.
  , rateThreshold :: !NominalDiffTime
    -- ^ How far 'rateStamp' may run ahead of the current time before
    -- a tick has to wait.
  , ratePenalty   :: !NominalDiffTime
    -- ^ Amount by which each tick advances 'rateStamp'.
  }

-- | Construct a new rate limit with the given penalty and threshold.
--
-- Negative arguments are clamped to zero. The stamp starts at the
-- current time, so a freshly created limiter lets events through
-- without delay until the accumulated penalty exceeds the threshold.
newRateLimit ::
  Rational {- ^ penalty seconds -} ->
  Rational {- ^ threshold seconds -} ->
  IO RateLimit
newRateLimit penalty threshold =
  do now <- getCurrentTime
     ref <- newMVar now

     return RateLimit
       { rateStamp     = ref
       , rateThreshold = fromRational (max 0 threshold)
       , ratePenalty   = fromRational (max 0 penalty)
       }

-- | Account for an event in the context of a 'RateLimit'. This command
-- will block and delay as required to satisfy the current rate. Once
-- it returns it is safe to proceed with the rate limited action.
--
-- The stamp is advanced by the penalty, starting from the current time
-- if the stamp has fallen behind it. If the new stamp is more than the
-- threshold ahead of the current time, the caller sleeps for the excess.
--
-- The stamp's 'MVar' is held for the whole delay, so concurrent callers
-- are processed one at a time.
--
-- NOTE(review): 'getCurrentTime' reads wall-clock time, so a backwards
-- adjustment of the system clock could produce an unexpectedly long
-- delay; consider a monotonic clock if that matters.
tickRateLimit :: RateLimit -> IO ()
tickRateLimit r =
  modifyMVar_ (rateStamp r) $ \stamp ->
    do now <- getCurrentTime
       let stamp' = ratePenalty r `addUTCTime` max stamp now
           diff   = diffUTCTime stamp' now
           excess = diff - rateThreshold r

       -- threadDelay takes microseconds; round up so we never wake early.
       when (excess > 0) (threadDelay (ceiling (1000000 * excess)))

       return stamp'