{-# LANGUAGE CPP #-}

-- |
--  This module has functions to send commands LISTEN and NOTIFY to the database server.
--  It also has a function to wait for and handle notifications on a database connection.
--
--  For more information check the [PostgreSQL documentation](https://www.postgresql.org/docs/current/libpq-notify.html).
module Hasql.Notifications
  ( notifyPool,
    notify,
    listen,
    unlisten,
    waitForNotifications,
    PgIdentifier,
    toPgIdentifier,
    fromPgIdentifier,
    FatalError (..),
  )
where

import Control.Concurrent (threadDelay, threadWaitRead)
import Control.Exception (Exception, throw)
import Control.Monad (forever, unless, void, when)
import Data.ByteString.Char8 (ByteString)
import Data.Functor.Contravariant (contramap)
import Data.Text (Text)
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Database.PostgreSQL.LibPQ as PQ
import Hasql.Connection (Connection, withLibPQConnection)
import qualified Hasql.Decoders as HD
import qualified Hasql.Encoders as HE
import Hasql.Pool (Pool, UsageError, use)
import Hasql.Session (run, sql, statement)
import qualified Hasql.Session as S
import qualified Hasql.Statement as HST

-- | A PostgreSQL identifier that has already been wrapped in double quotes and
--  escaped, so it can be spliced directly into SQL text.
--  Build values with 'toPgIdentifier'; read them back with 'fromPgIdentifier'.
newtype PgIdentifier = PgIdentifier Text
  deriving (Show)

-- | Exception thrown by this module when libpq reports an error it cannot
--  recover from, e.g. a failed LISTEN/UNLISTEN command or a failure while
--  consuming input in 'waitForNotifications'.
--
--  It is an ordinary 'Exception' and can be handled with
--  'Control.Exception.catch' or 'Control.Exception.try'.
newtype FatalError = FatalError {fatalErrorMessage :: String}

instance Exception FatalError

-- | Shows only the wrapped message, without the constructor name.
instance Show FatalError where
  show (FatalError message) = message

-- | Unwraps a 'PgIdentifier', returning the already quoted and escaped
--  identifier text (surrounding double quotes included).
fromPgIdentifier :: PgIdentifier -> Text
fromPgIdentifier = \(PgIdentifier quoted) -> quoted

-- | Turns a raw name into a quoted 'PgIdentifier': every embedded double quote
--  is doubled and the whole name is wrapped in double quotes, so for example
--  @foo"bar@ becomes @"foo""bar"@.
toPgIdentifier :: Text -> PgIdentifier
toPgIdentifier raw = PgIdentifier (T.concat [quote, escaped, quote])
  where
    quote :: Text
    quote = "\""

    -- A literal double quote inside a quoted identifier is written twice.
    escaped :: Text
    escaped = T.replace quote "\"\"" raw

-- | Sends a notification using a connection taken from a Hasql 'Pool'.
--
--  Runs @SELECT pg_notify($1, $2)@ with the channel and payload bound as
--  statement parameters, so neither value is spliced into the SQL text.
--  Note that the channel is plain 'Text' here, not a 'PgIdentifier'.
notifyPool ::
  -- | Pool from which the connection will be used to issue a NOTIFY command.
  Pool ->
  -- | Channel where to send the notification
  Text ->
  -- | Payload to be sent with the notification
  Text ->
  IO (Either UsageError ())
notifyPool pool channel mesg = use pool session
  where
    session :: S.Session ()
    session = statement (channel, mesg) pgNotify

    -- Prepared flag is False: the statement is not prepared on the server.
    pgNotify :: HST.Statement (Text, Text) ()
    pgNotify = HST.Statement "SELECT pg_notify($1, $2)" params HD.noResult False

    -- First tuple component binds to $1 (channel), second to $2 (payload).
    params :: HE.Params (Text, Text)
    params = contramap fst textParam <> contramap snd textParam

    textParam :: HE.Params Text
    textParam = HE.param (HE.nonNullable HE.text)

-- | Given a Hasql 'Connection', a channel and a payload, sends a @NOTIFY@
--  command to the database.
--
--  The channel is spliced in as an already quoted 'PgIdentifier'. The payload
--  is embedded as an escape-string constant (@E'...'@) with backslashes and
--  single quotes doubled. The command text is executed via 'sql' exactly as
--  built, so without this escaping a payload containing @'@ could terminate
--  the literal early and inject arbitrary SQL.
notify ::
  -- | Connection to be used to send the NOTIFY command
  Connection ->
  -- | Channel where to send the notification
  PgIdentifier ->
  -- | Payload to be sent with the notification
  Text ->
  IO (Either S.QueryError ())
notify con channel mesg =
  run (sql $ T.encodeUtf8 command) con
  where
    command :: Text
    command = "NOTIFY " <> fromPgIdentifier channel <> ", " <> quotePayload mesg

    -- Render the payload as an E'...' string constant. Inside E-strings a
    -- backslash is always an escape character (independent of the server's
    -- standard_conforming_strings setting), so backslashes are doubled to keep
    -- them literal; single quotes are doubled so they cannot close the literal.
    quotePayload :: Text -> Text
    quotePayload payload =
      "E'" <> T.replace "'" "''" (T.replace "\\" "\\\\" payload) <> "'"

-- |
--  Given a Hasql 'Connection' and a channel, sends a @LISTEN@ command to the
--  database. Once the connection sends the LISTEN command, the server registers
--  that connection's interest in the channel, so it is important to keep track
--  of which connection was used to issue it.
--
--  Throws 'FatalError' if libpq reports that the command failed.
--
--  Example of listening and waiting for a notification:
--
--  @
--  import System.Exit (die)
--
--  import Hasql.Connection
--  import Hasql.Notifications
--
--  main :: IO ()
--  main = do
--    dbOrError <- acquire "postgres://localhost/db_name"
--    case dbOrError of
--        Right db -> do
--            let channelToListen = toPgIdentifier "sample-channel"
--            listen db channelToListen
--            waitForNotifications (\channel _ -> print $ "Just got notification on channel " <> channel) db
--        _ -> die "Could not open database connection"
--  @
listen ::
  -- | Connection to be used to send the LISTEN command
  Connection ->
  -- | Channel this connection will be registered to listen to
  PgIdentifier ->
  IO ()
listen con channel = withLibPQConnection con (executeOrPanic listenCommand)
  where
    listenCommand :: ByteString
    listenCommand = T.encodeUtf8 ("LISTEN " <> fromPgIdentifier channel)

-- | Given a Hasql 'Connection' and a channel, sends an @UNLISTEN@ command to
--  the database.
--
--  Throws 'FatalError' if libpq reports that the command failed.
unlisten ::
  -- | Connection currently registered by a previous 'listen' call
  Connection ->
  -- | Channel this connection will be deregistered from
  PgIdentifier ->
  IO ()
unlisten con channel = withLibPQConnection con (executeOrPanic unlistenCommand)
  where
    unlistenCommand :: ByteString
    unlistenCommand = T.encodeUtf8 ("UNLISTEN " <> fromPgIdentifier channel)

-- | Executes a command directly on a libpq connection and throws 'FatalError'
--  if it fails.
--
--  A failure is either:
--
--  * libpq returned no result at all; the connection's error message is used.
--  * The result status is 'PQ.FatalError' or 'PQ.BadResponse'; the result's
--    error message is used.
--
--  When libpq supplies no error message, or an empty one, a generic
--  @Error executing \<command\>@ message is thrown instead.
executeOrPanic :: ByteString -> PQ.Connection -> IO ()
executeOrPanic cmd pqCon = do
  mResult <- PQ.exec pqCon cmd
  case mResult of
    Nothing -> PQ.errorMessage pqCon >>= failWith
    Just result -> do
      status <- PQ.resultStatus result
      -- libpq reports an unusable server reply as BadResponse; it is a failed
      -- command just like FatalError and must not pass silently.
      when (status `elem` [PQ.FatalError, PQ.BadResponse]) $
        PQ.resultErrorMessage result >>= failWith
  where
    failWith :: Maybe ByteString -> IO ()
    failWith mError = panic $ case T.decodeUtf8Lenient <$> mError of
      Just msg | not (T.null msg) -> T.unpack msg
      _ -> "Error executing " <> show cmd

-- |
--  Given a function that handles notifications and a Hasql connection, waits on
--  the database connection and calls the handler every time a notification
--  arrives.
--
--  The handler receives two arguments: the channel name and the payload. It is
--  called synchronously from the waiting loop, so a slow handler delays the
--  processing of later notifications. See the example below for handling each
--  notification on a separate thread.
--
--  This function loops forever and only exits by throwing. A 'FatalError' is
--  thrown if consuming input from the connection fails. While libpq reports no
--  socket for the connection, the loop sleeps for one second between checks.
--
--  NOTE(review): the whole loop runs inside 'withLibPQConnection'; presumably
--  other use of the same Hasql connection waits until it ends. Confirm against
--  hasql.
--
--  @
--  import Control.Concurrent.Async (async)
--  import Control.Monad (void)
--  import System.Exit (die)
--
--  import Hasql.Connection
--  import Hasql.Notifications
--
--  notificationHandler :: ByteString -> ByteString -> IO()
--  notificationHandler channel payload =
--    void $ async $ do
--      print $ "Handle payload " <> payload <> " in its own thread"
--
--  main :: IO ()
--  main = do
--    dbOrError <- acquire "postgres://localhost/db_name"
--    case dbOrError of
--        Right db -> do
--            let channelToListen = toPgIdentifier "sample-channel"
--            listen db channelToListen
--            waitForNotifications notificationHandler db
--        _ -> die "Could not open database connection"
--  @
waitForNotifications ::
  -- | Callback function to handle incoming notifications
  (ByteString -> ByteString -> IO ()) ->
  -- | Connection where we will listen to
  Connection ->
  IO ()
waitForNotifications sendNotification con =
  withLibPQConnection con $ \pqCon -> forever (pqFetch pqCon)
  where
    -- One iteration: hand over a queued notification if libpq has one,
    -- otherwise wait for the socket to become readable and pull in input.
    pqFetch :: PQ.Connection -> IO ()
    pqFetch pqCon =
      PQ.notifies pqCon >>= maybe (awaitInput pqCon) deliver

    deliver :: PQ.Notify -> IO ()
    deliver notification =
      sendNotification (PQ.notifyRelname notification) (PQ.notifyExtra notification)

    -- Without a socket there is nothing to wait on, so back off for a second.
    awaitInput :: PQ.Connection -> IO ()
    awaitInput pqCon =
      PQ.socket pqCon >>= maybe (threadDelay 1000000) (consumeWhenReadable pqCon)

    consumeWhenReadable pqCon fd = do
      threadWaitRead fd
      consumed <- PQ.consumeInput pqCon
      unless consumed $
        PQ.errorMessage pqCon
          >>= panic . maybe "Error checking for PostgreSQL notifications" (T.unpack . T.decodeUtf8Lenient)

-- | Throws the given message wrapped in a 'FatalError'. Because it uses
--  'throw', it can stand in for a value of any type.
panic :: String -> a
panic = throw . FatalError