{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE RecordWildCards #-}

module System.TimeManager (
    -- ** Types
    Manager,
    TimeoutAction,
    Handle,

    -- ** Manager
    initialize,
    stopManager,
    killManager,
    withManager,
    withManager',

    -- ** Registering a timeout action
    withHandle,
    withHandleKillThread,

    -- ** Control timeout
    tickle,
    pause,
    resume,

    -- ** Low level
    register,
    registerKillThread,
    cancel,

    -- ** Exceptions
    TimeoutThread (..),
) where

import Control.Concurrent (mkWeakThreadId, myThreadId)
import qualified Control.Exception as E
import Control.Reaper
import Data.IORef (IORef)
import qualified Data.IORef as I
import Data.Typeable (Typeable)
import System.Mem.Weak (deRefWeak)

----------------------------------------------------------------

-- | A timeout manager: a 'Reaper' whose workload is a list of 'Handle's
--   and whose individual items are 'Handle's.
type Manager = Reaper [Handle] Handle

-- | An action to be performed on timeout.
type TimeoutAction = IO ()

-- | A handle used by a timeout manager.
--
--   Handles are compared by their state 'IORef' when they are removed
--   with 'cancel'.
data Handle = Handle
    { handleManager :: Manager
    -- ^ The manager whose workload this handle belongs to.
    , handleActionRef :: IORef TimeoutAction
    -- ^ The action run when the manager finds the handle timed out.
    , handleStateRef :: IORef State
    -- ^ Current state: 'Active', 'Inactive' or 'Paused'.
    }

-- | Lifecycle of a 'Handle'. On every reaper cycle each handle is visited
--   once:
--
--   * an 'Active' handle is demoted to 'Inactive' and kept;
--   * a handle that is already 'Inactive' (nobody called 'tickle' since the
--     previous cycle) has its timeout action fired and is dropped;
--   * a 'Paused' handle is kept unchanged.
--
--   'tickle' / 'resume' set 'Active'; 'pause' sets 'Paused'.
data State
    = Active -- Manager turns it to Inactive.
    | Inactive -- Manager removes it with timeout action.
    | Paused -- Manager does not change it.

----------------------------------------------------------------

-- | Creating timeout manager which works every N micro seconds
--   where N is the first argument.
--
--   On every cycle each registered 'Handle' is visited once:
--
--   * an active handle is marked inactive and kept;
--   * a handle still inactive from the previous cycle has its timeout
--     action run and is removed (synchronous exceptions from that action
--     are ignored, asynchronous ones are rethrown);
--   * a paused handle is kept unchanged.
initialize :: Int -> IO Manager
initialize timeout =
    mkReaper
        defaultReaperSettings
            { -- Data.Set cannot be used since 'partition' cannot be used
              -- with 'readIORef'. So, let's just use a list.
              reaperAction = mkListAction prune
            , reaperDelay = timeout
            , reaperThreadName = "WAI timeout manager (Reaper)"
            }
  where
    -- Process one handle for this cycle.
    -- 'Nothing' drops it from the workload; 'Just' keeps it.
    prune :: Handle -> IO (Maybe Handle)
    prune m@Handle{..} = do
        -- Read the old state and store its demoted value in one atomic
        -- step, so a concurrent 'tickle'/'pause' is not overwritten by a
        -- stale value.
        state <- I.atomicModifyIORef' handleStateRef (\x -> (inactivate x, x))
        case state of
            Inactive -> do
                onTimeout <- I.readIORef handleActionRef
                onTimeout `E.catch` ignoreSync
                return Nothing
            _ -> return $ Just m

    -- Only 'Active' is demoted; 'Inactive' and 'Paused' are left as-is.
    inactivate :: State -> State
    inactivate Active = Inactive
    inactivate x = x

----------------------------------------------------------------

-- | Stopping timeout manager with onTimeout fired.
--
--   The reaper is stopped, and the timeout action of every handle left
--   in its workload is run. Asynchronous exceptions are masked for the
--   whole operation. Synchronous exceptions from a timeout action are
--   ignored. An asynchronous exception raised by an action is rethrown,
--   which skips the remaining actions.
stopManager :: Manager -> IO ()
stopManager mgr = E.mask_ (reaperStop mgr >>= mapM_ fire)
  where
    -- Run one handle's timeout action, discarding synchronous failures.
    fire :: Handle -> IO ()
    fire Handle{..} = do
        onTimeout <- I.readIORef handleActionRef
        onTimeout `E.catch` ignoreSync

-- | Killing timeout manager immediately without firing onTimeout.
--
--   This delegates directly to 'reaperKill'.
killManager :: Manager -> IO ()
killManager = reaperKill

----------------------------------------------------------------

-- | Registering a timeout action and unregister its handle
--   when the body action is finished.
--
--   The body receives the handle, so it can 'tickle', 'pause' or 'resume'
--   the timeout. The handle is always cancelled afterwards, even if the
--   body throws.
--
--   'Nothing' is returned if a 'TimeoutThread' exception interrupts the
--   body (for example, one thrown to this thread by the timeout action).
--   Otherwise the body's result is returned in 'Just'.
withHandle :: Manager -> TimeoutAction -> (Handle -> IO a) -> IO (Maybe a)
withHandle mgr onTimeout action =
    E.handle ignore $ E.bracket (register mgr onTimeout) cancel $ \th ->
        Just <$> action th
  where
    ignore :: TimeoutThread -> IO (Maybe b)
    ignore TimeoutThread = return Nothing

-- | Registering a timeout action of killing this thread and
--   unregister its handle when the body action is killed or finished.
--
--   The handle is created with 'registerKillThread', so on timeout the
--   calling thread receives 'TimeoutThread'. That exception is caught here
--   and the function returns normally. The handle is always cancelled.
withHandleKillThread :: Manager -> TimeoutAction -> (Handle -> IO ()) -> IO ()
withHandleKillThread mgr onTimeout action =
    E.handle ignore $
        E.bracket (registerKillThread mgr onTimeout) cancel action
  where
    ignore :: TimeoutThread -> IO ()
    ignore TimeoutThread = return ()

----------------------------------------------------------------

-- | Registering a timeout action.
--
--   A fresh 'Handle' in the 'Active' state is created, added to the
--   manager's workload, and returned.
register :: Manager -> TimeoutAction -> IO Handle
-- NOTE(review): the bang forces the action to WHNF before it is stored,
-- presumably so no thunk is retained in the ref — confirm intent.
register mgr !onTimeout = do
    actionRef <- I.newIORef onTimeout
    stateRef <- I.newIORef Active
    let h =
            Handle
                { handleManager = mgr
                , handleActionRef = actionRef
                , handleStateRef = stateRef
                }
    reaperAdd mgr h
    return h

-- | Removing the 'Handle' from the 'Manager' immediately.
--
--   The handle's timeout action is not run. The handle is found in the
--   workload by comparing state 'IORef's, and only the first match is
--   removed.
cancel :: Handle -> IO ()
cancel Handle{..} = do
    _ <- reaperModify handleManager filt
    return ()
  where
    -- It's very important that this function forces the whole workload so we
    -- don't retain old handles, otherwise disastrous leaks occur.
    filt :: [Handle] -> [Handle]
    filt [] = []
    filt (h@(Handle _ _ ref) : hs)
        | handleStateRef == ref = hs
        | otherwise =
            let !hs' = filt hs
             in h : hs'

----------------------------------------------------------------

-- | The asynchronous exception thrown if a thread is registered via
-- 'registerKillThread'.
--
-- It is wrapped in 'E.SomeAsyncException', so handlers that only catch
-- synchronous exceptions let it through.
data TimeoutThread = TimeoutThread
    deriving (Typeable)

instance E.Exception TimeoutThread where
    toException = E.asyncExceptionToException
    fromException = E.asyncExceptionFromException

instance Show TimeoutThread where
    show TimeoutThread = "Thread killed by timeout manager"

-- | Registering a timeout action of killing this thread.
--   'TimeoutThread' is thrown to the thread which called this
--   function on timeout. Catch 'TimeoutThread' if you don't
--   want to leak the asynchronous exception to GHC RTS.
registerKillThread :: Manager -> TimeoutAction -> IO Handle
registerKillThread m onTimeout = do
    tid <- myThreadId
    -- Keep only a weak reference, so the manager's handle does not keep
    -- the calling thread alive. If the thread is gone, nothing is killed.
    wtid <- mkWeakThreadId tid
    let killCaller = do
            mtid <- deRefWeak wtid
            case mtid of
                Nothing -> return ()
                Just tid' -> E.throwTo tid' TimeoutThread
    -- First run the timeout action in case the child thread is masked.
    register m (onTimeout `E.finally` killCaller)

----------------------------------------------------------------

-- | Setting the state to active.
--   'Manager' turns active to inactive repeatedly.
--
--   Calling this at least once per manager cycle keeps the timeout from
--   firing.
tickle :: Handle -> IO ()
tickle Handle{..} = I.writeIORef handleStateRef Active

-- | Setting the state to paused.
--   'Manager' does not change the value.
--
--   A paused handle never times out until it is resumed with 'resume'
--   (or 'tickle').
pause :: Handle -> IO ()
pause Handle{..} = I.writeIORef handleStateRef Paused

-- | Setting the paused state to active.
--   This is an alias to 'tickle'.
resume :: Handle -> IO ()
resume = tickle

----------------------------------------------------------------

-- | Call the inner function with a timeout manager.
--   'stopManager' is used after that.
--
--   The manager is stopped even if the inner function throws. Stopping
--   fires the timeout actions of all remaining handles.
withManager
    :: Int
    -- ^ timeout in microseconds
    -> (Manager -> IO a)
    -> IO a
withManager timeout f = E.bracket (initialize timeout) stopManager f

-- | Call the inner function with a timeout manager.
--   'killManager' is used after that.
--
--   The manager is killed even if the inner function throws. Unlike
--   'withManager', no remaining timeout actions are fired.
withManager'
    :: Int
    -- ^ timeout in microseconds
    -> (Manager -> IO a)
    -> IO a
withManager' timeout f = E.bracket (initialize timeout) killManager f

----------------------------------------------------------------

-- | Whether an exception is asynchronous, i.e. whether its
--   'E.SomeException' form wraps a 'E.SomeAsyncException'.
isAsyncException :: E.Exception e => e -> Bool
isAsyncException e =
    case E.fromException (E.toException e) of
        Just (E.SomeAsyncException _) -> True
        Nothing -> False

-- | Exception handler used around timeout actions. Asynchronous
--   exceptions are rethrown so cancellation still propagates; synchronous
--   exceptions are discarded.
ignoreSync :: E.SomeException -> IO ()
ignoreSync se
    | isAsyncException se = E.throwIO se
    | otherwise = return ()