{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiWayIf #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
module Test.DejaFu.Conc.Internal where
import Control.Exception (Exception,
MaskingState(..),
toException)
import qualified Control.Monad.Catch as E
import qualified Control.Monad.Conc.Class as C
import Data.Foldable (foldrM)
import Data.Functor (void)
import Data.List (nub, partition, sortOn)
import qualified Data.Map.Strict as M
import Data.Maybe (isJust, isNothing)
import Data.Monoid ((<>))
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import GHC.Stack (HasCallStack)
import Test.DejaFu.Conc.Internal.Common
import Test.DejaFu.Conc.Internal.Memory
import Test.DejaFu.Conc.Internal.STM
import Test.DejaFu.Conc.Internal.Threading
import Test.DejaFu.Internal
import Test.DejaFu.Schedule
import Test.DejaFu.Types
-- | The execution trace accumulated by 'runThreads'. There is one entry per
-- executed step, holding:
--
-- * the scheduling decision that was made;
-- * the other runnable threads (with their lookahead) that were not chosen;
-- * the action the chosen thread performed.
type SeqTrace = Seq (Decision, [(ThreadId, Lookahead)], ThreadAction)
-- | Everything 'runThreads' hands back once scheduling stops.
data CResult n g a = CResult
  { finalContext  :: Context n g
    -- ^ The context as it stood when execution ended.
  , finalRef      :: C.IORef n (Maybe (Either Condition a))
    -- ^ Where the outcome is stored. 'runThreads' writes failure
    -- conditions here as @Just (Left condition)@. A successful result is
    -- presumably written by the main thread's final continuation (the
    -- @Just . Right@ passed to @runRefCont@) — TODO confirm.
  , finalRestore  :: Threads n -> n ()
    -- ^ Replays the snapshot effects recorded during the run. Effects are
    -- only accumulated when 'runThreads' was asked to snapshot.
  , finalTrace    :: SeqTrace
    -- ^ One entry per executed step.
  , finalDecision :: Maybe (ThreadId, ThreadAction)
    -- ^ The last thread stepped and the action it took, or 'Nothing' if
    -- no step ran.
  }
-- | Run a concurrent computation from scratch under the given scheduler and
-- memory model.
--
-- A fresh 'Context' is built from:
--
-- * the initial scheduler state @g@;
-- * the id source and capability count;
-- * the supplied invariants, all of which start out active.
--
-- The computation runs as 'initialThread', which is made bound when the RTS
-- supports bound threads. @forSnapshot@ is forwarded to 'runThreads'. Once
-- scheduling finishes, every thread still present in the final context is
-- killed.
runConcurrency :: (C.MonadConc n, HasCallStack)
  => [Invariant n ()]
  -> Bool
  -> Scheduler g
  -> MemType
  -> g
  -> IdSource
  -> Int
  -> ModelConc n a
  -> n (CResult n g a)
runConcurrency invariants forSnapshot sched memtype g idsrc caps ma = do
  (c, ref) <- runRefCont AStop (Just . Right) (runModelConc ma)
  let bindMain
        | C.rtsSupportsBoundThreads = makeBound initialThread
        | otherwise = pure
      mainOnly = launch' Unmasked initialThread (const c) M.empty
  threads <- bindMain mainOnly
  let ctx = Context
        { cSchedState = g
        , cIdSource = idsrc
        , cThreads = threads
        , cWriteBuf = emptyBuffer
        , cCaps = caps
        , cInvariants = InvariantContext { icActive = invariants, icBlocked = [] }
        , cNewInvariants = []
        , cCState = initialCState
        }
  result <- runThreads forSnapshot sched memtype ref ctx
  killAllThreads (finalContext result)
  pure result
-- | Re-run a computation starting from a snapshotted context.
--
-- 1. Any existing 'initialThread' in the context is replaced by a fresh one
--    running @ma@. It is made bound when the RTS supports bound threads.
-- 2. Every other thread whose @_bound@ field is set has 'makeBound' applied
--    again (presumably to give it a fresh OS thread — TODO confirm).
-- 3. @restore@ is applied to the resulting threads before scheduling starts.
-- 4. The run itself does not take snapshots.
-- 5. All remaining threads are killed afterwards.
runConcurrencyWithSnapshot :: (C.MonadConc n, HasCallStack)
  => Scheduler g
  -> MemType
  -> Context n g
  -> (Threads n -> n ())
  -> ModelConc n a
  -> n (CResult n g a)
runConcurrencyWithSnapshot sched memtype ctx restore ma = do
  (c, ref) <- runRefCont AStop (Just . Right) (runModelConc ma)
  let relaunched =
        launch' Unmasked initialThread (const c) (M.delete initialThread (cThreads ctx))
      -- Collected before the main thread is (re)bound, as in the original flow.
      rebind = M.keys (M.filter (isJust . _bound) relaunched)
  withMain <-
    if C.rtsSupportsBoundThreads
      then makeBound initialThread relaunched
      else pure relaunched
  threads <- foldrM makeBound withMain rebind
  restore threads
  result <- runThreads False sched memtype ref ctx { cThreads = threads }
  killAllThreads (finalContext result)
  pure result
-- | Kill every thread remaining in the given context.
--
-- Each call to 'kill' is given the same final thread map. The updated maps
-- it returns are discarded.
killAllThreads :: (C.MonadConc n, HasCallStack) => Context n g -> n ()
killAllThreads ctx = mapM_ killOne (M.keys remaining)
  where
    remaining = cThreads ctx
    killOne tid = kill tid remaining
-- | The state threaded through an execution, one step at a time.
data Context n g = Context
  { cSchedState    :: g
    -- ^ Scheduler state, passed to and returned by 'scheduleThread'.
  , cIdSource      :: IdSource
    -- ^ Supply of fresh identifiers for new threads, MVars, IORefs and STM
    -- transactions.
  , cThreads       :: Threads n
    -- ^ All live threads, keyed by 'ThreadId'.
  , cWriteBuf      :: WriteBuffer n
    -- ^ IORef writes buffered under 'TotalStoreOrder' /
    -- 'PartialStoreOrder' and not yet committed.
  , cCaps          :: Int
    -- ^ The modelled number of capabilities.
  , cInvariants    :: InvariantContext n
    -- ^ Invariants still to be checked, and those waiting on a dependency.
  , cNewInvariants :: [Invariant n ()]
    -- ^ Invariants registered by 'ANewInvariant' steps. Nothing in this
    -- module moves them into 'cInvariants'; presumably the caller does —
    -- TODO confirm.
  , cCState        :: ConcurrencyState
    -- ^ Concurrency state given to the scheduler and updated via
    -- 'updateCState' after every step.
  }
-- | Drive the scheduler loop: repeatedly ask the scheduler for a thread,
-- step it once, and fold the result back into the context.
--
-- The loop ends when any of the following happens:
--
-- * 'initialThread' is no longer in the thread map (normal termination);
-- * every real thread is blocked ('Deadlock');
-- * the scheduler declines to choose a thread ('Abort');
-- * a step fails (e.g. an uncaught exception in the main thread);
-- * an invariant throws ('InvariantFailure').
--
-- Failure conditions are written to @ref@ as @Just (Left condition)@. A
-- successful result is expected to be written by the main thread itself.
--
-- When @forSnapshot@ is 'True', the per-action snapshot effects returned by
-- 'stepThread' are chained into the 'finalRestore' function.
--
-- Throws 'ScheduledBlockedThread' if the scheduler picks a blocked thread.
-- Throws 'ScheduledMissingThread' if it picks an id that is not present.
runThreads :: (C.MonadConc n, HasCallStack)
  => Bool
  -> Scheduler g
  -> MemType
  -> C.IORef n (Maybe (Either Condition a))
  -> Context n g
  -> n (CResult n g a)
runThreads forSnapshot sched memtype ref = schedule (const $ pure ()) Seq.empty Nothing where
  -- Record a failure condition in the result ref, then stop.
  die reason finalR finalT finalD finalC = do
    C.writeIORef ref (Just $ Left reason)
    stop finalR finalT finalD finalC

  -- Package up the final state. This does not write to 'ref'.
  stop finalR finalT finalD finalC = pure CResult
    { finalContext = finalC
    , finalRef = ref
    , finalRestore = finalR
    , finalTrace = finalT
    , finalDecision = finalD
    }

  -- Check for termination / deadlock; otherwise pick a thread and step it.
  schedule restore sofar prior ctx
    | isTerminated = stop restore sofar prior ctx
    | isDeadlocked = die Deadlock restore sofar prior ctx
    | otherwise =
      let ctx' = ctx { cSchedState = g' }
      in case choice of
           Just chosen -> case M.lookup chosen threadsc of
             Just thread
               | isBlocked thread -> E.throwM ScheduledBlockedThread
               | otherwise ->
                 -- Continue: same thread as last step. Start: the previous
                 -- thread is no longer runnable. SwitchTo: pre-empting a
                 -- still-runnable thread.
                 let decision
                       | Just chosen == (fst <$> prior) = Continue
                       | (fst <$> prior) `notElem` map (Just . fst) runnable' = Start chosen
                       | otherwise = SwitchTo chosen
                     alternatives = filter (\(t, _) -> t /= chosen) runnable'
                 in step decision alternatives chosen thread restore sofar prior ctx'
             Nothing -> E.throwM ScheduledMissingThread
           Nothing -> die Abort restore sofar prior ctx'
    where
      (choice, g') = scheduleThread sched prior (efromList runnable') (cCState ctx) (cSchedState ctx)
      runnable' = [(t, lookahead (_continuation a)) | (t, a) <- sortOn fst $ M.assocs runnable]
      runnable = M.filter (not . isBlocked) threadsc
      -- Real threads plus the extra threads 'addCommitThreads' derives from
      -- the write buffer.
      threadsc = addCommitThreads (cWriteBuf ctx) threads
      threads = cThreads ctx
      isBlocked = isJust . _blocking
      -- O(log n) membership test rather than scanning the key list.
      isTerminated = M.notMember initialThread threads
      -- Deadlock is judged on real threads only, not on commit threads.
      isDeadlocked = all isBlocked threads

  -- Run the chosen thread for one step, then hand control back to 'schedule'.
  step decision alternatives chosen thread restore sofar prior ctx = do
      (res, actOrTrc, actionSnap) <- stepThread
          forSnapshot
          (isNothing prior)
          sched
          memtype
          chosen
          (_continuation thread)
          ctx
      let sofar' = sofar <> getTrc actOrTrc
      let prior' = getPrior actOrTrc
      -- Only accumulate per-action snapshot effects when snapshotting.
      let restore' threads' =
            if forSnapshot
            then restore threads' >> actionSnap threads'
            else restore threads'
      let ctx' = fixContext memtype chosen actOrTrc res ctx
      case res of
        Succeeded _ -> checkInvariants (cInvariants ctx') >>= \case
          Right ic ->
            schedule restore' sofar' prior' ctx' { cInvariants = ic }
          Left exc ->
            die (InvariantFailure exc) restore' sofar' prior' ctx'
        Failed failure ->
          die failure restore' sofar' prior' ctx'
    where
      getTrc a = Seq.singleton (decision, alternatives, a)
      getPrior a = Just (chosen, a)
-- | Tidy up the context after thread @tid@ has taken a step.
--
-- On 'Succeeded', threads blocked with 'OnMask' on @tid@ are woken by
-- 'unblockWaitingOn'. The exception is when @tid@ is still present and not
-- interruptible, in which case they stay blocked. On any other outcome the
-- pre-step context @ctx0@ is used instead.
--
-- In every case:
--
-- * 'delCommitThreads' is applied to the threads;
-- * blocked invariants touched by @act@ are reactivated;
-- * the concurrency state is updated with @act@.
fixContext :: MemType -> ThreadId -> ThreadAction -> What n g -> Context n g -> Context n g
fixContext memtype tid act what ctx0 = tidy base
  where
    base = case what of
      Succeeded ctx -> ctx { cThreads = releaseMaskWaiters (cThreads ctx) }
      _ -> ctx0

    releaseMaskWaiters threads
      | fmap interruptible (M.lookup tid threads) == Just False = threads
      | otherwise = unblockWaitingOn tid threads

    tidy ctx = ctx
      { cThreads = delCommitThreads (cThreads ctx)
      , cInvariants = unblockInvariants act (cInvariants ctx)
      , cCState = updateCState memtype (cCState ctx) tid act
      }
-- | Clear the blocking state of every thread blocked with @OnMask tid@,
-- i.e. threads waiting in a @throwTo@ aimed at @tid@. Other threads are
-- returned unchanged.
unblockWaitingOn :: ThreadId -> Threads n -> Threads n
unblockWaitingOn tid = fmap release
  where
    release thread
      | Just (OnMask t) <- _blocking thread, t == tid = thread { _blocking = Nothing }
      | otherwise = thread
-- | The outcome of a single 'stepThread' call.
data What n g
  = Succeeded (Context n g)
    -- ^ The step completed; carries the updated context.
  | Failed Condition
    -- ^ The step produced a condition that ends the execution
    -- ('runThreads' records it and stops).
-- | Run one step of a thread by interpreting its current 'Action' against
-- the context.
--
-- Returns a triple:
--
-- * the outcome: 'Succeeded' with the updated context, or 'Failed';
-- * the 'ThreadAction' recorded in the trace;
-- * an effect that 'runThreads' chains into its restore function when
--   snapshotting.
--
-- Most clauses return @const (pure ())@ as the effect. Allocation, write
-- and lift clauses return an effect that resets or replays the underlying
-- state.
--
-- Clauses wrapped in 'synchronised' flush the write buffer before running.
stepThread :: (C.MonadConc n, HasCallStack)
  => Bool
  -- ^ Ignored by every clause below ('runThreads' passes its snapshot flag).
  -> Bool
  -- ^ Ignored by every clause below ('runThreads' passes @isNothing prior@).
  -> Scheduler g
  -- ^ Ignored by every clause below.
  -> MemType
  -- ^ Memory model; consulted by the 'AWriteIORef' and 'ACommit' clauses.
  -> ThreadId
  -- ^ The thread being stepped.
  -> Action n
  -- ^ Its current continuation.
  -> Context n g
  -> n (What n g, ThreadAction, Threads n -> n ())
-- Fork: allocate a thread id, launch the child, continue the parent with its id.
stepThread _ _ _ _ tid (AFork n a b) = \ctx@Context{..} -> pure $
  let (idSource', newtid) = nextTId n cIdSource
      threads' = launch tid newtid a cThreads
  in ( Succeeded ctx { cThreads = goto (b newtid) tid threads', cIdSource = idSource' }
     , Fork newtid
     , const (pure ())
     )
-- As 'AFork', but the child is made bound with 'makeBound'.
stepThread _ _ _ _ tid (AForkOS n a b) = \ctx@Context{..} -> do
  let (idSource', newtid) = nextTId n cIdSource
  let threads' = launch tid newtid a cThreads
  threads'' <- makeBound newtid threads'
  pure ( Succeeded ctx { cThreads = goto (b newtid) tid threads'', cIdSource = idSource' }
       , ForkOS newtid
       , const (pure ())
       )
-- Report whether this thread's @_bound@ field is set.
stepThread _ _ _ _ tid (AIsBound c) = \ctx@Context{..} -> do
  let isBound = isJust . _bound $ elookup tid cThreads
  pure ( Succeeded ctx { cThreads = goto (c isBound) tid cThreads }
       , IsCurrentThreadBound isBound
       , const (pure ())
       )
-- Return the current thread's id.
stepThread _ _ _ _ tid (AMyTId c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto (c tid) tid cThreads }
       , MyThreadId
       , const (pure ())
       )
-- Return the modelled capability count.
stepThread _ _ _ _ tid (AGetNumCapabilities c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto (c cCaps) tid cThreads }
       , GetNumCapabilities cCaps
       , const (pure ())
       )
-- Set the modelled capability count.
stepThread _ _ _ _ tid (ASetNumCapabilities i c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads, cCaps = i }
       , SetNumCapabilities i
       , const (pure ())
       )
-- Yield: recorded in the trace; the thread simply continues.
stepThread _ _ _ _ tid (AYield c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Yield
       , const (pure ())
       )
-- Delay: recorded in the trace; the thread simply continues.
stepThread _ _ _ _ tid (ADelay n c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , ThreadDelay n
       , const (pure ())
       )
-- Allocate an empty MVar; the snapshot effect resets it to empty.
stepThread _ _ _ _ tid (ANewMVar n c) = \ctx@Context{..} -> do
  let (idSource', newmvid) = nextMVId n cIdSource
  ref <- C.newIORef Nothing
  let mvar = ModelMVar newmvid ref
  pure ( Succeeded ctx { cThreads = goto (c mvar) tid cThreads, cIdSource = idSource' }
       , NewMVar newmvid
       , const (C.writeIORef ref Nothing)
       )
-- Blocking put; the MVar logic lives in 'putIntoMVar'.
stepThread _ _ _ _ tid (APutMVar mvar@ModelMVar{..} a c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- putIntoMVar mvar a c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , if success then PutMVar mvarId woken else BlockedPutMVar mvarId
       , const effect
       )
-- Non-blocking put.
stepThread _ _ _ _ tid (ATryPutMVar mvar@ModelMVar{..} a c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- tryPutIntoMVar mvar a c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , TryPutMVar mvarId success woken
       , const effect
       )
-- Blocking read; no snapshot effect.
stepThread _ _ _ _ tid (AReadMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', _, _) <- readFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , if success then ReadMVar mvarId else BlockedReadMVar mvarId
       , const (pure ())
       )
-- Non-blocking read; no snapshot effect.
stepThread _ _ _ _ tid (ATryReadMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', _, _) <- tryReadFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , TryReadMVar mvarId success
       , const (pure ())
       )
-- Blocking take.
stepThread _ _ _ _ tid (ATakeMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- takeFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , if success then TakeMVar mvarId woken else BlockedTakeMVar mvarId
       , const effect
       )
-- Non-blocking take.
stepThread _ _ _ _ tid (ATryTakeMVar mvar@ModelMVar{..} c) = synchronised $ \ctx@Context{..} -> do
  (success, threads', woken, effect) <- tryTakeFromMVar mvar c tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , TryTakeMVar mvarId success woken
       , const effect
       )
-- Allocate an IORef holding the initial triple @(M.empty, 0, a)@; the
-- snapshot effect writes that same triple back.
stepThread _ _ _ _ tid (ANewIORef n a c) = \ctx@Context{..} -> do
  let (idSource', newiorid) = nextIORId n cIdSource
  let val = (M.empty, 0, a)
  ioref <- C.newIORef val
  let ref = ModelIORef newiorid ioref
  pure ( Succeeded ctx { cThreads = goto (c ref) tid cThreads, cIdSource = idSource' }
       , NewIORef newiorid
       , const (C.writeIORef ioref val)
       )
-- Read from this thread's point of view. This clause is not 'synchronised',
-- so the write buffer is not flushed first.
stepThread _ _ _ _ tid (AReadIORef ref@ModelIORef{..} c) = \ctx@Context{..} -> do
  val <- readIORef ref tid
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , ReadIORef iorefId
       , const (pure ())
       )
-- Read a ticket for a later compare-and-swap.
stepThread _ _ _ _ tid (AReadIORefCas ref@ModelIORef{..} c) = \ctx@Context{..} -> do
  tick <- readForTicket ref tid
  pure ( Succeeded ctx { cThreads = goto (c tick) tid cThreads }
       , ReadIORefCas iorefId
       , const (pure ())
       )
-- Atomic modify: read, then write the new value immediately.
stepThread _ _ _ _ tid (AModIORef ref@ModelIORef{..} f c) = synchronised $ \ctx@Context{..} -> do
  (new, val) <- f <$> readIORef ref tid
  effect <- writeImmediate ref new
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , ModIORef iorefId
       , const effect
       )
-- Atomic modify via read-for-ticket followed by 'casIORef'.
stepThread _ _ _ _ tid (AModIORefCas ref@ModelIORef{..} f c) = synchronised $ \ctx@Context{..} -> do
  tick@(ModelTicket _ _ old) <- readForTicket ref tid
  let (new, val) = f old
  (_, _, effect) <- casIORef ref tid tick new
  pure ( Succeeded ctx { cThreads = goto (c val) tid cThreads }
       , ModIORefCas iorefId
       , const effect
       )
-- Write, according to the memory model:
--   * SC: written immediately.
--   * TSO: buffered per thread, key @(tid, Nothing)@.
--   * PSO: buffered per thread and IORef, key @(tid, Just iorefId)@.
stepThread _ _ _ memtype tid (AWriteIORef ref@ModelIORef{..} a c) = \ctx@Context{..} -> case memtype of
  SequentialConsistency -> do
    effect <- writeImmediate ref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads }
         , WriteIORef iorefId
         , const effect
         )
  TotalStoreOrder -> do
    wb' <- bufferWrite cWriteBuf (tid, Nothing) ref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads, cWriteBuf = wb' }
         , WriteIORef iorefId
         , const (pure ())
         )
  PartialStoreOrder -> do
    wb' <- bufferWrite cWriteBuf (tid, Just iorefId) ref a
    pure ( Succeeded ctx { cThreads = goto c tid cThreads, cWriteBuf = wb' }
         , WriteIORef iorefId
         , const (pure ())
         )
-- Compare-and-swap against a previously read ticket.
stepThread _ _ _ _ tid (ACasIORef ref@ModelIORef{..} tick a c) = synchronised $ \ctx@Context{..} -> do
  (suc, tick', effect) <- casIORef ref tid tick a
  pure ( Succeeded ctx { cThreads = goto (c (suc, tick')) tid cThreads }
       , CasIORef iorefId suc
       , const effect
       )
-- Commit one buffered write, using the same buffer keys as 'AWriteIORef'.
-- Reaching this under SequentialConsistency is a fatal internal error.
stepThread _ _ _ memtype _ (ACommit t c) = \ctx@Context{..} -> do
  wb' <- case memtype of
    SequentialConsistency ->
      fatal "stepThread.ACommit" "Attempting to commit under SequentialConsistency"
    TotalStoreOrder ->
      commitWrite cWriteBuf (t, Nothing)
    PartialStoreOrder ->
      commitWrite cWriteBuf (t, Just c)
  pure ( Succeeded ctx { cWriteBuf = wb' }
       , CommitIORef t c
       , const (pure ())
       )
-- Run an STM transaction.
--   * Success: wakes threads blocked on the written TVars.
--   * Retry: blocks this thread on the touched TVars.
--   * Exception: the exception is thrown in this thread via 'stepThrow'.
-- In every case the id source is updated. The snapshot effect re-runs the
-- transaction and discards its result.
stepThread _ _ _ _ tid (AAtom stm c) = synchronised $ \ctx@Context{..} -> do
  let transaction = runTransaction stm cIdSource
  let effect = const (void transaction)
  (res, idSource', trace) <- transaction
  case res of
    Success _ written val -> do
      let (threads', woken) = wake (OnTVar written) cThreads
      pure ( Succeeded ctx { cThreads = goto (c val) tid threads', cIdSource = idSource' }
           , STM trace woken
           , effect
           )
    Retry touched -> do
      let threads' = block (OnTVar touched) tid cThreads
      pure ( Succeeded ctx { cThreads = threads', cIdSource = idSource'}
           , BlockedSTM trace
           , effect
           )
    Exception e -> do
      let act = STM trace []
      res' <- stepThrow (const act) tid e ctx
      pure $ case res' of
        (Succeeded ctx', _, effect') -> (Succeeded ctx' { cIdSource = idSource' }, act, effect')
        (Failed err, _, effect') -> (Failed err, act, effect')
-- Run an action of the underlying monad; the snapshot effect re-runs it.
stepThread _ _ _ _ tid (ALift na) = \ctx@Context{..} -> do
  let effect threads = runLiftedAct tid threads na
  a <- effect cThreads
  pure (Succeeded ctx { cThreads = goto a tid cThreads }
       , LiftIO
       , void <$> effect
       )
-- Throw in the current thread.
stepThread _ _ _ _ tid (AThrow e) = stepThrow Throw tid e
-- Throw to thread @t@:
--   * target interruptible, or target is this thread: throw immediately;
--   * target not interruptible: block on @OnMask t@;
--   * target does not exist: just continue.
stepThread _ _ _ _ tid (AThrowTo t e c) = synchronised $ \ctx@Context{..} ->
  let threads' = goto c tid cThreads
      blocked = block (OnMask t) tid cThreads
  in case M.lookup t cThreads of
       Just thread
         | interruptible thread || t == tid -> stepThrow (ThrowTo t) t e ctx { cThreads = threads' }
         | otherwise -> pure
           ( Succeeded ctx { cThreads = blocked }
           , BlockedThrowTo t
           , const (pure ())
           )
       Nothing -> pure
         (Succeeded ctx { cThreads = threads' }
         , ThrowTo t False
         , const (pure ())
         )
-- Install handler @h@ and run @ma@. The handler is popped afterwards via
-- 'APopCatching'.
stepThread _ _ _ _ tid (ACatching h ma c) = \ctx@Context{..} -> pure $
  let a = runModelConc ma (APopCatching . c)
      e exc = runModelConc (h exc) c
  in ( Succeeded ctx { cThreads = goto a tid (catching e tid cThreads) }
     , Catching
     , const (pure ())
     )
-- Remove the innermost handler.
stepThread _ _ _ _ tid (APopCatching a) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto a tid (uncatching tid cThreads) }
       , PopCatching
       , const (pure ())
       )
-- Run @ma@ with masking state @m@. The previous state @m'@ is restored
-- afterwards; @umask@ temporarily reinstates @m'@ around its argument.
stepThread _ _ _ _ tid (AMasking m ma c) = \ctx@Context{..} -> pure $
  let resetMask typ ms = ModelConc $ \k -> AResetMask typ True ms $ k ()
      umask mb = resetMask True m' >> mb >>= \b -> resetMask False m >> pure b
      m' = _masking $ elookup tid cThreads
      a = runModelConc (ma umask) (AResetMask False False m' . c)
  in ( Succeeded ctx { cThreads = goto a tid (mask m tid cThreads) }
     , SetMasking False m
     , const (pure ())
     )
-- Set the masking state. @b1@ selects SetMasking/ResetMasking in the trace.
stepThread _ _ _ _ tid (AResetMask b1 b2 m c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid (mask m tid cThreads) }
       , (if b1 then SetMasking else ResetMasking) b2 m
       , const (pure ())
       )
-- Return: recorded in the trace; the thread continues.
stepThread _ _ _ _ tid (AReturn c) = \ctx@Context{..} ->
  pure ( Succeeded ctx { cThreads = goto c tid cThreads }
       , Return
       , const (pure ())
       )
-- Run the finaliser @na@, then remove the thread.
stepThread _ _ _ _ tid (AStop na) = \ctx@Context{..} -> do
  na
  threads' <- kill tid cThreads
  pure ( Succeeded ctx { cThreads = threads' }
       , Stop
       , const (pure ())
       )
-- Register an invariant by prepending it to 'cNewInvariants'.
stepThread _ _ _ _ tid (ANewInvariant inv c) = \ctx@Context{..} ->
  pure ( Succeeded ctx
         { cThreads = goto c tid cThreads
         , cNewInvariants = inv : cNewInvariants
         }
       , RegisterInvariant
       , const (pure ())
       )
-- | Throw exception @e@ in thread @tid@.
--
-- If 'propagate' finds a handler (returns new threads), execution continues
-- with those threads and @act False@ is recorded. Otherwise the exception is
-- uncaught and @act True@ is recorded:
--
-- * in 'initialThread', the whole execution fails with 'UncaughtException';
-- * in any other thread, that thread is killed.
--
-- No snapshot effect is produced.
stepThrow :: (C.MonadConc n, Exception e)
  => (Bool -> ThreadAction)
  -> ThreadId
  -> e
  -> Context n g
  -> n (What n g, ThreadAction, Threads n -> n ())
stepThrow act tid e ctx =
  case propagate exc tid threads of
    Just handled -> finish (Succeeded ctx { cThreads = handled }) False
    Nothing
      | tid /= initialThread -> do
          survivors <- kill tid threads
          finish (Succeeded ctx { cThreads = survivors }) True
      | otherwise -> finish (Failed (UncaughtException exc)) True
  where
    exc = toException e
    threads = cThreads ctx
    finish outcome uncaught = pure (outcome, act uncaught, const (pure ()))
-- | Run a context action behind a write barrier.
--
-- 'writeBarrier' is applied to the current write buffer. The action then
-- sees a context whose buffer has been reset to 'emptyBuffer'.
synchronised :: C.MonadConc n
  => (Context n g -> n x)
  -> Context n g
  -> n x
synchronised ma ctx =
  writeBarrier (cWriteBuf ctx) >> ma ctx { cWriteBuf = emptyBuffer }
-- | The invariants of an execution, split by whether they need re-checking.
data InvariantContext n = InvariantContext
  { icActive  :: [Invariant n ()]
    -- ^ Invariants to check after the next successful step.
  , icBlocked :: [(Invariant n (), ([IORefId], [MVarId], [TVarId]))]
    -- ^ Invariants already checked, each paired with the IORefs, MVars and
    -- TVars it inspected. They become active again once an action touches
    -- one of those.
  }
-- | Reactivate blocked invariants whose dependencies were touched by @act@.
--
-- An invariant in 'icBlocked' is moved to the front of 'icActive' when any
-- of the following holds:
--
-- * the IORef reported by 'iorefOf' for @act@ is in its IORef list;
-- * the MVar reported by 'mvarOf' for @act@ is in its MVar list;
-- * any TVar reported by 'tvarsOf' for @act@ is in its TVar list.
--
-- Those identifiers depend only on @act@, so they are computed once rather
-- than once per blocked invariant.
unblockInvariants :: ThreadAction -> InvariantContext n -> InvariantContext n
unblockInvariants act ic = InvariantContext active blocked where
  active = map fst unblocked ++ icActive ic

  (unblocked, blocked) = partition isTouched (icBlocked ic)

  -- Loop-invariant: hoisted out of the per-invariant predicate.
  simplified = simplifyAction act
  touchedIORef = iorefOf simplified
  touchedMVar = mvarOf simplified
  touchedTVars = tvarsOf act

  isTouched (_, (ioridsB, mvidsB, tvidsB)) =
    maybe False (`elem` ioridsB) touchedIORef ||
    maybe False (`elem` mvidsB) touchedMVar ||
    any (`elem` tvidsB) touchedTVars
-- | Check every active invariant in order, stopping at the first that throws.
--
-- On success, all active invariants move to the blocked list. Each is paired
-- with the identifiers it inspected and placed ahead of the invariants that
-- were already blocked.
checkInvariants :: C.MonadConc n
  => InvariantContext n
  -> n (Either E.SomeException (InvariantContext n))
checkInvariants ic = do
  outcome <- checkAll (icActive ic)
  case outcome of
    Left exc -> pure (Left exc)
    Right newlyBlocked -> pure (Right (InvariantContext [] (newlyBlocked ++ icBlocked ic)))
  where
    checkAll [] = pure (Right [])
    checkAll (inv:rest) = do
      checked <- checkInvariant inv
      case checked of
        Left exc -> pure (Left exc)
        Right deps -> fmap ((inv, deps) :) <$> checkAll rest
-- | Run one invariant.
--
-- Returns the IORef, MVar and TVar identifiers it inspected, or the
-- exception it threw.
checkInvariant :: C.MonadConc n
  => Invariant n a
  -> n (Either E.SomeException ([IORefId], [MVarId], [TVarId]))
checkInvariant inv = do
  (outcome, iorefs, mvars, tvars) <- doInvariant inv
  case outcome of
    Left exc -> pure (Left exc)
    Right _ -> pure (Right (iorefs, mvars, tvars))
-- | Run an invariant to completion.
--
-- Steps the invariant with 'stepInvariant' until it stops or throws,
-- accumulating the identifiers each step inspected. If a step throws, the
-- exception is written to the result ref as @Just (Left exc)@ and stepping
-- ends. The value is then read back from that ref. This relies on the ref
-- having been filled, either by the 'IStop' continuation or by an exception
-- (hence 'efromJust').
--
-- The returned identifier lists are de-duplicated with 'nub'.
doInvariant :: C.MonadConc n
  => Invariant n a
  -> n (Either E.SomeException a, [IORefId], [MVarId], [TVarId])
doInvariant inv = do
  (start, ref) <- runRefCont IStop (Just . Right) (runInvariant inv)
  let loop act (seenIORefs, seenMVars, seenTVars) = do
        (res, iorefs', mvars', tvars') <- stepInvariant act
        let seen = (iorefs' ++ seenIORefs, mvars' ++ seenMVars, tvars' ++ seenTVars)
        case res of
          Right (Just next) -> loop next seen
          Right Nothing -> pure seen
          Left exc -> do
            C.writeIORef ref (Just (Left exc))
            pure seen
  (iorefs, mvars, tvars) <- loop start ([], [], [])
  val <- C.readIORef ref
  pure (efromJust val, nub iorefs, nub mvars, nub tvars)
-- | Execute a single invariant action.
--
-- The first component of the result is:
--
-- * @Right (Just next)@ to continue with @next@;
-- * @Right Nothing@ when the invariant has finished;
-- * @Left exc@ when it threw.
--
-- The remaining components list the IORef, MVar and TVar identifiers
-- inspected by this step. For 'ICatch' this includes everything the nested
-- invariant and any handler inspected.
stepInvariant :: C.MonadConc n
  => IAction n
  -> n (Either E.SomeException (Maybe (IAction n)), [IORefId], [MVarId], [TVarId])
stepInvariant = \case
  IInspectIORef ioref@ModelIORef{..} k -> do
    a <- readIORefGlobal ioref
    pure (Right (Just (k a)), [iorefId], [], [])
  IInspectMVar ModelMVar{..} k -> do
    a <- C.readIORef mvarRef
    pure (Right (Just (k a)), [], [mvarId], [])
  IInspectTVar ModelTVar{..} k -> do
    a <- C.readIORef tvarRef
    pure (Right (Just (k a)), [], [], [tvarId])
  ICatch h nx k -> do
    (res, iorefs, mvars, tvars) <- doInvariant nx
    case res of
      Right a -> pure (Right (Just (k a)), iorefs, mvars, tvars)
      Left exc -> case E.fromException exc of
        -- Not the handler's exception type: propagate unchanged.
        Nothing -> pure (Left exc, iorefs, mvars, tvars)
        Just exc' -> do
          (res', iorefs', mvars', tvars') <- doInvariant (h exc')
          let withIds r = pure (r, iorefs' ++ iorefs, mvars' ++ mvars, tvars' ++ tvars)
          either (withIds . Left) (withIds . Right . Just . k) res'
  IThrow exc -> pure (Left (toException exc), [], [], [])
  IStop finalise -> do
    finalise
    pure (Right Nothing, [], [], [])