{- |

Provides 'AtomicModify', a typeclass for mutable references that support atomic
modify operations. This generalizes atomic modify operations in 'IO' and 'STM'
contexts for 'IORef', 'MVar', 'TVar', and 'TMVar'.

* 'IORef' and 'MVar' can be modified in 'IO'.
* 'TVar' and 'TMVar' can be modified in 'IO' or 'STM'.

-}

{-# LANGUAGE BangPatterns, MultiParamTypeClasses #-}

module Control.Concurrent.AtomicModify
  ( AtomicModify (..)
  , atomicModifyStrict_
  , atomicModifyLazy_
  ) where

import Control.Concurrent.MVar (MVar, modifyMVar)
import Control.Concurrent.STM (STM, TMVar, TVar, atomically, putTMVar, readTVar,
                               takeTMVar, writeTVar)
import Control.Monad ((>>=))
import Data.Function (($), (&))
import Data.Functor (($>))
import Data.IORef (IORef, atomicModifyIORef, atomicModifyIORef')
import Prelude (IO, pure, ($!))


--------------------------------------------------------------------------------
--  Class
--------------------------------------------------------------------------------

-- | Mutable references that support an atomic read-modify-write operation.
--
-- The class is parameterised over two type variables:
--
-- * @ref@ — the reference type, such as 'IORef', 'TVar', 'MVar' or 'TMVar'.
-- * @m@ — the monad the update runs in, such as 'IO' or 'STM'.
--
-- Because each update is atomic, these operations let the threads of a
-- multithreaded program share a mutable reference without racing on it.
class AtomicModify ref m where

  -- | Atomically replace the contents of a @ref@ (of type @a@) using the
  -- given function, which also yields a result (of type @b@).
  --
  -- Strict: both the new contents of the @ref@ and the returned result are
  -- forced.
  atomicModifyStrict :: ref a -> (a -> (a, b)) -> m b

  -- | Atomically replace the contents of a @ref@ (of type @a@) using the
  -- given function, which also yields a result (of type @b@).
  --
  -- Lazy: the new contents and the result are left unevaluated. Calling
  -- 'atomicModifyLazy' many times while seldom inspecting the contents lets
  -- thunks pile up in memory, causing a space leak.
  atomicModifyLazy :: ref a -> (a -> (a, b)) -> m b


--------------------------------------------------------------------------------
--  Functions
--------------------------------------------------------------------------------

-- | Atomically apply a function to the contents of a @ref@, returning @()@.
--
-- This delegates to 'atomicModifyStrict', so it is strict: the value stored
-- in the @ref@ is forced, as is the (trivial) returned value.
atomicModifyStrict_ :: AtomicModify v m => v a -> (a -> a) -> m ()
atomicModifyStrict_ ref f = atomicModifyStrict ref step
  where
    -- Pair the updated contents with a unit result.
    step old = (f old, ())

-- | Atomically apply a function to the contents of a @ref@ (of type @a@),
-- returning @()@.
--
-- This delegates to 'atomicModifyLazy', so it is lazy: the new contents are
-- not forced. If the program calls 'atomicModifyLazy_' many times but seldom
-- inspects the contents, thunks will pile up in memory, resulting in a space
-- leak. Use 'atomicModifyStrict_' when that is a concern.
atomicModifyLazy_ :: AtomicModify v m => v a -> (a -> a) -> m ()
atomicModifyLazy_ ref f = atomicModifyLazy ref (\a -> (f a, ()))


--------------------------------------------------------------------------------
--  Instances
--------------------------------------------------------------------------------

instance AtomicModify IORef IO
  where
    -- Both methods are plain aliases of the "Data.IORef" primitives:
    -- 'atomicModifyIORef' is the lazy one and 'atomicModifyIORef'' the
    -- strict one.
    atomicModifyLazy   = atomicModifyIORef
    atomicModifyStrict = atomicModifyIORef'

instance AtomicModify MVar IO
  where
    -- Lazy: hand the pair produced by the update straight to 'modifyMVar'
    -- without forcing its components.
    -- NOTE(review): per the 'modifyMVar' documentation, this is only atomic
    -- if no other thread puts into the 'MVar' directly.
    atomicModifyLazy ref f =
      modifyMVar ref (\x -> pure (f x))

    -- Strict: 'pure $! f x' would only force the pair constructor, leaving
    -- the new contents and the result as thunks. The bang patterns force
    -- both components, matching the class contract and the strict 'TVar' and
    -- 'TMVar' instances in this module. The forcing happens inside the
    -- 'modifyMVar' callback, so if it throws, 'modifyMVar' restores the
    -- original contents.
    atomicModifyStrict ref f =
      modifyMVar ref (\x -> case f x of (!x', !y) -> pure (x', y))

instance AtomicModify TVar STM
  where
    -- Read the current value, apply the update, write the new value back and
    -- return the result; the enclosing STM transaction makes this atomic.
    atomicModifyLazy ref f =
      readTVar ref >>= \old ->
        case f old of
          (updated, result) -> writeTVar ref updated $> result

    -- As above, but the bang patterns force the new value and the result
    -- before the write happens.
    atomicModifyStrict ref f =
      readTVar ref >>= \old ->
        case f old of
          (!updated, !result) -> writeTVar ref updated $> result

instance AtomicModify TMVar STM
  where
    -- Take the current value out of the TMVar, apply the update, put the new
    -- value back and return the result. 'takeTMVar' retries while the TMVar
    -- is empty, so the transaction waits until a value is present.
    atomicModifyLazy ref f =
      takeTMVar ref >>= \old ->
        case f old of
          (updated, result) -> putTMVar ref updated $> result

    -- As above, but the bang patterns force the new value and the result
    -- before it is put back.
    atomicModifyStrict ref f =
      takeTMVar ref >>= \old ->
        case f old of
          (!updated, !result) -> putTMVar ref updated $> result

instance AtomicModify TVar IO
  where
    -- Each method runs the matching 'STM' method as its own transaction.
    atomicModifyLazy ref f = atomically $ atomicModifyLazy ref f
    atomicModifyStrict ref f = atomically $ atomicModifyStrict ref f

instance AtomicModify TMVar IO
  where
    -- Each method runs the matching 'STM' method as its own transaction.
    atomicModifyLazy ref f = atomically $ atomicModifyLazy ref f
    atomicModifyStrict ref f = atomically $ atomicModifyStrict ref f