{-# LANGUAGE CPP #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE UnboxedTuples #-}

module Ki.Concurrency
  ( IO,
    MVar,
    STM,
    TBQueue,
    TMVar,
    TQueue,
    TVar,
    ThreadId,
    atomically,
    catch,
    check,
    forkIO,
    modifyTVar',
    myThreadId,
    newEmptyTMVarIO,
    newMVar,
    newTBQueueIO,
    newTQueueIO,
    newTVar,
    newTVarIO,
    onException,
    putTMVar,
    putTMVarIO,
    readTBQueue,
    readTMVar,
    readTQueue,
    readTVar,
    readTVarIO,
    registerDelay,
    retry,
    threadDelay,
    throwIO,
    throwSTM,
    throwTo,
    try,
    uninterruptibleMask,
    uniqueInt,
    unsafeUnmask,
    withMVar,
    writeTBQueue,
    writeTQueue,
    writeTVar,
  )
where

#ifdef TEST

import Control.Concurrent.Classy hiding (MVar, STM, TBQueue, TMVar, TQueue, TVar, ThreadId, registerDelay)
import qualified Control.Concurrent.Classy
import Control.Exception (Exception, SomeException)
import Numeric.Natural (Natural)
import qualified Test.DejaFu
import qualified Test.DejaFu.Conc.Internal.Common
import qualified Test.DejaFu.Conc.Internal.STM
import Test.DejaFu.Types (ThreadId)
import qualified Prelude
import Prelude hiding (IO)

-- | Under TEST, 'IO' is DejaFu's model concurrency monad.
type IO = Test.DejaFu.ConcIO

-- | Under TEST, 'MVar' is DejaFu's model @MVar@ over 'Prelude.IO'.
type MVar = Test.DejaFu.Conc.Internal.Common.ModelMVar Prelude.IO

-- | Under TEST, 'STM' is DejaFu's model STM monad over 'Prelude.IO'.
type STM = Test.DejaFu.Conc.Internal.STM.ModelSTM Prelude.IO

-- | Under TEST, 'TBQueue' is the classy @TBQueue@ specialised to the model 'STM'.
type TBQueue = Control.Concurrent.Classy.TBQueue STM

-- | Under TEST, 'TQueue' is the classy @TQueue@ specialised to the model 'STM'.
type TQueue = Control.Concurrent.Classy.TQueue STM

-- | Under TEST, 'TMVar' is the classy @TMVar@ specialised to the model 'STM'.
type TMVar = Control.Concurrent.Classy.TMVar STM

-- | Under TEST, 'TVar' is DejaFu's model @TVar@ over 'Prelude.IO'.
type TVar = Test.DejaFu.Conc.Internal.STM.ModelTVar Prelude.IO

-- | Fork a model thread by delegating to the classy 'fork'.
forkIO :: IO () -> IO ThreadId
forkIO action = fork action

-- | Create a 'TBQueue' in a transaction of its own. The 'Natural' argument is
-- passed straight through to 'newTBQueue'.
newTBQueueIO :: Natural -> IO (TBQueue a)
newTBQueueIO size = atomically (newTBQueue size)

-- | Create an empty 'TQueue' in a transaction of its own.
newTQueueIO :: IO (TQueue a)
newTQueueIO = atomically newTQueue

-- | Create an empty 'TMVar' in a transaction of its own.
newEmptyTMVarIO :: IO (TMVar a)
newEmptyTMVarIO = atomically newEmptyTMVar

-- | Create a 'TVar' holding the given initial value, in a transaction of its own.
newTVarIO :: a -> IO (TVar a)
newTVarIO initial = atomically (newTVar initial)

-- | Run @action@. If it throws any exception, run @cleanup@ (discarding its
-- result) and then rethrow the original exception.
onException :: IO a -> IO b -> IO a
onException action cleanup =
  action `catch` rethrowAfterCleanup
  where
    -- The signature pins the caught exception type to 'SomeException', so
    -- every exception is intercepted.
    rethrowAfterCleanup :: SomeException -> IO c
    rethrowAfterCleanup exception = cleanup >> throwIO exception

-- | Run 'putTMVar' as a single 'atomically' transaction.
putTMVarIO :: TMVar a -> a -> IO ()
putTMVarIO var x = atomically $ putTMVar var x

-- | Read a 'TVar' in a transaction of its own.
readTVarIO :: TVar a -> IO a
readTVarIO var = atomically (readTVar var)

-- | Model version of 'registerDelay'.
--
-- Returns an 'STM' action that succeeds only once the classy delay 'TVar'
-- holds 'True' (it is 'check'ed), plus an unregister action that does nothing.
registerDelay :: Int -> IO (STM (), IO ())
registerDelay micros =
  Control.Concurrent.Classy.registerDelay micros >>= \fired ->
    pure (check =<< readTVar fired, pure ())

-- | Throw an exception in the model monad, via the classy 'throw'.
throwIO :: Exception e => e -> IO a
throwIO exception = throw exception

-- | Run an action. Return 'Right' with its result, or 'Left' with a thrown
-- exception of type @e@.
try :: Exception e => IO a -> IO (Either e a)
try action =
  fmap Right action `catch` (\err -> pure (Left err))

-- | Model stand-in for the process-wide counter: always yields @0@.
--
-- NOTE(review): under TEST every call returns the same value; confirm that
-- no caller relies on these being distinct.
uniqueInt :: IO Int
uniqueInt = return 0

#else

import Control.Concurrent hiding (forkIO)
import Control.Concurrent.STM hiding (registerDelay)
import Control.Exception
import Control.Monad (unless)
import Data.IORef (IORef, atomicModifyIORef', newIORef)
import GHC.Conc (ThreadId (ThreadId))
#if defined(mingw32_HOST_OS)
import GHC.Conc.Windows
#else
import GHC.Event
#endif
import GHC.Exts (fork#)
import GHC.IO (IO (IO), unsafePerformIO, unsafeUnmask)
import Prelude

-- | Fork a lightweight thread that runs the given action.
--
-- This calls the @fork#@ primop directly on the action. No exception handler
-- is installed around the child here.
--
-- NOTE(review): this passes the 'IO' action to @fork#@ as an arbitrary value,
-- matching the @fork# :: a -> State# RealWorld -> ...@ signature this was
-- written against; confirm the primop's argument type for every GHC version
-- this module supports.
forkIO :: IO () -> IO ThreadId
forkIO action =
  IO
    ( \s -> case fork# action s of
        (# s1, tid #) -> (# s1, ThreadId tid #)
    )

-- | Run 'putTMVar' as a single 'atomically' transaction.
putTMVarIO :: TMVar a -> a -> IO ()
putTMVarIO var x =
  atomically (putTMVar var x)

#if defined(mingw32_HOST_OS)
-- | Windows version of 'registerDelay', built on
-- 'GHC.Conc.Windows.registerDelay'.
--
-- Returns an 'STM' action that 'retry's until the delay's 'TVar' becomes
-- 'True', plus an unregister action.
registerDelay :: Int -> IO (STM (), IO ())
registerDelay micros = do
  fired <- GHC.Conc.Windows.registerDelay micros
  let waitUntilFired = do
        ready <- readTVar fired
        unless ready retry
  -- No unregistration is available on Windows, so cancelling does nothing.
  pure (waitUntilFired, pure ())
#else
-- | Register a timeout of @micros@ microseconds with the system timer manager.
--
-- Returns a pair of:
--
-- * an 'STM' action that 'retry's until the timeout callback has set the
--   flag 'TVar' to 'True', and
-- * an 'IO' action that unregisters the timeout.
--
-- NOTE(review): 'getSystemTimerManager' is expected to require the threaded
-- runtime; confirm the build links with @-threaded@.
registerDelay :: Int -> IO (STM (), IO ())
registerDelay micros = do
  fired <- newTVarIO False
  manager <- getSystemTimerManager
  key <- registerTimeout manager micros (atomically (writeTVar fired True))
  let waitUntilFired = readTVar fired >>= \ready -> unless ready retry
  pure (waitUntilFired, unregisterTimeout manager key)
#endif

-- | Return a fresh 'Int' by atomically incrementing the process-wide 'counter'
-- and returning the incremented value. The first call returns @1@.
uniqueInt :: IO Int
uniqueInt =
  atomicModifyIORef' counter (\n -> let n' = n + 1 in (n', n'))

-- | Process-wide counter backing 'uniqueInt', starting at @0@.
--
-- Created with 'unsafePerformIO'. The @NOINLINE@ pragma keeps this a single
-- shared 'IORef' rather than allowing GHC to inline a fresh one per use site.
counter :: IORef Int
counter =
  unsafePerformIO (newIORef 0)
{-# NOINLINE counter #-}

#endif