{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Moto.PostgreSQL
  ( registryConf
  ) where

import qualified Control.Exception.Safe as Ex
import Control.Monad.IO.Class (MonadIO, liftIO)
import qualified Data.ByteString as B
import qualified Data.ByteString.Char8 as B8
import Data.Foldable (foldlM, for_)
import qualified Data.List as List
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import qualified Database.PostgreSQL.Simple as Pg
import qualified Database.PostgreSQL.Simple.FromRow as Pg
import qualified Database.PostgreSQL.Simple.FromField as Pg
import qualified Database.PostgreSQL.Simple.ToRow as Pg
import qualified Database.PostgreSQL.Simple.ToField as Pg
import qualified Database.PostgreSQL.Simple.Types as Pg
import qualified Di.Df1 as Di
import qualified Moto
import qualified Moto.Registry as Moto
import qualified System.Environment as IO

--------------------------------------------------------------------------------

-- | A PostgreSQL connection string, held as the raw bytes that are handed to
-- 'Pg.connectPostgreSQL'.
--
-- See https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING
newtype ConnectionURI
  = ConnectionURI B.ByteString
  deriving (Show, Eq)

-- | Command-line configuration for a 'Moto.Registry' kept in an append-only
-- table of a PostgreSQL database, namely @moto.registry@ (table @registry@
-- inside schema @moto@).
--
-- The given string is accepted verbatim, UTF-8 encoded, as a libpq connection
-- string; opening the registry is delegated to 'withRegistry'.
registryConf :: Moto.RegistryConf
registryConf = Moto.RegistryConf
  { Moto.registryConf_with = withRegistry
  , Moto.registryConf_parse = parseConnectionURI
  , Moto.registryConf_help =
      "URI to the database where registry is stored. \
      \See https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-CONNSTRING"
  }
  where
    -- Never fails: any string is taken as a connection string as-is.
    parseConnectionURI s = Right (ConnectionURI (T.encodeUtf8 (T.pack s)))

-- | Obtain a 'Moto.Registry' stored in an append-only table in a PostgreSQL
-- database.
--
-- The table name shall be @registry@, inside a schema named @moto@.
--
-- Before connecting, 'checkPgEnvVars' makes this fail if any libpq @PG*@
-- environment variable it knows about is set. Once connected, the session
-- takes a PostgreSQL advisory lock (@pg_advisory_lock@) so that concurrent
-- executions of this program don't touch the registry table at the same time.
-- The lock is never released explicitly; as a session-level advisory lock it
-- lives until the connection is closed by the 'Ex.bracket' release action.
--
-- Every 'Moto.Log' appended through the resulting registry is inserted as one
-- row; if the database reports anything other than exactly one affected row,
-- an 'IOError' is thrown.
withRegistry
  :: (MonadIO m, Ex.MonadMask m)
  => Di.Df1
  -> ConnectionURI
  -> (Moto.Registry -> m a)
  -> m a
withRegistry di0 (ConnectionURI cu) k = Ex.bracket
  (liftIO $ do
     checkPgEnvVars di0
     Di.debug_ di0 "Connecting to PostgreSQL database..."
     Pg.connectPostgreSQL cu)
  (\conn -> liftIO $ do
     Di.debug_ di0 "Closing connection to PostgreSQL database..."
     Pg.close conn)
  (\conn -> k =<< liftIO (do
     Di.debug_ di0 "Creating necessary tables, if any, \
                   \and acquiring exclusive registry lock..."
     -- Lock to prevent multiple executions of this program from touching the
     -- table. Then get our initial state, creating our table if necessary.
     -- Only the result of the last statement (the SELECT) is returned.
     -- Notices from the IF NOT EXISTS statements are silenced temporarily;
     -- RESET (rather than SET ... = notice) restores whatever level this
     -- session would otherwise have had (server, role or connection options).
     logs :: [LogV1] <- Pg.query_ conn
        "SELECT pg_advisory_lock(589412153);\n\
        \SET client_min_messages = error;\n\
        \CREATE SCHEMA IF NOT EXISTS moto;\n\
        \CREATE TABLE IF NOT EXISTS moto.registry\n\
        \ ( ord bigserial NOT NULL\
        \ , act text NOT NULL\
        \ , t timestamptz NOT NULL\
        \ , mid text NULL\
        \ , dir text NULL );\n\
        \RESET client_min_messages;\n\
        \SELECT act, t, mid, dir FROM moto.registry ORDER BY ord ASC;"
     Di.debug_ di0 "Loading state from registry..."
     -- Replay the stored logs in insertion order; an invalid history is
     -- rethrown as an exception.
     state0 :: Moto.State <- either Ex.throwM pure $
        foldlM Moto.updateState Moto.emptyState (map unLogV1 logs)
     Moto.newAppendOnlyRegistry state0 $ \log_ ->
        Pg.execute conn
           "INSERT INTO moto.registry (act,t,mid,dir) VALUES (?,?,?,?);"
           (LogV1 log_) >>= \case
              1 -> pure ()
              n -> fail ("Expected to insert 1 row into moto.registry, \
                         \but " ++ show n ++ " rows were affected.")))

-- | Serialization-only wrapper around 'Moto.Log'.
--
-- The 'Pg.FromRow' and 'Pg.ToRow' instances are defined on this wrapper rather
-- than on 'Moto.Log' itself, so that 'Moto.Log' does not gain those instances.
newtype LogV1 = LogV1
  { unLogV1 :: Moto.Log
  }

-- | Encodes a log entry as the four columns @(act, t, mid, dir)@ of
-- @moto.registry@, in that order.
instance Pg.ToRow LogV1 where
  toRow (LogV1 entry) =
    case entry of
      Moto.Log_Prepare t mId d ->
        [ Pg.toField ("prepare" :: String)
        , Pg.toField t
        , Pg.toField (Moto.unMigId mId)
        , Pg.toField (Moto.direction "backwards" "forwards" d :: String)
        ]
      Moto.Log_Abort t -> withoutMigration "abort" t
      Moto.Log_Commit t -> withoutMigration "commit" t
    where
      -- Abort and commit entries carry neither a migration id nor a
      -- direction, so the @mid@ and @dir@ columns are written as SQL NULL.
      withoutMigration act t =
        [ Pg.toField (act :: String)
        , Pg.toField t
        , Pg.toField Pg.Null
        , Pg.toField Pg.Null
        ]

-- | Decodes a row of @(act, t, mid, dir)@ as produced by the 'Pg.ToRow'
-- instance. Unknown @act@ or @dir@ values, and non-NULL @mid@/@dir@ columns
-- on abort/commit rows, make the row parser fail.
instance Pg.FromRow LogV1 where
  fromRow = LogV1 <$> (Pg.field >>= entryFor)
    where
      -- The @act@ column selects how the remaining columns are decoded;
      -- columns are consumed strictly left to right.
      entryFor act = case act of
        "prepare" ->
          Moto.Log_Prepare
            <$> Pg.field
            <*> (Moto.MigId <$> Pg.field)
            <*> (Pg.field >>= directionFor)
        "abort" -> Moto.Log_Abort <$> Pg.field <* nullColumn <* nullColumn
        "commit" -> Moto.Log_Commit <$> Pg.field <* nullColumn <* nullColumn
        _ -> fail ("bad action: " ++ show (act :: String))
      directionFor dir = case dir of
        "backwards" -> pure Moto.Backwards
        "forwards" -> pure Moto.Forwards
        _ -> fail ("bad direction: " ++ show (dir :: String))
      -- Consumes one column, which must be SQL NULL.
      nullColumn = Pg.field >>= \Pg.Null -> pure ()

--------------------------------------------------------------------------------

-- | Refuses to proceed if any environment variable listed in 'pgEnvVars' is
-- set (to any value), logging and throwing an 'Err_PgEnvVar' for the first one
-- found, in list order.
checkPgEnvVars :: Di.Df1 -> IO ()
checkPgEnvVars di0 = mapM_ check pgEnvVars
  where
    check (envName, connParam) =
      IO.lookupEnv envName >>= \case
        Nothing -> pure ()
        Just _ -> do
          let err = Err_PgEnvVar envName connParam
          Di.error di0 (Ex.displayException err)
          Ex.throwM err

-- | Thrown by 'checkPgEnvVars' when a libpq environment variable is set.
data Err_PgEnvVar
  = Err_PgEnvVar
      String -- ^ Name of the offending environment variable, e.g. @PGHOST@.
      String -- ^ Connection string parameter to use instead, e.g. @host@.
  deriving (Show)

-- | Human-readable explanation naming both the offending environment
-- variable and the connection string parameter to use instead.
instance Ex.Exception Err_PgEnvVar where
  displayException (Err_PgEnvVar envVar connParam) = concat
    [ "Detected '"
    , envVar
    , "' environment variable. We don't like this, we could be accidentally \
      \using another PostgreSQL database configuration. Please set '"
    , connParam
    , "' in the connection string instead."
    ]

-- | A map from libpq @PG*@ environment variables to the connection string
-- parameter each one supplies a default for.
--
-- 'checkPgEnvVars' refuses to run if any of these is set, so this list tracks
-- libpq's documented connection environment variables, including those only
-- honoured by newer libpq releases (e.g. @PGGSSENCMODE@, @PGCHANNELBINDING@,
-- @PGREQUIREAUTH@): with a newer libpq they would silently fill in connection
-- parameters just like the older ones. Variables with no connection string
-- counterpart (e.g. @PGTZ@, @PGDATESTYLE@) are not listed.
pgEnvVars :: [(String, String)]
pgEnvVars =
  [ ("PGHOST", "host")
  , ("PGHOSTADDR", "hostaddr")
  , ("PGPORT", "port")
  , ("PGDATABASE", "dbname")
  , ("PGUSER", "user")
  , ("PGPASSWORD", "password")
  , ("PGPASSFILE", "passfile")
  , ("PGREQUIREAUTH", "require_auth")
  , ("PGCHANNELBINDING", "channel_binding")
  , ("PGSERVICE", "service")
  , ("PGOPTIONS", "options")
  , ("PGAPPNAME", "application_name")
  , ("PGSSLMODE", "sslmode")
  , ("PGREQUIRESSL", "requiressl")
  , ("PGSSLNEGOTIATION", "sslnegotiation")
  , ("PGSSLCOMPRESSION", "sslcompression")
  , ("PGSSLCERT", "sslcert")
  , ("PGSSLKEY", "sslkey")
  , ("PGSSLCERTMODE", "sslcertmode")
  , ("PGSSLROOTCERT", "sslrootcert")
  , ("PGSSLCRL", "sslcrl")
  , ("PGSSLCRLDIR", "sslcrldir")
  , ("PGSSLSNI", "sslsni")
  , ("PGREQUIREPEER", "requirepeer")
  , ("PGSSLMINPROTOCOLVERSION", "ssl_min_protocol_version")
  , ("PGSSLMAXPROTOCOLVERSION", "ssl_max_protocol_version")
  , ("PGGSSENCMODE", "gssencmode")
  , ("PGKRBSRVNAME", "krbsrvname")
  , ("PGGSSLIB", "gsslib")
  , ("PGGSSDELEGATION", "gssdelegation")
  , ("PGCONNECT_TIMEOUT", "connect_timeout")
  , ("PGCLIENTENCODING", "client_encoding")
  , ("PGTARGETSESSIONATTRS", "target_session_attrs")
  , ("PGLOADBALANCEHOSTS", "load_balance_hosts")
  ]