{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

module Ki.Internal.Prelude
  ( forkIO,
    forkOS,
    forkOn,
    interruptiblyMasked,
    uninterruptiblyMasked,
    module X,
  )
where

import Control.Applicative as X (optional, (<|>))
import Control.Concurrent hiding (forkIO, forkOS, forkOn)
import Control.Concurrent as X (ThreadId, myThreadId, threadDelay, throwTo)
import Control.Concurrent.MVar as X
import Control.Exception
import Control.Exception as X (Exception, SomeException, mask_, throwIO, try, uninterruptibleMask, uninterruptibleMask_)
import Control.Monad as X (join, when)
import Data.Coerce as X (coerce)
import Data.Data as X (Data)
import Data.Foldable as X (for_, traverse_)
import Data.Function as X (fix)
import Data.Functor as X (void, ($>), (<&>))
import Data.Int as X
import Data.IntMap.Strict as X (IntMap)
import Data.Map.Strict as X (Map)
import Data.Maybe as X (fromMaybe)
import Data.Sequence as X (Seq)
import Data.Set as X (Set)
import Data.Word as X (Word32)
import Foreign.C.Types (CInt (CInt))
import Foreign.StablePtr (StablePtr, freeStablePtr, newStablePtr)
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (ThreadId (ThreadId))
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.Generics as X (Generic)
import GHC.IO (IO (IO), unsafeUnmask)
import Numeric.Natural as X (Natural)
import Prelude as X

-- | Call an action with asynchronous exceptions interruptibly masked.
--
-- The action's underlying state-passing function is handed directly to the
-- 'maskAsyncExceptions#' primop. No @restore@ function is supplied to the
-- action, and the caller's current masking state is not inspected first.
interruptiblyMasked :: IO a -> IO a
interruptiblyMasked (IO io) =
  IO (maskAsyncExceptions# io)

-- | Call an action with asynchronous exceptions uninterruptibly masked.
--
-- The action's underlying state-passing function is handed directly to the
-- 'maskUninterruptible#' primop. No @restore@ function is supplied to the
-- action, and the caller's current masking state is not inspected first.
uninterruptiblyMasked :: IO a -> IO a
uninterruptiblyMasked (IO io) =
  IO (maskUninterruptible# io)

-- Control.Concurrent.forkIO without the dumb exception handler
--
-- The action is passed to the 'fork#' primop exactly as given (nothing is
-- wrapped around it here), and the resulting 'ThreadId#' is boxed into a
-- 'ThreadId'.
--
-- NOTE(review): this passes the 'IO' value itself to 'fork#', which matches
-- the primop's older, fully polymorphic argument type. Newer GHC releases
-- reportedly expect the unwrapped state-passing function instead - confirm
-- against the compiler versions this module must support.
forkIO :: IO () -> IO ThreadId
forkIO action =
  IO \s0 ->
    case fork# action s0 of
      (# s1, tid #) -> (# s1, ThreadId tid #)

-- Control.Concurrent.forkOn without the dumb exception handler
--
-- The unboxed capability number and the action are passed to the 'forkOn#'
-- primop exactly as given (nothing is wrapped around the action here), and
-- the resulting 'ThreadId#' is boxed into a 'ThreadId'.
--
-- NOTE(review): this passes the 'IO' value itself to 'forkOn#', which matches
-- the primop's older, fully polymorphic argument type. Newer GHC releases
-- reportedly expect the unwrapped state-passing function instead - confirm
-- against the compiler versions this module must support.
forkOn :: Int -> IO () -> IO ThreadId
forkOn (I# cap) action =
  IO \s0 ->
    case forkOn# cap action s0 of
      (# s1, tid #) -> (# s1, ThreadId tid #)

-- Control.Concurrent.forkOS without the dumb exception handler
--
-- Steps, in order:
--
--   1. Fail (via 'fail' in 'IO') if 'rtsSupportsBoundThreads' is 'False'.
--   2. Wrap the caller's action so that it runs under the masking state that
--      was in effect when 'forkOS' was called.
--   3. Hand a stable pointer to that wrapped action to @createThread@, and fail
--      (via 'fail' in 'IO') if it returns a non-zero code.
--   4. Wait for the new thread to report its own 'ThreadId' through an 'MVar',
--      free the stable pointer, and return the 'ThreadId'.
forkOS :: IO () -> IO ThreadId
forkOS action0 = do
  when (not rtsSupportsBoundThreads) do
    fail "RTS doesn't support multiple OS threads (use ghc -threaded when linking)"

  -- Filled by the new thread with its own ThreadId before it runs the action.
  threadIdVar <- newEmptyMVar

  actionStablePtr <- do
    action <-
      -- createThread creates a MaskedInterruptible thread; this computation emulates forkIO's inheriting masking state
      getMaskingState <&> \case
        Unmasked -> unsafeUnmask action0
        MaskedInterruptible -> action0
        MaskedUninterruptible -> uninterruptiblyMasked action0

    newStablePtr do
      threadId <- myThreadId
      putMVar threadIdVar threadId
      action

  -- NOTE(review): on the failure branch the stable pointer is never freed;
  -- confirm whether createThread can fail after taking ownership of it before
  -- adding a 'freeStablePtr' here.
  createThread actionStablePtr >>= \case
    0 -> pure ()
    _ -> fail "Cannot create OS thread."

  -- Block until the new thread has published its ThreadId; only then is the
  -- stable pointer released.
  threadId <- takeMVar threadIdVar
  freeStablePtr actionStablePtr
  return threadId

------------------------------------------------------------------------------------------------------------------------
-- FFI calls

-- C-calling-convention import used by 'forkOS'. It receives a stable pointer
-- to the 'IO' action the new thread should run and returns a status code;
-- 'forkOS' treats @0@ as success and any other value as failure.
--
-- NOTE(review): no entity string is given, so the C symbol is presumably also
-- named @createThread@; where that symbol is defined is not visible from this
-- module - confirm against the build configuration.
foreign import ccall
  createThread :: StablePtr (IO ()) -> IO CInt