module Control.Concurrent.FairRWLock
( RWLock, RWLockException(..), RWLockExceptionKind(..),FRW(..),LockKind(..),TMap,TSet
, new
, withRead, withWrite
, acquireRead, acquireWrite
, releaseRead, releaseWrite
, peekLock, checkLock
) where
import Control.Applicative(liftA2)
import Control.Concurrent
import Control.Exception(Exception,bracket_,onException,evaluate,uninterruptibleMask_,mask_,throw)
import Control.Monad((>=>),join,forM_)
import Data.Sequence((<|),(|>),(><),Seq,ViewL(..),ViewR(..))
import qualified Data.Sequence as Seq(empty,viewl,viewr,breakl,spanl)
import qualified Data.Foldable as F(toList)
import Data.Map(Map)
import qualified Data.Map as Map
import Data.Set(Set)
import qualified Data.Set as Set
import Data.Typeable(Typeable)
-- | Map from each thread to the number of times it currently holds the
-- read lock (read acquisitions are counted, i.e. re-entrant).
type TMap = Map ThreadId Int
-- | A set of threads; used for a group of readers queued together.
type TSet = Set ThreadId
-- | One entry of the wait queue: either a group of readers that share a
-- single wake-up 'MVar' and are admitted together, or a single writer.
data LockKind = ReaderKind { unRK :: TSet }
              | WriterKind { unWK :: ThreadId }
  deriving (Eq,Ord,Show)
-- | Queue of waiting threads, served from the left.  Each entry pairs the
-- waiter(s) with the 'MVar' that is filled with @()@ when the entry is
-- granted the lock.
type LockQ = Seq (LockKind,MVar ())
-- | The complete state of an 'RWLock'.
data LockUser =
    -- | Nobody holds the lock.  The lock only enters this state when no
    -- thread is waiting, so there is no queue to store.
    FreeLock
    -- | One or more threads hold the read lock.
  | Readers { readerCounts :: TMap
              -- ^ each reader and its acquisition count
            , queueR :: Maybe ( (ThreadId,MVar ())
                              , LockQ )
              -- ^ 'Nothing' when nobody waits.  Otherwise the first
              -- component is the writer at the head of the queue (new
              -- readers only join directly while this is 'Nothing'), and
              -- the second is everything queued behind that writer.
            }
    -- | A single thread holds the write lock.
  | Writer { writerID :: ThreadId
             -- ^ the thread holding the write lock
           , writerCount
           , readerCount :: !Int
             -- ^ how many times the writer has acquired the write lock and
             -- the read lock, respectively
           , queue :: LockQ
             -- ^ threads waiting for the lock
           }
  deriving (Eq,Typeable)
-- | A re-entrant reader/writer lock.  Its whole state is one 'LockUser'
-- kept in an 'MVar'; every operation updates that state with 'modifyMVar'
-- or inspects it with 'withMVar'.
newtype RWLock = RWL (MVar LockUser)
-- | Error from a lock operation: the calling thread, which operation
-- failed, and a human-readable message.  Misuse of 'releaseRead' /
-- 'releaseWrite' is returned as a 'Left' value; internal inconsistencies
-- are thrown.
data RWLockException = RWLockException ThreadId RWLockExceptionKind String
  deriving (Show,Typeable)
-- | The operation that produced an 'RWLockException'.  Cleanup after an
-- interrupted acquire is reported with the corresponding @acquire@ kind.
data RWLockExceptionKind = RWLock'acquireWrite | RWLock'releaseWrite
                         | RWLock'acquireRead | RWLock'releaseRead
  deriving (Show,Typeable)
instance Exception RWLockException
-- | Snapshot of who holds the lock, as reported by 'peekLock':
-- 'F' when free, 'R' with each reader's acquisition count, or 'W' with
-- the writer's id and its @(readerCount, writerCount)@ pair.
data FRW = F | R TMap | W (ThreadId,(Int,Int)) deriving (Show)
-- | Create a fresh, unlocked 'RWLock'.
new :: IO RWLock
new = do
  stateVar <- newMVar FreeLock
  return (RWL stateVar)
-- | Run an action while holding the read lock.  The lock is released
-- afterwards even if the action throws, and an error reported by
-- 'releaseRead' is rethrown.
withRead :: RWLock -> IO a -> IO a
withRead rwl = bracket_ (acquireRead rwl) release
  where
    release :: IO ()
    release = releaseRead rwl >>= either throw return
-- | Run an action while holding the write lock.  The lock is released
-- afterwards even if the action throws, and an error reported by
-- 'releaseWrite' is rethrown.
withWrite :: RWLock -> IO a -> IO a
withWrite rwl = bracket_ (acquireWrite rwl) release
  where
    release :: IO ()
    release = releaseWrite rwl >>= either throw return
-- | Inspect the lock without changing it.  Returns who holds it ('FRW')
-- and the kinds of the waiting entries in queue order.  In a 'Readers'
-- state, the writer stored at the head of the queue is listed first.
peekLock :: RWLock -> IO (FRW,[LockKind])
peekLock (RWL rwlVar) = withMVar rwlVar (return . snapshot)
  where
    snapshot :: LockUser -> (FRW,[LockKind])
    snapshot FreeLock = (F, [])
    snapshot (Readers { readerCounts = holders, queueR = pending }) =
      (R holders, maybe [] headWriterFirst pending)
    snapshot (Writer { writerID = owner, writerCount = wc, readerCount = rc, queue = waiters }) =
      (W (owner, (rc, wc)), kinds waiters)

    headWriterFirst :: ((ThreadId, MVar ()), LockQ) -> [LockKind]
    headWriterFirst ((wid, _), rest) = WriterKind wid : kinds rest

    kinds :: LockQ -> [LockKind]
    kinds = map fst . F.toList
-- | Report how many times the calling thread currently holds the lock,
-- as @(readCount, writeCount)@.  A thread holding nothing gets @(0,0)@.
checkLock :: RWLock -> IO (Int,Int)
checkLock (RWL rwlVar) = do
  me <- myThreadId
  let mine FreeLock = (0, 0)
      mine (Readers { readerCounts = holders }) =
        maybe (0, 0) (\n -> (n, 0)) (Map.lookup me holders)
      mine (Writer { writerID = owner, writerCount = wc, readerCount = rc })
        | owner == me = (rc, wc)
        | otherwise   = (0, 0)
  withMVar rwlVar (return . mine)
-- | Give up one read acquisition held by the calling thread.  Misuse, such
-- as releasing a read lock that is not held, is returned as 'Left' rather
-- than thrown.
releaseRead :: RWLock -> IO (Either RWLockException ())
releaseRead (RWL rwlVar) =
  mask_ (myThreadId >>= \me -> releaseRead' False me rwlVar)
-- | Core of 'releaseRead': drop one read acquisition of thread @me@.
--
-- With @abandon = False@ this is an ordinary release.  Caller errors are
-- returned as 'Left' and leave the state unchanged.
--
-- With @abandon = True@ this is the cleanup run when a blocked
-- 'acquireRead' is interrupted by an exception.  If the thread is still
-- queued, it is removed from the queue.  If it was already granted the
-- lock, that read acquisition is released.
--
-- States that should be unreachable are thrown as 'RWLockException' via
-- 'throw'; 'modifyMVar' then restores the old state.  The whole update runs
-- under 'uninterruptibleMask_'.
releaseRead' :: Bool -> ThreadId -> MVar LockUser -> IO (Either RWLockException ())
releaseRead' abandon me rwlVar = uninterruptibleMask_ . modifyMVar rwlVar $ \ rwd -> do
  let impossible :: Show x => String -> x -> IO a
      impossible s x = throw
        (RWLockException me (if abandon then RWLock'acquireRead else RWLock'releaseRead) (imp s x))
      -- Report a caller error while keeping the lock state unchanged.
      err :: Show x => String -> x -> IO (LockUser,Either RWLockException ())
      err s x = return . ((,) rwd) . Left $
        (RWLockException me (if abandon then RWLock'acquireRead else RWLock'releaseRead) (s++" : "++show x))
      ret :: LockUser -> IO (LockUser,Either RWLockException ())
      ret x = return (x,Right ())
      -- Remove this thread from the first reader group in the queue that
      -- contains it.  If that group becomes empty, the whole entry
      -- (including its wake-up MVar) is dropped.
      dropReader :: LockQ -> IO LockQ
      dropReader q = do
        let inR (ReaderKind rcs,_) = Set.member me rcs
            inR _ = False
            (pre,myselfPost) = Seq.breakl inR q
        case Seq.viewl myselfPost of
          EmptyL ->
            impossible "failure to abandon acquireRead, RWLock locked by other thread(s) and this thread is not in queue" me
          (myself,mblock) :< post -> do
            let rcs' = Set.delete me (unRK myself)
            evaluate $ if Set.null rcs' then pre >< post else pre >< ((ReaderKind rcs',mblock) <| post)
  case rwd of
    -- A blocked acquireRead always sees the lock held, so an abandon here
    -- indicates corrupted state.
    FreeLock | abandon ->
                 impossible "acquireRead interrupted with unlocked RWLock" me
             | otherwise ->
                 err "cannot releaseRead lock from unlocked RWLock" me
    -- The writer's own read acquisitions are counted in readerCount.
    w@(Writer { writerID=it, readerCount=rc, queue=q })
      | it==me -> do
          case rc of
            0 | abandon ->
                  impossible "acquireRead interrupted with write lock but not read lock" (me,it)
              | otherwise ->
                  err "releaseRead when holding write lock but not read lock" (me,it)
            _ -> do
              rc' <- evaluate $ pred rc
              ret (w { readerCount=rc' })
      -- Interrupted while still waiting behind another thread's write lock.
      | abandon -> do
          q' <- dropReader q
          ret (w { queue=q' })
      | otherwise ->
          err "releaseRead called when not read locked " me
    r@(Readers { readerCounts=rcs,queueR=qR }) ->
      case Map.lookup me rcs of
        -- Last acquisition by this thread.  This also covers an abandoned
        -- acquireRead that had already been granted.
        Just 1 -> do
          let rcs' = Map.delete me rcs
          if Map.null rcs'
            then case qR of
                   Nothing ->
                     ret FreeLock
                   -- No readers remain: wake the writer at the head of the
                   -- queue and hand it the lock.
                   Just ((wid,mblock),q) -> do
                     putMVar mblock ()
                     ret (Writer { writerID=wid, writerCount=1, readerCount=0, queue=q })
            else ret (r { readerCounts=rcs' })
        Just rc -> do
          rc' <- evaluate $ pred rc
          rcs' <- evaluate $ Map.insert me rc' rcs
          ret (r { readerCounts=rcs' })
        -- Not a holder: an abandoning thread must still be queued behind
        -- the head writer.
        Nothing | abandon ->
                    case qR of
                      Nothing ->
                        impossible "acquireRead interrupted not holding lock and with no queue" (me,rcs)
                      Just (w,q) -> do
                        q' <- dropReader q
                        ret (r { queueR = Just (w,q') })
                | otherwise ->
                    err "releaseRead called with read lock held by others" (me,rcs)
-- | Give up one write acquisition held by the calling thread.  Misuse, such
-- as releasing a write lock owned by another thread, is returned as 'Left'
-- rather than thrown.
releaseWrite :: RWLock -> IO (Either RWLockException ())
releaseWrite (RWL rwlVar) =
  mask_ (myThreadId >>= \me -> releaseWrite' False me rwlVar)
-- | Core of 'releaseWrite': drop one write acquisition of thread @me@.
--
-- With @abandon = False@ this is an ordinary release.  Caller errors are
-- returned as 'Left' and leave the state unchanged.
--
-- With @abandon = True@ this is the cleanup run when a blocked
-- 'acquireWrite' is interrupted.  If the thread is still queued, it is
-- removed.  If it was already granted the write lock, the lock is passed
-- on.
--
-- States that should be unreachable are thrown as 'RWLockException' via
-- 'throw'; 'modifyMVar' then restores the old state.  The whole update runs
-- under 'uninterruptibleMask_'.
releaseWrite' :: Bool -> ThreadId -> MVar LockUser -> IO (Either RWLockException ())
releaseWrite' abandon me rwlVar = uninterruptibleMask_ . modifyMVar rwlVar $ \ rwd -> do
  let impossible :: Show x => String -> x -> IO a
      impossible s x = throw
        (RWLockException me (if abandon then RWLock'acquireWrite else RWLock'releaseWrite) (imp s x))
      -- Report a caller error while keeping the lock state unchanged.
      err :: Show x => String -> x -> IO (LockUser,Either RWLockException ())
      err s x = return . ((,) rwd) . Left $
        (RWLockException me (if abandon then RWLock'acquireWrite else RWLock'releaseWrite) (s++" : "++show x))
      ret :: LockUser -> IO (LockUser,Either RWLockException ())
      ret x = return (x,Right ())
      -- Remove this thread's (single) writer entry from the queue.
      dropWriter :: LockQ -> IO LockQ
      dropWriter q = do
        let inW (WriterKind it,_) = me==it
            inW _ = False
            (pre,myselfPost) = Seq.breakl inW q
        case Seq.viewl myselfPost of
          EmptyL ->
            impossible "failure to abandon acquireWrite, RWLock locked by other and not in queue" me
          _ :< post ->
            evaluate $ pre><post
  case rwd of
    FreeLock | abandon ->
                 impossible "acquireWrite interrupted with unlocked RWLock" me
             | otherwise ->
                 err "cannot releaseWrite lock from unlocked RWLock" me
    w@(Writer { writerID=it, writerCount=wc, readerCount=rc, queue=q })
      | it==me ->
          case (wc,rc) of
            -- Last write acquisition and no reads held: pass the lock on.
            (1,0) -> ret =<< promote q
            -- A freshly granted writer has counts (1,0); anything else
            -- while abandoning means corrupted state.
            _ | abandon -> impossible "acquireWrite interrupted with write lock and bad RWLock state" (me,it,wc,rc)
            -- Last write acquisition but the writer still holds reads: it
            -- becomes a reader, together with readers queued at the front.
            (1,_) -> ret =<< promoteReader rc q
            (_,_) -> ret (w { writerCount=(pred wc) })
      -- Interrupted while waiting behind another thread's write lock.
      | abandon -> do
          q' <- dropWriter q
          ret (w { queue=q' })
      | otherwise ->
          err "cannot releaseWrite when not holding the write lock" (me,it)
    Readers { readerCounts=rcs } | not abandon ->
      err "cannot releaseWrite when RWLock is read locked" (me,rcs)
    Readers { readerCounts=rcs, queueR=Nothing } ->
      impossible "failure to abandon acquireWrite, RWLock read locked and no queue" (me,rcs)
    r@(Readers { readerCounts=rcs, queueR=Just (w@(it,_),q) })
      -- The abandoning thread is the head writer.  The readers queued right
      -- behind it no longer need to wait, so admit them now.
      | it==me -> do
          (rcs'new,qr) <- splitReaders q
          ret (r { readerCounts=Map.union rcs rcs'new, queueR=qr })
      | otherwise -> do
          q' <- dropWriter q
          ret (r { queueR=Just (w,q') })
  where
    -- Turn the releasing writer, which holds @rc@ reads, into a reader
    -- alongside any readers queued at the front.
    promoteReader :: Int -> LockQ -> IO LockUser
    promoteReader rc q = do
      (rcs'new, qr) <- splitReaders q
      let rcs = Map.insert me rc rcs'new
      return (Readers { readerCounts=rcs, queueR=qr })
    -- Grant the lock to the front of the queue: a single writer, or every
    -- reader group up to the next writer.
    promote :: LockQ -> IO LockUser
    promote qIn = do
      case Seq.viewl qIn of
        EmptyL -> return FreeLock
        (WriterKind it,mblock) :< qOut -> do
          putMVar mblock ()
          return (Writer { writerID=it, writerCount=1, readerCount=0, queue=qOut })
        _ -> do
          (rcs,qr) <- splitReaders qIn
          return (Readers { readerCounts=rcs, queueR=qr })
    -- Wake and admit all consecutive reader groups at the front of the
    -- queue, each with count 1.  Adjacent groups can exist after a writer
    -- between them was dropped.  Return them with the remaining queue,
    -- which starts with a writer if it is non-empty.
    splitReaders :: LockQ -> IO (TMap,Maybe ((ThreadId,MVar ()),LockQ))
    splitReaders qIn = do
      let (more'Readers,qTail) = Seq.spanl isReader qIn
          (rks,mblocks) = unzip (F.toList more'Readers)
          rcs = Map.fromDistinctAscList . map (\k -> (k,1)) . F.toList . Set.unions . map unRK $ rks
          qr = case Seq.viewl qTail of
                 EmptyL -> Nothing
                 (wk,mblock) :< qOut -> Just ((unWK wk,mblock),qOut)
      forM_ mblocks (\mblock -> putMVar mblock ())
      return (rcs,qr)
      where
        isReader (ReaderKind {},_) = True
        isReader _ = False
-- | Block until the calling thread holds the read lock.
--
-- Acquisitions are counted: a thread already holding the read lock, or
-- holding the write lock, just increments its count without blocking.
-- A new reader joins the current readers immediately only when nobody is
-- queued.  Otherwise it waits at the back of the queue, joining a trailing
-- reader group if one exists.
--
-- The state update happens inside 'modifyMVar', which returns the action
-- to run next.  'join' runs it after the MVar is released, so any blocking
-- wait happens outside the critical section.  If that wait is interrupted,
-- @releaseRead' True@ removes this thread from the queue, or releases the
-- lock if it was granted in the meantime.
acquireRead :: RWLock -> IO ()
acquireRead (RWL rwlVar) = mask_ . join . modifyMVar rwlVar $ \ rwd -> do
  me <- myThreadId
  -- Reader groups share one MVar, so wait with a non-consuming readMVar.
  let safeBlock mblock = (readMVar mblock) `onException` (releaseRead' True me rwlVar)
  case rwd of
    FreeLock ->
      return ( Readers { readerCounts=Map.singleton me 1, queueR=Nothing }
             , return () )
    w@(Writer { writerID=it, readerCount=rc, queue=q })
      | it == me -> do
          rc' <- evaluate $ succ rc
          return ( w { readerCount=rc' }
                 , return () )
      | otherwise -> do
          (q',mblock) <- enterQueueR q me
          return ( w { queue = q' }
                 , safeBlock mblock )
    r@(Readers { readerCounts=rcs }) | Just rc <- Map.lookup me rcs -> do
      rc' <- evaluate $ succ rc
      rcs' <- evaluate $ Map.insert me rc' rcs
      return ( r { readerCounts=rcs' }
             , return () )
    r@(Readers { readerCounts=rcs, queueR=Nothing }) -> do
      rcs' <- evaluate $ Map.insert me 1 rcs
      return ( r { readerCounts=rcs' }
             , return () )
    -- A writer is already waiting: queue behind it instead of overtaking it.
    r@(Readers { queueR=Just (w,q) }) -> do
      (q',mblock) <- enterQueueR q me
      return ( r { queueR=Just (w,q') }
             , safeBlock mblock )
  where
    -- Append this thread at the back of the queue.  If the last entry is
    -- a reader group, join it and share its MVar; otherwise start a new
    -- group with a fresh, empty MVar.
    enterQueueR :: LockQ -> ThreadId -> IO (LockQ,MVar ())
    enterQueueR qIn me = do
      case Seq.viewr qIn of
        pre :> (ReaderKind rcs,mblock) -> do
          rcs' <- addMe rcs
          return (pre |> (ReaderKind rcs', mblock),mblock)
        _ -> do
          mblock <- newEmptyMVar
          return (qIn |> (ReaderKind (Set.singleton me),mblock), mblock)
      where
        addMe :: TSet -> IO TSet
        addMe rcs | Set.member me rcs = error (imp "enterQueueR.addMe when already in set" me)
                  | otherwise = return (Set.insert me rcs)
-- | Acquire the read lock ahead of queued writers.
--
-- This is not exported.  'acquireWrite' uses it to re-take read locks the
-- thread gave up while upgrading.  In a 'Readers' state the thread joins
-- immediately even if a writer is queued.  In another thread's 'Writer'
-- state it is placed at the front of the queue.
--
-- The whole operation, including any wait, runs under
-- 'uninterruptibleMask_'.
acquireReadPriority :: RWLock -> IO ()
acquireReadPriority (RWL rwlVar) = uninterruptibleMask_ . join . modifyMVar rwlVar $ \ rwd -> do
  me <- myThreadId
  let safeBlock mblock = (readMVar mblock) `onException` (releaseRead' True me rwlVar)
  case rwd of
    FreeLock ->
      return ( Readers { readerCounts=Map.singleton me 1, queueR=Nothing }
             , return () )
    w@(Writer { writerID=it, readerCount=rc, queue=q })
      | it == me -> do
          rc' <- evaluate $ succ rc
          return ( w { readerCount=rc' }
                 , return () )
      | otherwise -> do
          (q',mblock) <- enterQueueL me q
          return ( w { queue = q' }
                 , safeBlock mblock )
    -- Unlike acquireRead, the queue is ignored here.
    r@(Readers { readerCounts=rcs }) -> do
      case Map.lookup me rcs of
        Just rc -> do
          rc' <- evaluate $ succ rc
          rcs' <- evaluate $ Map.insert me rc' rcs
          return ( r { readerCounts=rcs' }
                 , return () )
        Nothing -> do
          rcs' <- evaluate $ Map.insert me 1 rcs
          return ( r { readerCounts=rcs' }
                 , return () )
  where
    -- Put this thread at the front of the queue.  If the first entry is a
    -- reader group, join it and share its MVar; otherwise push a new group
    -- with a fresh, empty MVar.
    enterQueueL :: ThreadId -> LockQ -> IO (LockQ,MVar ())
    enterQueueL me qIn = do
      case Seq.viewl qIn of
        (ReaderKind rcs,mblock) :< post -> do
          rcs' <- addMe rcs
          return ((ReaderKind rcs', mblock) <| post,mblock)
        _ -> do
          mblock <- newEmptyMVar
          return ((ReaderKind (Set.singleton me),mblock) <| qIn , mblock)
      where
        addMe :: TSet -> IO TSet
        addMe rcs | Set.member me rcs = error (imp "enterQueueL.addMe when already in set" me)
                  | otherwise = return (Set.insert me rcs)
-- | Block until the calling thread holds the write lock.
--
-- Write acquisitions by the current writer are counted and never block.
-- Other writers wait at the back of the queue; in a 'Readers' state with no
-- waiters, the first writer becomes the queue head stored in @queueR@.
--
-- If the caller holds @rc@ read locks in a 'Readers' state, it upgrades in
-- three steps.  It releases all @rc@ reads, acquires the write lock
-- recursively, then re-takes @rc@ reads via 'acquireReadPriority'.  Once
-- the thread is the writer, that just raises @readerCount@.  The read
-- locks are not held continuously during the upgrade.
--
-- As in 'acquireRead', the blocking wait runs after 'modifyMVar' returns.
-- If it is interrupted, @releaseWrite' True@ cleans up.
acquireWrite :: RWLock -> IO ()
acquireWrite rwl@(RWL rwlVar) = mask_ . join . modifyMVar rwlVar $ \ rwd -> do
  me <- myThreadId
  -- Each writer has its own MVar, so the wait consumes it with takeMVar.
  let safeBlock mblock = (takeMVar mblock) `onException` (releaseWrite' True me rwlVar)
  case rwd of
    FreeLock ->
      return ( Writer { writerID=me, writerCount=1, readerCount=0, queue=Seq.empty }
             , return () )
    w@(Writer { writerID=it, writerCount=wc, queue=q })
      | it==me ->
          return ( w { writerCount=(succ wc) }
                 , return () )
      | otherwise -> do
          mblock <- newEmptyMVar
          q' <- evaluate $ q |> (WriterKind me,mblock)
          return ( w { queue=q' }
                 , safeBlock mblock )
    -- Upgrade: leave the state untouched here.  The follow-up action
    -- releases the reads and then calls acquireWrite again.
    Readers { readerCounts=rcs } | Just rc <- Map.lookup me rcs -> do
      return ( rwd
             , withoutReads rc (acquireWrite rwl) )
    r@(Readers { queueR=Nothing }) -> do
      mblock <- newEmptyMVar
      let qr = Just ((me,mblock),Seq.empty)
      return ( r { queueR=qr }
             , safeBlock mblock )
    r@(Readers { queueR=Just (w,q) }) -> do
      mblock <- newEmptyMVar
      q' <- evaluate $ q |> (WriterKind me,mblock)
      return ( r { queueR=Just (w,q') }
             , safeBlock mblock )
  where
    -- Run the action with @n@ of this thread's read locks given up, and
    -- re-take them afterwards (also on exception, via bracket_).
    withoutReads :: Int -> IO a -> IO a
    withoutReads n x = foldr (.) id (replicate n withoutRead) $ x
    withoutRead :: IO a -> IO a
    withoutRead = bracket_ (releaseRead rwl >>= either throw return) (acquireReadPriority rwl)
-- | Build the message for an internal-consistency failure: a fixed prefix,
-- a description, and the 'show'n diagnostic value.
imp :: Show x => String -> x -> String
imp s x = concat ["FairRWLock impossible error: ", s, " : ", show x]