{-# LANGUAGE FlexibleInstances, TypeSynonymInstances, LambdaCase #-}
module Control.Time (
  -- * Delays
    delay, delayTill
  -- * Timeouts
  , timeout, timeoutAt
  -- * Callbacks
  , CallbackKey
  , callbackAfter, callbackAt
  , updateCallbackToAfter, updateCallbackTo
  , cancelCallback
  -- * Time period conversion
  , AsMicro(..)
  ) where

-- Some writings on the problems with leap seconds:
-- http://www.madore.org/~david/computers/unix-leap-seconds.html
-- http://www.leapsecond.com/java/gpsclock.htm

import           Control.Concurrent
import qualified Control.Concurrent.Thread.Delay as D
import qualified Control.Concurrent.Timeout as T
import qualified Control.Exception as E
import qualified Control.Monad.Catch as MC
import           Control.Monad
import           Control.Monad.Trans
import           Data.Fixed
import           Data.Int
import           Data.Time
import           Data.Typeable
import           Data.Unique
import           Data.Word
import qualified GHC.Event as Ev
import           Numeric.Natural
import           Numeric.Units.Dimensional ((/~))
import qualified Numeric.Units.Dimensional as D
import qualified Numeric.Units.Dimensional.SIUnits as D

-- | Block the calling thread until at least the given period has elapsed.
--   The period is turned into microseconds by 'toMicro' (whose instances must
--   round up) and handed to the delay from "Control.Concurrent.Thread.Delay".
delay :: (MonadIO m, AsMicro period) => period -> m ()
delay period = liftIO (D.delay (toMicro period))

-- | One nominal day (86400 seconds): the longest single wait step used by
--   'delayTill' and by the callback scheduling code.
day :: NominalDiffTime
day = 24 * 60 * 60

-- | Delay until a specific 'UTCTime' has occurred (at least once).
--
--   Under some leap second handling regimes we cannot guarantee returning on
--   the first occurrence of the target time. For example, if 'delayTill' is
--   called during a leap second in which the system clock jumps back and
--   repeats the second, nothing indicates that the time has already passed
--   once, so we must wait for the second occurrence.
delayTill :: (MonadIO m) => UTCTime -> m ()
delayTill target = liftIO loop
  where
    loop = do
      now <- getCurrentTime
      when (now < target) $ do
        -- Sleep at most a day per step, so a backwards leap second can't make
        -- us overshoot by more than one second.
        delay (realToFrac (min day (target `diffUTCTime` now)) :: Pico)
        -- A wait across a leap second may not have been long enough, so
        -- re-check the clock before returning.
        loop

-- 'timeout'/'timeoutAt' code based on 'timeout' in base, which is
-- (c) The University of Glasgow 2007.

-- | The exception the watchdog thread of 'timeout' / 'timeoutAt' throws into
--   the thread running the guarded action. The 'Unique' makes every timeout
--   distinct, so each handler only reacts to its own exception.
newtype Timeout = Timeout Unique deriving (Eq, Typeable)

instance Show Timeout where
  showsPrec _ _ = showString "<<timeout>>"

-- Registered as a child of 'E.SomeAsyncException', i.e. it is classified as an
-- asynchronous exception rather than a synchronous failure.
instance E.Exception Timeout where
  toException   = E.asyncExceptionToException
  fromException = E.asyncExceptionFromException

-- | Run a monadic action, allowing it at least the given period to complete
--   (but not much more). Returns 'Nothing' if the period runs out first, and
--   returns 'Nothing' without running the action at all if the period is not
--   positive.
timeout :: (MonadIO m, MC.MonadMask m, AsMicro period) => period -> m a -> m (Maybe a)
timeout p a
  | toMicro p <= 0 = return Nothing
  | otherwise = do
      me  <- liftIO myThreadId
      tag <- Timeout <$> liftIO newUnique
      -- Watchdog: sleeps for the period, then interrupts this thread with the
      -- uniquely tagged exception. bracket kills it once 'a' returns or throws.
      let watchdog = liftIO . forkIOWithUnmask $ \unmask ->
                       unmask (delay p >> E.throwTo me tag)
          isOurs e = guard (e == tag)
      -- Only our own Timeout becomes Nothing; every other exception propagates.
      MC.handleJust isOurs (const (return Nothing)) $
        MC.bracket watchdog
                   (MC.uninterruptibleMask_ . liftIO . killThread)
                   (const (Just <$> a))

-- | Run a monadic action until it produces a result or a specific time occurs.
--   Returns 'Nothing' without running the action if that time has already
--   been reached. Leap second handling is as per 'delayTill'.
timeoutAt :: (MonadIO m, MC.MonadMask m) => UTCTime -> m a -> m (Maybe a)
timeoutAt t a = do
  now <- liftIO getCurrentTime
  if now >= t
    then return Nothing
    else do
      me  <- liftIO myThreadId
      tag <- Timeout <$> liftIO newUnique
      -- Watchdog: sleeps until the deadline, then interrupts this thread with
      -- the uniquely tagged exception. bracket kills it once 'a' returns or throws.
      let watchdog = liftIO . forkIOWithUnmask $ \unmask ->
                       unmask (delayTill t >> E.throwTo me tag)
          isOurs e = guard (e == tag)
      -- Only our own Timeout becomes Nothing; every other exception propagates.
      MC.handleJust isOurs (const (return Nothing)) $
        MC.bracket watchdog
                   (MC.uninterruptibleMask_ . liftIO . killThread)
                   (const (Just <$> a))

-- | A count of microseconds. The newtype only exists to make it clear what
--   units the wrapped Integer holds.
newtype MicroSeconds = MS Integer

-- | Our actual callback data.
--
--   'CallbackCanceled' is both the "canceled" and the "already fired" state:
--   doCallback switches to it when it hands out the action, and cancellation
--   switches to it as well. Because the state lives under the lock, a cancel
--   or update that races with execution can still tell the callback is no
--   longer pending.
--
--   The 'Ev.TimeoutKey' is the timer currently registered for this callback.
--   Fields are lazy on purpose: 'callbackAfter' / 'callbackAt' briefly store a
--   placeholder key there, which must never be forced.
--   'CallbackAt' keeps the target time; 'CallbackAfter' keeps the microseconds
--   still to wait after the currently registered timer step ends.
data CallbackHandle =
    CallbackCanceled
  | CallbackAt Ev.TimeoutKey UTCTime (IO ())
  | CallbackAfter Ev.TimeoutKey MicroSeconds (IO ())

-- | Handle to a scheduled callback, as returned by 'callbackAfter' and
--   'callbackAt'.
--   We hold the CallbackHandle in an MVar so that update/cancellation doesn't
--   race with executing the callback: updating or canceling holds the lock,
--   and holding the lock is also required to take the action out for running.
type CallbackKey = MVar CallbackHandle

-- | Advance a callback by one step: register the next timer step, fire the
--   callback's action, or notice that it was canceled.
--
--   'callbackAfter' and 'callbackAt' call this directly; after that it is the
--   callback registered with the system timer manager, so it normally runs on
--   the timer manager thread. GHC requires timer callbacks not to throw or
--   block for long, so the user's action is forked onto its own thread rather
--   than being run inline.
--
--   Each timer step is capped (at a day for 'CallbackAt', and always at what
--   fits in an 'Int' of microseconds); longer waits take several steps, and
--   'CallbackAt' re-reads the clock at every step.
--
--   NOTE(review): a timer that has already fired may still be blocked on the
--   MVar while an update re-registers the callback; that stale invocation then
--   acts on the updated state. Confirm whether this matters to callers.
doCallback :: CallbackKey -> IO ()
doCallback ck = do
  mngr <- Ev.getSystemTimerManager
  delayedAction <- modifyMVarMasked ck $ \case
    CallbackCanceled -> return (CallbackCanceled, return ())
    CallbackAt _ t act -> do
      n <- getCurrentTime
      case t `diffUTCTime` n of
        w | w <= 0 -> fire act
        w -> do
          -- Never round a positive remainder down to 0: 'Ev.registerTimeout'
          -- runs a non-positive timeout synchronously, which would re-enter
          -- this function while we still hold the MVar and deadlock.
          let micros = max 1 . toMicro $ (realToFrac . min day $ w :: Pico)
          newKey <- Ev.registerTimeout mngr (toIntMicros micros) (doCallback ck)
          return (CallbackAt newKey t act, return ())
    CallbackAfter _ (MS w) act | w <= 0 -> fire act
    CallbackAfter _ (MS w) act -> do
      let stepAmount = min maxIntMicros w
      newKey <- Ev.registerTimeout mngr (fromInteger stepAmount) (doCallback ck)
      return (CallbackAfter newKey (MS $ w - stepAmount) act, return ())
  -- Run whatever was decided only after the MVar has been released.
  delayedAction
  where
    -- Mark the callback as done and start its action on a separate thread.
    fire :: IO () -> IO (CallbackHandle, IO ())
    fire act = return (CallbackCanceled, void (forkIO act))
    -- Largest wait, in microseconds, that 'Ev.registerTimeout' can take.
    maxIntMicros :: Integer
    maxIntMicros = toInteger (maxBound :: Int)
    -- Convert to 'Int' without overflowing; doCallback waits again for the rest.
    toIntMicros :: Integer -> Int
    toIntMicros = fromInteger . min maxIntMicros

-- | Run a callback once the given period has passed.
--
--   The returned key can be passed to 'updateCallbackToAfter',
--   'updateCallbackTo' or 'cancelCallback'. A non-positive period triggers
--   the callback straight away.
callbackAfter :: (MonadIO m, AsMicro period) => period -> IO () -> m CallbackKey
callbackAfter p act = liftIO $ do
  -- doCallback never looks at the Ev.TimeoutKey, but does replace it (or the
  -- whole state), so the placeholder is gone by the time we return, which is
  -- the first time someone could call an update or cancel.
  h <- newMVar (CallbackAfter placeholderKey (MS . toMicro $ p) act)
  doCallback h
  return h
  where
    -- Never forced; the message makes a broken invariant easy to diagnose.
    placeholderKey :: Ev.TimeoutKey
    placeholderKey =
      error "Control.Time.callbackAfter: placeholder TimeoutKey was forced"

-- | Run a callback at a specific time.
--
--   The returned key can be passed to 'updateCallbackToAfter',
--   'updateCallbackTo' or 'cancelCallback'. A time that is not in the future
--   triggers the callback straight away.
callbackAt :: MonadIO m => UTCTime -> IO () -> m CallbackKey
callbackAt t act = liftIO $ do
  -- doCallback never looks at the Ev.TimeoutKey, but does replace it (or the
  -- whole state), so the placeholder is gone by the time we return, which is
  -- the first time someone could call an update or cancel.
  h <- newMVar (CallbackAt placeholderKey t act)
  doCallback h
  return h
  where
    -- Never forced; the message makes a broken invariant easy to diagnose.
    placeholderKey :: Ev.TimeoutKey
    placeholderKey =
      error "Control.Time.callbackAt: placeholder TimeoutKey was forced"

-- | Reschedule a pending callback to run once the given period has elapsed,
--   measured from the call of 'updateCallbackToAfter'. Whatever it was
--   previously scheduled for is discarded.
--
--   Only callbacks that have neither run nor been canceled are affected: in
--   those states the action is no longer stored, so this is a no-op.
--   A non-positive period starts the action immediately on a freshly forked
--   thread.
updateCallbackToAfter :: (MonadIO m, AsMicro period) => CallbackKey -> period -> m ()
updateCallbackToAfter ck p = liftIO $ do
    -- Change state under the lock; run any resulting action after releasing it.
    delayed <- modifyMVarMasked ck $ \case
      CallbackCanceled -> return (CallbackCanceled, return ())
      CallbackAt tk _ act -> reRegister tk act
      CallbackAfter tk _ act -> reRegister tk act
    delayed
  where
    p' :: Integer
    p' = toMicro p
    -- Drop the old timer, then either fire now or register the first step of
    -- the new wait.
    reRegister :: Ev.TimeoutKey -> IO () -> IO (CallbackHandle, IO ())
    reRegister tk act = do
      mngr <- Ev.getSystemTimerManager
      Ev.unregisterTimeout mngr tk
      if p' <= 0
        then return (CallbackCanceled, void . forkIO $ act)
        else do
          -- registerTimeout takes an Int, so longer waits are split; the
          -- remainder is stored and registered by doCallback when this fires.
          let stepAmount = min (toInteger (maxBound :: Int)) p'
          newKey <- Ev.registerTimeout mngr (fromInteger stepAmount) (doCallback ck)
          return (CallbackAfter newKey (MS $ p' - stepAmount) act, return ())

-- | Reschedule a pending callback to run at a specific time. Whatever it was
--   previously scheduled for is discarded.
--
--   Only callbacks that have neither run nor been canceled are affected: in
--   those states the action is no longer stored, so this is a no-op.
--   A time that is not in the future starts the action immediately on a
--   freshly forked thread.
updateCallbackTo :: MonadIO m => CallbackKey -> UTCTime -> m ()
updateCallbackTo ck t = liftIO $ do
    -- Change state under the lock; run any resulting action after releasing it.
    delayed <- modifyMVarMasked ck $ \case
      CallbackCanceled -> return (CallbackCanceled, return ())
      CallbackAt tk _ act -> reRegister tk act
      CallbackAfter tk _ act -> reRegister tk act
    delayed
  where
    -- Drop the old timer, then either fire now or register the first step of
    -- the wait for 't'.
    reRegister :: Ev.TimeoutKey -> IO () -> IO (CallbackHandle, IO ())
    reRegister tk act = do
      mngr <- Ev.getSystemTimerManager
      Ev.unregisterTimeout mngr tk
      n <- getCurrentTime
      let w = t `diffUTCTime` n
      if w <= 0
        then return (CallbackCanceled, void . forkIO $ act)
        else do
          -- Wait at most a day (and at most maxBound :: Int microseconds) per
          -- step; doCallback re-reads the clock when the step ends. A positive
          -- remainder is never rounded down to 0 microseconds, which would
          -- either fire before 't' or hit registerTimeout's synchronous
          -- non-positive path while we still hold the MVar.
          let micros = max 1 . toMicro $ (realToFrac . min day $ w :: Pico)
              stepAmount = fromInteger (min (toInteger (maxBound :: Int)) micros)
          newKey <- Ev.registerTimeout mngr stepAmount (doCallback ck)
          return (CallbackAt newKey t act, return ())

-- | Terminate a callback that has not run yet. Canceling a callback that has
--   already run or already been canceled has no effect.
cancelCallback :: MonadIO m => CallbackKey -> m ()
cancelCallback ck =
    liftIO . modifyMVarMasked_ ck $ \st -> do
      forM_ (pendingTimer st) $ \tk -> do
        mngr <- Ev.getSystemTimerManager
        Ev.unregisterTimeout mngr tk
      return CallbackCanceled
  where
    -- The timer currently registered for the callback, if it is still pending.
    pendingTimer :: CallbackHandle -> Maybe Ev.TimeoutKey
    pendingTimer CallbackCanceled       = Nothing
    pendingTimer (CallbackAt tk _ _)    = Just tk
    pendingTimer (CallbackAfter tk _ _) = Just tk

-- | Number of microseconds in one second, in any numeric type.
microPrecision :: Num n => n
microPrecision = 1000000

-- | Calculate the number of microseconds of delay a value represents.
--   Instances must round up for correctness: 'delay' and the callback code
--   wait for exactly 'toMicro' microseconds, so rounding down could end a
--   wait before the requested period has elapsed.
class AsMicro d where
  -- | The value as a whole number of microseconds, rounded up.
  toMicro :: d -> Integer

-- UTCTime and NominalDiffTime can't have an instance of AsMicro since we can't calculate the
-- number of microseconds until such a time actually occurs.
-- To be specific: over a long period there could be a decision to institute a leap second,
-- in either direction, for UTC.  This could cause us to return early.

-- | Rounded up to the next whole microsecond.
instance AsMicro DiffTime where
  toMicro dt = ceiling (microPrecision * dt)

-- | A dimensional time quantity, converted via its value in seconds.
instance (Fractional n, AsMicro n) => AsMicro (D.Time n) where
  toMicro = toMicro . (/~ D.second)

-- | Interpreted as whole seconds.
instance AsMicro Integer where
  toMicro secs = microPrecision * secs

-- | Interpreted as whole seconds.
instance AsMicro Natural where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Int where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Int8 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Int16 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Int32 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Int64 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Word where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Word8 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Word16 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Word32 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as whole seconds.
instance AsMicro Word64 where
  toMicro n = toMicro (toInteger n)

-- | Interpreted as (fractional) seconds, rounded up to the next microsecond.
instance AsMicro Float where
  toMicro secs = ceiling (microPrecision * secs)

-- | Interpreted as (fractional) seconds, rounded up to the next microsecond.
instance AsMicro Double where
  toMicro secs = ceiling (microPrecision * secs)

-- | Interpreted as seconds. Values finer than a microsecond (e.g. 'Pico') are
--   rounded up to the next whole microsecond, as the class requires; the
--   conversion goes through the exact 'Rational' value, so nothing is lost
--   before rounding.
instance HasResolution d => AsMicro (Fixed d) where
  toMicro x = ceiling (toRational x * microPrecision)