{-# LANGUAGE RecordWildCards, ScopedTypeVariables #-}
module Database.PostgreSQL.Simple.Transaction
(
withTransaction
, withTransactionLevel
, withTransactionMode
, withTransactionModeRetry
, withTransactionSerializable
, TransactionMode(..)
, IsolationLevel(..)
, ReadWriteMode(..)
, defaultTransactionMode
, defaultIsolationLevel
, defaultReadWriteMode
, begin
, beginLevel
, beginMode
, commit
, rollback
, withSavepoint
, Savepoint
, newSavepoint
, releaseSavepoint
, rollbackToSavepoint
, rollbackToAndReleaseSavepoint
, isSerializationError
, isNoActiveTransactionError
, isFailedTransactionError
) where
import qualified Control.Exception as E
import qualified Data.ByteString as B
import Database.PostgreSQL.Simple.Internal
import Database.PostgreSQL.Simple.Types
import Database.PostgreSQL.Simple.Errors
import Database.PostgreSQL.Simple.Compat (mask, (<>))
-- | Isolation level requested when a transaction is started by 'beginMode'.
--
-- Constructors are listed in the order used by the derived 'Ord', 'Enum'
-- and 'Bounded' instances.
data IsolationLevel
    = DefaultIsolationLevel
      -- ^ Add no @ISOLATION LEVEL@ clause to @BEGIN@.
    | ReadCommitted   -- ^ @ISOLATION LEVEL READ COMMITTED@
    | RepeatableRead  -- ^ @ISOLATION LEVEL REPEATABLE READ@
    | Serializable    -- ^ @ISOLATION LEVEL SERIALIZABLE@
    deriving (Bounded, Enum, Eq, Ord, Show)
-- | Access mode requested when a transaction is started by 'beginMode'.
--
-- Constructors are listed in the order used by the derived 'Ord', 'Enum'
-- and 'Bounded' instances.
data ReadWriteMode
    = DefaultReadWriteMode
      -- ^ Add no access-mode clause to @BEGIN@.
    | ReadWrite  -- ^ @READ WRITE@
    | ReadOnly   -- ^ @READ ONLY@
    deriving (Bounded, Enum, Eq, Ord, Show)
-- | Options controlling the @BEGIN@ statement issued by 'beginMode'.
--
-- Field order matters: 'TransactionMode' is also applied positionally
-- (isolation level first, access mode second).
data TransactionMode = TransactionMode
    { isolationLevel :: !IsolationLevel  -- ^ isolation clause for @BEGIN@
    , readWriteMode  :: !ReadWriteMode   -- ^ access-mode clause for @BEGIN@
    } deriving (Eq, Show)
-- | Isolation level used by 'defaultTransactionMode'; with it, 'beginMode'
-- adds no @ISOLATION LEVEL@ clause.
defaultIsolationLevel :: IsolationLevel
defaultIsolationLevel = DefaultIsolationLevel

-- | Access mode used by 'defaultTransactionMode'; with it, 'beginMode' adds
-- no @READ WRITE@ / @READ ONLY@ clause.
defaultReadWriteMode :: ReadWriteMode
defaultReadWriteMode = DefaultReadWriteMode

-- | Transaction mode combining 'defaultIsolationLevel' and
-- 'defaultReadWriteMode', i.e. a plain @BEGIN@.
defaultTransactionMode :: TransactionMode
defaultTransactionMode = TransactionMode
    { isolationLevel = defaultIsolationLevel
    , readWriteMode  = defaultReadWriteMode
    }
-- | Run an action inside a transaction opened with 'defaultTransactionMode'.
--
-- Equivalent to @'withTransactionMode' 'defaultTransactionMode'@.
withTransaction :: Connection -> IO a -> IO a
withTransaction conn act = withTransactionMode defaultTransactionMode conn act
-- | Run an action inside a @SERIALIZABLE READ WRITE@ transaction.
--
-- The whole transaction is re-run via 'withTransactionModeRetry' whenever
-- it fails with a 'SqlError' accepted by 'isSerializationError'.
withTransactionSerializable :: Connection -> IO a -> IO a
withTransactionSerializable conn act =
    withTransactionModeRetry serializableReadWrite isSerializationError conn act
  where
    serializableReadWrite = TransactionMode Serializable ReadWrite
-- | Run an action inside a transaction with the given isolation level and
-- the default access mode of 'defaultTransactionMode'.
withTransactionLevel :: IsolationLevel -> Connection -> IO a -> IO a
withTransactionLevel lvl conn act =
    withTransactionMode levelMode conn act
  where
    levelMode = defaultTransactionMode { isolationLevel = lvl }
-- | Run an action inside a transaction started with the given mode.
--
-- Asynchronous exceptions are masked except while the action itself runs.
-- The sequence is: 'beginMode', then the action, then 'commit'. If the
-- action throws, a best-effort rollback ('rollback_') is issued and the
-- exception is rethrown. No explicit rollback is issued here if 'commit'
-- itself throws.
withTransactionMode :: TransactionMode -> Connection -> IO a -> IO a
withTransactionMode mode conn act =
    mask $ \restore ->
      beginMode mode conn
        >> (restore act `E.onException` rollback_ conn)
        >>= \result -> commit conn >> return result
-- | Like 'withTransactionMode', but re-runs the whole transaction when it
-- fails with a 'SqlError' accepted by the predicate.
--
-- Each attempt issues @BEGIN@ (via 'beginMode'), runs the action with
-- asynchronous exceptions restored, then issues @COMMIT@. If anything in the
-- attempt throws (the action or the commit), a best-effort rollback
-- ('rollback_') is issued. The exception is then handled as follows:
--
-- * if it is a 'SqlError' for which @shouldRetry@ returns 'True', a new
--   attempt starts;
-- * otherwise it is rethrown.
--
-- There is no limit on the number of attempts.
withTransactionModeRetry :: TransactionMode -> (SqlError -> Bool) -> Connection -> IO a -> IO a
withTransactionModeRetry mode shouldRetry conn act =
    mask $ \restore ->
      attemptLoop (E.try (restore act >>= \a -> commit conn >> return a))
  where
    -- Open a transaction, run one attempt, and dispatch on its outcome.
    attemptLoop :: IO (Either E.SomeException b) -> IO b
    attemptLoop attempt = do
      beginMode mode conn
      outcome <- attempt
      either (onFailure attempt) return outcome

    -- Roll back, then retry on an accepted SqlError or rethrow anything else.
    onFailure :: IO (Either E.SomeException b) -> E.SomeException -> IO b
    onFailure attempt err = do
      rollback_ conn
      case E.fromException err of
        Just sqlErr | shouldRetry sqlErr -> attemptLoop attempt
        _                                -> E.throwIO err
-- | Issue @ROLLBACK@ on the connection, discarding the execution result.
rollback :: Connection -> IO ()
rollback conn = do
    _ <- execute_ conn "ROLLBACK"
    return ()
-- | Best-effort @ROLLBACK@ used on exception paths. A failure of the
-- rollback itself is ignored, so the exception that triggered the cleanup
-- is the one that propagates.
--
-- Both 'IOError' and 'SqlError' are ignored. Ignoring only 'IOError' let a
-- failing @ROLLBACK@ reported as a 'SqlError' replace the original exception
-- in 'E.onException' / retry cleanup. Connection-level failures (e.g. a
-- closed connection, or a query left in flight by an interrupted action)
-- may be reported that way.
rollback_ :: Connection -> IO ()
rollback_ conn =
    rollback conn `E.catches`
      [ E.Handler (\(_ :: IOError) -> return ())
      , E.Handler (\(_ :: SqlError) -> return ())
      ]
-- | Issue @COMMIT@ on the connection, discarding the execution result.
commit :: Connection -> IO ()
commit conn = do
    _ <- execute_ conn "COMMIT"
    return ()
-- | Start a transaction using 'defaultTransactionMode' (a plain @BEGIN@).
begin :: Connection -> IO ()
begin conn = beginMode defaultTransactionMode conn
-- | Start a transaction with the given isolation level and the access mode
-- of 'defaultTransactionMode'.
beginLevel :: IsolationLevel -> Connection -> IO ()
beginLevel lvl conn = beginMode levelMode conn
  where
    levelMode = defaultTransactionMode { isolationLevel = lvl }
-- | Issue a @BEGIN@ statement built from the transaction mode.
--
-- The default constructors ('DefaultIsolationLevel', 'DefaultReadWriteMode')
-- contribute no text. The other constructors append the matching
-- @ISOLATION LEVEL ...@ and @READ WRITE@ / @READ ONLY@ clauses. The query
-- text is forced before 'execute_' is called.
beginMode :: TransactionMode -> Connection -> IO ()
beginMode mode conn = do
    _ <- execute_ conn $! Query (B.concat [ "BEGIN"
                                          , isolationClause (isolationLevel mode)
                                          , accessClause (readWriteMode mode)
                                          ])
    return ()
  where
    isolationClause :: IsolationLevel -> B.ByteString
    isolationClause DefaultIsolationLevel = ""
    isolationClause ReadCommitted         = " ISOLATION LEVEL READ COMMITTED"
    isolationClause RepeatableRead        = " ISOLATION LEVEL REPEATABLE READ"
    isolationClause Serializable          = " ISOLATION LEVEL SERIALIZABLE"

    accessClause :: ReadWriteMode -> B.ByteString
    accessClause DefaultReadWriteMode = ""
    accessClause ReadWrite            = " READ WRITE"
    accessClause ReadOnly             = " READ ONLY"
-- | Run an action inside a freshly created savepoint.
--
-- Asynchronous exceptions are masked except while the body runs.
--
-- * If the body throws, the savepoint is rolled back to and released, and
--   the exception is rethrown.
-- * Otherwise the savepoint is released. If that release fails with an
--   error accepted by 'isFailedTransactionError' (presumably the body caught
--   a database error itself, leaving the transaction aborted), the savepoint
--   is rolled back to and released instead. Any other release error is
--   rethrown.
withSavepoint :: Connection -> IO a -> IO a
withSavepoint conn body =
    mask $ \restore -> do
      sp <- newSavepoint conn
      result <- E.onException (restore body) (rollbackToAndReleaseSavepoint conn sp)
      E.handle (recover sp) (releaseSavepoint conn sp)
      return result
  where
    recover sp err
      | isFailedTransactionError err = rollbackToAndReleaseSavepoint conn sp
      | otherwise                    = E.throwIO err
-- | Create a savepoint with a connection-generated name ('newTempName') and
-- return a handle to it.
newSavepoint :: Connection -> IO Savepoint
newSavepoint conn =
    newTempName conn >>= \spName ->
      execute_ conn ("SAVEPOINT " <> spName) >> return (Savepoint spName)
-- | Issue @RELEASE SAVEPOINT@ for the given savepoint.
releaseSavepoint :: Connection -> Savepoint -> IO ()
releaseSavepoint conn (Savepoint spName) = do
    _ <- execute_ conn ("RELEASE SAVEPOINT " <> spName)
    return ()
-- | Issue @ROLLBACK TO SAVEPOINT@ for the given savepoint. The savepoint
-- itself is not released.
rollbackToSavepoint :: Connection -> Savepoint -> IO ()
rollbackToSavepoint conn (Savepoint spName) = do
    _ <- execute_ conn ("ROLLBACK TO SAVEPOINT " <> spName)
    return ()
-- | Roll back to the given savepoint and then release it. Both statements
-- are sent together as a single query string in one 'execute_' call.
rollbackToAndReleaseSavepoint :: Connection -> Savepoint -> IO ()
rollbackToAndReleaseSavepoint conn (Savepoint spName) = do
    _ <- execute_ conn statements
    return ()
  where
    statements = "ROLLBACK TO SAVEPOINT " <> spName <> "; RELEASE SAVEPOINT " <> spName