module Network.WebSockets.Connection.PingPong
    ( withPingPong
    , PingPongOptions(..)
    , PongTimeout(..)
    , defaultPingPongOptions
    ) where 

import Control.Concurrent.Async as Async
import Control.Exception
import Control.Monad (void)
import Network.WebSockets.Connection (Connection, connectionHeartbeat, pingThread)
import Control.Concurrent.MVar (takeMVar)
import System.Timeout (timeout)


-- | Exception type used to kill connections if there
-- is a pong timeout.
--
-- It carries no payload; 'withPingPong' throws it from its heartbeat
-- watchdog when the pong timeout elapses.
data PongTimeout = PongTimeout
    deriving Show

-- Uses the default 'Exception' methods, so the value can be thrown with
-- 'throwIO' and caught by its type.
instance Exception PongTimeout


-- | Options for ping-pong.
--
-- Make sure that the ping interval is less than the pong timeout,
-- for example half of it, so that at least one ping is sent within
-- every timeout window.
data PingPongOptions = PingPongOptions
    { pingInterval :: Int   -- ^ Interval between pings, in seconds
    , pongTimeout  :: Int   -- ^ Timeout for the heartbeat, in seconds
    , pingAction   :: IO () -- ^ Action to perform after sending a ping
    }

-- | Default options for ping-pong.
--
--   Ping every 15 seconds, timeout after 30 seconds, and do nothing
--   extra after each ping.
defaultPingPongOptions :: PingPongOptions
defaultPingPongOptions = PingPongOptions
    { pingInterval = 15
    , pongTimeout  = 30
    , pingAction   = return ()
    }

-- | Run an application with ping-pong enabled. Raises 'PongTimeout' if a
-- pong is not received within the configured timeout.
--
-- Can be used with Client and Server connections.
--
-- Three workers run concurrently: the application itself, the ping
-- sender ('pingThread'), and a heartbeat watchdog. As soon as any one of
-- them finishes or throws, the other two are cancelled; an exception
-- from the finished worker is rethrown to the caller.
withPingPong :: PingPongOptions -> Connection -> (Connection -> IO ()) -> IO ()
withPingPong options connection app = void $
    withAsync (app connection) $ \appAsync ->
        withAsync pinger $ \pingAsync ->
            withAsync watchdog $ \heartbeatAsync ->
                -- Returns (or rethrows) when the first worker completes and
                -- cancels the rest, so the app is also stopped if the ping
                -- worker ends.
                waitAnyCancel [appAsync, pingAsync, heartbeatAsync]
    where
        -- Ping sender, configured with the interval and the post-ping action.
        pinger :: IO ()
        pinger = pingThread connection (pingInterval options) (pingAction options)

        -- Once the heartbeat loop gives up, kill the connection handling
        -- with 'PongTimeout'.
        watchdog :: IO ()
        watchdog = heartbeat >> throwIO PongTimeout

        -- 'timeout' expects microseconds, while 'pongTimeout' is in seconds.
        timeoutMicros :: Int
        timeoutMicros = pongTimeout options * 1000 * 1000

        -- Keep taking from the connection's heartbeat 'MVar' for as long as
        -- each take succeeds within the timeout; return on the first miss.
        -- NOTE(review): presumably the MVar is filled whenever a pong
        -- arrives -- confirm against "Network.WebSockets.Connection".
        heartbeat :: IO ()
        heartbeat = whileJust $
            timeout timeoutMicros $ takeMVar (connectionHeartbeat connection)

        -- Loop until action returns Nothing
        whileJust :: IO (Maybe a) -> IO ()
        whileJust action = do
            result <- action
            case result of
                Nothing -> return ()
                Just _  -> whileJust action