-- | This module provides a container-/cgroup-aware substitute for GHC's RTS
-- @-N@ flag. See 'initRTSThreads'.
module Control.Concurrent.CGroup (
  initRTSThreads,
) where

import Control.Exception (Exception (..), SomeAsyncException (SomeAsyncException), SomeException, catch, throwIO)
import qualified Data.Ratio as Ratio
import GHC.Conc (getNumProcessors, setNumCapabilities)
import System.CGroup.Types (CPUQuota (..))
import qualified System.CGroup.V1.CPU as V1
import qualified System.CGroup.V2.CPU as V2

-- | A container-/cgroup-aware substitute for GHC's RTS @-N@ flag.
--
-- On most platforms this sets the number of runtime threads (capabilities) to
-- the processor count reported by 'GHC.Conc.getNumProcessors', which is what
-- GHC's @-N@ flag does by default.
--
-- When running within a cgroup on linux (most often within a container), the
-- process' cgroup cpu quota is used to constrain the number of runtime
-- threads. The cgroup v1 quota is tried first, then the effective cgroup v2
-- quota. If both lookups throw a synchronous exception, the process is treated
-- as having no quota.
--
-- See 'CPUQuota'
initRTSThreads :: IO ()
initRTSThreads = detectQuota >>= initRTSThreadsFromQuota
  where
    -- Try cgroup v1, then cgroup v2, and finally assume there is no quota.
    detectQuota :: IO CPUQuota
    detectQuota =
      (V1.getProcessCPUQuota `fallback` V2.getProcessEffectiveCPUQuota)
        `fallback` pure NoQuota

-- | Use a CPU quota to set the number of runtime threads.
--
-- With 'NoQuota' this defers to 'defaultInitRTSThreads'. Otherwise the quota
-- ratio is rounded down to a whole number of CPUs (integer division of its
-- numerator by its denominator). The result is then clamped to the range
-- @[1, n]@, where @n@ is the processor count from 'getNumProcessors'. For
-- example, a quota of half a CPU still yields one runtime thread.
initRTSThreadsFromQuota :: CPUQuota -> IO ()
initRTSThreadsFromQuota quota =
  case quota of
    NoQuota -> defaultInitRTSThreads
    CPUQuota ratio -> do
      processorCount <- getNumProcessors
      let wholeCpus = Ratio.numerator ratio `div` Ratio.denominator ratio
      -- Never go below one capability, nor above the physical processor count.
      setNumCapabilities (clamp 1 processorCount wholeCpus)

-- | Set the number of runtime threads to the processor count reported by
-- 'getNumProcessors'. This matches the behavior of GHC's RTS @-N@ flag.
defaultInitRTSThreads :: IO ()
defaultInitRTSThreads = do
  processorCount <- getNumProcessors
  setNumCapabilities processorCount

-- | @clamp lower upper x@ limits @x@ to the range @[lower, upper]@.
--
-- The upper bound is applied first and the lower bound second, so when
-- @lower > upper@ the result is always @lower@.
clamp :: Int -> Int -> Int -> Int
clamp lower upper x
  | x > upper = max lower upper
  | otherwise = max lower x

-- | Like 'catch', but only hands synchronous exceptions to the handler.
--
-- An exception for which 'isSyncException' is 'False' (one in the
-- 'SomeAsyncException' hierarchy) is rethrown unchanged with 'throwIO'.
safeCatch :: Exception e => IO a -> (e -> IO a) -> IO a
safeCatch act hdl = catch act handler
  where
    handler caught
      | isSyncException caught = hdl caught
      | otherwise = throwIO caught

-- | 'True' unless the exception, once wrapped with 'toException', can be viewed
-- as a 'SomeAsyncException' (for example 'Control.Exception.ThreadKilled').
isSyncException :: Exception e => e -> Bool
isSyncException e =
  maybe True (\(SomeAsyncException _) -> False) (fromException (toException e))

-- | Return the result of the first successful action.
--
-- Runs @primary@. If it throws a synchronous exception of any type, that
-- exception is discarded and @backup@ runs instead. Asynchronous exceptions
-- are not caught (see 'safeCatch').
fallback :: IO a -> IO a -> IO a
fallback primary backup = safeCatch primary (onAnyException backup)
  where
    -- The signature fixes the caught type to 'SomeException'.
    onAnyException :: IO b -> SomeException -> IO b
    onAnyException next _ = next