-- | This module provides a container-/cgroup-aware substitute for GHC's RTS
-- @-N@ flag. See 'initRTSThreads'.
module Control.Concurrent.CGroup (
  initRTSThreads,
) where

import Control.Exception (Exception (..), SomeAsyncException (SomeAsyncException), SomeException, catch, throwIO)
import GHC.Conc (getNumProcessors, setNumCapabilities)
import System.CGroup.CPU (CPUQuota (..), getCPUQuota, resolveCPUController)

-- | A container-/cgroup-aware substitute for GHC's RTS @-N@ flag.
--
-- On most platforms, this sets the number of runtime threads to match the
-- number of processors reported by 'GHC.Conc.getNumProcessors', which is the
-- default behavior of the GHC @-N@ flag.
--
-- When running within a cgroup on Linux (most often within a container), this
-- observes the current process' cgroup CPU quota to constrain the number of
-- runtime threads.
--
-- If the cgroup-based setup fails with a synchronous exception (for example,
-- because the process is not running within a cgroup), this falls back to the
-- plain @-N@ behavior. Asynchronous exceptions are re-thrown, not swallowed.
--
-- See 'CPUQuota'.
initRTSThreads :: IO ()
initRTSThreads = initRTSThreadsFromCGroup `safeCatch` fallback
  where
    -- The explicit signature fixes the handled exception type to
    -- 'SomeException' without a pattern type annotation, so no
    -- ScopedTypeVariables extension is required.
    fallback :: SomeException -> IO ()
    fallback _ = defaultInitRTSThreads

-- | Uses the current process' cgroup CPU quota to set the number of runtime
-- threads.
--
-- * With no quota, this behaves like 'defaultInitRTSThreads'.
-- * With a quota of @quota@ per @period@, the capability count is
--   @quota `div` period@, clamped to lie between 1 and the processor count
--   reported by 'getNumProcessors'.
-- * A non-positive @period@ cannot describe a usable quota, so it also falls
--   back to 'defaultInitRTSThreads' rather than dividing by it.
--
-- Throws an exception when the current process is not running within a cgroup.
initRTSThreadsFromCGroup :: IO ()
initRTSThreadsFromCGroup = do
  cpuController <- resolveCPUController
  cgroupCpuQuota <- getCPUQuota cpuController
  case cgroupCpuQuota of
    NoQuota -> defaultInitRTSThreads
    CPUQuota quota period
      -- Validate explicitly instead of letting `div` throw DivideByZero.
      | period <= 0 -> defaultInitRTSThreads
      | otherwise -> do
          procs <- getNumProcessors
          -- Integer division rounds down, so a fractional allowance
          -- (e.g. 1.5 CPUs) yields 1 capability. The lower bound of 1 keeps
          -- quotas below one full CPU from requesting zero capabilities.
          let capabilities = clamp 1 procs (quota `div` period)
          setNumCapabilities capabilities

-- | Set the number of runtime threads to the number of processors reported by
-- 'getNumProcessors'. This matches the behavior of GHC's RTS @-N@ flag.
defaultInitRTSThreads :: IO ()
defaultInitRTSThreads = getNumProcessors >>= setNumCapabilities

-- | Clamp a value to the inclusive range @[lower, upper]@.
--
-- The upper bound is applied first and the lower bound last, so when
-- @lower > upper@ the result is always @lower@.
clamp :: Int -> Int -> Int -> Int
clamp lower upper x = max lower (min upper x)

-- | Like 'catch', but only handles synchronous exceptions.
--
-- An exception that 'isSyncException' classifies as asynchronous is re-thrown
-- with 'throwIO' instead of being passed to the handler, so that asynchronous
-- signals are not swallowed by a catch-all handler.
safeCatch :: Exception e => IO a -> (e -> IO a) -> IO a
safeCatch act hdl = act `catch` handler
  where
    handler e
      | isSyncException e = hdl e
      | otherwise = throwIO e

-- | Whether an exception is synchronous, i.e. its 'SomeException'
-- representation does not wrap a 'SomeAsyncException'.
isSyncException :: Exception e => e -> Bool
isSyncException e =
  case fromException (toException e) of
    Just (SomeAsyncException _) -> False
    Nothing -> True