{-|
Module      : PostgresWebsockets.Middleware
Description : PostgresWebsockets WAI middleware, add functionality to any WAI application.

Allow websockets connections that will communicate with the database through LISTEN/NOTIFY channels.
-}
{-# LANGUAGE DeriveGeneric #-}

module PostgresWebsockets.Middleware
  ( postgresWsMiddleware
  ) where

import Protolude hiding (toS)
import Protolude.Conv
import Data.Time.Clock (UTCTime)
import Data.Time.Clock.POSIX (utcTimeToPOSIXSeconds, posixSecondsToUTCTime)
import Control.Concurrent.AlarmClock (newAlarmClock, setAlarm)
import qualified Hasql.Notifications as H
import qualified Hasql.Pool as H
import qualified Network.Wai as Wai
import qualified Network.Wai.Handler.WebSockets as WS
import qualified Network.WebSockets as WS

import qualified Data.Aeson as A
import qualified Data.Aeson.KeyMap as A
import qualified Data.Aeson.Key as Key

import qualified Data.Text as T
import qualified Data.ByteString.Lazy as BL

import PostgresWebsockets.Broadcast (onMessage)
import PostgresWebsockets.Claims ( ConnectionInfo, validateClaims )
import PostgresWebsockets.Context ( Context(..) )
import PostgresWebsockets.Config (AppConfig(..))
import qualified PostgresWebsockets.Broadcast as B


-- | Kind of event carried by a 'Message' sent to the database.
--
-- The constructor names are part of the JSON produced by the 'A.ToJSON'
-- instance, so renaming them changes the wire format.
data Event
  = WebsocketMessage -- ^ Data received from a websocket client.
  | ConnectionOpen   -- ^ A websocket connection has just been accepted.
  deriving (Show, Eq, Generic)

-- | Envelope that is JSON-encoded and sent to the database channel.
--
-- The field names are part of the JSON produced by the 'A.ToJSON' instance,
-- so renaming them changes the wire format.
data Message = Message
  { claims  :: A.Object -- ^ Validated JWT claims of the sending connection;
                        --   'timestampMessage' adds @message_delivered_at@ here.
  , event   :: Event    -- ^ What kind of event this message reports.
  , payload :: Text     -- ^ Text received from the websocket, or the
                        --   comma-separated channel list for 'ConnectionOpen'.
  , channel :: Text     -- ^ Channel the message is addressed to.
  } deriving (Show, Eq, Generic)

-- JSON encoders for the outgoing message types. No methods are defined here,
-- so the class defaults are used (both types derive Generic above).
instance A.ToJSON Event
instance A.ToJSON Message

-- | Given a 'Context' (time source, Hasql pool, multiplexer and application
-- configuration) this will give you a WAI middleware.
--
-- Requests are handed to 'WS.websocketsOr' with the default connection
-- options, using 'wsApp' as the websocket server application.
postgresWsMiddleware :: Context -> Wai.Middleware
postgresWsMiddleware ctx =
  WS.websocketsOr WS.defaultConnectionOptions (wsApp ctx)

-- private functions
-- | Websocket close code sent to the client when its JWT @exp@ claim is
-- reached (see the alarm set up in 'wsApp').
jwtExpirationStatusCode :: Word16
jwtExpirationStatusCode = 3001

-- | Websocket server application run for every pending connection.
--
-- The request path carries the token (and optionally a channel). The token is
-- validated with 'validateClaims' against the configured JWT secret and the
-- current time; on failure the request is rejected, otherwise it is accepted
-- and served by @forkSessions@.
--
-- When the websocket is closed a ConnectionClosed exception is triggered;
-- this kills all children and frees resources for us.
wsApp :: Context -> WS.ServerApp
wsApp Context{..} pendingConn =
  ctxGetTime
    >>= validateClaims requestChannel (configJwtSecret ctxConfig) (toS jwtToken)
    >>= either rejectRequest forkSessions
  where
    -- The mode granted by the claims is one of "r", "w" or "rw".
    hasRead :: Text -> Bool
    hasRead m = m == "r" || m == "rw"

    hasWrite :: Text -> Bool
    hasWrite m = m == "w" || m == "rw"

    -- Log the reason on stderr and refuse the websocket handshake.
    rejectRequest :: Text -> IO ()
    rejectRequest msg = do
      putErrLn $ "Rejecting Request: " <> msg
      WS.rejectRequest pendingConn (toS msg)

    -- the URI has one of the two formats - /:jwt or /:channel/:jwt
    pathElements :: [Text]
    pathElements =
      T.split (== '/') . T.drop 1 . toSL . WS.requestPath $ WS.pendingRequest pendingConn

    -- With two or more path elements the first is the channel and the second
    -- the token; with a single element it is the token and no channel is given.
    requestChannel :: Maybe Text
    jwtToken :: Text
    (requestChannel, jwtToken) =
      case pathElements of
        ch : token : _ -> (Just ch, token)
        [token] -> (Nothing, token)
        [] -> (Nothing, "")

    forkSessions :: ConnectionInfo -> IO ()
    forkSessions (chs, mode, validClaims) = do
      -- We should accept only after verifying JWT
      conn <- WS.acceptRequest pendingConn
      -- Fork a pinging thread to ensure browser connections stay alive
      WS.withPingThread conn 30 (pure ()) $ do
        -- Close the connection with a dedicated status code once the token's
        -- numeric "exp" claim is reached; other claim shapes are ignored.
        -- NOTE(review): the alarm clock is never explicitly destroyed when the
        -- connection ends first -- confirm whether cleanup is needed.
        case A.lookup "exp" validClaims of
          Just (A.Number expClaim) -> do
            connectionExpirer <-
              newAlarmClock $
                const (WS.sendCloseCode conn jwtExpirationStatusCode ("JWT expired" :: ByteString))
            setAlarm connectionExpirer (posixSecondsToUTCTime $ realToFrac expClaim)
          _ -> pure ()

        let sendMessageToDatabase :: Message -> IO ()
            sendMessageToDatabase = sendToDatabase ctxPool (configListenChannel ctxConfig)

            sendMessageWithTimestamp :: Message -> IO ()
            sendMessageWithTimestamp = timestampMessage ctxGetTime >=> sendMessageToDatabase

            -- Forward a websocket payload addressed to the given channel.
            sendNotification :: Text -> Text -> IO ()
            sendNotification msg ch =
              sendMessageWithTimestamp (Message validClaims WebsocketMessage msg ch)

        -- Announce the new connection on the meta channel, when one is configured.
        forM_ (configMetaChannel ctxConfig) $ \metaChannel ->
          sendMessageWithTimestamp $
            Message validClaims ConnectionOpen (T.intercalate "," chs) metaChannel

        -- Readers get every multiplexer message of their channels as text frames.
        when (hasRead mode) $
          forM_ chs $ \ch ->
            onMessage ctxMulti ch (WS.sendTextData conn . B.payload)

        -- Writers have their incoming frames relayed to the database.
        when (hasWrite mode) $
          notifySession conn sendNotification chs

        -- Nothing ever fills this MVar: block here to keep the session open.
        waitForever <- newEmptyMVar :: IO (MVar ())
        void $ takeMVar waitForever

-- | Relay every message received on the websocket to all given channels.
--
-- Having both channel and claims as parameters seems redundant,
-- but it allows the function to ignore the claims structure and the source
-- of the channel, so all claims decoding can be coded in the caller.
--
-- The callback receives the message text first and the channel name second.
-- The relay loop runs in an async thread that is awaited, so this call only
-- returns by way of an exception.
notifySession :: WS.Connection -> (Text -> Text -> IO ()) -> [Text] -> IO ()
notifySession wsCon sendToChannel chs =
  withAsync (forever relayData) wait
  where
    relayData :: IO ()
    relayData = do
      msg <- WS.receiveData wsCon
      forM_ chs (sendToChannel msg)

-- | JSON-encode a 'Message' and publish it on the given database channel
-- with 'H.notifyPool'.
--
-- Delivery is best-effort: a 'Left' returned by the pool is reported on
-- stderr rather than silently dropped, and is never turned into an exception.
sendToDatabase :: H.Pool -> Text -> Message -> IO ()
sendToDatabase pool dbChannel msg =
  H.notifyPool pool dbChannel (toS jsonMsg) >>= either reportFailure pure
  where
    jsonMsg :: ByteString
    jsonMsg = BL.toStrict (A.encode msg)

    reportFailure :: H.UsageError -> IO ()
    reportFailure err =
      putErrLn ("Failed to notify channel " <> dbChannel <> ": " <> show err :: Text)

-- | Stamp a 'Message' with the current time.
--
-- Runs the supplied clock action and stores the result, as POSIX seconds,
-- under the @message_delivered_at@ key of the message claims (replacing any
-- existing value for that key). All other fields are left untouched.
timestampMessage :: IO UTCTime -> Message -> IO Message
timestampMessage getTime msg = do
  deliveredAt <- utcTimeToPOSIXSeconds <$> getTime
  pure
    msg
      { claims =
          A.insert
            (Key.fromText "message_delivered_at")
            (A.Number (realToFrac deliveredAt))
            (claims msg)
      }