{-# LANGUAGE CPP #-}

module Network.Wai.Handler.Warp.Counter (
    Counter,
    newCounter,
    waitForZero,
    increase,
    decrease,
    waitForDecreased,
) where

import Control.Concurrent.STM

import Network.Wai.Handler.Warp.Imports

-- | An 'Int' counter held in a 'TVar', so it can be read and updated
-- inside STM transactions.
--
-- The data constructor is not exported from this module; values are
-- created with 'newCounter'.
newtype Counter = Counter (TVar Int)

-- | Create a fresh 'Counter' whose initial value is 0.
newCounter :: IO Counter
newCounter = Counter <$> newTVarIO 0

-- | Block the calling thread until the counter is no longer positive.
--
-- The transaction retries for as long as the stored value is greater
-- than 0, so this returns once the value is 0 (or below).
waitForZero :: Counter -> IO ()
waitForZero (Counter var) = atomically $ do
    x <- readTVar var
    -- 'check' retries the transaction when its argument is False.
    check (x <= 0)

-- | Block the calling thread until the counter is strictly smaller than
-- the value it had when this function was called.
--
-- The initial value is sampled separately from the waiting transaction,
-- so a decrease that happens between the two steps is still observed and
-- lets the wait finish immediately.
waitForDecreased :: Counter -> IO ()
waitForDecreased (Counter var) = do
    -- Snapshot of the value to compare against; a plain read is enough
    -- here, no full transaction is required.
    n0 <- readTVarIO var
    atomically $ do
        n <- readTVar var
        -- Retry until the current value drops below the snapshot.
        check (n < n0)

-- | Atomically add 1 to the counter.
--
-- Uses the strict 'modifyTVar'' so the new value is evaluated before it
-- is stored.
increase :: Counter -> IO ()
increase (Counter var) = atomically $ modifyTVar' var (+ 1)

-- | Atomically subtract 1 from the counter.
--
-- Uses the strict 'modifyTVar'' so the new value is evaluated before it
-- is stored. No lower bound is enforced: the value may become negative.
decrease :: Counter -> IO ()
decrease (Counter var) = atomically $ modifyTVar' var (subtract 1)