module Control.Monad.Trans.Resource.Extra
   ( -- * Acquire
    mkAcquire1
   , mkAcquireType1
   , acquireReleaseSelf

    -- * MonadResource
   , acquireReleaseKey
   , registerType

    -- * MonadMask
   , withAcquire
   , withAcquireRelease

    -- * Restore
   , Restore (..)
   , getRestoreIO
   , withRestoreIO

    -- * IO
   , once
   , onceK
   ) where

import Control.Concurrent.MVar
import Control.Exception.Safe qualified as Ex
import Control.Monad
import Control.Monad.IO.Class
import Control.Monad.Trans.Resource.Internal qualified as R
import Data.Acquire.Internal qualified as A
import Data.IORef
import Data.IntMap.Strict qualified as IntMap
import Data.Kind
import System.IO.Unsafe

--------------------------------------------------------------------------------

-- | Like 'A.mkAcquire', but the release function will be run at most once
-- per acquisition.
--
-- The “already released” flag is allocated in 'IO' every time the resulting
-- 'A.Acquire' is acquired. Reusing the same 'A.Acquire' value therefore
-- yields independent resources, each of which gets its own release.
mkAcquire1 :: IO a -> (a -> IO ()) -> A.Acquire a
mkAcquire1 m f = A.Acquire \_ -> do
   -- NOTE(review): the restore function is ignored on purpose, mirroring
   -- what 'A.mkAcquire' is understood to do — confirm against the resourcet
   -- version in use.
   --
   -- The flag is created before running @m@ so that no further allocation
   -- happens between obtaining the resource and returning its release.
   done <- newMVar False
   a <- m
   pure $ A.Allocated a \_ ->
      -- Holding the 'MVar' serialises concurrent releases; the flag is set
      -- to 'True' even if @f@ throws, so a failed release is not retried.
      Ex.bracket
         (takeMVar done)
         (\_ -> putMVar done True)
         (\d -> unless d (f a))

-- | Like 'A.mkAcquireType', but the release function will be run at most
-- once per acquisition.
--
-- The “already released” flag is allocated in 'IO' every time the resulting
-- 'A.Acquire' is acquired. Reusing the same 'A.Acquire' value therefore
-- yields independent resources, each of which gets its own release.
mkAcquireType1 :: IO a -> (a -> A.ReleaseType -> IO ()) -> A.Acquire a
mkAcquireType1 m f = A.Acquire \_ -> do
   -- NOTE(review): the restore function is ignored on purpose, mirroring
   -- what 'A.mkAcquireType' is understood to do — confirm against the
   -- resourcet version in use.
   --
   -- The flag is created before running @m@ so that no further allocation
   -- happens between obtaining the resource and returning its release.
   done <- newMVar False
   a <- m
   pure $ A.Allocated a \t ->
      -- Holding the 'MVar' serialises concurrent releases; the flag is set
      -- to 'True' even if @f@ throws, so a failed release is not retried.
      Ex.bracket
         (takeMVar done)
         (\_ -> putMVar done True)
         (\d -> unless d (f a t))

-- | Turn an 'A.Acquire' that produces a function awaiting its own release
-- action into an 'A.Acquire' of the function's result.
--
-- The release action handed to that function is the same one installed as
-- the resource's release, wrapped with 'onceK' so that it runs at most once
-- no matter which side triggers it first.
acquireReleaseSelf :: A.Acquire ((A.ReleaseType -> IO ()) -> a) -> A.Acquire a
acquireReleaseSelf (A.Acquire alloc) = A.Acquire \restore ->
   alloc restore >>= \(A.Allocated mk release) ->
      let releaseOnce = onceK release
      in pure (A.Allocated (mk releaseOnce) releaseOnce)

--------------------------------------------------------------------------------

-- | Like 'withAcquireRelease', but the continuation is not handed the
-- release function.
--
-- The resource is allocated with asynchronous exceptions masked, the
-- continuation runs with the previous masking state restored, and the
-- resource is released with 'A.ReleaseExceptionWith' if the continuation
-- throws, or with 'A.ReleaseNormal' once it returns.
withAcquire :: (Ex.MonadMask m, MonadIO m) => A.Acquire a -> (a -> m b) -> m b
withAcquire (A.Acquire alloc) use = do
   -- Captured before masking, so the allocator is given a restore function
   -- for the caller's masking state.
   Restore unmaskIO <- getRestoreIO
   Ex.mask \unmaskM -> do
      A.Allocated resource release <- liftIO (alloc unmaskIO)
      let releaseWith rt = liftIO (release rt)
      result <-
         unmaskM (use resource) `Ex.withException` \err ->
            releaseWith (A.ReleaseExceptionWith err)
      releaseWith A.ReleaseNormal
      pure result

-- | @'withAcquireRelease' acq \\release a -> act@ acquires the @a@ and
-- automatically releases it when @act@ returns or throws an exception.
-- If desired, @release@ can be used from within @act@ to release @a@
-- earlier; the automatic release then becomes a no-op.
withAcquireRelease
   :: (Ex.MonadMask m, MonadIO m)
   => A.Acquire a
   -> ((A.ReleaseType -> IO ()) -> a -> m b)
   -> m b
withAcquireRelease (A.Acquire alloc) use = do
   -- Captured before masking, so the allocator is given a restore function
   -- for the caller's masking state.
   Restore unmaskIO <- getRestoreIO
   Ex.mask \unmaskM -> do
      A.Allocated resource release <- liftIO (alloc unmaskIO)
      -- Guard the release so that it is not performed a second time when
      -- the continuation has already released the resource itself.
      let releaseOnce = onceK release
          releaseWith rt = liftIO (releaseOnce rt)
      result <-
         unmaskM (use releaseOnce resource) `Ex.withException` \err ->
            releaseWith (A.ReleaseExceptionWith err)
      releaseWith A.ReleaseNormal
      pure result

--------------------------------------------------------------------------------

-- | Like 'R.register', but the registered action also receives the
-- 'A.ReleaseType' describing why it is being run.
registerType
   :: (R.MonadResource m) => (A.ReleaseType -> IO ()) -> m R.ReleaseKey
registerType rel =
   R.liftResourceT $ R.ResourceT \istate -> R.registerType istate rel

-- | 'acquireReleaseKey' will 'R.unprotect' the 'R.ReleaseKey',
-- and use 'A.Acquire' to manage the release action instead.
--
-- If the key is no longer registered by the time the 'A.Acquire' is
-- acquired, the resulting release is a no-op.
acquireReleaseKey :: R.ReleaseKey -> A.Acquire ()
acquireReleaseKey (R.ReleaseKey istate key) =
   void $ A.mkAcquireType acq rel
  where
   -- Remove the release action for @key@ from the release map, returning it
   -- if it was still there.
   acq :: IO (Maybe (A.ReleaseType -> IO ()))
   acq =
      -- The following code does pretty much the same as 'R.unprotect',
      -- which we can't use directly because its result doesn't allow us
      -- to specify the 'A.ReleaseType' during release.
      --
      -- The strict variant is used so that the updated map is stored
      -- evaluated, rather than as a thunk retaining the previous map.
      atomicModifyIORef' istate \case
         R.ReleaseMap next rf im
            | Just g <- IntMap.lookup key im ->
               (R.ReleaseMap next rf (IntMap.delete key im), Just g)
         -- Key not present, or the map is in any other state: leave it
         -- untouched.
         rm -> (rm, Nothing)
   -- Run the action taken over from the release map, if there was one.
   rel :: Maybe (A.ReleaseType -> IO ()) -> A.ReleaseType -> IO ()
   rel Nothing _ = pure ()
   rel (Just g) t = g t

--------------------------------------------------------------------------------

-- | Wrapper around a “restore” function like the one given
-- by @'mask' (\\restore -> ...)@, in a particular 'Monad' @m@.
--
-- The wrapped function is polymorphic in its result type @x@. Putting it
-- behind a constructor lets it be returned as an ordinary value, which is
-- how 'getRestoreIO' hands it out.
type Restore :: (Type -> Type) -> Type
newtype Restore m = Restore (forall x. m x -> m x)

-- | Obtain, wrapped in 'Restore', the 'IO' “restore” function that 'Ex.mask'
-- would provide at the point where this action runs.
getRestoreIO :: (MonadIO m) => m (Restore IO)
getRestoreIO = liftIO captureRestore
  where
   -- Ugly, but safe. Check the implementation of @mask@ in base: the restore
   -- function is leaked out of the masked scope on purpose.
   captureRestore :: IO (Restore IO)
   captureRestore = Ex.mask \restore -> pure (Restore restore)

-- | Continuation-passing variant of 'getRestoreIO': the current 'IO'
-- “restore” function is passed directly to the given continuation, without
-- the 'Restore' wrapper.
withRestoreIO
   :: (Ex.MonadMask m, MonadIO m) => ((forall x. IO x -> IO x) -> m a) -> m a
withRestoreIO k = do
   Restore restore <- getRestoreIO
   k restore

--------------------------------------------------------------------------------

-- | @'once' ma@ wraps @ma@ so that it is executed at most once: running the
-- same @'once' ma@ value again does nothing. The wrapper may be used
-- concurrently; at most one thread gets to execute the actual @ma@.
--
-- This is 'onceK' specialised to an action that takes no input.
once :: (MonadIO m, Ex.MonadMask m) => m () -> m ()
once ma = runOnce ()
  where
   runOnce = onceK \() -> ma

-- | Kleisli version of 'once': @'onceK' kma@ is a function that runs @kma@
-- on the first argument it is ever applied to, and does nothing on every
-- later application.
--
-- Concurrent callers are serialised through an 'MVar'; a caller arriving
-- while @kma@ is still running waits for it to finish and then does nothing.
-- The flag is set even when @kma@ throws, so a failed run is not retried.
--
-- NOTE(review): the flag is created with 'unsafePerformIO' when the
-- @'onceK' kma@ application is first evaluated, so every distinct evaluation
-- of @'onceK' kma@ gets its own flag while a shared one reuses it. The
-- @NOINLINE@ pragma is presumably there to keep that stable — confirm that
-- the optimisation settings in use cannot merge or duplicate applications.
onceK :: (MonadIO m, Ex.MonadMask m) => (a -> m ()) -> (a -> m ())
{-# NOINLINE onceK #-}
onceK kma = unsafePerformIO do
   -- 'True' once @kma@ has been attempted.
   done <- newMVar False
   let guarded a =
          Ex.bracket
             (liftIO (takeMVar done))
             (const (liftIO (putMVar done True)))
             (\ran -> unless ran (kma a))
   pure guarded