module Prometheus.Metric.Counter (
    Counter
,   counter
,   incCounter
,   addCounter
,   unsafeAddCounter
,   addDurationToCounter
,   getCounter
,   countExceptions
) where

import Prometheus.Info
import Prometheus.Metric
import Prometheus.Metric.Observer (timeAction)
import Prometheus.MonadMonitor

import Control.DeepSeq
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad (unless)
import qualified Data.Atomics as Atomics
import qualified Data.ByteString.UTF8 as BS
import qualified Data.IORef as IORef


-- | A counter metric: a 'Double' held in a mutable 'IORef.IORef'.
--
-- The update operations in this module only ever add to the stored value;
-- 'addCounter' refuses negative amounts.
newtype Counter = MkCounter (IORef.IORef Double)

-- | Reducing a 'Counter' to normal form only evaluates the 'IORef.IORef'
-- handle itself; the 'Double' stored inside the reference is not touched.
instance NFData Counter where
  rnf (MkCounter ref) = ref `seq` ()

-- | Creates a new counter metric with a given name and help string.
--
-- Each time the wrapped 'IO' action runs it allocates a fresh
-- 'IORef.IORef' initialised to 0. It pairs the resulting 'Counter' with the
-- action that collects its samples under the given 'Info'.
counter :: Info -> Metric Counter
counter info = Metric (pairWithCollector <$> IORef.newIORef 0)
  where
    pairWithCollector ref = (MkCounter ref, collectCounter info ref)

-- | Atomically apply @f@ to the value stored in a counter.
--
-- Internal helper (not exported); 'incCounter' and 'addCounter' are built
-- on it.
--
-- NOTE(review): 'Atomics.atomicModifyIORefCAS_' is assumed to store the
-- result of @f@ without evaluating it, as 'IORef.atomicModifyIORef' does.
-- Confirm against atomic-primops. Under that assumption, a counter that is
-- updated many times between scrapes would accumulate a chain of
-- unevaluated @f@ applications. So the value is forced to WHNF right after
-- each update to keep memory use bounded.
withCounter :: MonadMonitor m
            => Counter
            -> (Double -> Double)
            -> m ()
withCounter (MkCounter ioref) f = doIO $ do
    Atomics.atomicModifyIORefCAS_ ioref f
    -- Evaluating whatever is now stored updates that thunk in place. The next
    -- modification therefore builds on an evaluated 'Double', not on a chain.
    current <- IORef.readIORef ioref
    current `seq` return ()

-- | Increments the value of a counter metric by 1.
incCounter :: MonadMonitor m => Counter -> m ()
incCounter target = withCounter target (1 +)

-- | Add the given value to the counter, if it is zero or more.
--
-- Returns 'True' when the value was added. Returns 'False' when it was
-- rejected, in which case the counter is left unchanged. A value is rejected
-- when it is negative, because a counter must never decrease. It is also
-- rejected when it is NaN: NaN is not "zero or more", and adding it would
-- leave the counter stuck at NaN forever.
addCounter :: MonadMonitor m => Counter -> Double -> m Bool
addCounter c x
  | isNaN x || x < 0 = return False
  | otherwise = do
      -- (+) on Double is strict in both operands, so no extra 'seq' is needed.
      withCounter c (+ x)
      return True

-- | Add the given value to the counter. Panic if it is less than zero.
--
-- Delegates to 'addCounter'. If that reports the value as rejected, this
-- calls 'error' with a message that includes the offending value.
unsafeAddCounter :: MonadMonitor m => Counter -> Double -> m ()
unsafeAddCounter c x =
  addCounter c x >>= \accepted ->
    if accepted
      then return ()
      else error ("Tried to add negative value to counter: " ++ show x)

-- | Add the duration of an IO action (in seconds) to a counter.
--
-- If the IO action throws, no duration is added: the counter update only
-- runs after 'timeAction' has returned. The 'Bool' reported by 'addCounter'
-- is discarded, and the action's own result is returned unchanged.
addDurationToCounter :: (MonadIO m, MonadMonitor m) => Counter -> m a -> m a
addDurationToCounter metric io =
    timeAction io >>= \(result, elapsed) ->
        result <$ addCounter metric elapsed

-- | Retrieves the current value of a counter metric.
--
-- This is a plain (non-atomic-modify) read of the underlying
-- 'IORef.IORef', lifted into any 'MonadIO'.
getCounter :: MonadIO m => Counter -> m Double
getCounter c = liftIO (readStored c)
  where
    readStored (MkCounter ref) = IORef.readIORef ref

-- | Snapshot a counter as a single 'CounterType' sample group.
--
-- The group contains one 'Sample' named after the metric, with no labels.
-- Its value is the 'show' rendering of the current 'Double', converted to a
-- UTF-8 'ByteString'.
collectCounter :: Info -> IORef.IORef Double -> IO [SampleGroup]
collectCounter info c = toGroups <$> IORef.readIORef c
  where
    toGroups current =
        [SampleGroup info CounterType [Sample (metricName info) [] (render current)]]
    render = BS.fromString . show

-- | Count the number of times an action throws an exception.
--
-- >>> exceptions <- register $ counter (Info "exceptions_total" "Total amount of exceptions thrown")
-- >>> countExceptions exceptions $ return ()
-- >>> getCounter exceptions
-- 0.0
-- >>> countExceptions exceptions (error "Oh no!") `catch` (\SomeException{} -> return ())
-- >>> getCounter exceptions
-- 1.0
--
-- The exception is not swallowed: after the counter is incremented it is
-- rethrown to the caller. Every exception is counted regardless of its type.
-- If you want more granular counting of exceptions, you will need to write
-- custom code using 'incCounter'.
--
-- NOTE(review): 'onException' from "Control.Monad.Catch" reacts to any
-- 'SomeException', so asynchronous exceptions may be counted as well as
-- synchronous ones. Confirm against the exceptions package version in use.
countExceptions :: (MonadCatch m, MonadMonitor m) => Counter -> m a -> m a
countExceptions m io = onException io (incCounter m)

-- $setup
-- >>> :module +Prometheus
-- >>> :set -XOverloadedStrings