{-# LANGUAGE CPP             #-}
#ifdef STM_STATS
{-# LANGUAGE RecordWildCards #-}
#endif
module Control.Concurrent.STM.Stats
    ( atomicallyNamed
    , atomically
    , getSTMStats
    , dumpSTMStats
    , module Control.Concurrent.STM
    ) where
import           Control.Concurrent.STM hiding (atomically)
import qualified Control.Concurrent.STM as STM
import           Data.Map               (Map)
#ifdef STM_STATS
import           Control.Exception      (BlockedIndefinitelyOnSTM, Exception,
                                         catch, throwIO)
import           Control.Monad
import           Data.IORef
import qualified Data.Map.Strict        as M
import           Data.Time              (getCurrentTime)
import           Data.Typeable          (Typeable)
import           GHC.Conc               (unsafeIOToSTM)
import           System.IO
import           System.IO.Unsafe
import           Text.Printf
#endif
-- | Run an STM transaction. When built with @STM_STATS@, commit and retry
-- counts are recorded under the given name, and a 'BlockedIndefinitelyOnSTM'
-- is rethrown with the name attached. Without @STM_STATS@ the name is ignored.
atomicallyNamed :: String -> STM a -> IO a
-- | Run an STM transaction. When built with @STM_STATS@, statistics are
-- recorded under the name @"_anonymous_"@.
atomically :: STM a -> IO a
-- | Print the recorded STM statistics to stderr (a no-op without @STM_STATS@).
dumpSTMStats :: IO ()
-- | Snapshot of the recorded statistics: transaction name -> (commits, retries).
-- Always empty without @STM_STATS@.
getSTMStats :: IO (Map String (Int,Int))
#ifndef STM_STATS
-- Stub implementations used when built without STM_STATS: nothing is
-- recorded, 'atomicallyNamed' ignores its name, and both entry points
-- delegate straight to 'STM.atomically'. The type signatures are the shared
-- ones declared above the #ifndef, so none are repeated here (repeating
-- them is a duplicate-signature error).
getSTMStats = pure mempty
atomicallyNamed _ = atomically
dumpSTMStats = pure ()
atomically = STM.atomically
#else
-- With STM_STATS on, both entry points go through the tracking wrappers;
-- 'atomically' records its statistics under the anonymous name.
atomicallyNamed name tx = trackNamedSTM name tx
atomically tx = trackSTM tx
-- Process-wide statistics table keyed by transaction name; each value is
-- (commits, accumulated retries).
--
-- 'unsafePerformIO' is used to create a single top-level IORef. The NOINLINE
-- pragma below is required: if GHC inlined the definition, each use site
-- could allocate its own IORef. In this module, updates go through
-- atomicModifyIORef'.
globalRetryCountMap :: IORef (Map String (Int,Int))
globalRetryCountMap = unsafePerformIO (newIORef M.empty)
{-# NOINLINE globalRetryCountMap #-}
-- | Configuration for 'trackSTMConf'.
data TrackSTMConf = TrackSTMConf
    { tryThreshold      :: Maybe Int
        -- ^ Per-call attempt threshold @n@. After commit, 'warnFunction' is
        -- called if the call took more than @n@ attempts. While running,
        -- 'warnInSTMFunction' is called on attempt @j@ when @j >= 2*n@ and
        -- either @j >= 4*n@ or @j@ is a multiple of @2*n@.
        -- 'Nothing' disables both warnings.
    , globalThreshold   :: Maybe Int
        -- ^ 'warnFunction' is called whenever the accumulated retry count
        -- for a transaction name crosses a multiple of this value.
        -- 'Nothing' disables the warning.
    , extendException   :: Bool
        -- ^ When 'True', a 'BlockedIndefinitelyOnSTM' thrown while running
        -- the transaction is rethrown as 'BlockedIndefinitelyOnNamedSTM'
        -- carrying the transaction name.
    , warnFunction      :: String -> IO ()
        -- ^ Warning sink called outside the transaction, after it commits.
    , warnInSTMFunction :: String -> IO ()
        -- ^ Warning sink called from inside the transaction via
        -- 'unsafeIOToSTM'. It fires on qualifying attempts, including attempts
        -- that are later retried, and its effects are not rolled back.
        -- Keep it cheap.
    }
-- | Default tracking configuration. It warns when a single call needs more
-- than 10 attempts, and whenever a name's accumulated retries cross a
-- multiple of 3000. It rewraps blocked-indefinitely exceptions, sends
-- warnings to stderr, and discards in-transaction warnings.
defaultTrackSTMConf :: TrackSTMConf
defaultTrackSTMConf =
    TrackSTMConf
        { warnFunction      = hPutStrLn stderr
        , warnInSTMFunction = const (pure ())
        , extendException   = True
        , globalThreshold   = Just 3000
        , tryThreshold      = Just 10
        }
-- | Track a transaction under the shared name @"_anonymous_"@, using the
-- default configuration with exception rewrapping turned off.
trackSTM :: STM a -> IO a
trackSTM = trackSTMConf anonymousConf "_anonymous_"
  where
    anonymousConf = defaultTrackSTMConf { extendException = False }
-- | Track a transaction under the given name, using the default configuration.
trackNamedSTM :: String -> STM a -> IO a
trackNamedSTM name = trackSTMConf defaultTrackSTMConf name
-- | Run @txm@ atomically under the given configuration, and record one commit
-- plus its retry count under @name@ in 'globalRetryCountMap'.
--
-- Every execution of the transaction body (the first try and each re-run)
-- bumps a per-call counter. After commit the counter therefore holds the
-- number of attempts, and @counter - 1@ is the number of retries.
trackSTMConf :: TrackSTMConf -> String -> STM a -> IO a
trackSTMConf (TrackSTMConf {..}) name txm = do
    counter <- newIORef 0
    let wrappedTx =
            do  unsafeIOToSTM $ do
                    -- Runs on every attempt, including ones that later retry;
                    -- STM does not roll back this IORef update.
                    i <- atomicModifyIORef' counter incCounter
                    when (warnPred i) $
                        warnInSTMFunction $ msgPrefix ++ " reached try count of " ++ show i
                txm
    -- Optionally replace the anonymous "blocked indefinitely" exception with
    -- one that names this transaction.
    res <- if extendException
          then STM.atomically wrappedTx
              `catch` (\(_::BlockedIndefinitelyOnSTM) ->
                       throwIO (BlockedIndefinitelyOnNamedSTM name))
          else STM.atomically wrappedTx
    -- Number of attempts, including the one that committed.
    i <- readIORef counter
    doMB tryThreshold $ \threshold ->
       when (i > threshold) $
            warnFunction $ msgPrefix ++ " finished after " ++ show (i-1) ++ " retries"
    incGlobalRetryCount (i - 1)
    return res
  where
    -- Run the action only when a threshold is configured.
    doMB Nothing _  = return ()
    doMB (Just x) m = m x
    incCounter i = let j = i + 1 in (j, j)
    -- With threshold n, this fires at attempt 2n and then on every attempt
    -- from 4n onward.
    warnPred j = case tryThreshold of
        Nothing -> False
        Just n  -> j >= 2*n && (j >= 4 * n || j `mod` (2 * n) == 0)
    msgPrefix = "STM transaction " ++ name
    -- Add one commit and @i@ retries to this name's entry. Then warn if the
    -- name's accumulated retry total crossed a multiple of the global threshold.
    incGlobalRetryCount i = do
        -- k = retry total before this update (0 if the name is new), k' = after.
        (k,k') <- atomicModifyIORef' globalRetryCountMap $ \m ->
                let (oldVal, m') = M.insertLookupWithKey
                                    (\_ (a1,b1) (a2,b2) -> ((,) $! a1+a2) $! b1+b2)
                                    name
                                    (1,i)
                                    m
                in (m', let j = maybe 0 snd oldVal in (j,j+i))
        doMB globalThreshold $ \globalRetryThreshold ->
            when (k `div` globalRetryThreshold /= k' `div` globalRetryThreshold) $
                warnFunction $ msgPrefix ++ " reached global retry count of " ++ show k'
-- | Thrown by a tracked transaction (when 'extendException' is set) in place
-- of 'BlockedIndefinitelyOnSTM', so the message identifies which named
-- transaction blocked.
newtype BlockedIndefinitelyOnNamedSTM = BlockedIndefinitelyOnNamedSTM String
    deriving (Typeable)

instance Show BlockedIndefinitelyOnNamedSTM where
    -- The trailing space separates the fixed text from the transaction name;
    -- without it the name was glued onto "transaction".
    showsPrec _ (BlockedIndefinitelyOnNamedSTM name) =
        showString ("thread blocked indefinitely in STM transaction " ++ name)

instance Exception BlockedIndefinitelyOnNamedSTM
-- Snapshot of the accumulated per-name (commits, retries) table.
getSTMStats = readIORef globalRetryCountMap
-- Write a table of the recorded statistics to stderr: a timestamped title,
-- a column header, then one row per transaction name with its commit count,
-- retry count and retries per commit.
dumpSTMStats = do
    snapshot <- getSTMStats
    now <- getCurrentTime
    hPutStrLn stderr ("STM transaction statistics (" ++ show now ++ "):")
    hPrintf stderr "%-22s %10s %10s %10s\n" "Transaction" "Commits" "Retries" "Ratio"
    forM_ (M.toList snapshot) $ \(txName, (commits, retries)) ->
        when (commits > 0) $ do
            let perCommit = fromIntegral retries / fromIntegral commits :: Double
            hPrintf stderr "%-22s %10d %10d %10.2f\n" txName commits retries perCommit
#endif