-- | Like Control.Concurrent.MVar.Strict, but stored values are reduced to WHNF (via 'evaluate') rather than to NF
{-# LANGUAGE CPP, MagicHash, UnboxedTuples #-}
module Control.Distributed.Process.Internal.StrictMVar
  ( StrictMVar(StrictMVar)
  , newEmptyMVar
  , newMVar
  , takeMVar
  , putMVar
  , readMVar
  , withMVar
  , modifyMVar_
  , modifyMVar
  , modifyMVarMasked
  , mkWeakMVar
  ) where

import Control.Monad ((>=>))
import Control.Exception (evaluate, mask_, onException)
import qualified Control.Concurrent.MVar as MVar
  ( MVar
  , newEmptyMVar
  , newMVar
  , takeMVar
  , putMVar
  , readMVar
  , withMVar
  , modifyMVar_
  , modifyMVar
  )
import GHC.MVar (MVar(MVar))
import GHC.IO (IO(IO), unIO)
import GHC.Exts (mkWeak#)
import GHC.Weak (Weak(Weak))

-- | A thin newtype wrapper around 'MVar.MVar'.
--
-- The value-storing operations in this module ('newMVar', 'putMVar',
-- 'modifyMVar_', 'modifyMVar') force the value to weak head normal form with
-- 'evaluate' before storing it. The constructor is exported, so callers can
-- also wrap an existing 'MVar.MVar' directly; doing so does not force its
-- current contents.
newtype StrictMVar a = StrictMVar (MVar.MVar a)

-- | Create a new, empty 'StrictMVar'.
newEmptyMVar :: IO (StrictMVar a)
newEmptyMVar = fmap StrictMVar MVar.newEmptyMVar

-- | Create a 'StrictMVar' holding the given value.
--
-- The value is forced to WHNF first; if forcing throws, no MVar is created
-- and the exception propagates.
newMVar :: a -> IO (StrictMVar a)
newMVar x = do
  forced <- evaluate x
  mv     <- MVar.newMVar forced
  return (StrictMVar mv)

-- | Take the value out of a 'StrictMVar'. This is a direct wrapper over
-- 'MVar.takeMVar' and has the same blocking behaviour.
takeMVar :: StrictMVar a -> IO a
takeMVar sv = case sv of
  StrictMVar mv -> MVar.takeMVar mv

-- | Force a value to WHNF, then put it into the 'StrictMVar'.
--
-- If forcing throws, the exception propagates before 'MVar.putMVar' runs,
-- so the MVar is left untouched.
putMVar :: StrictMVar a -> a -> IO ()
putMVar (StrictMVar mv) val = evaluate val >>= MVar.putMVar mv

-- | Read the current value of a 'StrictMVar' without removing it; this is a
-- direct wrapper over 'MVar.readMVar'.
readMVar :: StrictMVar a -> IO a
readMVar = \(StrictMVar mv) -> MVar.readMVar mv

-- | Run an action on the contents of a 'StrictMVar', delegating to
-- 'MVar.withMVar'. The contents are handed back unchanged, so no extra
-- forcing is done here.
withMVar :: StrictMVar a -> (a -> IO b) -> IO b
withMVar (StrictMVar mv) act = MVar.withMVar mv act

-- | Replace the contents of a 'StrictMVar' with the result of an action.
--
-- The new value is forced to WHNF inside the action passed to
-- 'MVar.modifyMVar_', so a failure while forcing is handled by that
-- function's exception handling just like a failure in @f@ itself.
modifyMVar_ :: StrictMVar a -> (a -> IO a) -> IO ()
modifyMVar_ (StrictMVar mv) f =
  MVar.modifyMVar_ mv $ \old -> f old >>= evaluate

-- | Replace the contents of a 'StrictMVar' and return an extra result.
--
-- Only the new state (the first component of the pair) is forced to WHNF;
-- the returned result is left as-is. Forcing happens inside the action given
-- to 'MVar.modifyMVar', so it is covered by that function's exception
-- handling.
modifyMVar :: StrictMVar a -> (a -> IO (a, b)) -> IO b
modifyMVar (StrictMVar mv) f = MVar.modifyMVar mv $ \old -> do
  (new, result) <- f old
  _ <- evaluate new
  return (new, result)

-- | Like 'modifyMVar', but the whole operation, including the user action,
-- runs with asynchronous exceptions masked (via 'mask_').
--
-- The new state is forced to WHNF before it is put back. Forcing happens
-- inside the 'onException' scope, so if either @f@ or the forcing throws,
-- the original value is restored and the exception is rethrown.
modifyMVarMasked :: StrictMVar a -> (a -> IO (a, b)) -> IO b
modifyMVarMasked (StrictMVar v) f =
  mask_ $ do
    a       <- MVar.takeMVar v
    (a', b) <- (f a >>= evaluateFst) `onException` MVar.putMVar v a
    MVar.putMVar v a'
    return b
  where
    -- 'evaluate' on the pair alone would only expose the tuple constructor and
    -- leave the new state as a thunk; force the state component explicitly.
    evaluateFst :: (s, r) -> IO (s, r)
    evaluateFst (x, y) = evaluate x >> return (x, y)

-- | Make a weak pointer for a 'StrictMVar', with the given action as its
-- finalizer.
--
-- The weak pointer's key is the primitive @MVar#@ inside the wrapper, and its
-- value is the 'StrictMVar' itself.
mkWeakMVar :: StrictMVar a -> IO () -> IO (Weak (StrictMVar a))
mkWeakMVar smv@(StrictMVar (MVar mvar#)) finalizer = IO weaken
  where
    -- The CPP split only changes how the finalizer is handed to 'mkWeak#':
    -- with base >= 4.9 it is passed as a raw state transformer (hence
    -- 'unIO'); with older base it is passed as the 'IO' action itself.
#if MIN_VERSION_base(4,9,0)
    finalizer# = unIO finalizer
#else
    finalizer# = finalizer
#endif
    weaken s = case mkWeak# mvar# smv finalizer# s of
      (# s', w #) -> (# s', Weak w #)