module Engine.Types.RefCounted where

import RIO

import Control.Monad.Trans.Resource (allocate_)
import GHC.IO.Exception (IOErrorType(UserError), IOException(IOError))
import UnliftIO.Resource (MonadResource)

-- | A shared counter paired with a finalizer.
--
-- 'releaseRefCounted' runs the finalizer when a release brings the count
-- down to exactly 0.
data RefCounted = RefCounted
  { rcCount :: IORef Int
    -- ^ Number of references currently held; 'newRefCounted' starts it at 1.
  , rcAction :: IO ()
    -- ^ Finalizer run by 'releaseRefCounted' when the count reaches 0.
  }

-- | Wrap a finalizer in a fresh 'RefCounted' whose count starts at 1.
--
-- The caller therefore already owns one reference. A matching
-- 'releaseRefCounted' (with no other takes outstanding) brings the count to
-- 0 and runs the finalizer.
newRefCounted :: MonadIO m => IO () -> m RefCounted
newRefCounted finalizer =
  liftIO $ (\counter -> RefCounted counter finalizer) <$> newIORef 1

-- | Drop one reference by atomically decrementing the counter.
--
-- What happens next depends on the new count:
--
--   * exactly 0: 'rcAction' runs immediately, in the calling thread;
--   * positive: nothing else happens;
--   * negative (more releases than takes): an 'IOException' of type
--     'UserError' is thrown. The decrement has already been stored at that
--     point, so the counter is left negative.
--
-- The decrement and the follow-up step run under 'mask_'. An asynchronous
-- exception therefore cannot slip in between the count reaching 0 and
-- 'rcAction' being started. 'rcAction' itself also runs with asynchronous
-- exceptions masked.
releaseRefCounted :: MonadIO m => RefCounted -> m ()
releaseRefCounted (RefCounted counter finalizer) =
  liftIO . mask_ $ do
    remaining <- atomicModifyIORef' counter (\c -> let c' = c - 1 in (c', c'))
    case compare remaining 0 of
      LT -> throwIO underflow
      EQ -> finalizer
      GT -> pure ()
  where
    underflow =
      IOError
        Nothing
        UserError
        ""
        "Ref counted value decremented below 0"
        Nothing
        Nothing

-- | Take one more reference by atomically incrementing the counter.
--
-- The increment is unconditional. If the count is currently 0 (for example
-- because the last 'releaseRefCounted' already ran 'rcAction'), it goes back
-- to 1. A later release that reaches 0 would then run 'rcAction' again.
takeRefCounted :: MonadIO m => RefCounted -> m ()
takeRefCounted (RefCounted counter _finalizer) =
  liftIO $ atomicModifyIORef' counter bump
  where
    bump n = (n + 1, ())

-- | Hold one extra reference for the lifetime of the enclosing
-- 'MonadResource' scope.
--
-- The reference is taken now with 'takeRefCounted'. 'releaseRefCounted' is
-- registered via 'allocate_' to run when that scope is cleaned up. The
-- 'ReleaseKey' is discarded, so the reference cannot be returned early
-- through this function.
resourceTRefCount :: MonadResource f => RefCounted -> f ()
resourceTRefCount ref = do
  _releaseKey <- allocate_ (takeRefCounted ref) (releaseRefCounted ref)
  pure ()