{- |
Copyright : Flipstone Technology Partners 2023
License   : MIT
Stability : Stable

@since 1.0.0.0
-}
module Orville.PostgreSQL.Internal.Bracket
  ( bracketWithResult
  , BracketResult (BracketSuccess, BracketError)
  ) where

import Control.Exception (SomeException, catch, mask, throwIO)
import Control.Monad.IO.Class (MonadIO (liftIO))

import Orville.PostgreSQL.Monad.MonadOrville (MonadOrvilleControl (liftCatch, liftMask))

{- |
  INTERNAL: Passed to the release action of 'bracketWithResult' so it can
  tell whether the bracketed action finished normally or was interrupted by
  an exception.
-}
data BracketResult
  = -- | The bracketed action returned without throwing.
    BracketSuccess
  | -- | The bracketed action threw an exception, which will be rethrown after
    -- the release action runs.
    BracketError

{- |
  INTERNAL: A version of 'Control.Exception.bracket' that allows us to
  distinguish between exception and non-exception release cases. This is
  available in certain packages as a typeclass function under the name
  "generalBracket", but is implemented here directly in terms of IO's 'mask'
  and 'catch' to guarantee our exception handling semantics without forcing
  the Orville user's choice of library for lifting and unlifting IO actions
  (e.g. UnliftIO).

  The whole sequence runs inside 'mask', lifted through 'liftMask'. Only the
  user's @action@ is run under @restore@. @acquire@ and both release paths
  stay masked, assuming the 'MonadOrvilleControl' instance lifts 'mask'
  faithfully.

  * If @action@ throws, @release@ is called with 'BracketError' and the
    original exception is rethrown.
  * Otherwise @release@ is called with 'BracketSuccess' and the action's
    result is returned.

  The value produced by @release@ is discarded.

@since 1.0.0.0
-}
bracketWithResult ::
  (MonadIO m, MonadOrvilleControl m) =>
  m a ->
  (a -> BracketResult -> m c) ->
  (a -> m b) ->
  m b
bracketWithResult acquire release action =
  liftMask mask $ \restore -> do
    resource <- acquire

    -- Only the user's action is unmasked. This means an asynchronous
    -- exception cannot slip in between acquiring the resource and installing
    -- the handler that guarantees its release.
    result <-
      liftCatch
        catch
        (restore (action resource))
        (handleAndRethrow (release resource BracketError))

    _ <- release resource BracketSuccess

    pure result

{- |
  INTERNAL: An exception handler that runs the given cleanup action and then
  rethrows the exception it received. Its exception argument is fixed to
  'SomeException'. Using it as the handler passed to 'catch' (or
  'liftCatch') therefore makes that catch apply to every exception type.

  The result of the cleanup action is discarded. If the cleanup action
  itself throws, that exception propagates and the original exception is
  not rethrown.

@since 1.0.0.0
-}
handleAndRethrow ::
  MonadIO m =>
  m a ->
  SomeException ->
  m b
handleAndRethrow handler ex = do
  _ <- handler
  liftIO (throwIO ex)