-- | This module provides a container-/cgroup-aware substitute for GHC's RTS
-- @-N@ flag. See 'initRTSThreads'.
module Control.Concurrent.CGroup (
  initRTSThreads,
  initRTSThreadsWith,
  RoundQuota (..),
) where

import Control.Exception (Exception (..), SomeAsyncException (SomeAsyncException), SomeException, catch, throwIO)
import qualified Data.Ratio as Ratio
import GHC.Conc (getNumProcessors, setNumCapabilities)
import System.CGroup.Types (CPUQuota (..))
import qualified System.CGroup.V1.CPU as V1
import qualified System.CGroup.V2.CPU as V2

-- | CPU quotas can be fractional, but the number of RTS capabilities must be
-- an integer. This type selects how a fractional quota is turned into a whole
-- number of capabilities.
--
-- The constructor names correspond to the similarly named methods of
-- 'RealFrac'.
data RoundQuota
  = -- | Round the quota up, as 'ceiling' does.
    CeilingQuota
  | -- | Round the quota down, as 'floor' does.
    FloorQuota
  | -- | Round the quota to the nearest integer, as 'round' does (a quota
    -- exactly half-way between two integers goes to the even one).
    RoundQuota
  deriving (Show, Read, Eq, Ord)

-- | Convert a fractional CPU quota into a whole number using the requested
-- rounding mode.
--
-- Each 'RoundQuota' constructor selects the 'RealFrac' method of the
-- corresponding name: 'ceiling', 'floor' or 'round'.
roundQuota :: RoundQuota -> Ratio.Ratio Int -> Int
roundQuota CeilingQuota = ceiling
roundQuota FloorQuota = floor
roundQuota RoundQuota = round

-- | A container-/cgroup-aware substitute for GHC's RTS @-N@ flag.
--
-- On most platforms, this sets the number of runtime threads to match the
-- number of physical processors (see 'GHC.Conc.getNumProcessors'), which is the
-- default behavior of the GHC @-N@ flag.
--
-- When running within a cgroup on Linux (most often within a container), this
-- observes the current process' cgroup CPU quota to constrain the number of
-- runtime threads.
--
-- By default, the number of capabilities is determined by rounding the CPU
-- quota down ('FloorQuota'). The result is never less than one and never more
-- than the number of processors. Use 'initRTSThreadsWith' to choose a
-- different rounding mode.
--
-- See 'CPUQuota'
initRTSThreads :: IO ()
initRTSThreads = initRTSThreadsWith FloorQuota

-- | Same as 'initRTSThreads' but lets you specify the quota round mode.
--
-- The quota is looked up with the cgroup v1 reader first. If that throws a
-- synchronous exception, the cgroup v2 reader is tried. If that fails as
-- well, the process is treated as having no quota ('NoQuota').
initRTSThreadsWith :: RoundQuota -> IO ()
initRTSThreadsWith roundMode =
  detectQuota >>= initRTSThreadsFromQuota roundMode
  where
    -- The chain associates to the left: (v1 `fallback` v2) `fallback` NoQuota.
    detectQuota :: IO CPUQuota
    detectQuota =
      V1.getProcessCPUQuota
        `fallback` V2.getProcessEffectiveCPUQuota
        `fallback` pure NoQuota

-- | Use a CPU quota to set the number of runtime threads.
--
-- With 'NoQuota' this defers to 'defaultInitRTSThreads'. Otherwise the quota
-- is rounded according to the given 'RoundQuota' mode and then clamped.
initRTSThreadsFromQuota :: RoundQuota -> CPUQuota -> IO ()
initRTSThreadsFromQuota _ NoQuota = defaultInitRTSThreads
initRTSThreadsFromQuota roundMode (CPUQuota ratio) = do
  procs <- getNumProcessors
  -- Keep at least one capability (a quota below one CPU may round to zero)
  -- and never more capabilities than there are processors.
  setNumCapabilities (clamp 1 procs (roundQuota roundMode ratio))

-- | Set the number of runtime threads to the number of available processors,
-- as reported by 'getNumProcessors'. This matches the behavior of GHC's RTS
-- @-N@ flag.
defaultInitRTSThreads :: IO ()
defaultInitRTSThreads = getNumProcessors >>= setNumCapabilities

-- | Clamp a value within a range.
--
-- @clamp lower upper value@ returns @value@ limited to at most @upper@ and at
-- least @lower@. The lower bound is applied last, so it wins if
-- @lower > upper@.
clamp :: Int -> Int -> Int -> Int
clamp lower upper value = max lower (min upper value)

-- | Catch non-async exceptions.
--
-- Behaves like 'catch', except that the handler only runs for exceptions that
-- 'isSyncException' classifies as synchronous. Asynchronous exceptions are
-- rethrown with 'throwIO' instead of being handled.
safeCatch :: Exception e => IO a -> (e -> IO a) -> IO a
safeCatch act hdl = act `catch` dispatch
  where
    dispatch e
      | isSyncException e = hdl e
      | otherwise = throwIO e

-- | Decide whether an exception is synchronous.
--
-- The exception is converted with 'toException' and then probed with
-- 'fromException' for a 'SomeAsyncException' wrapper. The result is 'False'
-- if the wrapper is found and 'True' otherwise.
isSyncException :: Exception e => e -> Bool
isSyncException e =
  case fromException (toException e) of
    Just (SomeAsyncException _) -> False
    Nothing -> True

-- | Return the result of the first successful action.
--
-- The first action is run. If it throws a synchronous exception, that
-- exception is discarded and the second action is run instead. Asynchronous
-- exceptions are not intercepted (see 'safeCatch').
fallback :: IO a -> IO a -> IO a
fallback a b = a `safeCatch` (\(_ :: SomeException) -> b)