{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}

-- | Miscellaneous IO utilities
module Ki.Internal.IO
  ( -- * Unexceptional IO
    UnexceptionalIO (..),
    IOResult (..),
    unexceptionalTry,
    unexceptionalTryEither,

    -- * Exception utils
    isAsyncException,
    interruptiblyMasked,
    uninterruptiblyMasked,
    tryEitherSTM,

    -- * Fork utils
    forkIO,
    forkOn,
  )
where

import Control.Exception
import Control.Monad (join)
import Data.Coerce (coerce)
import GHC.Base (maskAsyncExceptions#, maskUninterruptible#)
import GHC.Conc (STM, ThreadId (ThreadId), catchSTM)
import GHC.Exts (Int (I#), fork#, forkOn#)
import GHC.IO (IO (IO))
import Prelude

-- | A little promise that this IO action cannot throw an exception.
--
-- Yeah, it's verbose, and maybe not strictly necessary, but the code that bothers to use it really does require
-- exception-free IO actions for correctness, so here we are.
--
-- This is a plain newtype over 'IO'. Its 'Functor', 'Applicative' and 'Monad' instances are reused from 'IO' via
-- newtype deriving and add no exception handling of their own. Because the constructor is exported, the promise is
-- only as good as the actions that get wrapped with it.
newtype UnexceptionalIO a = UnexceptionalIO
  {runUnexceptionalIO :: IO a}
  deriving newtype (Applicative, Functor, Monad)

-- | The outcome of running an 'IO' action while catching every exception it throws.
data IOResult a
  = -- | The action threw. This may be a synchronous or an asynchronous exception (the field is strict).
    Failure !SomeException
  | -- | The action returned normally with this value.
    Success a

-- | Run an 'IO' action and turn any exception it throws into a 'Failure' instead of letting it propagate. A normal
-- return becomes 'Success'.
--
-- The handler catches 'SomeException', so asynchronous exceptions delivered while the action runs are captured as
-- well. 'isAsyncException' can tell the two kinds apart.
unexceptionalTry :: forall a. IO a -> UnexceptionalIO (IOResult a)
unexceptionalTry action =
  -- The handler's argument type is fixed to 'SomeException' by 'Failure'; the handler itself cannot throw.
  UnexceptionalIO (catch (Success <$> action) (pure . Failure))

-- | Like 'try', but with continuations. Run the action, then pass either the exception it threw to @onFailure@ or
-- its result to @onSuccess@. This catches all exceptions, because that's the only flavor we need.
--
-- 'catch' only decides /which/ continuation to run. It returns that continuation unexecuted, and 'join' runs it
-- afterwards. So the continuations execute outside the handler, and an exception thrown by @onSuccess@ is not
-- routed to @onFailure@.
unexceptionalTryEither ::
  forall a b.
  (SomeException -> UnexceptionalIO b) ->
  (a -> UnexceptionalIO b) ->
  IO a ->
  UnexceptionalIO b
unexceptionalTryEither onFailure onSuccess action =
  UnexceptionalIO (join (catch (runSuccess <$> action) (pure . runFailure)))
  where
    -- Strip the 'UnexceptionalIO' wrapper off each continuation at zero cost. The explicit target types also pin
    -- the exception type caught by 'catch' to 'SomeException'.
    runSuccess = coerce @_ @(a -> IO b) onSuccess
    runFailure = coerce @_ @(SomeException -> IO b) onFailure

-- | Whether an exception is asynchronous, that is, whether 'fromException' can view it as a 'SomeAsyncException'.
isAsyncException :: SomeException -> Bool
isAsyncException exception =
  case fromException exception :: Maybe SomeAsyncException of
    Nothing -> False
    Just _ -> True

-- | Call an action with asynchronous exceptions interruptibly masked.
--
-- The action is unwrapped and run directly under the 'maskAsyncExceptions#' primop.
-- NOTE(review): presumably the primop is used instead of 'mask_' so that the interruptible masking state is set
-- unconditionally, whatever the caller's current state is. Confirm before relying on it.
interruptiblyMasked :: IO a -> IO a
interruptiblyMasked (IO io) =
  IO (maskAsyncExceptions# io)

-- | Call an action with asynchronous exceptions uninterruptibly masked.
--
-- The action is unwrapped and run directly under the 'maskUninterruptible#' primop.
-- NOTE(review): presumably the primop is used instead of 'uninterruptibleMask_' so that the uninterruptible masking
-- state is set unconditionally. Confirm before relying on it.
uninterruptiblyMasked :: IO a -> IO a
uninterruptiblyMasked (IO io) =
  IO (maskUninterruptible# io)

-- | Like 'try', but in 'STM' and with continuations. Run the action, then pass either the exception it threw to
-- @onFailure@ or its result to @onSuccess@.
--
-- Only exceptions of type @e@ are caught; any others propagate. 'catchSTM' merely selects a continuation, and
-- 'join' runs it afterwards, outside the handler. So an exception thrown by a continuation is not routed to
-- @onFailure@.
tryEitherSTM :: Exception e => (e -> STM b) -> (a -> STM b) -> STM a -> STM b
tryEitherSTM onFailure onSuccess action =
  join (catchSTM (fmap onSuccess action) (pure . onFailure))

-- | 'Control.Concurrent.forkIO' without the exception handler.
--
-- The unwrapped action is handed straight to the 'fork#' primop, and the raw thread id is boxed into a 'ThreadId'.
-- NOTE(review): callers are presumably expected to do their own exception handling inside the action. Confirm
-- against the call sites.
forkIO :: IO () -> IO ThreadId
forkIO (IO action) =
  IO
    ( \s0 ->
        case fork# action s0 of
          (# s1, tid #) -> (# s1, ThreadId tid #)
    )

-- | 'Control.Concurrent.forkOn' without the exception handler.
--
-- The capability number is unboxed, and it and the unwrapped action are handed straight to the 'forkOn#' primop.
-- The raw thread id is then boxed into a 'ThreadId'.
-- NOTE(review): callers are presumably expected to do their own exception handling inside the action. Confirm
-- against the call sites.
forkOn :: Int -> IO () -> IO ThreadId
forkOn (I# cap) (IO action) =
  IO
    ( \s0 ->
        case forkOn# cap action s0 of
          (# s1, tid #) -> (# s1, ThreadId tid #)
    )