module Database.PostgreSQL.PQTypes.Transaction
  ( Savepoint (..)
  , withSavepoint
  , withTransaction
  , begin
  , commit
  , rollback
  , withTransaction'
  , begin'
  , commit'
  , rollback'
  ) where

import Control.Monad
import Control.Monad.Catch
import Data.Function
import Data.String
import Data.Typeable
import GHC.Stack

import Data.Monoid.Utils
import Database.PostgreSQL.PQTypes.Class
import Database.PostgreSQL.PQTypes.Internal.Exception
import Database.PostgreSQL.PQTypes.SQL.Raw
import Database.PostgreSQL.PQTypes.Transaction.Settings
import Database.PostgreSQL.PQTypes.Utils

-- | Wrapper that represents a savepoint name.
--
-- The wrapped 'RawSQL' fragment is appended to the @SAVEPOINT@,
-- @RELEASE SAVEPOINT@ and @ROLLBACK TO SAVEPOINT@ statements issued by
-- 'withSavepoint'.
newtype Savepoint = Savepoint (RawSQL ())

-- | Allows savepoint names to be written as string literals: the string is
-- first converted to 'RawSQL' via its own 'IsString' instance and then wrapped.
instance IsString Savepoint where
  fromString = Savepoint . fromString

-- | Create a savepoint and roll back to it if given monadic action throws.
-- This may only be used if a transaction is already active. Note that it
-- provides something like \"nested transaction\".
--
-- The savepoint is created as the acquire step of 'generalBracket'. When the
-- action finishes with 'ExitCaseSuccess' the savepoint is released; in every
-- other exit case the transaction is first rolled back to the savepoint and
-- the savepoint is then released. The result of the action is returned.
--
-- See <http://www.postgresql.org/docs/current/static/sql-savepoint.html>
withSavepoint :: (HasCallStack, MonadDB m, MonadMask m) => Savepoint -> m a -> m a
withSavepoint (Savepoint savepoint) m =
  -- 'generalBracket' returns (action result, release result); only the
  -- former is of interest.
  fst
    <$> generalBracket
      (runQuery_ $ "SAVEPOINT" <+> savepoint)
      ( \() -> \case
          ExitCaseSuccess _ -> runQuery_ sqlReleaseSavepoint
          _ -> rollbackAndReleaseSavepoint
      )
      (\() -> m)
  where
    -- Statement that discards the savepoint created above.
    sqlReleaseSavepoint = "RELEASE SAVEPOINT" <+> savepoint

    -- Undo everything done since the savepoint, then discard the savepoint.
    rollbackAndReleaseSavepoint = do
      runQuery_ $ "ROLLBACK TO SAVEPOINT" <+> savepoint
      runQuery_ sqlReleaseSavepoint

----------------------------------------

-- | Same as 'withTransaction'' except that it uses current
-- transaction settings instead of custom ones. It is worth
-- noting that changing transaction settings inside supplied
-- monadic action won't have any effect on the final 'commit'
-- / 'rollback' as settings that were in effect during the call
-- to 'withTransaction' will be used.
withTransaction :: (HasCallStack, MonadDB m, MonadMask m) => m a -> m a
withTransaction m = do
  -- Settings are read exactly once, before the transaction is started.
  ts <- getTransactionSettings
  withTransaction' ts m

-- | Begin transaction using current transaction settings, i.e. the ones
-- returned by 'getTransactionSettings', which are passed on to 'begin''.
begin :: (HasCallStack, MonadDB m) => m ()
begin = getTransactionSettings >>= begin'

-- | Commit active transaction using current transaction settings, i.e. the
-- ones returned by 'getTransactionSettings', which are passed on to 'commit''.
commit :: (HasCallStack, MonadDB m) => m ()
commit = getTransactionSettings >>= commit'

-- | Rollback active transaction using current transaction settings, i.e. the
-- ones returned by 'getTransactionSettings', which are passed on to
-- 'rollback''.
rollback :: (HasCallStack, MonadDB m) => m ()
rollback = getTransactionSettings >>= rollback'

----------------------------------------

-- | Execute monadic action within a transaction using given transaction
-- settings. Note that it won't work as expected if a transaction is already
-- active (in such case 'withSavepoint' should be used instead).
--
-- The transaction is started with 'begin'' and finished with 'commit'' if the
-- action exits with 'ExitCaseSuccess', or with 'rollback'' in every other exit
-- case.
--
-- If 'tsRestartPredicate' is set, the whole begin / action / finish sequence
-- is run again whenever it throws an exception that the predicate accepts.
-- The predicate is given the exception and the number of the attempt that
-- just failed, counted from 1.
withTransaction'
  :: (HasCallStack, MonadDB m, MonadMask m)
  => TransactionSettings
  -> m a
  -> m a
withTransaction' ts m = (`fix` 1) $ \loop n -> do
  -- Optimization for squashing possible space leaks.
  -- It looks like GHC doesn't like 'catch' and passes
  -- on introducing strictness in some cases.
  let maybeRestart = case tsRestartPredicate ts of
        -- Install the handler only when there is a predicate to consult.
        Just _ -> handleJust (expred n) (\_ -> loop $ n + 1)
        Nothing -> id
  maybeRestart $
    -- 'generalBracket' returns (action result, release result); only the
    -- former is of interest.
    fst
      <$> generalBracket
        (begin' ts)
        ( \() -> \case
            ExitCaseSuccess _ -> commit' ts
            _ -> rollback' ts
        )
        (\() -> m)
  where
    -- Decide whether the exception thrown during attempt number @n@ warrants
    -- a restart: 'Just' () means restart, 'Nothing' means rethrow.
    expred :: Integer -> SomeException -> Maybe ()
    expred !n e = do
      -- check if the predicate exists
      RestartPredicate f <- tsRestartPredicate ts
      -- cast exception to the type expected by the predicate
      err <-
        msum
          [ -- either cast the exception itself...
            fromException e
          , -- ...or extract it from DBException
            fromException e >>= \DBException {..} -> cast dbeError
          ]
      -- check if the predicate allows for the restart
      guard $ f err n

-- | Begin transaction using given transaction settings.
--
-- The statement is built by joining @BEGIN@, the isolation level clause and
-- the permissions clause with single spaces. A default isolation level or
-- default permissions contribute an empty fragment, i.e. no explicit clause.
begin' :: (HasCallStack, MonadDB m) => TransactionSettings -> m ()
begin' ts = runSQL_ . mintercalate " " $ ["BEGIN", isolationLevel, permissions]
  where
    -- Isolation level clause selected by 'tsIsolationLevel'.
    isolationLevel = case tsIsolationLevel ts of
      DefaultLevel -> ""
      ReadCommitted -> "ISOLATION LEVEL READ COMMITTED"
      RepeatableRead -> "ISOLATION LEVEL REPEATABLE READ"
      Serializable -> "ISOLATION LEVEL SERIALIZABLE"

    -- Access mode clause selected by 'tsPermissions'.
    permissions = case tsPermissions ts of
      DefaultPermissions -> ""
      ReadOnly -> "READ ONLY"
      ReadWrite -> "READ WRITE"

-- | Commit active transaction using given transaction settings.
--
-- After issuing @COMMIT@, a new transaction is started right away with the
-- same settings if 'tsAutoTransaction' is enabled.
commit' :: (HasCallStack, MonadDB m) => TransactionSettings -> m ()
commit' ts = do
  runSQL_ "COMMIT"
  when (tsAutoTransaction ts) $
    begin' ts

-- | Rollback active transaction using given transaction settings.
--
-- After issuing @ROLLBACK@, a new transaction is started right away with the
-- same settings if 'tsAutoTransaction' is enabled.
rollback' :: (HasCallStack, MonadDB m) => TransactionSettings -> m ()
rollback' ts = do
  runSQL_ "ROLLBACK"
  when (tsAutoTransaction ts) $
    begin' ts