{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE CPP                 #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RankNTypes          #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Snap.Internal.Http.Server.TimeoutManager
  ( TimeoutManager
  , TimeoutThread
  , initialize
  , stop
  , register
  , tickle
  , set
  , modify
  , cancel
  ) where

------------------------------------------------------------------------------
import           Control.Exception                (evaluate, finally)
import qualified Control.Exception                as E
import           Control.Monad                    (Monad (return, (>>=)), mapM_, void, when)
import qualified Data.ByteString.Char8            as S
import           Data.IORef                       (IORef, newIORef, readIORef, writeIORef)
import           Prelude                          (Bool, Double, IO, Int, Show (..), const, fromIntegral, max, null, otherwise, round, ($), ($!), (+), (++), (-), (.), (<=), (==))
------------------------------------------------------------------------------
import           Control.Concurrent               (MVar, newEmptyMVar, putMVar, readMVar, takeMVar, tryPutMVar)
------------------------------------------------------------------------------
import           Snap.Internal.Http.Server.Clock  (ClockTime)
import qualified Snap.Internal.Http.Server.Clock  as Clock
import           Snap.Internal.Http.Server.Common (atomicModifyIORef', eatException)
import qualified Snap.Internal.Http.Server.Thread as T


------------------------------------------------------------------------------
-- | Timeout state of a registered thread: the absolute 'ClockTime' deadline
-- at which the thread should be cancelled, or the sentinel 'canceled' once
-- the timeout has been cancelled or has already fired.
type State = ClockTime

-- | Sentinel 'State' meaning that no deadline is pending.
--
-- NOTE(review): this relies on the clock never producing a real deadline of
-- exactly 0.
canceled :: State
canceled = 0

-- | Is this 'State' the 'canceled' sentinel?
isCanceled :: State -> Bool
isCanceled = (== canceled)


------------------------------------------------------------------------------
-- | Handle for a thread registered with a 'TimeoutManager'.
data TimeoutThread = TimeoutThread {
      _thread     :: !T.SnapThread    -- ^ the forked worker thread
    , _state      :: !(IORef State)   -- ^ current deadline, or 'canceled'
    , _hGetTime   :: !(IO ClockTime)  -- ^ clock used when adjusting the deadline
    }

-- | Rendered as the underlying 'T.SnapThread'.
instance Show TimeoutThread where
    show = show . _thread


------------------------------------------------------------------------------
-- | Given the current time, apply a modification function to the amount of
-- time remaining before a deadline, and return the resulting absolute
-- deadline.
--
-- The remaining time is clamped at zero before the function sees it, so an
-- already-expired deadline is treated as having no time left. A 'canceled'
-- state is returned untouched.
smap :: ClockTime                 -- ^ current time
     -> (ClockTime -> ClockTime)  -- ^ transformation of the remaining time
     -> State                     -- ^ current deadline, or 'canceled'
     -> State
smap now f deadline
    | isCanceled deadline = deadline
    | otherwise           = now + f remaining
  where
    remaining = max 0 (deadline - now)


------------------------------------------------------------------------------
-- | Shared state of a timeout manager.
data TimeoutManager = TimeoutManager {
      _defaultTimeout :: !ClockTime                -- ^ initial timeout used by 'register'
    , _pollInterval   :: !ClockTime                -- ^ sleep between manager passes
    , _getTime        :: !(IO ClockTime)           -- ^ current-time source
    , _threads        :: !(IORef [TimeoutThread])  -- ^ registered threads
    , _morePlease     :: !(MVar ())                -- ^ wakes an idle manager thread
    , _managerThread  :: !(MVar T.SnapThread)      -- ^ filled once the manager is forked
    }


------------------------------------------------------------------------------
-- | Create a new 'TimeoutManager' and fork its manager thread.
--
-- Both durations are converted with 'Clock.fromSecs'. Setup runs under
-- 'E.uninterruptibleMask_' so that, once the manager thread is forked, its
-- handle is always stored in '_managerThread'. 'stop' blocks reading that
-- variable, so it must never be left empty.
initialize :: Double            -- ^ default timeout, in seconds
           -> Double            -- ^ poll interval, in seconds
           -> IO ClockTime      -- ^ function to get current time
           -> IO TimeoutManager
initialize defaultTimeout interval getTime = E.uninterruptibleMask_ $ do
    threadsRef <- newIORef []
    wakeVar    <- newEmptyMVar
    mgrVar     <- newEmptyMVar

    let tm = TimeoutManager
               { _defaultTimeout = Clock.fromSecs defaultTimeout
               , _pollInterval   = Clock.fromSecs interval
               , _getTime        = getTime
               , _threads        = threadsRef
               , _morePlease     = wakeVar
               , _managerThread  = mgrVar
               }

    thr <- T.fork "snap-server: timeout manager" $ managerThread tm
    putMVar mgrVar thr
    return tm


------------------------------------------------------------------------------
-- | Stop a 'TimeoutManager' by cancelling its manager thread and waiting for
-- it to finish.
--
-- Blocks until 'initialize' has stored the manager thread handle.
stop :: TimeoutManager -> IO ()
stop tm = do
    mgr <- readMVar (_managerThread tm)
    T.cancelAndWait mgr


------------------------------------------------------------------------------
-- | Signal the manager thread that there may be new work.
--
-- Never blocks: if a wake-up signal is already pending, this does nothing.
wakeup :: TimeoutManager -> IO ()
wakeup tm = void (tryPutMVar (_morePlease tm) ())


------------------------------------------------------------------------------
-- | Register a new thread with the 'TimeoutManager'.
--
-- The thread is forked with the given label and action. Its initial deadline
-- is the current time plus the default timeout of the manager. Forking the
-- thread and adding it to the registry both happen under
-- 'E.uninterruptibleMask_'. As a result, an async exception cannot leave a
-- forked thread that the manager does not track.
register :: TimeoutManager                        -- ^ manager to register
                                                  --   with
         -> S.ByteString                          -- ^ thread label
         -> ((forall a . IO a -> IO a) -> IO ())  -- ^ thread action to run
         -> IO TimeoutThread
register tm label action = do
    now <- getTime
    let !deadline = now + defaultTimeout
    stateRef <- newIORef deadline
    th <- E.uninterruptibleMask_ $ do
        t <- T.fork label action
        let h = TimeoutThread t stateRef getTime
        atomicModifyIORef' threads (\x -> (h:x, ())) >>= evaluate
        return $! h
    -- The manager thread may be idle waiting for registrations; wake it.
    wakeup tm
    return th

  where
    getTime        = _getTime tm
    threads        = _threads tm
    defaultTimeout = _defaultTimeout tm


------------------------------------------------------------------------------
-- | Tickle the timeout on a connection to be at least N seconds into the
-- future. If the existing timeout is set for M seconds from now, where M > N,
-- then the timeout is unaffected.
--
-- M is measured in whole seconds: 'modify' rounds the remaining time before
-- the comparison. A cancelled timeout is left as it is.
tickle :: TimeoutThread -> Int -> IO ()
tickle th n = modify th (max n)
{-# INLINE tickle #-}


------------------------------------------------------------------------------
-- | Set the timeout on a connection to be N seconds into the future,
-- regardless of the current deadline. A cancelled timeout is left as it is.
set :: TimeoutThread -> Int -> IO ()
set th n = modify th (const n)
{-# INLINE set #-}


------------------------------------------------------------------------------
-- | Modify the timeout with the given function.
--
-- The function receives the number of seconds remaining before the current
-- deadline, rounded to a whole number and never negative. It returns how many
-- seconds from now the new deadline should be. A cancelled timeout is left
-- unchanged.
--
-- NOTE(review): the read and the write of the state are not a single atomic
-- step. If the manager thread marks the timeout cancelled in between, that
-- mark is overwritten by the new deadline.
modify :: TimeoutThread -> (Int -> Int) -> IO ()
modify th f = do
    now   <- getTime
    state <- readIORef stateRef
    let !state' = smap now f' state
    writeIORef stateRef state'

  where
    -- Lift the whole-second function to one on 'ClockTime' durations.
    f' !x    = Clock.fromSecs $! fromIntegral $ f $ round $ Clock.toSecs x
    getTime  = _hGetTime th
    stateRef = _state th
{-# INLINE modify #-}


------------------------------------------------------------------------------
-- | Cancel a timeout and the thread it guards.
--
-- Marks the state as 'canceled' and then cancels the underlying thread, both
-- under 'E.uninterruptibleMask_'. The handle stays in the registry of the
-- manager until the manager sees that the thread has finished.
cancel :: TimeoutThread -> IO ()
cancel h = E.uninterruptibleMask_ $ do
    writeIORef (_state h) canceled
    T.cancel (_thread h)
{-# INLINE cancel #-}


------------------------------------------------------------------------------
-- | Body of the manager thread forked by 'initialize'.
--
-- Each pass does the following:
--
-- * takes the whole registry of threads;
-- * cancels every thread whose deadline has passed;
-- * drops handles that are 'canceled' and whose thread has finished;
-- * puts the survivors back;
-- * sleeps for the poll interval.
--
-- When nothing is registered, the thread blocks until 'wakeup' is called. On
-- exit, every thread still in the registry is cancelled and then waited for.
managerThread :: TimeoutManager -> (forall a. IO a -> IO a) -> IO ()
managerThread tm restore = restore loop `finally` cleanup
  where
    -- Best-effort teardown: exceptions raised while tearing down are ignored.
    cleanup = E.uninterruptibleMask_ $
              eatException (readIORef threads >>= destroyAll)

    --------------------------------------------------------------------------
    getTime      = _getTime tm
    morePlease   = _morePlease tm
    pollInterval = _pollInterval tm
    threads      = _threads tm

    --------------------------------------------------------------------------
    loop = do
        now <- getTime
        E.uninterruptibleMask $ \restore' -> do
            -- Take the whole registry. Threads registered meanwhile are
            -- prepended by 'register' and merged with the survivors below.
            handles <- atomicModifyIORef' threads (\x -> ([], x))
            if null handles
              then restore' $ takeMVar morePlease
              else do
                handles' <- processHandles now handles
                atomicModifyIORef' threads (\x -> (handles' ++ x, ()))
                    >>= evaluate
        Clock.sleepFor pollInterval
        loop

    --------------------------------------------------------------------------
    processHandles now handles = go handles []
      where
        go [] !kept = return $! kept

        go (x:xs) !kept = do
            !state <- readIORef $ _state x
            !kept' <-
                if isCanceled state
                  then do
                    -- Forget the handle only once its thread has exited.
                    b <- T.isFinished (_thread x)
                    return $! if b then kept else x : kept
                  else do
                    -- Deadline passed: kill the thread and mark it cancelled.
                    -- The handle is kept until the thread has finished.
                    when (state <= now) $ do
                        T.cancel (_thread x)
                        writeIORef (_state x) canceled
                    return (x : kept)
            go xs kept'

    --------------------------------------------------------------------------
    -- Cancel everything first, then wait, so the threads shut down in parallel.
    destroyAll xs = do
        mapM_ (T.cancel . _thread) xs
        mapM_ (T.wait . _thread) xs