module Data.Concurrent.Deque.Tests
(
test_fifo_filldrain, test_fifo_OneBottleneck, tests_fifo,
test_ws_triv1, test_ws_triv2, tests_wsqueue,
tests_all,
numElems, getNumAgents, producerRatio,
setTestThreads,
stdTestHarness
)
where
import Data.Concurrent.Deque.Class as C
import qualified Data.Concurrent.Deque.Reference as R
import Control.Monad
import Data.Array as A
import Data.IORef
import Data.Int
import qualified Data.Set as S
import Text.Printf
import GHC.Conc (throwTo, threadDelay, myThreadId)
import Control.Concurrent.MVar
import Control.Concurrent (yield, forkOS, forkIO, ThreadId)
import Control.Exception (catch, SomeException, fromException, bracket, AsyncException(ThreadKilled))
import System.Environment (withArgs, getArgs, getEnvironment)
import System.IO (hPutStrLn, stderr, hFlush, stdout)
import System.IO.Unsafe (unsafePerformIO)
import System.Random (randomIO, randomRIO)
import qualified Test.Framework as TF
import Test.Framework.Providers.HUnit (hUnitTestToTests)
import Test.HUnit as HU
import Debug.Trace (trace)
#if __GLASGOW_HASKELL__ >= 704
-- GHC 7.4 and later provide the capability-control functions directly.
import GHC.Conc (getNumCapabilities, setNumCapabilities, getNumProcessors)
#else
-- Fallback shims for older GHC releases that lack these functions.
import GHC.Conc (numCapabilities)
-- | Report the fixed capability count chosen at program start.
getNumCapabilities :: IO Int
getNumCapabilities = return numCapabilities
-- | Not supported on old GHC: raises an error as soon as it is applied.
setNumCapabilities :: Int -> IO ()
setNumCapabilities = error "setNumCapabilities not supported in this older GHC! Set NUMTHREADS and +RTS -N to match."
-- | Placeholder: always reports a single hardware thread on old GHC.
getNumProcessors :: IO Int
getNumProcessors = return 1
#endif
-- | Snapshot of the process environment, shared by the configuration values
-- in this module (NUMELEMS, NUMAGENTS, OSTHREADS, PRODUCERRATIO, DEBUG).
--
-- Built with @unsafePerformIO@ so it can be consulted from pure code.  The
-- NOINLINE pragma keeps it a single shared CAF, so the environment is read
-- at most once instead of at every inlined use site.
theEnv :: [(String, String)]
theEnv = unsafePerformIO getEnvironment
{-# NOINLINE theEnv #-}
-- | Number of elements each test pushes through the queue(s).
-- Defaults to 100000; the NUMELEMS environment variable overrides it.
numElems :: Int
numElems = maybe defaultCount fromEnvVar (lookup "NUMELEMS" theEnv)
  where
    defaultCount = 100 * 1000
    fromEnvVar str = warnUsing ("NUMELEMS = " ++ str) (read str)
-- | Thread-spawning primitive used for every test worker.
--
-- Chosen by the OSTHREADS environment variable: unset, "0" or "False"
-- selects @forkIO@; "1" or "True" selects @forkOS@.  Any other value is a
-- configuration error whose message names the variable actually consulted.
forkThread :: IO () -> IO ThreadId
forkThread = case lookup "OSTHREADS" theEnv of
  Nothing -> forkIO
  Just x -> warnUsing ("OSTHREADS = " ++ x) $
    case x of
      "0"     -> forkIO
      "False" -> forkIO
      "1"     -> forkOS
      "True"  -> forkOS
      oth     -> error $ "OSTHREADS environment variable set to unrecognized option: " ++ oth
-- | Number of worker threads from which the producer and consumer counts
-- are derived.  Defaults to the current capability count; the NUMAGENTS
-- environment variable overrides it.
getNumAgents :: IO Int
getNumAgents =
  maybe getNumCapabilities
        (\str -> warnUsing ("NUMAGENTS = " ++ str) (return (read str)))
        (lookup "NUMAGENTS" theEnv)
-- | Ratio of producer threads to consumer threads; 1.0 (the default) splits
-- the agents evenly.  The PRODUCERRATIO environment variable overrides it.
producerRatio :: Double
producerRatio = maybe 1.0 parseRatio (lookup "PRODUCERRATIO" theEnv)
  where
    parseRatio str = warnUsing ("PRODUCERRATIO = " ++ str) (read str)
-- | Identity on its second argument, but emits (through @trace@) a warning
-- that an environment-variable override described by the first is in use.
warnUsing :: String -> a -> a
warnUsing str = trace (" [Warning]: Using environment variable " ++ str)
-- | Make every test case inside @tst@ run with exactly @nm@ capabilities,
-- and prefix the outermost labels with "N<nm>_" so that results for
-- different thread counts can be told apart.  Nested labels are unchanged.
setTestThreads :: Int -> HU.Test -> HU.Test
setTestThreads nm = go False
  where
    go insideLabel t = case t of
      TestLabel lb inner -> TestLabel (label insideLabel lb) (go True inner)
      TestList ts        -> TestList (map (go insideLabel) ts)
      TestCase io        -> TestCase (bracketThreads nm io)
    label insideLabel lb
      | insideLabel = lb
      | otherwise   = "N" ++ show nm ++ "_" ++ lb
-- | Run @act@ with the capability count set to @n@.  The count in force
-- beforehand is restored afterwards, even if @act@ throws.
bracketThreads :: Int -> IO a -> IO a
bracketThreads n act = bracket getNumCapabilities setNumCapabilities body
  where
    body _ = do
      dbgPrint 1 ("\n [Setting # capabilities to " ++ show n ++ " before test] \n")
      setNumCapabilities n
      act
-- | Standard @main@ for test executables built on this module.
--
-- Prints the configuration and chooses the capability counts to test.
-- With NUMTHREADS set, only that count is used.  Otherwise the tests are
-- replicated for 1, 2, P/2, P and 2P capabilities, where P is the number of
-- hardware threads; zero entries are dropped, because on a one-processor
-- machine P/2 is 0 and setNumCapabilities rejects a count of 0.
--
-- The tests run under test-framework with "-j1" and XML output to
-- test-results.xml appended to the command-line arguments.
stdTestHarness :: (IO Test) -> IO ()
stdTestHarness genTests = do
  numAgents <- getNumAgents
  putStrLn $ "Running with numElems " ++ show numElems ++ " and numAgents " ++ show numAgents
  putStrLn "Use NUMELEMS, NUMAGENTS, NUMTHREADS to control the size of this benchmark."
  args <- getArgs
  np <- getNumProcessors
  putStrLn $ "Running on a machine with " ++ show np ++ " hardware threads."
  let all_threads = case lookup "NUMTHREADS" theEnv of
        Just str -> [read str]
        Nothing  -> S.toList $ S.fromList $
                      filter (> 0) [1, 2, np `quot` 2, np, 2 * np]
  putStrLn $ "Running tests for these thread settings: " ++ show all_threads
  all_tests <- genTests
  withArgs (args ++ ["-j1", "--jxml=test-results.xml"]) $ do
    tests <- case all_threads of
      [one] -> do
        cap <- getNumCapabilities
        unless (cap == one) $ setNumCapabilities one
        return all_tests
      _ -> return $ TestList [ setTestThreads n all_tests | n <- all_threads ]
    TF.defaultMain $ hUnitTestToTests tests
-- | Element type pushed through the queues by these tests (a 64-bit
-- integer); the sums that the tests check are accumulated in this type.
type Elt = Int64
-- | Sequential sanity test.  Push 1..numElems on the left end, then pop
-- numElems elements from the right end (with backoff).  Fails unless the
-- popped values sum to the pushed ones.
test_fifo_filldrain :: DequeClass d => d Elt -> IO ()
test_fifo_filldrain q = do
  dbgPrintLn 1 "\nTest FIFO queue: sequential fill and then drain"
  dbgPrintLn 1 "==============================================="
  let n = fromIntegral numElems
  dbgPrintLn 1 $ "Done creating queue. Pushing " ++ show n ++ " elements:"
  forM_ [1 .. n] $ \i -> do
    pushL q i
    when (i < 200) $ dbgPrint 1 $ printf " %d" i
  dbgPrintLn 1 "\nDone filling queue with elements. Now popping..."
  -- Count down from n, accumulating the popped values strictly.
  let loop 0 !sumR = return sumR
      loop i !sumR = do
        (x, _) <- spinPopBkoff q
        when (i < 200) $ dbgPrint 1 $ printf " %d" x
        loop (i - 1) (sumR + x)
  s <- loop n 0
  let expected = sum [1 .. n] :: Elt
  dbgPrint 1 $ printf "\nSum of popped vals: %d should be %d\n" s expected
  when (s /= expected) (assertFailure "Incorrect sum!")
-- | Many producers and consumers contend on one shared queue.
--
-- Producers push disjoint ranges of integers on the left end, and consumers
-- pop from the right end.  The test checks that the popped values sum to the
-- pushed values and that the queue ends empty.  Both ends must be thread safe.
--
-- @doBackoff@ selects @spinPopBkoff@ (True) or the busy-waiting
-- @spinPopHard@ (False) for the consumers.  @total@ is the requested element
-- count; it is rounded down to a multiple of the producer count.
test_fifo_OneBottleneck :: DequeClass d => Bool -> Int -> d Elt -> IO ()
test_fifo_OneBottleneck doBackoff total q = do
  assertBool "test_fifo_OneBottleneck requires thread safe left end" (leftThreadSafe q)
  assertBool "test_fifo_OneBottleneck requires thread safe right end" (rightThreadSafe q)
  dbgPrintLn 1 $ "\nTest FIFO queue: producers & consumers thru 1 queue"
    ++ (if doBackoff then " (with backoff)" else "(hard busy wait)")
  dbgPrintLn 1 "======================================================"
  bl <- nullQ q
  dbgPrintLn 1 $ "Check that queue is initially null: " ++ show bl
  numAgents <- getNumAgents
  let producers = max 1 (round $ producerRatio * fromIntegral numAgents / (producerRatio + 1))
      consumers = max 1 (numAgents - producers)
      perthread = total `quot` producers
      -- Only producers*perthread elements are ever pushed; the remainder of
      -- total/producers is dropped.  Consumers must therefore split exactly
      -- that many pops between them.  Asking for total pops would spin
      -- forever whenever total is not a multiple of producers.
      realtotal = producers * perthread
      (perthread2, remain) = realtotal `quotRem` consumers
  numCap <- getNumCapabilities
  when (not doBackoff && (numCap == 1 || numCap < producers + consumers)) $
    error $ "The aggressively busy-waiting version of the test can only run with the right thread settings."
  dbgPrint 1 $ printf "Forking %d producer threads, each producing %d elements.\n" producers perthread
  dbgPrint 1 $ printf "Forking %d consumer threads, each consuming ~%d elements.\n" consumers perthread2
  forM_ [0 .. producers - 1] $ \ind ->
    myfork "producer thread" $ do
      let start = ind * perthread
      dbgPrint 1 $ printf " * Producer thread %d pushing ints from %d to %d \n" ind start (start + perthread - 1)
      forI_ start (start + perthread) $ \i ->
        pushL q i
  ls <- forkJoin consumers $ \ind ->
    let mymax = if ind == 0 then perthread2 + remain else perthread2
        consume_loop summ maxiters i | i == mymax = return (summ, maxiters)
        consume_loop !summ !maxiters i = do
          (x, iters) <- if doBackoff then spinPopBkoff q else spinPopHard q
          when (i >= mymax - 10) $ dbgPrint 1 $ printf " [c%d] popped #%d = %d \n" ind i x
          consume_loop (summ + x) (max maxiters iters) (i + 1)
    in consume_loop 0 0 0
  let finalSum = Prelude.sum (map fst ls)
  dbgPrintLn 1 $ "Consumers DONE. Maximum retries for each consumer thread: " ++ show (map snd ls)
  dbgPrintLn 1 $ "Final sum: " ++ show finalSum
  assertEqual "Correct final sum" (expectedSum (fromIntegral realtotal)) finalSum
  dbgPrintLn 1 "Checking that queue is finally null..."
  b <- nullQ q
  if b
    then dbgPrintLn 1 "Sum matched expected, test passed."
    else assertFailure "Queue was not empty!!"
-- | Contention-free parallel test: N private queues, each with exactly one
-- producer and one consumer.
--
-- Every producer pushes 0..perthread-1 in order, and its consumer fails if
-- the values arrive out of order.  The test then checks the overall sum and
-- that every queue ends empty.  @doBackoff@ selects @spinPopBkoff@ or
-- @spinPopHard@ for the consumers.
test_contention_free_parallel :: DequeClass d => Bool -> Int -> IO (d Elt) -> IO ()
test_contention_free_parallel doBackoff total newqueue = do
  dbgPrintLn 1 $ "\nTest FIFO queue: producers & consumers thru N queues"
    ++ (if doBackoff then " (with backoff)" else "(hard busy wait)")
  dbgPrintLn 1 "======================================================"
  mv <- newEmptyMVar
  numAgents <- getNumAgents
  let producers = max 1 (round $ producerRatio * fromIntegral numAgents / (producerRatio + 1))
      consumers = producers
      perthread = total `quot` producers
  qs <- replicateM consumers newqueue
  numCap <- getNumCapabilities
  when (not doBackoff && (numCap == 1 || numCap < producers + consumers)) $
    error $ "The aggressively busy-waiting version of the test can only run with the right thread settings."
  dbgPrint 1 $ printf "Forking %d producer threads, each producing %d elements.\n" producers perthread
  dbgPrint 1 $ printf "Forking %d consumer threads, each consuming %d elements.\n" consumers perthread
  -- One producer per queue, each pushing the same sequence 0..perthread-1.
  forM_ qs $ \q ->
    myfork "producer thread" $
      forI_ 0 perthread $ \i ->
        pushL q i
  -- One consumer per queue; it reports (sum, max retries) through mv.
  forM_ (zip [0 .. consumers - 1] qs) $ \(cid, q) ->
    myfork "consumer thread" $ do
      let consume_loop acc maxiters i | i == perthread = return (acc, maxiters)
          consume_loop !acc !maxiters i = do
            (x, iters) <- if doBackoff then spinPopBkoff q else spinPopHard q
            when (i < 10) $ dbgPrint 1 $ printf " [%d] popped #%d = %d \n" cid i x
            unless (x == fromIntegral i) $
              error $ "Message out of order! Expected " ++ show i ++ " recevied " ++ show x
            consume_loop (acc + x) (max maxiters iters) (i + 1)
      pr <- consume_loop 0 0 0
      putMVar mv pr
  dbgPrint 1 $ printf "Reading sums from MVar...\n"
  ls <- replicateM consumers (takeMVar mv)
  let finalSum = Prelude.sum (map fst ls)
  dbgPrintLn 1 $ "Consumers DONE. Maximum retries for each consumer thread: " ++ show (map snd ls)
  dbgPrintLn 1 $ "All messages received in order. Final sum: " ++ show finalSum
  assertEqual "Correct final sum" (fromIntegral producers * expectedSum (fromIntegral perthread)) finalSum
  dbgPrintLn 1 "Checking that queue is finally null..."
  bs <- mapM nullQ qs
  if and bs
    then dbgPrintLn 1 "Sum matched expected, test passed."
    else assertFailure "Queue was not empty!!"
-- | Producers and consumers communicate through an array of @size@ queues.
--
-- Every push and every pop attempt picks one queue uniformly at random.
-- The test checks that the popped values sum to the pushed values and that
-- every queue ends empty.  Both queue ends must be thread safe.  @total@ is
-- rounded down to a multiple of the producer count.
test_random_array_comm :: DequeClass d => Int -> Int -> IO (d Elt) -> IO ()
test_random_array_comm size total newqueue = do
  assertBool "positive size" (size > 0)
  qs <- replicateM size newqueue
  let arr = A.listArray (0, size - 1) qs
  assertBool "test_random_array_comm requires thread safe left end" (leftThreadSafe (head qs))
  assertBool "test_random_array_comm requires thread safe right end" (rightThreadSafe (head qs))
  dbgPrintLn 1 $ "\nTest FIFO queue: producers & consumers select random queues"
  dbgPrintLn 1 "======================================================"
  numAgents <- getNumAgents
  let producers = max 1 (round $ producerRatio * fromIntegral numAgents / (producerRatio + 1))
      consumers = max 1 (numAgents - producers)
      perthreadI = total `quot` producers
      perthread = fromIntegral perthreadI :: Elt
      -- Consumers must pop exactly as many elements as are pushed.  The
      -- remainder of total/producers is never pushed, so basing the pop
      -- count on total could leave consumers spinning forever.
      realtotal = producers * perthreadI
      (perthread2, remain) = realtotal `quotRem` consumers
  dbgPrint 1 $ printf "Forking %d producer threads, each producing %d elements.\n" producers perthread
  dbgPrint 1 $ printf "Forking %d consumer threads, each consuming ~%d elements.\n" consumers perthread2
  forM_ [0 .. producers - 1] $ \_ ->
    myfork "producer thread" $
      for_ 0 perthread $ \i -> do
        ix <- randomRIO (0, size - 1) :: IO Int
        pushL (arr ! ix) i
  sums <- forkJoin consumers $ \ind ->
    let mymax = if ind == 0 then perthread2 + remain else perthread2
        consume_loop summ i | i == mymax = return summ
        consume_loop !summ i = do
          ix <- randomRIO (0, size - 1) :: IO Int
          m <- spinPopN 100 (tryPopR (arr ! ix))
          case m of
            Just x -> do
              when (i < 10) $ dbgPrint 1 $ printf " [%d] popped #%d = %d\n" ind i x
              consume_loop (summ + x) (i + 1)
            Nothing ->
              consume_loop summ i
    in consume_loop 0 0
  dbgPrintLn 1 "Reading sums from MVar..."
  let finalSum = Prelude.sum sums
  dbgPrintLn 1 $ "Final sum: " ++ show finalSum ++ ", per-consumer sums: " ++ show sums
  dbgPrintLn 1 "Checking that queue is finally null..."
  assertEqual "Correct final sum" (fromIntegral producers * expectedSum perthread) finalSum
  bs <- mapM nullQ qs
  if and bs
    then dbgPrintLn 1 "Sum matched expected, test passed."
    else assertFailure "Queue was not empty!!"
-- | Sum of the integers 0 .. n-1, i.e. n*(n-1)/2.  This is the total
-- pushed by a producer that pushes the values 0 up to, but excluding, @n@.
expectedSum :: Integral a => a -> a
expectedSum n = (n * (n - 1)) `quot` 2
-- | Test suite for single-ended (FIFO) queues.  It combines the basic tests
-- with the multi-threaded tests that need both ends to be thread safe.
tests_fifo :: DequeClass d => (forall elt. IO (d elt)) -> Test
tests_fifo newq =
  TestLabel "single-ended-queue-tests" . TestList $
    concat [tests_basic newq, tests_fifo_exclusive newq]
-- | FIFO-only tests: one shared bottleneck queue with backoff, plus
-- random-array communication over 10 and over 100 queues.
tests_fifo_exclusive :: DequeClass d => (forall elt. IO (d elt)) -> [Test]
tests_fifo_exclusive newq =
  TestLabel "test_fifo_OneBottleneck_backoff"
    (TestCase $ assert $ newq >>= test_fifo_OneBottleneck True numElems)
  : map arrayCommTest [10, 100]
  where
    arrayCommTest size =
      TestLabel ("test_random_array_comm_" ++ show size)
        (TestCase $ assert $ test_random_array_comm size numElems newq)
-- | Tests that apply to any deque: a sequential fill/drain, and a
-- contention-free parallel run with one producer and one consumer per queue.
tests_basic :: DequeClass d => (forall elt. IO (d elt)) -> [Test]
tests_basic newq =
  [ TestLabel "test_fifo_filldrain" $
      TestCase $ assert (newq >>= test_fifo_filldrain)
  , TestLabel "test_contention_free_parallel" $
      TestCase $ assert (test_contention_free_parallel True numElems newq)
  ]
-- | Trivial work-stealing check.  After a single push on the left end, a
-- pop from the left end must return the pushed value.
test_ws_triv1 :: PopL d => d [Char] -> IO ()
test_ws_triv1 q = do
  pushL q "hi"
  mx <- tryPopL q
  case mx of
    -- Report an empty pop as a test failure, not a pattern-match crash.
    Nothing -> assertFailure "test_ws_triv1: tryPopL returned Nothing after a push"
    -- assertEqual takes the expected value before the actual one.
    Just x  -> assertEqual "test_ws_triv1" "hi" x
-- | Pushes four strings on the left end, then mixes pops from both ends.
-- The oldest elements must come out on the right and the newest on the
-- left.  Once the deque is empty, pops from either end return Nothing.
test_ws_triv2 :: PopL d => d [Char] -> IO ()
test_ws_triv2 q = do
  mapM_ (pushL q) ["one", "two", "three", "four"]
  ls <- sequence [ tryPopR q, tryPopR q
                 , tryPopL q, tryPopL q
                 , tryPopL q, tryPopR q ]
  -- assertEqual takes the expected value before the actual one.
  assertEqual "test_ws_triv2"
    [Just "one", Just "two", Just "four", Just "three", Nothing, Nothing]
    ls
-- | Work-stealing stress test.
--
-- Each producer owns one queue.  It randomly interleaves two actions:
-- pushing its next value (0..perthread-1) on the left end, and popping
-- from that same left end.  Consumers repeatedly pick a random producer
-- queue and pop from its right end, each making a fixed number of pop
-- attempts whether or not they succeed.
--
-- After all threads finish, the remaining elements are drained.  The test
-- then checks two things: producer pops, consumer pops and leftovers
-- together account for every pushed element exactly once (count and sum),
-- and all queues end empty.
test_random_work_stealing :: (DequeClass d, PopL d) => Int -> IO (d Elt) -> IO ()
test_random_work_stealing total newqueue = do
  dbgPrintLn 1 $ "\nTest FIFO queue: producers & consumers select random queues"
  dbgPrintLn 1 "======================================================"
  numAgents <- getNumAgents
  let producers, consumers :: Int
      perthread, realtotal, perthread2 :: Elt
      producers = max 1 (round $ producerRatio * fromIntegral numAgents / (producerRatio + 1))
      consumers = max 1 (fromIntegral numAgents - producers)
      perthread = fromIntegral total `quot` fromIntegral producers
      realtotal = perthread * fromIntegral producers
      (perthread2, remain) = realtotal `quotRem` fromIntegral consumers
  qs <- replicateM producers newqueue
  assertBool "test_random_work_stealing requires thread safe right end" (rightThreadSafe (head qs))
  let arr = A.listArray (0, producers - 1) qs
  dbgPrint 1 $ printf "Forking %d producer threads, each producing %d elements.\n" producers perthread
  dbgPrint 1 $ printf "Forking %d consumer threads, each consuming ~%d elements.\n" consumers perthread2
  prod_results <- newEmptyMVar
  -- Producers: on each step a coin flip chooses between popping own work
  -- (left end) and pushing the next value.  Each reports (pops, sum).
  forM_ qs $ \myQ ->
    myfork "producer thread" $
      let loop :: Elt -> Elt -> Elt -> IO ()
          loop i !pops !acc
            | i == perthread = putMVar prod_results (pops, acc)
            | otherwise = do
                b <- randomIO :: IO Bool
                if b
                  then do
                    x <- spinPopN 100 (tryPopL myQ)
                    case x of
                      Nothing -> loop i pops acc
                      Just n  -> loop i (pops + 1) (n + acc)
                  else do
                    pushL myQ i
                    loop (i + 1) pops acc
      in loop 0 0 0
  -- Consumers: mymax steal attempts each; failed attempts also count.
  consumer_sums <- forkJoin consumers $ \ind ->
    let mymax = if ind == 0 then perthread2 + remain else perthread2
        consume_loop !summ !successes i
          | i == mymax = return (successes, summ)
          | otherwise = do
              ix <- randomRIO (0, producers - 1) :: IO Int
              m <- spinPopN 100 (tryPopR (arr ! ix))
              case m of
                Just x -> do
                  when (i < 10) $ dbgPrint 1 $ printf " [%d] popped try#%d = %d\n" ind i x
                  consume_loop (summ + x) (successes + 1) (i + 1)
                Nothing ->
                  consume_loop summ successes (i + 1)
    in do
      dbgPrintLn 1 (" Beginning consumer thread loop for " ++ show mymax ++ " attempts.")
      consume_loop 0 0 0
  dbgPrintLn 1 "Reading sums from MVar..."
  prod_ls <- replicateM producers (takeMVar prod_results)
  -- Every worker has finished, so drain whatever is still queued.
  leftovers <- forM qs $ \q ->
    let drain !cnt !acc = do
          x <- tryPopR q
          case x of
            Nothing -> return (cnt, acc)
            Just n  -> drain (cnt + 1) (acc + n)
    in drain 0 0
  let allResults = consumer_sums ++ prod_ls ++ leftovers
      finalSum = Prelude.sum (map snd allResults)
      totalPops = Prelude.sum (map fst allResults)
  dbgPrintLn 0 $ "Final sum: " ++ show finalSum ++ ", producer/consumer/leftover sums: " ++ show (prod_ls, consumer_sums, leftovers)
  dbgPrintLn 1 $ "Total pop events: " ++ show totalPops ++ " should be " ++ show realtotal
  dbgPrintLn 1 "Checking that queue is finally null..."
  assertEqual "Correct total pop events" realtotal totalPops
  assertEqual "Correct final sum" (fromIntegral producers * expectedSum perthread) finalSum
  bs <- mapM nullQ qs
  if and bs
    then dbgPrintLn 1 "Sum matched expected, test passed."
    else assertFailure "Queue was not empty!!"
-- | Test suite for work-stealing deques: the work-stealing-specific tests
-- followed by the basic tests.
tests_wsqueue :: (PopL d) => (forall elt. IO (d elt)) -> Test
tests_wsqueue newq =
  TestLabel "work-stealing-deque-tests" . TestList $
    concat [tests_wsqueue_exclusive newq, tests_basic newq]
-- | Tests that need a left-end pop: the two trivial push/pop checks and
-- the randomized work-stealing stress test.
tests_wsqueue_exclusive :: (PopL d) => (forall elt. IO (d elt)) -> [Test]
tests_wsqueue_exclusive newq =
  [ mk "test_ws_triv1" (newq >>= test_ws_triv1)
  , mk "test_ws_triv2" (newq >>= test_ws_triv2)
  , mk "test_random_work_stealing" (test_random_work_stealing numElems newq)
  ]
  where
    mk name act = TestLabel name (TestCase (assert act))
-- | Every test in this module, for deques that support all operations.
tests_all :: (PopL d) => (forall elt. IO (d elt)) -> Test
tests_all newq =
  TestLabel "full-deque-tests" $ TestList $ concat
    [ tests_basic newq
    , tests_fifo_exclusive newq
    , tests_wsqueue_exclusive newq
    ]
-- | Fork a test worker with @forkThread@.  An uncaught exception in the
-- worker is reported on stderr under the given description and re-thrown
-- to the forking thread.
myfork :: String -> IO () -> IO ThreadId
myfork = forkWithExceptions forkThread
-- | Run @action@ on a thread created by @forkit@.  If the action dies with
-- an exception other than ThreadKilled, print it to stderr, tagged with
-- @descr@, and re-throw it to the thread that did the forking.
forkWithExceptions :: (IO () -> IO ThreadId) -> String -> IO () -> IO ThreadId
forkWithExceptions forkit descr action = do
  parent <- myThreadId
  let report :: SomeException -> IO ()
      report e
        | Just ThreadKilled <- fromException e = return ()
        | otherwise = do
            hPutStrLn stderr $ "Exception inside child thread " ++ show descr ++ ": " ++ show e
            throwTo parent e
  forkit (action `Control.Exception.catch` report)
-- | Run @action 0@ through @action (numthreads - 1)@ on separately forked
-- threads.  Blocks until all of them finish, then returns their results in
-- index order.
forkJoin :: Int -> (Int -> IO b) -> IO [b]
forkJoin numthreads action = do
  slots <- replicateM numthreads newEmptyMVar
  zipWithM_ spawn [0 ..] slots
  mapM readMVar slots
  where
    spawn ix slot = myfork "forkJoin worker" (action ix >>= putMVar slot)
-- | Pop from the right end, retrying until an element is available.
-- Returns the element and the number of attempts made (starting at 1).
--
-- Retry policy after the nth consecutive failure:
--
-- * every 1000th failure: sleep n microseconds;
-- * otherwise, above 10 failures: print "." (debug level 1) and yield;
-- * otherwise: retry immediately.
--
-- A warning is printed at debug level 0 after the 5000th attempt.
spinPopBkoff :: DequeClass d => d t -> IO (t, Int)
spinPopBkoff q = go 1
  where
    hardspinfor, sleepevery, warnafter :: Int
    hardspinfor = 10
    sleepevery  = 1000
    warnafter   = 5000
    go n = do
      when (n == warnafter) $
        dbgPrintLn 0 $ "Warning: Failed to pop " ++ show warnafter ++
                       " times consecutively. That shouldn't happen in this benchmark."
      tryPopR q >>= maybe (backoff n >> go (n + 1)) (\v -> return (v, n))
    backoff n
      | n `mod` sleepevery == 0 = dbgPrint 1 "!" >> threadDelay n
      | n > hardspinfor         = dbgPrint 1 "." >> hFlush stdout >> yield
      | otherwise               = return ()
-- | Pop from the right end in a tight retry loop (no yield, no sleep) until
-- an element is available.  Returns the element and the number of attempts
-- made (starting at 1).
spinPopHard :: (DequeClass d) => d t -> IO (t, Int)
spinPopHard q = attempt 1
  where
    attempt n = tryPopR q >>= maybe (attempt (n + 1)) (\v -> return (v, n))
-- | Run a non-blocking pop action up to @tries@ times.  Returns the first
-- @Just@ result, or @Nothing@ if every attempt failed.  A non-positive
-- @tries@ makes no attempts, so a negative count cannot loop forever.
spinPopN :: Int -> IO (Maybe t) -> IO (Maybe t)
spinPopN tries act
  | tries <= 0 = return Nothing
  | otherwise = do
      x <- act
      case x of
        Nothing -> spinPopN (tries - 1) act
        res@(Just _) -> return res
-- | Run @fn i@ for each @i@ from @start@ up to, but excluding, @end@.
-- Raises an error if @start@ is greater than @end@.
for_ :: Monad m => Elt -> Elt -> (Elt -> m ()) -> m ()
for_ start end fn
  | start > end = error "for_: start is greater than end"
  | otherwise   = go start
  where
    go !i
      | i == end  = return ()
      | otherwise = fn i >> go (i + 1)
-- | @for_@ with Int bounds, which are converted to Elt.
forI_ :: Monad m => Int -> Int -> (Elt -> m ()) -> m ()
forI_ st en fn = for_ (fromIntegral st) (fromIntegral en) fn
-- | Debug verbosity, read from the DEBUG environment variable.
-- If DEBUG is unset, empty or "0", the level is @defaultDbg@.  Any other
-- value is announced with @trace@ and must begin with something readable as
-- an Int; otherwise an error is raised.
dbg :: Int
dbg =
  case lookup "DEBUG" theEnv of
    Just s | s `notElem` ["", "0"] ->
      trace (" ! Responding to env Var: DEBUG=" ++ s) $
        case reads s of
          (n, _) : _ -> n
          []         -> error $ "Attempt to parse DEBUG env var as Int failed: " ++ show s
    _ -> defaultDbg

-- | Verbosity used when DEBUG is unset, empty or "0".
defaultDbg :: Int
defaultDbg = 0
-- | Write a message to stdout, and flush, when the debug level @dbg@ is at
-- least @lvl@.
--
-- The message is written verbatim with @putStr@, not passed to @printf@ as
-- a format string.  Callers pass text that is often already formatted, and
-- a literal percent sign in it would otherwise be misread as a format
-- directive and crash the test.
dbgPrint :: Int -> String -> IO ()
dbgPrint lvl str = when (dbg >= lvl) $ do
  putStr str
  hFlush stdout
-- | Like @dbgPrint@, but appends a newline to the message.
dbgPrintLn :: Int -> String -> IO ()
dbgPrintLn lvl = dbgPrint lvl . (++ "\n")