{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}

-- |
-- Module: Internal.Rc
-- Description: Reference counted boxes.
--
-- This module provides a reference-counted cell type 'Rc', which contains a
-- value and a finalizer. When the reference count reaches zero, the value is
-- dropped and the finalizer is run.
module Internal.Rc
  ( Rc,
    new,
    get,
    incr,
    decr,
    release,
  )
where

import Control.Concurrent.STM

-- | A reference-counted container for a value of type @a@.
--
-- The underlying 'TVar' holds 'Just' the live state while the value is
-- retained, and 'Nothing' once it has been released (either because 'decr'
-- brought the count to zero, or because 'release' was called).
--
-- Equality compares the underlying 'TVar's, so two 'Rc' handles are equal
-- exactly when they refer to the same cell.
newtype Rc a = Rc (TVar (Maybe (RcState a)))
  deriving (Eq)

-- | The live state of an 'Rc', stored only while its value has not been
-- released.
data RcState a = RcState
  { -- | Number of outstanding references; 'new' initialises it to 1.
    refCount :: !Int,
    -- | The wrapped value.
    value :: a,
    -- | Action run when the value is released.
    finalizer :: STM ()
  }

-- | @'new' val finalizer@ creates a new 'Rc' containing the value @val@, with
-- an initial reference count of 1. The finalizer runs when the reference
-- count drops to zero via 'decr', or when the 'Rc' is explicitly 'release'd.
new :: a -> STM () -> STM (Rc a)
new value finalizer =
  Rc <$> newTVar (Just RcState {refCount = 1, value, finalizer})

-- | Increment the reference count.
--
-- Has no effect if the value has already been released.
incr :: Rc a -> STM ()
incr (Rc tv) =
  modifyTVar' tv $ \case
    Nothing -> Nothing
    -- 'modifyTVar'' only evaluates the new contents to WHNF ('Just _'), so
    -- force the updated record explicitly; otherwise repeated increments
    -- would pile up a chain of unevaluated record-update thunks.
    Just s@RcState {refCount} -> Just $! s {refCount = refCount + 1}

-- | Decrement the reference count. If this brings the count to zero, release
-- the value and run the finalizer.
--
-- Has no effect if the value has already been released. The 'TVar' is
-- cleared before the finalizer runs, so a finalizer that touches this same
-- 'Rc' observes it as already released.
decr :: Rc a -> STM ()
decr (Rc tv) =
  readTVar tv >>= \case
    Nothing -> pure ()
    -- '<= 1' rather than '== 1': a live count should never be below 1, but if
    -- that invariant were ever broken we still release instead of going
    -- negative.
    Just RcState {refCount, finalizer}
      | refCount <= 1 -> do
          writeTVar tv Nothing
          finalizer
    -- Store the decremented record already evaluated, so no thunk is left
    -- behind in the 'TVar'.
    Just s@RcState {refCount} ->
      writeTVar tv (Just $! s {refCount = refCount - 1})

-- | Release the value immediately and run the finalizer, regardless of the
-- current reference count.
--
-- Has no effect if the value has already been released. The 'TVar' is
-- cleared /before/ the finalizer runs, as 'decr' does. This way a finalizer
-- that calls 'release' or 'decr' on this same 'Rc' is a no-op rather than
-- running the finalizer a second time.
release :: Rc a -> STM ()
release (Rc tv) =
  readTVar tv >>= \case
    Nothing -> pure ()
    Just RcState {finalizer} -> do
      writeTVar tv Nothing
      finalizer

-- | Fetch the value, or 'Nothing' if it has been released.
--
-- Does not change the reference count.
get :: Rc a -> STM (Maybe a)
get (Rc tv) = fmap value <$> readTVar tv