module Prometheus.Metric.Counter (
    Counter
,   counter
,   incCounter
,   addCounter
,   unsafeAddCounter
,   addDurationToCounter
,   getCounter
,   countExceptions
) where

import Prometheus.Info
import Prometheus.Metric
import Prometheus.Metric.Observer (timeAction)
import Prometheus.MonadMonitor

import Control.DeepSeq
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad (unless)
import qualified Data.Atomics as Atomics
import qualified Data.ByteString.UTF8 as BS
import qualified Data.IORef as IORef


-- | A counter metric. Its current value is a 'Double' held in a mutable
-- 'IORef.IORef'.
newtype Counter = MkCounter (IORef.IORef Double)

-- | Forcing a 'Counter' only evaluates the wrapped 'IORef.IORef' reference
-- to weak head normal form; the value stored inside the reference is not
-- read or forced.
instance NFData Counter where
  rnf (MkCounter ioref) = ioref `seq` ()

-- | Creates a new counter metric with a given name and help string.
--
-- The counter starts at @0@. The returned 'Metric' pairs the fresh 'Counter'
-- handle with a collection action ('collectCounter') that reads the same
-- underlying reference.
counter :: Info -> Metric Counter
counter info = Metric $ do
    ioref <- IORef.newIORef 0
    return (MkCounter ioref, collectCounter info ioref)

-- | Apply a pure update function to the value stored in a counter.
--
-- The update is performed with 'Atomics.atomicModifyIORefCAS_' and lifted
-- into the monitoring monad with 'doIO'. This helper performs no validation
-- of the new value; callers are responsible for only passing functions that
-- are appropriate for a counter.
withCounter :: MonadMonitor m
            => Counter
            -> (Double -> Double)
            -> m ()
withCounter (MkCounter ioref) f =
    doIO $ Atomics.atomicModifyIORefCAS_ ioref f

-- | Increments the value of a counter metric by 1.
incCounter :: MonadMonitor m => Counter -> m ()
incCounter c = withCounter c (+ 1)

-- | Add the given value to the counter, if it is zero or more.
--
-- Returns 'False' and leaves the counter untouched when the value is less
-- than zero; otherwise adds the value via 'withCounter' and returns 'True'.
--
-- Note that the only check is @x < 0@, so a NaN argument (which is not less
-- than zero) is accepted and added.
addCounter :: MonadMonitor m => Counter -> Double -> m Bool
addCounter c x
  | x < 0 = return False
  | otherwise = do
      withCounter c add
      return True
  where
    -- Force both the current value and the increment before summing them.
    add i = i `seq` x `seq` i + x

-- | Add the given value to the counter. Panic if it is less than zero.
--
-- This delegates to 'addCounter'; when that reports the value was rejected,
-- 'error' is called with a message that includes the offending value.
unsafeAddCounter :: MonadMonitor m => Counter -> Double -> m ()
unsafeAddCounter c x = do
  added <- addCounter c x
  unless added $
    error $ "Tried to add negative value to counter: " ++ show x

-- | Add the duration of an IO action (in seconds) to a counter.
--
-- If the IO action throws, no duration is added.
--
-- The duration is measured with 'timeAction' and added with 'addCounter'.
-- The 'Bool' returned by 'addCounter' is discarded, so a duration that
-- 'addCounter' rejects (one less than zero) is silently dropped. The result
-- of the action itself is returned unchanged.
addDurationToCounter :: (MonadIO m, MonadMonitor m) => Counter -> m a -> m a
addDurationToCounter metric io = do
    (result, duration) <- timeAction io
    _ <- addCounter metric duration
    return result

-- | Retrieves the current value of a counter metric.
--
-- This is a plain read of the underlying 'IORef.IORef'; it does not modify
-- the counter.
getCounter :: MonadIO m => Counter -> m Double
getCounter (MkCounter ioref) = liftIO $ IORef.readIORef ioref

-- | Collection action for a counter.
--
-- Reads the current value from the reference, renders it with 'show', and
-- converts the result to a 'BS.ByteString' with 'BS.fromString'. The output
-- is a single 'SampleGroup' of type 'CounterType' holding one 'Sample',
-- named after @metricName info@ and carrying no label pairs.
collectCounter :: Info -> IORef.IORef Double -> IO [SampleGroup]
collectCounter info c = do
    value <- IORef.readIORef c
    let sample = Sample (metricName info) [] (BS.fromString $ show value)
    return [SampleGroup info CounterType [sample]]

-- | Count the number of times an action throws any synchronous exception.
--
-- >>> exceptions <- register $ counter (Info "exceptions_total" "Total amount of exceptions thrown")
-- >>> countExceptions exceptions $ return ()
-- >>> getCounter exceptions
-- 0.0
-- >>> countExceptions exceptions (error "Oh no!") `catch` (\SomeException{} -> return ())
-- >>> getCounter exceptions
-- 1.0
--
-- It's important to note that this will count *all* synchronous exceptions. If
-- you want more granular counting of exceptions, you will need to write custom
-- code using 'incCounter'.
--
-- The action is wrapped with 'onException', so the counter is incremented
-- with 'incCounter' and the exception then continues to propagate; it is not
-- caught here.
countExceptions :: (MonadCatch m, MonadMonitor m) => Counter -> m a -> m a
countExceptions c io = io `onException` incCounter c

-- $setup
-- >>> :module +Prometheus
-- >>> :set -XOverloadedStrings