module Control.Monad.IOSimPOR.Timeout
  ( Timeout
  , timeout
  , unsafeTimeout
  ) where

-- This module provides a 'timeout' function like the one in System.Timeout,
-- except that time spent in garbage collection does not count towards the
-- limit (provided GHC runtime statistics are enabled with +RTS -T -RTS).
-- This makes it a more reliable way to limit computation time.

import           Control.Concurrent
import           Control.Exception (Exception (..), asyncExceptionFromException,
                     asyncExceptionToException, bracket, handleJust,
                     uninterruptibleMask_)
import           Control.Monad
import           Data.Unique (Unique, newUnique)
import           GHC.Stats
import           System.IO.Unsafe


-- The 'Timeout' type below is thrown as an asynchronous exception to
-- interrupt the running IO computation when the timeout has
-- expired.

-- | An exception thrown to a thread by 'timeout' to interrupt a timed-out
-- computation.
--
-- Each call to 'timeout' wraps a fresh 'Unique', and 'timeout' uses the
-- derived equality to recognise (and catch) only its own exception, so
-- nested timeouts do not intercept each other.
newtype Timeout = Timeout Unique
  deriving (Eq)

-- | Every 'Timeout' is shown as the same fixed string; the 'Unique' tag it
-- carries is not displayed.
instance Show Timeout where
  show _ = "<<timeout>>"

-- | 'Timeout' is registered as an asynchronous exception: 'toException'
-- wraps it via 'asyncExceptionToException' and 'fromException' unwraps it via
-- 'asyncExceptionFromException', so it can be told apart from synchronous
-- exceptions raised by the interrupted computation.
instance Exception Timeout where
  toException = asyncExceptionToException
  fromException = asyncExceptionFromException

-- | Run an action with a time limit of @n@ microseconds, in the style of
-- 'System.Timeout.timeout'.
--
-- * @n < 0@: no limit; the action runs to completion and its result is
--   wrapped in 'Just'.
-- * @n == 0@: the action is not run and 'Nothing' is returned.
-- * otherwise: a watchdog thread is forked that waits @n@ microseconds
--   (via 'waitFor', which extends the wait by any GC time observed) and then
--   throws a fresh 'Timeout' to the calling thread.  If that exception
--   interrupts the action, 'Nothing' is returned; otherwise the action's
--   result is returned in 'Just'.
--
-- Only the 'Timeout' value created by this call is caught (compared with
-- '=='), so nested timeouts and other exceptions propagate unchanged.
timeout :: Int -> IO a -> IO (Maybe a)
timeout n f
  | n < 0 = fmap Just f
  | n == 0 = return Nothing
  | otherwise = do
      pid <- myThreadId
      ex <- fmap Timeout newUnique
      handleJust
        (\e -> if e == ex then Just () else Nothing)
        (\_ -> return Nothing)
        ( bracket
            -- The acquire step of 'bracket' runs masked and the forked
            -- thread inherits that state, so the watchdog unmasks itself to
            -- stay interruptible while it sleeps.
            (forkIOWithUnmask $ \unmask -> unmask $ waitFor n >> throwTo pid ex)
            -- Stop the watchdog; uninterruptibleMask_ keeps our own pending
            -- 'Timeout' from aborting this cleanup part-way.
            (uninterruptibleMask_ . killThread)
            (\_ -> fmap Just f)
        )

-- | Sleep for @n@ microseconds, not counting time spent in garbage
-- collection.
--
-- The GC time that accumulated during the delay (as reported by
-- 'getGCTime') is measured.  If it grew, the function sleeps again for
-- exactly that extra amount, and repeats until a delay completes with no
-- additional GC time.
waitFor :: Int -> IO ()
waitFor n = do
  gcBefore <- getGCTime
  threadDelay n
  gcAfter <- getGCTime
  -- allow some extra time because of GC
  when (gcAfter > gcBefore) $
    waitFor (gcAfter - gcBefore)

-- | Total elapsed time spent in garbage collection so far, in microseconds
-- (the RTS reports nanoseconds; they are divided by 1000).
--
-- Runtime statistics are only collected when the program runs with
-- @+RTS -T@.  Without them 'getRTSStats' throws, which would kill the
-- watchdog thread forked by 'timeout' before it could fire, so the timeout
-- would silently never trigger.  In that case this returns 0, and 'waitFor'
-- degrades to a plain 'threadDelay'.
getGCTime :: IO Int
getGCTime = do
  statsEnabled <- getRTSStatsEnabled
  if statsEnabled
    then fromIntegral . (`div` 1000) . gc_elapsed_ns <$> getRTSStats
    else return 0

-- | @unsafeTimeout n a@ forces @a@ to weak head normal form with a time
-- limit of @n@ microseconds.  It returns 'Just' the value if evaluation
-- finished in time and 'Nothing' if it was interrupted.
--
-- The limit follows 'timeout': a negative @n@ means no limit, and @n == 0@
-- always yields 'Nothing'.  This uses 'unsafePerformIO' and the outcome
-- depends on timing, so identical calls may give different answers on
-- different runs.
unsafeTimeout :: Int -> a -> Maybe a
unsafeTimeout n a = unsafePerformIO (timeout n (return $! a))