{-# LANGUAGE CPP #-}
{-# LANGUAGE StrictData #-}

module Control.Concurrent.TokenBucket
  ( Rate (..),
    mkRate,
    infRate,
    TokenBucket,
    newTokenBucket,
    tryAllocateTokens,
  )
where

import Control.Concurrent.MVar qualified as M
import Data.Word (Word64)
import System.Clock qualified as C

-- | Parameters of a token bucket: how many tokens it can hold and how fast
-- it is refilled.
data Rate = Rate
  { -- | Capacity of the bucket, i.e. the largest number of tokens it can hold
    -- (and therefore the largest burst that can be allocated at once).
    rateBurstAmount :: !Word64,
    -- | Refill interval: the number of nanoseconds after which one token is
    -- added to the bucket. A value of 0 is treated by 'tryAllocateTokens' as
    -- an unlimited rate.
    rateNanosPerToken :: !Word64
  }
  deriving stock (Eq, Show)

-- | @mkRate burst (numOperations, numSeconds)@ creates a 'Rate' with the given
-- burst amount that allows @numOperations@ operations per @numSeconds@
-- seconds.
--
-- @numOperations@ is expected to be > 0. A value of 0 is interpreted as
-- "never refill": the refill interval is set to 'maxBound' nanoseconds rather
-- than dividing by zero (rounding the resulting non-finite 'Double' has no
-- meaningful value and may come out as 0, which 'tryAllocateTokens' treats as
-- an unlimited rate).
--
-- The refill interval is rounded to the nearest nanosecond. If it rounds to 0
-- (a very high operation count per second, or @numSeconds == 0@) the result
-- is indistinguishable from an unlimited rate.
mkRate :: Word64 -> (Word64, Word64) -> Rate
mkRate burst (numOperations, numSeconds)
  | numOperations == 0 = Rate burst maxBound
  | otherwise = Rate burst perToken
  where
    -- Length of the period in nanoseconds. The multiplication is done in
    -- Double so that a large numSeconds cannot wrap around in Word64.
    nanos :: Double
    nanos = fromIntegral numSeconds * C.s2ns

    -- Nanoseconds that must elapse for a single token to be added.
    perToken :: Word64
    perToken = round (nanos / fromIntegral numOperations)

-- | A 'Rate' that never limits anything. Both fields are 0; a refill interval
-- of 0 nanoseconds is what 'tryAllocateTokens' recognises as "unlimited", so
-- every allocation made with this rate succeeds.
infRate :: Rate
infRate = Rate 0 0

-- | Internal state of a bucket, stored inside the 'M.MVar' of a 'TokenBucket'.
data TB = TB
  { -- | Number of tokens currently in the bucket.
    tbTokens :: !Word64,
    -- | Time of the last refill calculation, expressed in nanoseconds since
    -- some starting point.
    tbLastCheck :: !Word64,
    -- | Rate most recently used with this bucket.
    tbRate :: !Rate
  }
  deriving stock (Eq, Show)

-- | Handle to a token bucket. It wraps an 'M.MVar' holding the bucket's
-- internal state ('TB'); the constructor is not exported, so buckets are
-- created with 'newTokenBucket' and used through 'tryAllocateTokens'.
newtype TokenBucket = TokenBucket (M.MVar TB)

-- | Create a token bucket that starts out full: it holds 'rateBurstAmount'
-- tokens of the given rate, and its last-check time is the current reading
-- of the monotonic clock.
newTokenBucket :: Rate -> IO TokenBucket
newTokenBucket rate = do
  createdAt <- getTimeNanos
  TokenBucket
    <$> M.newMVar
      TB
        { tbTokens = rateBurstAmount rate,
          tbLastCheck = createdAt,
          tbRate = rate
        }

#if linux_HOST_OS==1
-- | Current reading of the monotonic clock, in nanoseconds.
--
-- On Linux the coarse monotonic clock is used for better performance.
getTimeNanos :: IO Word64
getTimeNanos = fromInteger . C.toNanoSecs <$> C.getTime C.MonotonicCoarse
#else
-- | Current reading of the monotonic clock, in nanoseconds.
getTimeNanos :: IO Word64
getTimeNanos = fromInteger . C.toNanoSecs <$> C.getTime C.Monotonic
#endif

-- Saturating arithmetic on unsigned 64-bit values: results are clamped to the
-- representable range instead of wrapping around.
--
-- No fixity is declared for these names, so when written infix with
-- backticks they are infixl 9 and bind tighter than '*' and '+'.
minus, plus :: Word64 -> Word64 -> Word64
-- x - y, or 0 when y is at least as large as x.
minus x y = if x > y then x - y else 0
-- x + y, or maxBound when the sum does not fit in a Word64.
plus x y
  | total < x = maxBound -- the addition wrapped around
  | otherwise = total
  where
    total = x + y

-- | @tryAllocateTokens tb amount rate@ attempts to take @amount@ tokens from
-- the bucket @tb@, after first refilling it according to @rate@ and the time
-- elapsed since the bucket was last checked.
--
-- Returns 'Nothing' on success. On failure it returns @'Just' nanos@, the
-- time to wait in nanoseconds before the same allocation can succeed. If
-- @amount@ is larger than the burst amount of @rate@ the request can never be
-- satisfied and @'Just' 'maxBound'@ is returned.
--
-- A @rate@ whose 'rateNanosPerToken' is 0 (see 'infRate') is unlimited: the
-- call succeeds immediately and the bucket is left untouched. In every other
-- case the @rate@ passed here is the one applied, and it replaces the rate
-- stored in the bucket.
tryAllocateTokens :: TokenBucket -> Word64 -> Rate -> IO (Maybe Word64)
tryAllocateTokens (TokenBucket mv) amountRequested r
  | nanosPerToken == 0 = return Nothing -- infinite token rate
  | amountRequested > burst = return (Just maxBound)
  | otherwise = M.modifyMVar mv $ \(TB lvl ts _) -> do
      now <- getTimeNanos
      let elapsed = now `minus` ts
          -- Whole tokens earned since the last check, and the leftover time
          -- that has not yet produced a token.
          (earned, partial) = elapsed `quotRem` nanosPerToken
          -- Back-date the check time by the leftover so that progress toward
          -- the next token is carried over to the next call.
          lastCheck = now `minus` partial
          -- The bucket never holds more than the burst amount.
          level = min burst (lvl `plus` earned)
      if level >= amountRequested
        then return (TB (level - amountRequested) lastCheck r, Nothing)
        else do
          let missing = amountRequested `minus` level
              -- Time for the missing tokens to accumulate, less the progress
              -- already made toward the next one. The parentheses matter:
              -- backtick operators without a fixity declaration are infixl 9
              -- and would otherwise bind tighter than the multiplication.
              wait = (missing `times` nanosPerToken) `minus` partial
          return (TB level lastCheck r, Just wait)
  where
    nanosPerToken = rateNanosPerToken r
    burst = rateBurstAmount r

    -- Saturating multiplication, in line with the saturating 'plus' and
    -- 'minus' used above: clamps to maxBound instead of wrapping around.
    times :: Word64 -> Word64 -> Word64
    times a b
      | b /= 0 && a > maxBound `div` b = maxBound
      | otherwise = a * b