{- |
Copyright : Flipstone Technology Partners 2023
License   : MIT
Stability : Stable

This module provides the functionality to work with SQL transactions - notably
to ensure some Haskell action occurs within a database transaction.

@since 1.0.0.0
-}
module Orville.PostgreSQL.Execution.Transaction
  ( withTransaction
  )
where

import Control.Monad.IO.Class (MonadIO, liftIO)

import qualified Orville.PostgreSQL.Execution.Execute as Execute
import qualified Orville.PostgreSQL.Execution.QueryType as QueryType
import qualified Orville.PostgreSQL.Expr as Expr
import qualified Orville.PostgreSQL.Internal.Bracket as Bracket
import qualified Orville.PostgreSQL.Internal.MonadOrville as MonadOrville
import qualified Orville.PostgreSQL.Internal.OrvilleState as OrvilleState
import qualified Orville.PostgreSQL.Monad as Monad
import qualified Orville.PostgreSQL.Raw.RawSql as RawSql

{- |
  Runs an Orville action inside a database transaction.

  The transaction is opened before @action@ runs. If @action@ finishes without
  raising an exception, the transaction is committed. If @action@ raises an
  exception, the transaction is rolled back.

  It is safe to call this from within another transaction. In that case a
  savepoint is established at the start of the nested call. The savepoint is
  then either released or rolled back to, depending on how @action@ finished.

  Each transaction event (opening, finishing successfully, rolling back) is
  handled in two steps. First its SQL is executed on the current connection.
  Then the event is passed to the transaction callback held in the
  'OrvilleState.OrvilleState'.

  Note: Exceptions are handled using the implementations of 'Monad.catch' and
  'Monad.mask' provided by the 'Monad.MonadOrvilleControl' instance for @m@.

@since 1.0.0.0
-}
withTransaction :: Monad.MonadOrville m => m a -> m a
withTransaction action =
  MonadOrville.withConnectedState $ \connectedState -> do
    orvilleState <- Monad.askOrvilleState

    let
      connection =
        OrvilleState.connectedConnection connectedState

      -- Derived from whatever transaction is already recorded on this
      -- connection (if any). Per the contract above, a nested call yields a
      -- savepoint rather than a new top-level transaction.
      transaction =
        OrvilleState.newTransaction (OrvilleState.connectedTransaction connectedState)

      -- The connected state that 'action' runs under: the same as the outer
      -- one, except that it records the transaction opened here.
      nestedConnectedState =
        connectedState {OrvilleState.connectedTransaction = Just transaction}

      -- Executes the SQL for a transaction event, then notifies the
      -- configured transaction callback about that event.
      emitEvent :: OrvilleState.TransactionEvent -> IO ()
      emitEvent event = do
        Execute.executeVoidIO
          QueryType.OtherQuery
          (transactionEventSql orvilleState event)
          orvilleState
          connection
        OrvilleState.orvilleTransactionCallback orvilleState event

      openTransaction =
        liftIO . emitEvent $ OrvilleState.openTransactionEvent transaction

      runInTransaction () =
        Monad.localOrvilleState
          (OrvilleState.connectState nestedConnectedState)
          action

      -- Commits (or releases) on success; rolls back on error.
      closeTransaction :: MonadIO n => () -> Bracket.BracketResult -> n ()
      closeTransaction () outcome =
        liftIO . emitEvent $
          case outcome of
            Bracket.BracketSuccess -> OrvilleState.transactionSuccessEvent transaction
            Bracket.BracketError -> OrvilleState.rollbackTransactionEvent transaction

    Bracket.bracketWithResult openTransaction closeTransaction runInTransaction

{- |
  INTERNAL: Produces the SQL statement that carries out a single transaction
  event.

  The 'OrvilleState.BeginTransaction' event uses the begin-transaction
  expression configured in the given 'OrvilleState.OrvilleState'. Savepoint
  events (creating, rolling back to, releasing) refer to the savepoint by a
  name derived from its nesting level. Plain rollback and commit use fixed
  expressions.
-}
transactionEventSql ::
  OrvilleState.OrvilleState ->
  OrvilleState.TransactionEvent ->
  RawSql.RawSql
transactionEventSql state OrvilleState.BeginTransaction =
  RawSql.toRawSql (OrvilleState.orvilleBeginTransactionExpr state)
transactionEventSql _ (OrvilleState.NewSavepoint savepoint) =
  RawSql.toRawSql (Expr.savepoint (savepointName savepoint))
transactionEventSql _ OrvilleState.RollbackTransaction =
  RawSql.toRawSql Expr.rollback
transactionEventSql _ (OrvilleState.RollbackToSavepoint savepoint) =
  RawSql.toRawSql (Expr.rollbackTo (savepointName savepoint))
transactionEventSql _ OrvilleState.CommitTransaction =
  RawSql.toRawSql Expr.commit
transactionEventSql _ (OrvilleState.ReleaseSavepoint savepoint) =
  RawSql.toRawSql (Expr.releaseSavepoint (savepointName savepoint))

{- |
  INTERNAL: Builds the name used for a savepoint from the current nesting
  level of transactions, as tracked by the `OrvilleState.Savepoint` type. The
  resulting names have the form @orville_savepoint_level_N@, where @N@ is that
  nesting level.

  Strictly speaking, distinct names are not necessary for PostgreSQL, because
  it supports shadowing savepoint names. The SQL standard does not allow
  savepoint name shadowing, however. In other databases, re-using the same
  name would overwrite the earlier savepoint rather than shadow it. Giving
  each nesting level its own name keeps the generated SQL valid on any
  database that implements savepoints according to the SQL standard, even
  though Orville currently supports only PostgreSQL.

@since 1.0.0.0
-}
savepointName :: OrvilleState.Savepoint -> Expr.SavepointName
savepointName =
  Expr.savepointName
    . ("orville_savepoint_level_" <>)
    . show
    . OrvilleState.savepointNestingLevel