{-# LANGUAGE DefaultSignatures   #-}
{-# LANGUAGE FlexibleContexts    #-}
{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}

module Control.Monad.Class.MonadTimer
  ( MonadDelay (..)
  , MonadTimer (..)
  ) where

import qualified Control.Concurrent as IO
import           Control.Concurrent.Class.MonadSTM
import qualified Control.Concurrent.STM.TVar as STM

import           Control.Monad.Reader (ReaderT (..))
import           Control.Monad.Trans (lift)

import qualified System.Timeout as IO

-- | Monads in which the current thread can be suspended for a while.
--
-- The 'IO' instance delegates to 'IO.threadDelay' from
-- "Control.Concurrent", whose argument is a number of microseconds.
class Monad m => MonadDelay m where
  -- | Suspend the calling thread for the given delay.
  threadDelay
    :: Int   -- ^ delay (for 'IO': microseconds)
    -> m ()

-- | Monads that, on top of 'MonadDelay' and 'MonadSTM', can arm a timer
-- observable as a 'TVar' and can bound the running time of an action.
class (MonadDelay m, MonadSTM m) => MonadTimer m where
  -- | Arm a timer and return the 'TVar' it writes to.
  --
  -- For 'IO' this is 'STM.registerDelay': the 'TVar' starts as 'False'
  -- and is set to 'True' after the given number of microseconds.
  registerDelay
    :: Int               -- ^ delay (for 'IO': microseconds)
    -> m (TVar m Bool)

  -- | Run an action under a time limit.
  --
  -- For 'IO' this is 'System.Timeout.timeout': the result is 'Nothing'
  -- when the action did not finish in time and 'Just' its result otherwise.
  timeout
    :: Int               -- ^ time limit (for 'IO': microseconds)
    -> m a               -- ^ action to run
    -> m (Maybe a)

--
-- Instances for IO
--

-- | 'threadDelay' for 'IO' is exactly 'IO.threadDelay' from
-- "Control.Concurrent": the delay is given in microseconds as an 'Int',
-- so the longest possible delay is bounded by the range of 'Int'.
instance MonadDelay IO where
  threadDelay = IO.threadDelay


-- | Timers for 'IO' delegate directly to the implementations in
-- "Control.Concurrent.STM.TVar" and "System.Timeout".
instance MonadTimer IO where
  -- The returned 'TVar' starts as 'False' and becomes 'True' once the
  -- delay (in microseconds) has elapsed. The underlying implementation
  -- requires the threaded RTS.
  registerDelay = STM.registerDelay
  -- 'Nothing' if the action did not complete within the limit.
  timeout = IO.timeout

--
-- Monad transformer instances
--

-- | Delays in 'ReaderT' ignore the environment and run in the base monad.
instance MonadDelay m => MonadDelay (ReaderT r m) where
  threadDelay = lift . threadDelay

-- | Timers lifted through 'ReaderT'.
--
-- 'registerDelay' runs in the base monad. 'timeout' supplies the current
-- environment to the guarded action and applies the base monad's
-- 'timeout' to the result.
instance MonadTimer m => MonadTimer (ReaderT r m) where
  -- NOTE(review): this type-checks only if the 'MonadSTM' instance for
  -- 'ReaderT' (not visible here) defines @TVar (ReaderT r m)@ as @TVar m@.
  registerDelay = lift . registerDelay
  timeout micros action =
    ReaderT $ \env -> timeout micros (runReaderT action env)