-- Copyright 2013 Kevin Backhouse.

-- | 'Control.Monad.MultiPass.ThreadContext.CounterTC' defines a
-- thread context which is used to generate a series of unique
-- consecutive numbers. It has two passes. The first pass,
-- 'CounterTC1', creates a log of the number of new values that need
-- to be generated in each thread. The second pass, 'CounterTC2', uses
-- the log to compute the correct starting value for each thread, so
-- that the threads appear to be incrementing a single global counter,
-- even though they are operating concurrently.
module Control.Monad.MultiPass.ThreadContext.CounterTC
  ( -- * First Pass
    CounterTC1
  , counterVal1, incrCounterTC1, addkCounterTC1
  , newCounterTC1

    -- * Second Pass
  , CounterTC2
  , counterVal2, incrCounterTC2, addkCounterTC2
  , newCounterTC2, resetCounterTC2
  )
where

import Control.Monad.State.Strict
import Control.Monad.ST2
import Control.Monad.MultiPass

-- Log of a single thread. The first field is the counter value that
-- was stored in the 'CounterTC1' node when the log was frozen by
-- 'mkCounterLogSequential'; for a sub-thread recorded by
-- 'mergeThreadContext' this is the sum of the counts of the
-- preceding sibling threads. The second field holds one entry per
-- parallel section which the thread executed, in execution order.
data CounterLogSequential i r
  = CounterLogSequential
      !i
      !(ST2RArray r Int (CounterLogParallel i r))

-- Log of a single parallel section: one sequential log per
-- sub-thread, indexed by the sub-thread number.
newtype CounterLogParallel i r
  = CounterLogParallel (ST2RArray r Int (CounterLogSequential i r))

-- | 'CounterTC1' is used during the first pass. It builds up a log of
-- the parallel tasks that were spawned, which is used during the
-- second pass to generate a series of unique consecutive numbers.
data CounterTC1 i r
  = CounterTC1
    { -- Counter log for the current node. (Accumulates in reverse.)
      counterLog1 :: ![CounterLogParallel i r]

    , counterVal1 :: !i
      -- ^ Get the current value of the counter.
    }

instance Num i => ThreadContext r w (CounterTC1 i r) where
  -- Every sub-thread starts counting from zero with an empty log.
  splitThreadContext _ _ _ =
    return $ CounterTC1 [] 0

  -- Freeze the logs of the @m@ sub-threads into an array, recording
  -- for each sub-thread the running total of its predecessors, and
  -- add the overall total to the counter of the parent node.
  mergeThreadContext m getSubNode node =
    do xs <- newST2Array_ (0,m-1)
       -- The state is the sum of the counts of the sub-threads which
       -- have been visited so far.
       total <- flip execStateT 0 $ sequence_
         [ do subnode0 <- lift $ getSubNode i
              offset <- get
              -- Replace the sub-thread's own count with its offset
              -- before freezing, so that the log stores the offset.
              let subnode1 = subnode0 { counterVal1 = offset }
              put (offset + counterVal1 subnode0)
              subnode2 <- lift $ mkCounterLogSequential subnode1
              lift $ writeST2Array xs i subnode2
         | i <- [0 .. m-1]
         ]
       let xs' = CounterLogParallel (mkST2RArray xs)
       return $ CounterTC1
         { counterLog1 = xs' : counterLog1 node
         , counterVal1 = total + counterVal1 node
         }

instance Num i => NextThreadContext r w () gc (CounterTC1 i r) where
  nextThreadContext _ _ () _ = return newCounterTC1

instance Num i =>
         NextThreadContext r w (CounterTC1 i r) gc (CounterTC1 i r) where
  nextThreadContext _ _ _ _ = return newCounterTC1

-- | Create a new counter.
newCounterTC1 :: Num i => CounterTC1 i r
newCounterTC1 = CounterTC1 [] 0

-- | Increment the counter.
incrCounterTC1 :: Num i => CounterTC1 i r -> CounterTC1 i r
incrCounterTC1 = addkCounterTC1 1

-- | Add @k@ to the counter.
addkCounterTC1 :: Num i => i -> CounterTC1 i r -> CounterTC1 i r
addkCounterTC1 k (CounterTC1 h c) = CounterTC1 h (c+k)

-- The log has been accumulated as a list in reverse order. This
-- function reverses the list and converts it to a read-only array.
mkCounterLogSequential
  :: CounterTC1 i r -> ST2 r w (CounterLogSequential i r)
mkCounterLogSequential (CounterTC1 xs c) =
  let n = length xs in
  do xs' <- newST2Array_ (0,n-1)
     -- The head of the list is the most recent entry, so it is
     -- written to the last index (n-1) and the final element of the
     -- list is written to index 0.
     sequence_
       [ writeST2Array xs' (n-i) x
       | (x,i) <- zip xs [1 .. n]
       ]
     return (CounterLogSequential c (mkST2RArray xs'))

-- | 'CounterTC2' is used during the second pass. It uses the log
-- which was computed by 'CounterTC1' to generate a series of unique
-- consecutive numbers.
data CounterTC2 i r
  = CounterTC2
    { -- Log of the parallel sections executed by the current thread.
      counterLog2 :: !(ST2RArray r Int (CounterLogParallel i r))

      -- Current index in the counter log.
    , counterIdx2 :: !Int

    , counterVal2 :: !i
      -- ^ Get the current value of the counter.
    }

-- | Increment the counter.
incrCounterTC2 :: Num i => CounterTC2 i r -> CounterTC2 i r
incrCounterTC2 = addkCounterTC2 1

-- | Add @k@ to the counter.
addkCounterTC2 :: Num i => i -> CounterTC2 i r -> CounterTC2 i r
addkCounterTC2 k node =
  node { counterVal2 = k + counterVal2 node }

instance Num i => ThreadContext r w (CounterTC2 i r) where
  splitThreadContext _ i node =
    do -- Read the current index of the log.
       CounterLogParallel ps <-
         readST2RArray (counterLog2 node) (counterIdx2 node)
       -- Get the log for thread i.
       CounterLogSequential k pss <- readST2RArray ps i
       -- The sub-thread starts at the parent's current value plus
       -- the offset which was logged for it, and reads its own log
       -- from the beginning.
       return $ CounterTC2
         { counterLog2 = pss
         , counterIdx2 = 0
         , counterVal2 = k + counterVal2 node
         }

  mergeThreadContext m getSubNode node
    -- With no sub-threads there is no last sub-node to read. The
    -- first pass adds nothing to the counter in that case but still
    -- records a log entry, so only the log index is advanced.
    | m <= 0 =
        return $ node { counterIdx2 = 1 + counterIdx2 node }
    | otherwise =
        do -- Get the new counter value from the last sub-node.
           lastSubNode <- getSubNode (m-1)
           return $ node
             { counterIdx2 = 1 + counterIdx2 node
             , counterVal2 = counterVal2 lastSubNode
             }

instance Num i =>
         NextThreadContext r w (CounterTC1 i r) gc (CounterTC2 i r) where
  nextThreadContext _ _ node _ = newCounterTC2 node

instance Num i =>
         NextThreadContext r w (CounterTC2 i r) gc (CounterTC1 i r) where
  nextThreadContext _ _ _ _ = return newCounterTC1

instance Num i =>
         NextThreadContext r w (CounterTC2 i r) gc (CounterTC2 i r) where
  nextThreadContext _ _ node _ = return (resetCounterTC2 node)

-- | Convert a 'CounterTC1' to a 'CounterTC2'.
newCounterTC2 :: Num i => CounterTC1 i r -> ST2 r w (CounterTC2 i r)
newCounterTC2 node =
  do -- Only the frozen log is carried over; the counter value which
     -- was reached in the first pass is discarded and counting
     -- restarts from zero.
     CounterLogSequential _ pss <- mkCounterLogSequential node
     return $ CounterTC2
       { counterLog2 = pss
       , counterIdx2 = 0
       , counterVal2 = 0
       }

-- | Reset the counter to zero and rewind to the beginning of the log.
resetCounterTC2 :: Num i => CounterTC2 i r -> CounterTC2 i r
resetCounterTC2 node =
  node { counterIdx2 = 0, counterVal2 = 0 }