{- This file is part of time-out.
 -
 - Written in 2016 by fr33domlover <fr33domlover@riseup.net>.
 -
 - ♡ Copying is an act of love. Please copy, reuse and share.
 -
 - The author(s) have dedicated all copyright and related and neighboring
 - rights to this software to the public domain worldwide. This software is
 - distributed without any warranty.
 -
 - You should have received a copy of the CC0 Public Domain Dedication along
 - with this software. If not, see
 - <http://creativecommons.org/publicdomain/zero/1.0/>.
 -}

{-# LANGUAGE MultiParamTypeClasses #-}

module Control.Timeout
    ( timeout
    , delay
    )
where

import Control.Concurrent
import Control.Monad (when)
import Control.Monad.Catch
import Control.Monad.IO.Class
import Control.Monad.Timeout.Class
import Data.List (genericReplicate)
import Data.Maybe (isJust)
import Data.Time.Units

data Timeout' = Timeout' deriving Show

instance Exception Timeout'

instance MonadTimeout IO IO where
    timeoutThrow t act = do
        result <- timeoutCatch t act
        case result of
            Nothing -> throwM Timeout'
            Just a  -> return a

    timeoutCatch = timeout

-- | If the action succeeds, return 'Just' the result. If a timeout exception
-- is thrown during the action, catch it and return 'Nothing'. Other exceptions
-- aren't caught.
catchTimeout :: (MonadIO m, MonadCatch m) => m a -> m (Maybe a)
catchTimeout action = catch (Just <$> action) $ \ Timeout' -> return Nothing

-- | Run a monadic action with a time limit. If it finishes before that time
-- passes and returns value @x@, then @Just x@ is returned. If the timeout
-- passes, the action is aborted and @Nothing@ is returned. If the action
-- throws an exception, it is aborted and the exception is rethrown.
--
-- >>> timeout (3 :: Second) $ delay (1 :: Second) >> return "hello"
-- Just "hello"
--
-- >>> timeout (3 :: Second) $ delay (5 :: Second) >> return "hello"
-- Nothing
--
-- >>> timeout (1 :: Second) $ error "hello"
-- *** Exception: hello
timeout :: (TimeUnit t, MonadIO m, MonadCatch m) => t -> m a -> m (Maybe a)
timeout time action = do
    tidMain <- liftIO myThreadId
    tidTemp <- liftIO $ forkIO $ delay time >> throwTo tidMain Timeout'
    result <- catchTimeout action `onException` liftIO (killThread tidTemp)
    when (isJust result) $ liftIO $ killThread tidTemp
    return result

delayInt :: MonadIO m => Int -> m ()
delayInt usec = liftIO $ threadDelay usec

delayInteger :: MonadIO m => Integer -> m ()
delayInteger usec =
    when (usec > 0) $ do
        let maxInt = maxBound :: Int
            (times, rest) = usec `divMod` toInteger maxInt
        sequence_ $ genericReplicate times $ delayInt maxInt
        delayInt $ fromInteger rest

-- | Suspend the current thread for the given amount of time.
--
-- Example:
--
-- > delay (5 :: Second)
delay :: (TimeUnit t, MonadIO m) => t -> m ()
delay = delayInteger . toMicroseconds