module Data.STM.RollingQueue (
RollingQueue,
new,
newIO,
write,
read,
tryRead,
isEmpty,
length,
setLimit,
getLimit,
checkInvariants,
CheckException(..),
dump,
) where
import Prelude hiding (length, read)
import qualified Prelude
import Control.Concurrent.STM hiding (check)
import Control.Exception (Exception)
import Control.Monad (join)
import Data.Typeable (Typeable)
-- | A bounded FIFO queue for use in STM.
--
-- Writers never block: when the number of buffered items would exceed the
-- size limit, the oldest items are discarded to make room. Readers learn how
-- many items they missed from the count returned by 'read' / 'tryRead'.
--
-- The read end and the write end are kept in separate 'TVar's.
data RollingQueue a
  = RQ (TVar (ReadEnd a)) (TVar (WriteEnd a))
  deriving Typeable

-- | Identity equality: two handles are equal when they share the same
-- read-end 'TVar', i.e. they refer to the same queue.
instance Eq (RollingQueue a) where
  RQ lhs _ == RQ rhs _ = lhs == rhs
-- | State owned by the consumer side of the queue.
data ReadEnd a = ReadEnd
  { readPtr       :: !(TCell a)
    -- ^ Cell holding the oldest unread item ('TNil' when nothing is buffered).
  , readCounter   :: !Int
    -- ^ Items read since the two ends were last resynchronised.
  , readDiscarded :: !Int
    -- ^ Items dropped by overflow since the previous successful read.
  }
-- | State owned by the producer side of the queue.
data WriteEnd a = WriteEnd
  { writePtr     :: !(TCell a)
    -- ^ The empty hole ('TNil') that the next written item is stored into.
  , writeCounter :: !Int
    -- ^ Write-side counter; @writeCounter - readCounter@ is the number of
    -- buffered items. Kept at or below 'sizeLimit'.
  , sizeLimit    :: !Int
    -- ^ Maximum number of buffered items (never negative).
  }
-- | A mutable link in the queue's singly-linked list.
type TCell a = TVar (TList a)

-- | Contents of a 'TCell': either the end of the list, or an item followed
-- by the cell holding the rest of the list.
data TList a
  = TNil
  | TCons a (TCell a)
-- | Create an empty queue holding at most the given number of items.
-- A negative limit is treated as 0.
new :: Int -> STM (RollingQueue a)
new limit = do
  -- Both ends start out pointing at the same empty hole.
  hole <- newTVar TNil
  RQ <$> newTVar (ReadEnd hole 0 0)
     <*> newTVar (WriteEnd hole 0 (max 0 limit))
-- | 'IO' variant of 'new', built from 'newTVarIO' rather than a transaction.
-- A negative limit is treated as 0.
newIO :: Int -> IO (RollingQueue a)
newIO limit = do
  -- Both ends start out pointing at the same empty hole.
  hole <- newTVarIO TNil
  RQ <$> newTVarIO (ReadEnd hole 0 0)
     <*> newTVarIO (WriteEnd hole 0 (max 0 limit))
-- | Append an item to the queue. This never blocks: if the queue would grow
-- past its limit, the oldest items are discarded instead.
write :: RollingQueue a -> a -> STM ()
write rq@(RQ _ wv) x = do
  WriteEnd slot written limit <- readTVar wv
  -- Fill the current hole and open a fresh one behind it.
  nextSlot <- newTVar TNil
  writeTVar slot (TCons x nextSlot)
  updateWriteEnd rq (WriteEnd nextSlot (written + 1) limit)
-- | Remove the oldest item, blocking (via 'retry') while the queue is empty.
-- The 'Int' is the number of items discarded since the previous read.
read :: RollingQueue a -> STM (a, Int)
read rq = maybe retry return =<< tryRead rq
-- | Remove the oldest item if there is one. Returns 'Nothing' on an empty
-- queue. Otherwise it returns the item together with the number of items
-- discarded since the previous read, and resets that count to 0.
tryRead :: RollingQueue a -> STM (Maybe (a, Int))
tryRead (RQ rv _) = do
  ReadEnd cell reads discarded <- readTVar rv
  node <- readTVar cell
  case node of
    TNil -> pure Nothing
    TCons item next -> do
      writeTVar rv (ReadEnd next (reads + 1) 0)
      pure (Just (item, discarded))
-- | 'True' when no items are buffered (the read pointer is at the list end).
isEmpty :: RollingQueue a -> STM Bool
isEmpty (RQ rv _) = atEnd <$> (readTVar rv >>= readTVar . readPtr)
  where
    atEnd TNil = True
    atEnd (TCons _ _) = False
-- | Number of items currently buffered in the queue.
length :: RollingQueue a -> STM Int
length (RQ rv wv) = do
  r <- readTVar rv
  w <- readTVar wv
  -- 'syncEnds' rebases both counters together, so their difference is the
  -- number of items written but not yet read or discarded.
  return (writeCounter w - readCounter r)
-- | Change the maximum number of buffered items. Negative values are treated
-- as 0. If more items than the new limit are buffered, the oldest are
-- discarded immediately.
setLimit :: RollingQueue a -> Int -> STM ()
setLimit rq@(RQ _ wv) limit' =
  readTVar wv >>= \w -> updateWriteEnd rq w { sizeLimit = max 0 limit' }
-- | The current maximum number of buffered items.
getLimit :: RollingQueue a -> STM Int
getLimit (RQ _ wv) = sizeLimit <$> readTVar wv
-- | Store a new write end. If its counter has gone past the limit, first
-- resynchronise both ends with 'syncEnds' (which may discard the oldest
-- items), so that @writeCounter <= sizeLimit@ holds afterwards.
updateWriteEnd :: RollingQueue a -> WriteEnd a -> STM ()
updateWriteEnd (RQ rv wv) w
  | writeCounter w > sizeLimit w = do
      (r', w') <- readTVar rv >>= \r -> syncEnds r w
      writeTVar rv r'
      writeTVar wv w'
  | otherwise = writeTVar wv w
-- | Rebase both counters so that @readCounter@ becomes 0 and @writeCounter@
-- equals the number of buffered items.
--
-- If that number exceeds the limit, the oldest excess items are skipped
-- from the read side, and their count is added to 'readDiscarded'. The
-- write counter is then capped at the limit.
syncEnds :: ReadEnd a -> WriteEnd a -> STM (ReadEnd a, WriteEnd a)
syncEnds r w
  | count > limit = do
      let excess = count - limit
      rp' <- dropItems excess (readPtr r)
      return ( ReadEnd rp' 0 (readDiscarded r + excess)
             , WriteEnd (writePtr w) limit limit
             )
  | otherwise =
      return ( ReadEnd (readPtr r) 0 (readDiscarded r)
             , WriteEnd (writePtr w) count limit
             )
  where
    -- Items written but not yet read or discarded.
    count = writeCounter w - readCounter r
    limit = sizeLimit w
-- | Follow @n@ links along the list starting at @cell@ and return the cell
-- reached. Stops early at the end of the list ('TNil'). A non-positive @n@
-- returns @cell@ unchanged.
dropItems :: Int -> TCell a -> STM (TCell a)
dropItems n cell
  | n <= 0 = return cell
  | otherwise = do
      xs <- readTVar cell
      case xs of
        TNil -> return cell
        TCons _ cell' -> dropItems (n - 1) cell'
-- | Thrown (via 'throwSTM') by 'checkInvariants' when the queue's internal
-- state is inconsistent. The payload describes the violated invariant.
data CheckException = CheckException String
  deriving Typeable

instance Show CheckException where
  showsPrec _ (CheckException msg) =
    showString "Data.STM.RollingQueue checkInvariants: " . showString msg

instance Exception CheckException
-- | Verify the queue's internal invariants. Throws a 'CheckException' (via
-- 'throwSTM') naming the first one that fails.
--
-- The checks run in this order:
--
-- * every counter and the limit is non-negative;
-- * @writeCounter <= sizeLimit@;
-- * the write pointer refers to an empty 'TNil' hole;
-- * @writeCounter >= readCounter@;
-- * the counter difference equals the number of items reachable from the
--   read pointer.
--
-- The last check walks the whole list.
checkInvariants :: RollingQueue a -> STM ()
checkInvariants (RQ rv wv) = do
  r <- readTVar rv
  w <- readTVar wv
  check (readCounter r >= 0) "readCounter >= 0"
  check (readDiscarded r >= 0) "readDiscarded >= 0"
  check (writeCounter w >= 0) "writeCounter >= 0"
  check (sizeLimit w >= 0) "sizeLimit >= 0"
  check (writeCounter w <= sizeLimit w) "writeCounter <= sizeLimit"
  hole <- readTVar (writePtr w)
  case hole of
    TNil -> return ()
    TCons _ _ -> throwSTM $ CheckException "writePtr does not point to a TNil"
  check (writeCounter w >= readCounter r) "writeCounter >= readCounter"
  len <- traverseLength (readPtr r)
  check (writeCounter w - readCounter r == len) "writeCounter - readCounter == length"
  where
    check b expr
      | b = return ()
      | otherwise = throwSTM $ CheckException $ expr ++ " does not hold"
-- | Count the items reachable from a cell by following 'TCons' links up to
-- the terminating 'TNil'.
traverseLength :: TCell a -> STM Int
traverseLength = go 0
  where
    -- The accumulator is forced at each step so no chain of (+1) thunks
    -- builds up.
    go acc cell = acc `seq` do
      node <- readTVar cell
      case node of
        TNil -> return acc
        TCons _ next -> go (acc + 1) next
-- | Snapshot of the buffered items, oldest first. The queue is not modified.
getItems :: RollingQueue a -> STM [a]
getItems (RQ rv _) = readTVar rv >>= collect [] . readPtr
  where
    -- Gather items newest-first, then reverse once at the end of the list.
    collect acc cell = do
      node <- readTVar cell
      case node of
        TNil -> return (reverse acc)
        TCons x next -> collect (x : acc) next
-- | The queue's internal counters as @(name, rendered value)@ pairs, for
-- debugging output.
getAttributes :: RollingQueue a -> STM [(String, String)]
getAttributes (RQ rv wv) = do
  r <- readTVar rv
  w <- readTVar wv
  let names  = ["readCounter", "readDiscarded", "writeCounter", "sizeLimit"]
      values = [readCounter r, readDiscarded r, writeCounter w, sizeLimit w]
  return (zip names (map show values))
-- | Debugging aid: print the buffered items, then the internal counters one
-- per line with their names padded to a common width.
--
-- Runs 'checkInvariants' first, so an inconsistent queue raises a
-- 'CheckException' instead of printing.
dump :: Show a => RollingQueue a -> IO ()
dump rq = join $ atomically $ do
  checkInvariants rq
  xs <- getItems rq
  attrs <- getAttributes rq
  -- Hand back the printing action so the I/O runs outside the transaction.
  return $ do
    print xs
    -- Seeding with 0 keeps 'maximum' total even for an empty attribute list.
    let c1width = maximum (0 : map (Prelude.length . fst) attrs)
    mapM_ putStrLn
      [ k ++ replicate (c1width - Prelude.length k) ' ' ++ " = " ++ v
      | (k, v) <- attrs
      ]