{-# LANGUAGE CPP #-}
{-# LANGUAGE NumericUnderscores #-}
{-# LANGUAGE TypeFamilies #-}
#if defined(__GLASGOW_HASKELL__) && \
!defined(mingw32_HOST_OS) && \
!defined(__GHCJS__) && \
!defined(js_HOST_ARCH) && \
!defined(wasm32_HOST_ARCH)
#define GHC_TIMERS_API
#endif
module Control.Monad.Class.MonadTimer.NonStandard
( TimeoutState (..)
, newTimeout
, readTimeout
, cancelTimeout
, awaitTimeout
, NewTimeout
, ReadTimeout
, CancelTimeout
, AwaitTimeout
) where
import Control.Concurrent.STM qualified as STM
#ifndef GHC_TIMERS_API
import Control.Monad (when)
#endif
import Control.Monad.Class.MonadSTM
#ifdef GHC_TIMERS_API
import GHC.Event qualified as GHC (TimeoutKey, getSystemTimerManager,
registerTimeout, unregisterTimeout)
#else
import GHC.Conc.IO qualified as GHC (registerDelay)
#endif
-- | Lifecycle state of a timeout.
--
-- The derived 'Ord' instance follows constructor order:
-- @TimeoutPending < TimeoutFired < TimeoutCancelled@.
--
data TimeoutState
  = TimeoutPending
    -- ^ The timeout has been registered and has neither fired nor been
    -- cancelled yet; this is the state a new timeout starts in.
  | TimeoutFired
    -- ^ The delay elapsed before the timeout was cancelled.
  | TimeoutCancelled
    -- ^ The timeout was cancelled while it was still pending.
  deriving (Eq, Ord, Show)
-- | Handle of a timeout, as produced by 'newTimeout' and consumed by
-- 'readTimeout', 'cancelTimeout' and 'awaitTimeout'.
--
-- The representation depends on the platform:
--
-- * With @GHC_TIMERS_API@ defined it holds a 'STM.TVar' with the current
--   'TimeoutState' together with the key under which the timeout was
--   registered with the GHC timer manager.
--
-- * Otherwise it holds a 'STM.TVar' pointing at the @TVar Bool@ returned by
--   @registerDelay@ (read as the "fired" flag), plus a separate @TVar Bool@
--   used as the "cancelled" flag.
--
#ifdef GHC_TIMERS_API
data Timeout = TimeoutIO !(STM.TVar TimeoutState) !GHC.TimeoutKey
#else
data Timeout = TimeoutIO !(STM.TVar (STM.TVar Bool)) !(STM.TVar Bool)
#endif
-- | Create a new timeout, initially in the 'TimeoutPending' state.
--
-- The delay argument is forwarded unchanged to @GHC.registerTimeout@ (when
-- @GHC_TIMERS_API@ is defined) or to @GHC.registerDelay@ (otherwise); GHC
-- documents both as taking microseconds.
--
newTimeout :: NewTimeout IO Timeout
-- | Type of an action which creates a timeout from an 'Int' delay in a
-- monad @m@.
--
type NewTimeout m timeout = Int -> m timeout
-- | Read the current 'TimeoutState' of a timeout inside an STM transaction.
-- It only reads transactional variables and never retries by itself.
--
readTimeout :: ReadTimeout IO Timeout
-- | Type of an STM action which reads the state of a timeout.
--
type ReadTimeout m timeout = timeout -> STM m TimeoutState
-- | Cancel a timeout.  A pending timeout becomes 'TimeoutCancelled'; the
-- state of a timeout which already fired or was already cancelled is left
-- unchanged.
--
cancelTimeout :: CancelTimeout IO Timeout
-- | Type of an action which cancels a timeout.
--
type CancelTimeout m timeout = timeout -> m ()
-- | Block (by retrying the transaction) while the timeout is pending.
-- Returns 'True' if the timeout fired and 'False' if it was cancelled.
--
awaitTimeout :: AwaitTimeout IO Timeout
-- | Type of an STM action which waits for a timeout to fire or be cancelled.
--
type AwaitTimeout m timeout = timeout -> STM m Bool
#ifdef GHC_TIMERS_API
-- GHC timer manager implementation (the type signature is declared earlier
-- in the module).  The state lives directly in the TVar stored in the handle,
-- so reading the timeout is a plain TVar read; the timer key is not needed.
readTimeout (TimeoutIO var _key) = STM.readTVar var
-- GHC timer manager implementation (the type signature is declared earlier
-- in the module).
--
-- Allocates the state variable in the pending state, registers a callback
-- with the system timer manager for the delay @d@, and returns a handle
-- carrying both the state variable and the registration key (the key is what
-- 'cancelTimeout' uses to unregister the callback).
newTimeout d = do
    var <- STM.newTVarIO TimeoutPending
    mgr <- GHC.getSystemTimerManager
    key <- GHC.registerTimeout mgr d (STM.atomically (timeoutAction var))
    return (TimeoutIO var key)
  where
    -- Callback run by the timer manager once the delay has elapsed.
    timeoutAction :: STM.TVar TimeoutState -> STM.STM ()
    timeoutAction var = do
      x <- STM.readTVar var
      case x of
        -- Normal case: nobody cancelled the timeout, mark it as fired.
        TimeoutPending   -> STM.writeTVar var TimeoutFired
        -- Only this callback ever writes 'TimeoutFired', so seeing it here
        -- would mean the callback ran twice.
        TimeoutFired     -> error "MonadTimer(IO): invariant violation"
        -- The timeout was cancelled before the callback ran: leave the
        -- cancelled state in place.
        TimeoutCancelled -> return ()
-- GHC timer manager implementation (the type signature is declared earlier
-- in the module).
--
-- First flips the state to cancelled (only if the timeout is still pending),
-- then unregisters the callback from the system timer manager.  The state
-- change happens in its own transaction before the unregistration, so a
-- callback which still runs afterwards finds 'TimeoutCancelled' and does
-- nothing.
cancelTimeout (TimeoutIO var key) = do
    STM.atomically $ do
      x <- STM.readTVar var
      case x of
        TimeoutPending   -> STM.writeTVar var TimeoutCancelled
        -- Already fired or already cancelled: the state is final.
        TimeoutFired     -> return ()
        TimeoutCancelled -> return ()
    mgr <- GHC.getSystemTimerManager
    GHC.unregisterTimeout mgr key
#else
-- Fallback implementation built on top of @registerDelay@.
--
-- The state is derived from two flags: the "cancelled" flag stored in the
-- handle and the "fired" flag reached through the nested TVar.  Cancellation
-- takes precedence over firing.
readTimeout (TimeoutIO delayRef cancelFlag) = do
    wasCancelled <- STM.readTVar cancelFlag
    delayVar     <- STM.readTVar delayRef
    hasFired     <- STM.readTVar delayVar
    return (classify wasCancelled hasFired)
  where
    classify True _    = TimeoutCancelled
    classify _    True = TimeoutFired
    classify _    _    = TimeoutPending
-- Fallback implementation built on top of @registerDelay@.
--
-- Effects run in this order: register the delay (yielding the "fired" flag),
-- wrap that flag in a fresh TVar, then allocate the "cancelled" flag,
-- initially 'False'.
newTimeout d =
    TimeoutIO
      <$> (GHC.registerDelay d >>= STM.newTVarIO)
      <*> STM.newTVarIO False
-- Fallback implementation built on top of @registerDelay@.
--
-- Sets the "cancelled" flag, but only when the delay has not fired yet, so a
-- timeout which already fired keeps reporting 'TimeoutFired'.
cancelTimeout (TimeoutIO delayRef cancelFlag) =
    STM.atomically $ do
      delayVar <- STM.readTVar delayRef
      hasFired <- STM.readTVar delayVar
      when (not hasFired) (STM.writeTVar cancelFlag True)
#endif
-- Shared by both implementations (the type signature is declared earlier in
-- the module).
--
-- Retries the enclosing transaction while the timeout is pending; once the
-- state is final, reports whether the timeout fired ('True') or was
-- cancelled ('False').
awaitTimeout t = do
    s <- readTimeout t
    case s of
      TimeoutPending   -> retry
      TimeoutFired     -> return True
      TimeoutCancelled -> return False