{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash, UnboxedTuples, PatternGuards, ScopedTypeVariables, RankNTypes #-}
module Control.Distributed.Process.Internal.CQueue
( CQueue
, BlockSpec(..)
, MatchOn(..)
, newCQueue
, enqueue
, enqueueSTM
, dequeue
, mkWeakCQueue
, queueSize
) where
import Prelude hiding (length, reverse)
import Control.Concurrent.STM
( atomically
, STM
, TChan
, TVar
, modifyTVar'
, tryReadTChan
, newTChan
, newTVarIO
, writeTChan
, readTChan
, readTVarIO
, orElse
, retry
)
import Control.Applicative ((<$>), (<*>))
import Control.Exception (mask_, onException)
import System.Timeout (timeout)
import Control.Distributed.Process.Internal.StrictMVar
( StrictMVar(StrictMVar)
, newMVar
, takeMVar
, putMVar
)
import Control.Distributed.Process.Internal.StrictList
( StrictList(..)
, append
)
import Data.Maybe (fromJust)
import Data.Traversable (traverse)
import GHC.MVar (MVar(MVar))
import GHC.IO (IO(IO), unIO)
import GHC.Exts (mkWeak#)
import GHC.Weak (Weak(Weak))
-- | A concurrent message queue supporting selective (pattern-matched)
-- dequeueing.
data CQueue a = CQueue
  (StrictMVar (StrictList a))  -- ^ messages already moved off the channel but not yet consumed
  (TChan a)                    -- ^ newly enqueued messages
  (TVar Int)                   -- ^ messages enqueued and not yet returned by 'dequeue'
-- | Create an empty queue: an empty arrived list, a fresh channel and a
-- size counter starting at zero.
newCQueue :: IO (CQueue a)
newCQueue = do
  arrivedVar <- newMVar Nil
  chan       <- atomically newTChan
  counter    <- newTVarIO 0
  return (CQueue arrivedVar chan counter)
-- | Append a message to the queue in its own STM transaction. The bang
-- forces the message to WHNF before it is written.
enqueue :: CQueue a -> a -> IO ()
enqueue queue !msg = atomically $ enqueueSTM queue msg
-- | STM version of 'enqueue': write the (WHNF-forced) message to the
-- incoming channel and bump the size counter.
enqueueSTM :: CQueue a -> a -> STM ()
enqueueSTM (CQueue _ chan counter) !msg =
  writeTChan chan msg >> modifyTVar' counter succ
-- | How long 'dequeue' may wait for a matching message.
data BlockSpec
  = NonBlocking  -- ^ give up (return 'Nothing') instead of waiting
  | Blocking     -- ^ wait until something matches
  | Timeout Int  -- ^ wait at most this many microseconds, as for 'System.Timeout.timeout'
-- | One alternative for 'dequeue'.
data MatchOn m a
  = MatchMsg (m -> Maybe a)  -- ^ applied to queued messages; 'Just' selects the message
  | MatchChan (STM a)        -- ^ STM action combined with others via 'orElse'
  deriving (Functor)

-- | Maximal consecutive runs of 'MatchMsg' functions ('Left') and
-- 'MatchChan' actions ('Right'), in their original order.
type MatchChunks msg r = [Either [msg -> Maybe r] [STM r]]
-- | Group alternatives into maximal runs of message matchers ('Left') and
-- channel matchers ('Right'), preserving their order.
chunkMatches :: [MatchOn m a] -> MatchChunks m a
chunkMatches matches = case matches of
  [] -> []
  MatchMsg f : more ->
    let (fs, remaining) = spanMatchMsg more
    in Left (f : fs) : chunkMatches remaining
  MatchChan c : more ->
    let (cs, remaining) = spanMatchChan more
    in Right (c : cs) : chunkMatches remaining
-- | Split off the leading run of 'MatchMsg' alternatives. Returns their
-- functions and the rest of the list, which is either empty or starts with
-- a 'MatchChan'.
--
-- The recursive call is made only when the head is a 'MatchMsg'. A strict
-- @where@ binding shared by both guards would be forced before either
-- right-hand side is chosen. It would then scan the whole tail even after the
-- run has ended, which makes 'chunkMatches' quadratic in the number of runs.
spanMatchMsg :: [MatchOn m a] -> ([m -> Maybe a], [MatchOn m a])
spanMatchMsg (MatchMsg msg : ms) =
  let !(msgs, rest) = spanMatchMsg ms
  in (msg : msgs, rest)
spanMatchMsg ms = ([], ms)
-- | Split off the leading run of 'MatchChan' alternatives. Returns their
-- STM actions and the rest of the list, which is either empty or starts
-- with a 'MatchMsg'.
--
-- As in 'spanMatchMsg', the strict recursive call happens only in the
-- 'MatchChan' case. This avoids forcing a traversal of the entire tail once
-- the run has ended.
spanMatchChan :: [MatchOn m a] -> ([STM a], [MatchOn m a])
spanMatchChan (MatchChan stm : ms) =
  let !(stms, rest) = spanMatchChan ms
  in (stm : stms, rest)
spanMatchChan ms = ([], ms)
-- | Remove and return the first value selected by the given alternatives.
--
-- Alternatives are tried in list order:
--
-- * 'MatchMsg' functions are applied to queued messages, oldest first.
-- * 'MatchChan' actions are polled, or waited on, with 'orElse'.
--
-- Returns 'Nothing' when nothing matched and the 'BlockSpec' did not allow
-- (further) waiting, or when a 'Timeout' expired.
--
-- @'Timeout' 0@ is treated as 'NonBlocking'. Passing 0 straight to
-- 'timeout' would return 'Nothing' without running the action, so a
-- matching message already in the queue would never be inspected.
--
-- The size counter is decremented only when the result was taken from the
-- message queue, not when it came from a 'MatchChan' action.
dequeue :: forall m a.
           CQueue m
        -> BlockSpec
        -> [MatchOn m a]
        -> IO (Maybe a)
dequeue (CQueue arrived incoming size) requestedSpec matchons = mask_ $ decrementJust $
  case blockSpec of
    -- 'run' only yields 'Nothing' for 'NonBlocking' (see 'waitIncoming'),
    -- so 'fromJust' cannot fail here.
    Timeout n -> timeout n $ fmap fromJust run
    _other ->
      case chunks of
        -- Only channel alternatives: the message queue need not be touched.
        [Right ports] ->
          case blockSpec of
            NonBlocking -> atomically $ waitChans ports (return Nothing)
            _           -> atomically $ waitChans ports retry
        _other -> run
  where
    -- A zero timeout means "poll once", not "give up before looking".
    blockSpec :: BlockSpec
    blockSpec = case requestedSpec of
      Timeout 0 -> NonBlocking
      other     -> other

    -- 'Right' results were taken from the message queue, so the size counter
    -- is decremented for them. 'Left' results came from a 'MatchChan' action.
    -- This runs under 'mask_', and the decrement transaction never retries.
    decrementJust :: IO (Maybe (Either a a)) -> IO (Maybe a)
    decrementJust f =
      traverse (either return (\x -> decrement >> return x)) =<< f

    decrement = atomically $ modifyTVar' size pred

    chunks = chunkMatches matchons

    -- Take the arrived list, append everything currently in the channel,
    -- then scan it. The MVar stays empty until 'returnOld' (or the
    -- 'onException' handler in 'goWait') puts a list back.
    -- NOTE(review): matcher functions run while the MVar is taken. If one
    -- throws, nothing here refills 'arrived'. Confirm matchers are total.
    run = do
      arr <- takeMVar arrived
      let grabNew xs = do
            r <- atomically $ tryReadTChan incoming
            case r of
              Nothing -> return xs
              Just x  -> grabNew (Snoc xs x)
      arr' <- grabNew arr
      goCheck chunks arr'

    -- Try each channel action in order; fall back to 'on_block' if none of
    -- them can proceed.
    waitChans :: [STM a] -> STM (Maybe (Either a a)) -> STM (Maybe (Either a a))
    waitChans ports on_block =
      foldr orElse on_block (map (fmap (Just . Left)) ports)

    -- Scan the already-arrived messages, one chunk of alternatives at a time.
    goCheck :: MatchChunks m a
            -> StrictList m
            -> IO (Maybe (Either a a))
    goCheck [] old = goWait old
    goCheck (Right ports : rest) old = do
      r <- atomically $ waitChans ports (return Nothing)
      case r of
        Just _  -> returnOld old r
        Nothing -> goCheck rest old
    goCheck (Left matches : rest) old =
      case checkArrived matches old of
        (old', Just r)  -> returnOld old' (Just (Right r))
        (old', Nothing) -> goCheck rest old'

    -- A single transaction that waits either for a new message on the
    -- channel ('Left') or for a value from a channel alternative ('Right').
    mkSTM :: MatchChunks m a -> STM (Either m a)
    mkSTM [] = retry
    mkSTM (Left _ : rest)
      = fmap Left (readTChan incoming) `orElse` mkSTM rest
    mkSTM (Right ports : rest)
      = foldr orElse (mkSTM rest) (map (fmap Right) ports)

    -- Returns 'Nothing' only for 'NonBlocking'; otherwise blocks in STM.
    waitIncoming :: IO (Maybe (Either m a))
    waitIncoming = case blockSpec of
        NonBlocking -> atomically $ fmap Just stm `orElse` return Nothing
        _           -> atomically $ fmap Just stm
      where
        stm = mkSTM chunks

    -- Nothing already arrived matched, so wait for new input. If the wait is
    -- interrupted (e.g. by 'timeout'), the arrived list is put back first.
    goWait :: StrictList m -> IO (Maybe (Either a a))
    goWait old = do
      r <- waitIncoming `onException` putMVar arrived old
      case r of
        Nothing -> returnOld old Nothing
        Just e  -> case e of
          Left m  -> goCheck1 chunks m old
          Right a -> returnOld old (Just (Left a))

    -- Check one newly received message against the alternatives in order.
    -- Channel alternatives listed earlier are polled again and take priority.
    -- An unmatched message is appended to the arrived list before waiting
    -- again.
    goCheck1 :: MatchChunks m a
             -> m
             -> StrictList m
             -> IO (Maybe (Either a a))
    goCheck1 [] m old = goWait (Snoc old m)
    goCheck1 (Right ports : rest) m old = do
      r <- atomically $ waitChans ports (return Nothing)
      case r of
        Nothing -> goCheck1 rest m old
        Just _  -> returnOld (Snoc old m) r
    goCheck1 (Left matches : rest) m old =
      case checkMatches matches m of
        Nothing -> goCheck1 rest m old
        Just p  -> returnOld old (Just (Right p))

    -- Put the (possibly updated) arrived list back and return the result.
    returnOld :: StrictList m -> Maybe (Either a a) -> IO (Maybe (Either a a))
    returnOld old r = do putMVar arrived old; return r
-- | Find the first element, in list order, accepted by any of the matchers.
-- Returns the list without that element together with the match. When
-- nothing matches, the returned list holds the same elements in the same
-- order, rebuilt from 'Cons' cells.
checkArrived :: [m -> Maybe a] -> StrictList m -> (StrictList m, Maybe a)
checkArrived matches list = walk list Nil
  where
    -- 'pending' holds elements that follow everything in 'front'.
    walk front pending = case front of
      Nil -> case pending of
        Nil -> (Nil, Nothing)
        _   -> walk pending Nil
      Append xs ys -> walk xs (append ys pending)
      Snoc xs x    -> walk xs (Cons x pending)
      Cons x xs    -> case checkMatches matches x of
        Just y  -> (append xs pending, Just y)
        Nothing -> let !(rest, r) = walk xs pending
                   in (Cons x rest, r)
-- | Apply the matchers in order and return the first 'Just' result. Later
-- matchers are not evaluated once one succeeds.
checkMatches :: [m -> Maybe a] -> m -> Maybe a
checkMatches matchers msg = foldr tryOne Nothing matchers
  where
    tryOne matcher next = case matcher msg of
      Nothing -> next
      found   -> found
-- | Create a weak pointer to the queue with the given finalizer.
--
-- The weak pointer's key is the primitive @MVar#@ inside the queue's
-- arrived-list 'StrictMVar'; its value is the 'CQueue' itself.
-- NOTE(review): keying on the primitive MVar (the same approach as
-- 'GHC.MVar.mkWeakMVar') rather than the boxed 'CQueue' is presumably meant
-- to make reachability track the underlying mutable object. Confirm.
--
-- For base >= 4.9 the finalizer is passed to 'mkWeak#' unwrapped with
-- 'unIO'; for older base the 'IO' action is passed directly.
mkWeakCQueue :: CQueue a -> IO () -> IO (Weak (CQueue a))
mkWeakCQueue m@(CQueue (StrictMVar (MVar m#)) _ _) f = IO $ \s ->
#if MIN_VERSION_base(4,9,0)
  case mkWeak# m# m (unIO f) s of (# s1, w #) -> (# s1, Weak w #)
#else
  case mkWeak# m# m f s of (# s1, w #) -> (# s1, Weak w #)
#endif
-- | Read the queue's size counter. It counts messages that have been
-- enqueued but not yet returned by 'dequeue'.
queueSize :: CQueue a -> IO Int
queueSize queue = case queue of
  CQueue _ _ counter -> readTVarIO counter