{-# LANGUAGE OverloadedStrings #-}
module Network.WebSockets.Server
( ServerApp
, runServer
, ServerOptions (..)
, defaultServerOptions
, runServerWithOptions
, runServerWith
, makeListenSocket
, makePendingConnection
, makePendingConnectionFromStream
, PongTimeout
) where
import Control.Concurrent (threadDelay)
import qualified Control.Concurrent.Async as Async
import Control.Exception (Exception, allowInterrupt,
bracket, bracketOnError,
finally, mask_, throwIO)
import Control.Monad (forever, void, when)
import qualified Data.IORef as IORef
import Data.Maybe (isJust)
import Network.Socket (Socket)
import qualified Network.Socket as S
import qualified System.Clock as Clock
import Network.WebSockets.Connection
import Network.WebSockets.Http
import qualified Network.WebSockets.Stream as Stream
import Network.WebSockets.Types
-- | A WebSockets server application. It is run once for every accepted
-- connection and receives the 'PendingConnection' built from that client's
-- HTTP request head.
type ServerApp = PendingConnection -> IO ()
-- | Run a WebSockets server on the given host and port, using the default
-- 'ConnectionOptions' and no pong requirement.
--
-- This is a convenience wrapper around 'runServerWithOptions'. It calls
-- that function directly instead of going through the deprecated
-- 'runServerWith'. The accept loop runs 'forever', so this only returns by
-- throwing.
runServer :: String     -- ^ Address to bind to
          -> Int        -- ^ Port to listen on
          -> ServerApp  -- ^ Application run for every connection
          -> IO ()      -- ^ Never returns normally
runServer host port app =
    runServerWithOptions
        (defaultServerOptions { serverHost = host, serverPort = port })
        app
-- | Like 'runServer', but with explicit 'ConnectionOptions'.
--
-- All other settings are taken from 'defaultServerOptions'.
runServerWith :: String -> Int -> ConnectionOptions -> ServerApp -> IO ()
runServerWith host port opts app = runServerWithOptions serverOpts app
  where
    serverOpts = defaultServerOptions
        { serverHost              = host
        , serverPort              = port
        , serverConnectionOptions = opts
        }
{-# DEPRECATED runServerWith "Use 'runServerWithOptions' instead" #-}
-- | Settings for 'runServerWithOptions'.
data ServerOptions = ServerOptions
    { serverHost :: String
      -- ^ Host to bind the listening socket to. It is looked up with
      -- 'S.getAddrInfo', and the first result is used.
    , serverPort :: Int
      -- ^ TCP port to listen on.
    , serverConnectionOptions :: ConnectionOptions
      -- ^ Options passed to every accepted connection.
    , serverRequirePong :: Maybe Int
      -- ^ @'Just' n@ installs a per-connection watchdog. Each received pong
      -- moves a deadline about @n@ seconds into the future, and the
      -- application thread is cancelled with 'PongTimeout' once the deadline
      -- has passed. The deadline has whole-second resolution and is checked
      -- periodically. Any existing 'connectionOnPong' handler is still
      -- called. 'Nothing' disables the watchdog.
    }
-- | Baseline 'ServerOptions':
--
-- * listens on @127.0.0.1@, port @8080@;
-- * uses 'defaultConnectionOptions';
-- * does not require pongs.
--
-- Override individual fields with record update syntax.
defaultServerOptions :: ServerOptions
defaultServerOptions =
    ServerOptions
        { serverConnectionOptions = defaultConnectionOptions
        , serverRequirePong       = Nothing
        , serverHost              = "127.0.0.1"
        , serverPort              = 8080
        }
-- | Run a WebSockets server configured by 'ServerOptions'.
--
-- A listening socket is bound on 'serverHost' / 'serverPort', and
-- connections are then accepted forever.
--
-- * Each accepted connection runs @app@ in its own 'Async.Async' thread.
-- * The client socket is closed when that thread finishes, whether it
--   returns or throws.
-- * The application thread's result or exception is never awaited by the
--   accept loop, so application failures do not stop the server.
--
-- When 'serverRequirePong' is @'Just' n@, every received pong pushes a
-- deadline @n@ seconds into the future. A watchdog thread cancels the
-- application with 'PongTimeout' once that deadline has passed.
runServerWithOptions :: ServerOptions -> ServerApp -> IO a
runServerWithOptions opts app = S.withSocketsDo $
    bracket
    (makeListenSocket host port)
    S.close $ \sock -> mask_ $ forever $ do
        -- Async exceptions are masked in this loop and delivered only at
        -- interruptible points such as 'allowInterrupt' and a blocking
        -- 'S.accept'. Presumably this keeps an accepted socket from leaking
        -- before the 'finally' below is in place.
        allowInterrupt
        (conn, _) <- S.accept sock

        -- Monotonic second at which the application may be killed.
        -- 'tickle' moves it forward whenever a pong arrives.
        killRef <- IORef.newIORef =<< nextDeadline
        let tickle = IORef.writeIORef killRef =<< nextDeadline

        let connOpts'
                | not useKiller = connOpts
                | otherwise     = connOpts
                    { connectionOnPong = tickle >> connectionOnPong connOpts
                    }

        appAsync <- Async.asyncWithUnmask $ \unmask ->
            unmask (runApp conn connOpts' app) `finally` S.close conn

        when useKiller $ void $ Async.async (killer killRef appAsync)
  where
    host     = serverHost opts
    port     = serverPort opts
    connOpts = serverConnectionOptions opts

    -- Whole seconds on the monotonic clock.
    getSecs = Clock.sec <$> Clock.getTime Clock.Monotonic

    -- Deadline killDelay seconds from now. It is forced before being stored
    -- so the IORef holds a plain number rather than a thunk.
    nextDeadline = do
        now <- getSecs
        return $! now + killDelay

    useKiller = isJust $ serverRequirePong opts
    killDelay = maybe 0 fromIntegral (serverRequirePong opts)

    -- Watchdog: stop once the app has finished; otherwise sleep until the
    -- current deadline and re-check; cancel the app once it has passed.
    killer killRef appAsync = do
        killAt   <- IORef.readIORef killRef
        now      <- getSecs
        appState <- Async.poll appAsync
        case appState of
            -- Already finished, killed or crashed: nothing left to do.
            Just _ -> return ()
            -- Sleep only until the current deadline, not a full killDelay.
            -- A silent client is then cancelled about killDelay seconds
            -- after its last pong instead of up to twice that. In this
            -- branch killAt - now >= 1, so this never busy-loops.
            Nothing | now < killAt -> do
                threadDelay (fromIntegral (killAt - now) * 1000 * 1000)
                killer killRef appAsync
            -- Deadline passed without a pong.
            _ -> Async.cancelWith appAsync PongTimeout
-- | Create a TCP socket bound to the given host and port, in listening mode.
--
-- * The address is resolved with 'S.getAddrInfo', and the first result is
--   used.
-- * @ReuseAddr@ and @NoDelay@ are set before binding.
-- * If any step after socket creation throws, the socket is closed before
--   the exception propagates.
makeListenSocket :: String -> Int -> IO Socket
makeListenSocket host port = do
    -- NOTE(review): this relies on 'S.getAddrInfo' throwing rather than
    -- returning an empty list, as the network package documents.
    addr:_ <- S.getAddrInfo (Just hints) (Just host) (Just (show port))
    bracketOnError
        (S.socket (S.addrFamily addr) S.Stream S.defaultProtocol)
        S.close
        (\sock -> do
            S.setSocketOption sock S.ReuseAddr 1
            S.setSocketOption sock S.NoDelay 1
            S.bind sock (S.addrAddress addr)
            S.listen sock backlog
            return sock
        )
  where
    hints = S.defaultHints { S.addrSocketType = S.Stream }
    -- A fixed backlog of 5 is easily exhausted by a burst of simultaneous
    -- connects, and further attempts may then be refused or dropped (see
    -- listen(2)). Use the system's maximum queue length instead.
    backlog = S.maxListenQueue
-- | Build a 'PendingConnection' on an accepted socket and hand it to the
-- application.
--
-- The connection's stream is closed afterwards, even if the application
-- throws.
runApp :: Socket
       -> ConnectionOptions
       -> ServerApp
       -> IO ()
runApp socket opts app = bracket acquire release app
  where
    acquire         = makePendingConnection socket opts
    release pending = Stream.close (pendingStream pending)
-- | Wrap a connected socket in a 'Stream.Stream' and read the client's HTTP
-- request head from it with 'makePendingConnectionFromStream'.
makePendingConnection
    :: Socket -> ConnectionOptions -> IO PendingConnection
makePendingConnection socket opts =
    Stream.makeSocketStream socket >>= \sockStream ->
        makePendingConnectionFromStream sockStream opts
-- | Parse an HTTP request head from the stream.
--
-- The result is packaged with the stream and options into a
-- 'PendingConnection' whose 'pendingOnAccept' hook does nothing.
--
-- Throws 'ConnectionClosed' when 'Stream.parse' yields 'Nothing'
-- (presumably the stream ended before a full request head arrived).
makePendingConnectionFromStream
    :: Stream.Stream -> ConnectionOptions -> IO PendingConnection
makePendingConnectionFromStream stream opts =
    Stream.parse stream (decodeRequestHead False) >>= maybe closed pending
  where
    closed = throwIO ConnectionClosed
    pending request = return PendingConnection
        { pendingOptions  = opts
        , pendingRequest  = request
        , pendingOnAccept = const (return ())
        , pendingStream   = stream
        }
-- | Exception that 'runServerWithOptions' uses to cancel an application
-- thread whose pong deadline (see 'serverRequirePong') has passed.
data PongTimeout = PongTimeout deriving Show
instance Exception PongTimeout