{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
module Test.DejaFu.Internal where
import Control.DeepSeq (NFData(..))
import Control.Exception (MaskingState(..))
import qualified Control.Monad.Conc.Class as C
import Data.List.NonEmpty (NonEmpty(..))
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe (fromMaybe)
import Data.Set (Set)
import qualified Data.Set as S
import GHC.Generics (Generic)
import GHC.Stack (HasCallStack, withFrozenCallStack)
import System.Random (RandomGen)
import Test.DejaFu.Types
-- | Internal settings record for running and exploring a concurrent
-- computation.
--
-- @n@ is the monad debug messages are printed in (see '_debugPrint') and
-- @a@ is the result type of the computation under test.
data Settings n a = Settings
  { _way :: Way
  -- ^ Exploration strategy: 'Systematic' (with 'Bounds') or 'Randomly'.
  , _lengthBound :: Maybe LengthBound
  -- ^ Optional execution-length bound (presumably 'Nothing' means no
  -- bound — confirm at the use site).
  , _memtype :: MemType
  -- ^ Memory model; e.g. under 'SequentialConsistency' 'updateIOState'
  -- tracks no outstanding 'IORef' writes.
  , _discard :: Maybe (Either Condition a -> Maybe Discard)
  -- ^ Optional per-result discard decision (presumably 'Nothing' from the
  -- function means "keep" — TODO confirm).
  , _debugShow :: Maybe (a -> String)
  -- ^ Optional renderer for results in debug output.
  , _debugPrint :: Maybe (String -> n ())
  -- ^ Optional action for emitting a debug message in @n@.
  , _debugFatal :: Bool
  -- ^ Flag; presumably makes debug-level problems fatal — confirm.
  , _earlyExit :: Maybe (Either Condition a -> Bool)
  -- ^ Optional predicate on a result; presumably stops exploration early
  -- when it holds — confirm.
  , _equality :: Maybe (a -> a -> Bool)
  -- ^ Optional equality on results (use site not visible here).
  , _simplify :: Bool
  -- ^ Flag; presumably enables trace simplification — confirm.
  , _safeIO :: Bool
  -- ^ Flag; presumably asserts that lifted IO is safe — TODO confirm.
  , _showAborts :: Bool
  -- ^ Flag; presumably controls whether aborted executions are reported —
  -- confirm.
  }
-- | How the executions of a concurrent program are explored.
data Way where
  -- | Systematic exploration, carrying the 'Bounds' to apply.
  Systematic :: Bounds -> Way
  -- | Random exploration.  Fields: a function drawing an 'Int' from a
  -- generator, the initial generator (its type @g@ is hidden), and an 'Int'
  -- (presumably the number of executions to try — confirm at the use site).
  Randomly :: RandomGen g => (g -> (Int, g)) -> g -> Int -> Way
-- | Debug rendering of a 'Way'.  The drawing function and the (existentially
-- typed) generator cannot be shown, so they appear as placeholders.
instance Show Way where
  show way = case way of
    Systematic bounds -> concat ["Systematic (", show bounds, ")"]
    Randomly _ _ count -> "Randomly <f> <gen> " ++ show count
-- | Supply of fresh identifiers, with one independent supply per kind of
-- identifier.
--
-- Each field is a pair of the last number handed out (0 when none yet) and
-- the non-empty names used so far, most recent first; see 'nextId'.
data IdSource = IdSource
  { _iorids :: (Int, [String])
  -- ^ Supply for 'IORefId's.
  , _mvids :: (Int, [String])
  -- ^ Supply for 'MVarId's.
  , _tvids :: (Int, [String])
  -- ^ Supply for 'TVarId's.
  , _tids :: (Int, [String])
  -- ^ Supply for 'ThreadId's.
  } deriving (Eq, Ord, Show, Generic, NFData)
-- | Allocate a fresh 'IORefId' with the given (possibly empty) name,
-- returning the advanced 'IdSource'.
nextIORId :: String -> IdSource -> (IdSource, IORefId)
nextIORId name idsource = (idsource { _iorids = supply }, IORefId ident)
  where
    (ident, supply) = nextId name (_iorids idsource)
-- | Allocate a fresh 'MVarId' with the given (possibly empty) name,
-- returning the advanced 'IdSource'.
nextMVId :: String -> IdSource -> (IdSource, MVarId)
nextMVId name idsource = (idsource { _mvids = supply }, MVarId ident)
  where
    (ident, supply) = nextId name (_mvids idsource)
-- | Allocate a fresh 'TVarId' with the given (possibly empty) name,
-- returning the advanced 'IdSource'.
nextTVId :: String -> IdSource -> (IdSource, TVarId)
nextTVId name idsource = (idsource { _tvids = supply }, TVarId ident)
  where
    (ident, supply) = nextId name (_tvids idsource)
-- | Allocate a fresh 'ThreadId' with the given (possibly empty) name,
-- returning the advanced 'IdSource'.
nextTId :: String -> IdSource -> (IdSource, ThreadId)
nextTId name idsource = (idsource { _tids = supply }, ThreadId ident)
  where
    (ident, supply) = nextId name (_tids idsource)
-- | Draw the next identifier from a (counter, used-names) supply.
--
-- The counter is incremented and the new value is the numeric part of the
-- returned 'Id'.  An empty name gives an unlabelled 'Id' and is not
-- recorded.  A non-empty name is recorded; if it was already used @k > 0@
-- times, the label becomes @name-k@ to disambiguate the repeat.
nextId :: String -> (Int, [String]) -> (Id, (Int, [String]))
nextId name (num, used) = (Id label next, (next, used'))
  where
    next = num + 1
    anonymous = null name
    previous = length (filter (== name) used)
    label
      | anonymous = Nothing
      | previous == 0 = Just name
      | otherwise = Just (name ++ "-" ++ show previous)
    used'
      | anonymous = used
      | otherwise = name : used
-- | An 'IdSource' in which every counter is zero and no names are used.
initialIdSource :: IdSource
initialIdSource = IdSource
  { _iorids = fresh
  , _mvids = fresh
  , _tvids = fresh
  , _tids = fresh
  }
  where
    fresh = (0, [])
-- | Is this one of the @Blocked*@ thread actions (throwTo, take/read/put on
-- an MVar, or STM)?
isBlock :: ThreadAction -> Bool
isBlock = \case
  BlockedThrowTo _ -> True
  BlockedTakeMVar _ -> True
  BlockedReadMVar _ -> True
  BlockedPutMVar _ -> True
  BlockedSTM _ -> True
  _ -> False
-- | Every 'TVar' an action touches: the ones it reads together with the
-- ones it creates or writes.
tvarsOf :: ThreadAction -> Set TVarId
tvarsOf act = S.unions [tvarsRead act, tvarsWritten act]
-- | 'TVar's created ('TNew') or written ('TWrite') by an STM action, whether
-- it completed or blocked.  Both branches of 'TOrElse' / 'TCatch' trace
-- entries are searched, the second branch only when present.  Non-STM
-- actions give the empty set.
tvarsWritten :: ThreadAction -> Set TVarId
tvarsWritten = \case
  STM trc _ -> fromTrace trc
  BlockedSTM trc -> fromTrace trc
  _ -> S.empty
  where
    fromTrace = S.fromList . concatMap written
    written = \case
      TNew tv -> [tv]
      TWrite tv -> [tv]
      TOrElse ta tb -> concatMap written (ta ++ fromMaybe [] tb)
      TCatch ta tb -> concatMap written (ta ++ fromMaybe [] tb)
      _ -> []
-- | 'TVar's read ('TRead') by an STM action, whether it completed or
-- blocked.  Both branches of 'TOrElse' / 'TCatch' trace entries are
-- searched, the second branch only when present.  Non-STM actions give the
-- empty set.
tvarsRead :: ThreadAction -> Set TVarId
tvarsRead = \case
  STM trc _ -> fromTrace trc
  BlockedSTM trc -> fromTrace trc
  _ -> S.empty
  where
    fromTrace = S.fromList . concatMap readIn
    readIn = \case
      TRead tv -> [tv]
      TOrElse ta tb -> concatMap readIn (ta ++ fromMaybe [] tb)
      TCatch ta tb -> concatMap readIn (ta ++ fromMaybe [] tb)
      _ -> []
-- | Map an executed 'ThreadAction' back to the 'Lookahead' of the step that
-- produced it.
--
-- Blocked variants map to the lookahead of the operation they attempted
-- (e.g. 'BlockedPutMVar' to 'WillPutMVar').  Some fields are dropped, such
-- as the new 'ThreadId' of a 'Fork'.
rewind :: ThreadAction -> Lookahead
rewind = \case
  -- threads and scheduling
  Fork _ -> WillFork
  ForkOS _ -> WillForkOS
  IsCurrentThreadBound _ -> WillIsCurrentThreadBound
  MyThreadId -> WillMyThreadId
  GetNumCapabilities _ -> WillGetNumCapabilities
  SetNumCapabilities i -> WillSetNumCapabilities i
  Yield -> WillYield
  ThreadDelay n -> WillThreadDelay n
  -- MVars
  NewMVar _ -> WillNewMVar
  PutMVar c _ -> WillPutMVar c
  BlockedPutMVar c -> WillPutMVar c
  TryPutMVar c _ _ -> WillTryPutMVar c
  ReadMVar c -> WillReadMVar c
  BlockedReadMVar c -> WillReadMVar c
  TryReadMVar c _ -> WillTryReadMVar c
  TakeMVar c _ -> WillTakeMVar c
  BlockedTakeMVar c -> WillTakeMVar c
  TryTakeMVar c _ _ -> WillTryTakeMVar c
  -- IORefs
  NewIORef _ -> WillNewIORef
  ReadIORef c -> WillReadIORef c
  ReadIORefCas c -> WillReadIORefCas c
  ModIORef c -> WillModIORef c
  ModIORefCas c -> WillModIORefCas c
  WriteIORef c -> WillWriteIORef c
  CasIORef c _ -> WillCasIORef c
  CommitIORef t c -> WillCommitIORef t c
  -- STM
  STM _ _ -> WillSTM
  BlockedSTM _ -> WillSTM
  -- exceptions and masking
  Catching -> WillCatching
  PopCatching -> WillPopCatching
  Throw _ -> WillThrow
  ThrowTo t _ -> WillThrowTo t
  BlockedThrowTo t -> WillThrowTo t
  SetMasking b m -> WillSetMasking b m
  ResetMasking b m -> WillResetMasking b m
  -- miscellaneous
  LiftIO -> WillLiftIO
  Return -> WillReturn
  Stop -> WillStop
  RegisterInvariant -> WillRegisterInvariant
-- | Classify a 'Lookahead' as "releasing".
--
-- 'True' for:
--
-- * forks;
-- * yield and delay;
-- * the MVar put/read/take lookaheads, including the try-put and try-take
--   variants but not 'WillTryReadMVar';
-- * STM, 'WillThrow', masking changes, and 'WillStop'.
--
-- 'False' for everything else.
willRelease :: Lookahead -> Bool
willRelease = \case
  WillFork -> True
  WillForkOS -> True
  WillYield -> True
  WillThreadDelay _ -> True
  WillPutMVar _ -> True
  WillTryPutMVar _ -> True
  WillReadMVar _ -> True
  WillTakeMVar _ -> True
  WillTryTakeMVar _ -> True
  WillSTM -> True
  WillThrow -> True
  WillSetMasking _ _ -> True
  WillResetMasking _ _ -> True
  WillStop -> True
  _ -> False
-- | A coarse classification of actions by how they synchronise, as produced
-- by 'simplifyLookahead' and 'simplifyAction'.
data ActionType =
    UnsynchronisedRead IORefId
  -- ^ An 'IORef' read ('WillReadIORef' or 'WillReadIORefCas').
  | UnsynchronisedWrite IORefId
  -- ^ A plain 'IORef' write ('WillWriteIORef').
  | UnsynchronisedOther
  -- ^ Any action not classified by another constructor.
  | PartiallySynchronisedCommit IORefId
  -- ^ An 'IORef' commit ('WillCommitIORef').
  | PartiallySynchronisedWrite IORefId
  -- ^ An 'IORef' compare-and-swap ('WillCasIORef').
  | PartiallySynchronisedModify IORefId
  -- ^ A CAS-based 'IORef' modify ('WillModIORefCas').
  | SynchronisedModify IORefId
  -- ^ An 'IORef' modify ('WillModIORef').
  | SynchronisedRead MVarId
  -- ^ An MVar read or take, including the @try@ variants.
  | SynchronisedWrite MVarId
  -- ^ An MVar put, including the @try@ variant.
  | SynchronisedOther
  -- ^ STM ('WillSTM') or 'WillThrowTo'.
  deriving (Eq, Show, Generic, NFData)
-- | 'True' exactly for the fully synchronised action types
-- ('SynchronisedModify', 'SynchronisedRead', 'SynchronisedWrite',
-- 'SynchronisedOther').  'updateIOState' clears its map on such actions.
isBarrier :: ActionType -> Bool
isBarrier = \case
  SynchronisedModify _ -> True
  SynchronisedRead _ -> True
  SynchronisedWrite _ -> True
  SynchronisedOther -> True
  _ -> False
-- | Is this a partially synchronised action (commit, CAS write or CAS
-- modify) on the given 'IORef'?  Every other action type gives 'False'.
isCommit :: ActionType -> IORefId -> Bool
isCommit act r = case act of
  PartiallySynchronisedCommit c -> c == r
  PartiallySynchronisedWrite c -> c == r
  PartiallySynchronisedModify c -> c == r
  _ -> False
-- | Does this action synchronise with respect to the given 'IORef'?  It
-- does when it is a barrier, or when it is a commit-like action on that
-- very 'IORef'.
synchronises :: ActionType -> IORefId -> Bool
synchronises a r
  | isBarrier a = True
  | otherwise = isCommit a r
-- | The 'IORef' an action type refers to, if it refers to one.
iorefOf :: ActionType -> Maybe IORefId
iorefOf = \case
  UnsynchronisedRead r -> Just r
  UnsynchronisedWrite r -> Just r
  SynchronisedModify r -> Just r
  PartiallySynchronisedCommit r -> Just r
  PartiallySynchronisedWrite r -> Just r
  PartiallySynchronisedModify r -> Just r
  _ -> Nothing
-- | The 'MVar' an action type refers to, if it refers to one.
mvarOf :: ActionType -> Maybe MVarId
mvarOf = \case
  SynchronisedRead c -> Just c
  SynchronisedWrite c -> Just c
  _ -> Nothing
-- | The other threads named by an action: forked children, thread lists
-- carried by MVar/STM actions, the thread of a commit, and the target of a
-- (possibly blocked) 'ThrowTo'.  Other actions give the empty set.
tidsOf :: ThreadAction -> Set ThreadId
tidsOf = \case
  Fork tid -> S.singleton tid
  ForkOS tid -> S.singleton tid
  PutMVar _ tids -> S.fromList tids
  TryPutMVar _ _ tids -> S.fromList tids
  TakeMVar _ tids -> S.fromList tids
  TryTakeMVar _ _ tids -> S.fromList tids
  CommitIORef tid _ -> S.singleton tid
  STM _ tids -> S.fromList tids
  ThrowTo tid _ -> S.singleton tid
  BlockedThrowTo tid -> S.singleton tid
  _ -> S.empty
-- | Classify an executed action by first recovering its 'Lookahead'.
simplifyAction :: ThreadAction -> ActionType
simplifyAction act = simplifyLookahead (rewind act)
-- | Classify a 'Lookahead' into its 'ActionType'.  Lookaheads not listed
-- here are 'UnsynchronisedOther'.
simplifyLookahead :: Lookahead -> ActionType
simplifyLookahead = \case
  -- MVar operations
  WillPutMVar c -> SynchronisedWrite c
  WillTryPutMVar c -> SynchronisedWrite c
  WillReadMVar c -> SynchronisedRead c
  WillTryReadMVar c -> SynchronisedRead c
  WillTakeMVar c -> SynchronisedRead c
  WillTryTakeMVar c -> SynchronisedRead c
  -- IORef operations
  WillReadIORef r -> UnsynchronisedRead r
  WillReadIORefCas r -> UnsynchronisedRead r
  WillModIORef r -> SynchronisedModify r
  WillModIORefCas r -> PartiallySynchronisedModify r
  WillWriteIORef r -> UnsynchronisedWrite r
  WillCasIORef r -> PartiallySynchronisedWrite r
  WillCommitIORef _ r -> PartiallySynchronisedCommit r
  -- other synchronising actions
  WillSTM -> SynchronisedOther
  WillThrowTo _ -> SynchronisedOther
  _ -> UnsynchronisedOther
-- | The starting 'ConcurrencyState'.  It has no outstanding 'IORef' write
-- counts, an empty 'MVar' set, and no recorded masking states.
initialCState :: ConcurrencyState
initialCState = ConcurrencyState
  { concIOState = M.empty
  , concMVState = S.empty
  , concMaskState = M.empty
  }
-- | Advance a 'ConcurrencyState' past one action taken by the given thread.
-- Each of its three components is updated independently: 'IORef' write
-- counts, the 'MVar' set, and per-thread masking states.
updateCState :: MemType -> ConcurrencyState -> ThreadId -> ThreadAction -> ConcurrencyState
updateCState memtype cstate tid act = ConcurrencyState
  { concIOState = ioState'
  , concMVState = mvState'
  , concMaskState = maskState'
  }
  where
    ioState' = updateIOState memtype act (concIOState cstate)
    mvState' = updateMVState act (concMVState cstate)
    maskState' = updateMaskState tid act (concMaskState cstate)
-- | Per-'IORef' count of writes not yet matched by a commit.
--
-- * Under 'SequentialConsistency' the map is always empty.
-- * Otherwise a 'WriteIORef' increments the count and a 'CommitIORef'
--   decrements it, removing the entry when it would reach zero.
-- * Any barrier action (see 'isBarrier') clears the whole map.
updateIOState :: MemType -> ThreadAction -> Map IORefId Int -> Map IORefId Int
updateIOState memtype act pending = case (memtype, act) of
  (SequentialConsistency, _) -> M.empty
  (_, CommitIORef _ r) -> M.update commitOne r pending
  (_, WriteIORef r) -> M.insertWith (+) r 1 pending
  _
    | isBarrier (simplifyAction act) -> M.empty
    | otherwise -> pending
  where
    -- the last outstanding write is committed: drop the entry entirely
    commitOne n
      | n == 1 = Nothing
      | otherwise = Just (n - 1)
-- | Maintain a set of 'MVar's.
--
-- * An 'MVar' is added by 'PutMVar', and by 'TryPutMVar' when its flag is
--   'True'.
-- * It is removed by 'TakeMVar', and by 'TryTakeMVar' when its flag is
--   'True'.
-- * Every other action leaves the set unchanged.
updateMVState :: ThreadAction -> Set MVarId -> Set MVarId
updateMVState act = case act of
  PutMVar mvid _ -> S.insert mvid
  TryPutMVar mvid True _ -> S.insert mvid
  TakeMVar mvid _ -> S.delete mvid
  TryTakeMVar mvid True _ -> S.delete mvid
  _ -> id
-- | Track each thread's 'MaskingState' after the given thread's action.
--
-- * A thread created by 'Fork' or 'ForkOS' starts with its parent's
--   recorded masking state, if the parent has one.
-- * 'SetMasking' and 'ResetMasking' record the new state for the acting
--   thread.
-- * The entry is removed on 'Stop', on 'Throw' with flag 'True' (acting
--   thread), and on 'ThrowTo' with flag 'True' (target thread).  The flag
--   presumably means the thread was killed; confirm.
-- * Every other action leaves the map unchanged.
updateMaskState :: ThreadId -> ThreadAction -> Map ThreadId MaskingState -> Map ThreadId MaskingState
updateMaskState tid act = case act of
  Fork child -> inheritFrom child
  -- A bound (OS) thread inherits its parent's masking state exactly like a
  -- forkIO'd one (GHC's forkOS preserves it), so it must be tracked too;
  -- otherwise the child would look unmasked.
  ForkOS child -> inheritFrom child
  SetMasking _ ms -> M.insert tid ms
  ResetMasking _ ms -> M.insert tid ms
  Throw True -> M.delete tid
  ThrowTo target True -> M.delete target
  Stop -> M.delete tid
  _ -> id
  where
    -- copy the parent's (acting thread's) masking state to the child
    inheritFrom child masks = case M.lookup tid masks of
      Just ms -> M.insert child ms masks
      Nothing -> masks
-- | 'tail' that reports an empty list through 'fatal', keeping the caller's
-- call stack.
etail :: HasCallStack => [a] -> [a]
etail xs = case xs of
  [] -> withFrozenCallStack $ fatal "tail: empty list"
  _ : rest -> rest
-- | Checked list indexing: like '(!!)', but out-of-range indices are
-- reported through 'fatal' with the caller's call stack.
--
-- The list is walked only up to position @i@; its full @length@ is never
-- demanded.  Indexing therefore works on infinite or partially-defined
-- lists and needs a single traversal.
eidx :: HasCallStack => [a] -> Int -> a
eidx xs i
  | i < 0 = withFrozenCallStack $ fatal "(!!): negative index"
  | otherwise = case drop i xs of
      x : _ -> x
      [] -> withFrozenCallStack $ fatal "(!!): index too large"
-- | 'fromJust' that reports 'Nothing' through 'fatal', keeping the caller's
-- call stack.
efromJust :: HasCallStack => Maybe a -> a
efromJust = fromMaybe (withFrozenCallStack $ fatal "fromJust: Nothing")
-- | Convert a list to 'NonEmpty', reporting an empty list through 'fatal'.
efromList :: HasCallStack => [a] -> NonEmpty a
efromList xs = case xs of
  y : ys -> y :| ys
  [] -> withFrozenCallStack $ fatal "fromList: empty list"
-- | Extract a 'Right' value, reporting a 'Left' through 'fatal'.
efromRight :: HasCallStack => Either a b -> b
efromRight = either (const (withFrozenCallStack $ fatal "fromRight: Left")) id
-- | Extract a 'Left' value, reporting a 'Right' through 'fatal'.
efromLeft :: HasCallStack => Either a b -> a
efromLeft = either id (const (withFrozenCallStack $ fatal "fromLeft: Right"))
-- | Apply a function to the value at a key, reporting a missing key
-- through 'fatal'.
eadjust :: (Ord k, Show k, HasCallStack) => (v -> v) -> k -> M.Map k v -> M.Map k v
eadjust f k m = maybe missing (\v -> M.insert k (f v) m) (M.lookup k m)
  where
    missing = withFrozenCallStack $ fatal ("adjust: key '" ++ show k ++ "' not found")
-- | Insert a new key, reporting an already-present key through 'fatal'
-- instead of overwriting it.
einsert :: (Ord k, Show k, HasCallStack) => k -> v -> M.Map k v -> M.Map k v
einsert k v m
  | M.notMember k m = M.insert k v m
  | otherwise = withFrozenCallStack $ fatal ("insert: key '" ++ show k ++ "' already present")
-- | Look up a key, reporting a missing key through 'fatal'.
elookup :: (Ord k, Show k, HasCallStack) => k -> M.Map k v -> v
elookup k m = case M.lookup k m of
  Just v -> v
  Nothing -> withFrozenCallStack $ fatal ("lookup: key '" ++ show k ++ "' not found")
-- | Abort with an 'error' whose message is prefixed with @(dejafu) @.
-- The call stack is frozen, so 'error' adds no frame of its own and the
-- stack passed in by the caller is what gets reported.
fatal :: HasCallStack => String -> a
fatal msg = withFrozenCallStack (error (concat ["(dejafu) ", msg]))
-- | Allocate a fresh reference, initially 'Nothing', and build a result
-- with @k@.
--
-- @k@ receives a continuation that maps its argument through @f@, writes
-- the outcome to the reference, and passes that write to @act@.  The built
-- result is returned together with the reference, so the caller can read
-- what was written.
runRefCont :: C.MonadConc n
  => (n () -> x)
  -> (a -> Maybe b)
  -> ((a -> x) -> x)
  -> n (x, C.IORef n (Maybe b))
runRefCont act f k = package <$> C.newIORef Nothing
  where
    package ref = (k (act . C.writeIORef ref . f), ref)