{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, RecordWildCards, OverloadedStrings #-}
module Database.PostgreSQL.Transact where
import Control.Monad.Trans.Reader
import qualified Database.PostgreSQL.Simple as Simple
import Database.PostgreSQL.Simple (ToRow, FromRow, Connection, SqlError (..))
import Database.PostgreSQL.Simple.Types as Simple
import qualified Database.PostgreSQL.Simple.Transaction as Simple
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Catch
import Data.Int
import Control.Monad
import qualified Data.ByteString as BS
import qualified Control.Monad.Fail as Fail
-- | Monad transformer for code that runs against a PostgreSQL 'Connection':
-- a 'ReaderT' whose environment is that connection.
--
-- 'MonadTrans' and 'MonadThrow' are derived through the underlying
-- 'ReaderT' (GeneralizedNewtypeDeriving).
newtype DBT m a = DBT { unDBT :: ReaderT Connection m a }
  deriving (MonadTrans, MonadThrow)

-- | 'DBT' specialised to 'IO'.
type DB = DBT IO
-- | Maps over the result of the action; the connection environment is
-- passed through unchanged.
instance Functor m => Functor (DBT m) where
  fmap f (DBT inner) = DBT (fmap f inner)
-- | Both sides read the same 'Connection'; combination is delegated to the
-- underlying 'ReaderT' applicative.
instance Applicative m => Applicative (DBT m) where
  pure a = DBT (pure a)
  DBT mf <*> DBT mx = DBT (mf <*> mx)
-- | Lifts an 'IO' action through the underlying 'ReaderT' into @m@; the
-- connection is not consulted.
instance MonadIO m => MonadIO (DBT m) where
  liftIO io = DBT (liftIO io)
-- | Sequencing passes the same 'Connection' environment to both actions
-- (the underlying 'ReaderT' bind).
--
-- 'return' is intentionally not defined: it falls back to its default
-- (@return = pure@). A separate @return = lift . return@ definition is a
-- non-canonical Monad instance that GHC flags with
-- @-Wnoncanonical-monad-instances@.
instance Monad m => Monad (DBT m) where
  DBT m >>= k = DBT $ m >>= unDBT . k
-- | 'Fail.fail' is delegated, through the 'ReaderT' layer, to the base
-- monad @m@.
instance Fail.MonadFail m => Fail.MonadFail (DBT m) where
  fail msg = DBT (Fail.fail msg)
-- | True when the error's SQLSTATE belongs to class 25, i.e. its code starts
-- with @25@. In PostgreSQL's error-code table, class 25 is "invalid
-- transaction state".
isClass25 :: SqlError -> Bool
isClass25 err = "25" `BS.isPrefixOf` sqlState err
-- | Exceptions caught inside 'DBT' are scoped to a savepoint.
--
-- A fresh savepoint is created before the protected action runs. If the
-- action throws an exception of the handler's type @e@, the connection is
-- rolled back to that savepoint and then the handler runs. Afterwards,
-- releasing the savepoint is attempted in every case: success, handled
-- exception, unhandled exception, or a handler that itself throws.
-- Exceptions of other types are not rolled back here; they only trigger the
-- release attempt and then propagate.
instance (MonadIO m, MonadMask m) => MonadCatch (DBT m) where
  catch (DBT act) handler = DBT $ mask $ \restore -> do
    conn <- ask
    -- Created while async exceptions are masked, so (barring interruptible
    -- operations) nothing can be delivered between creating the savepoint
    -- and installing the 'finally' that releases it.
    sp <- liftIO $ Simple.newSavepoint conn
    -- Only the user's action gets the caller's masking state back (via
    -- @restore@); the rollback and the handler run masked. This inner
    -- 'catch' is the 'ReaderT' instance (its argument is a bare ReaderT
    -- action), not a recursive use of this instance.
    let setup = catch (restore act) $ \e -> do
          liftIO $ Simple.rollbackToSavepoint conn sp
          unDBT $ handler e
    -- A release failure with SQLSTATE class 25 is swallowed (the 'tryJust'
    -- result is discarded); any other release error propagates.
    -- NOTE(review): presumably class 25 shows up here when the surrounding
    -- transaction is already aborted or finished — confirm.
    setup `finally` liftIO (tryJust (guard . isClass25) (Simple.releaseSavepoint conn sp))
-- | Masking and bracketing are delegated to the underlying
-- @ReaderT Connection m@ by unwrapping and rewrapping the 'DBT' newtype.
instance (MonadIO m, MonadMask m) => MonadMask (DBT m) where
  mask a = DBT $ mask $ \u -> unDBT (a $ q u)
    where
      -- Adapt the ReaderT-level restore function @u@ so it applies to 'DBT'
      -- actions. @q@ is polymorphic in its result type, so @q u@ has the
      -- rank-2 type that the callback passed to 'mask' expects.
      q :: (ReaderT Connection m a -> ReaderT Connection m a) -> DBT m a -> DBT m a
      q u (DBT b) = DBT $ u b
  -- Same wrapping as 'mask', with 'uninterruptibleMask' underneath.
  uninterruptibleMask a =
    DBT $ uninterruptibleMask $ \u -> unDBT (a $ q u)
    where
      q :: (ReaderT Connection m a -> ReaderT Connection m a) -> DBT m a -> DBT m a
      q u (DBT b) = DBT $ u b
  -- Delegates straight to the ReaderT instance, unwrapping each argument.
  -- Unlike 'catch' above, no savepoint is created on this path.
  generalBracket acquire release use = DBT $
    generalBracket
      (unDBT acquire)
      (\resource exitCase -> unDBT (release resource exitCase))
      (\resource -> unDBT (use resource))
-- | The 'Connection' the current 'DBT' action is running against (the
-- reader environment, handed back unchanged).
getConnection :: Monad m => DBT m Connection
getConnection = DBT (ReaderT return)
-- | Run a 'DBT' action on @conn@ inside 'Simple.withTransactionLevel' at the
-- given isolation level.
--
-- 'control' lifts that IO-level wrapper into any @m@ with a
-- @MonadBaseControl IO@ instance.
runDBT :: MonadBaseControl IO m => DBT m a -> Simple.IsolationLevel -> Connection -> m a
runDBT action level conn =
  control (\runInIO -> Simple.withTransactionLevel level conn (runInIO inner))
  where
    -- The DBT action with its connection supplied, as a plain @m@ action.
    inner = runReaderT (unDBT action) conn
-- | Run a 'DBT' action on @conn@ inside 'Simple.withTransactionSerializable'.
--
-- NOTE(review): postgresql-simple documents that this wrapper may re-run the
-- action on serialization failure, so the action should be safe to repeat.
-- Confirm against the postgresql-simple version in use.
runDBTSerializable :: MonadBaseControl IO m => DBT m a -> Connection -> m a
runDBTSerializable action conn =
  control (\runInIO -> Simple.withTransactionSerializable conn (runInIO inner))
  where
    -- The DBT action with its connection supplied, as a plain @m@ action.
    inner = runReaderT (unDBT action) conn
-- | 'Simple.query' on the current connection.
query :: (ToRow a, FromRow b, MonadIO m) => Query -> a -> DBT m [b]
query q params = do
  conn <- getConnection
  liftIO (Simple.query conn q params)
-- | 'Simple.query_' (no parameters) on the current connection.
query_ :: (FromRow b, MonadIO m) => Query -> DBT m [b]
query_ q = do
  conn <- getConnection
  liftIO (Simple.query_ conn q)
-- | 'Simple.execute' on the current connection; returns what
-- 'Simple.execute' returns (an 'Int64').
execute :: (ToRow q, MonadIO m) => Query -> q -> DBT m Int64
execute q params = do
  conn <- getConnection
  liftIO (Simple.execute conn q params)
-- | 'Simple.execute_' (no parameters) on the current connection.
execute_ :: MonadIO m => Query -> DBT m Int64
execute_ q = do
  conn <- getConnection
  liftIO (Simple.execute_ conn q)
-- | 'Simple.executeMany' on the current connection, one parameter row per
-- list element.
executeMany :: (ToRow q, MonadIO m) => Query -> [q] -> DBT m Int64
executeMany q paramRows = do
  conn <- getConnection
  liftIO (Simple.executeMany conn q paramRows)
-- | 'Simple.returning' on the current connection: like 'executeMany', but
-- the resulting rows are parsed and returned.
returning :: (ToRow q, FromRow r, MonadIO m) => Query -> [q] -> DBT m [r]
returning q paramRows = do
  conn <- getConnection
  liftIO (Simple.returning conn q paramRows)
-- | 'Simple.formatQuery' on the current connection: the query with its
-- parameters substituted, as a 'BS.ByteString'.
formatQuery :: (ToRow q, MonadIO m) => Query -> q -> DBT m BS.ByteString
formatQuery q params = do
  conn <- getConnection
  liftIO (Simple.formatQuery conn q params)
-- | Run a query that is expected to yield at most one row.
--
-- Returns 'Nothing' when the query yields no rows, and @'Just' row@ when it
-- yields exactly one.
--
-- Throws 'Simple.QueryError' (via 'throwM') when more than one row comes
-- back, rather than reporting that ambiguous result as 'Nothing'.
queryOne :: (MonadIO m, MonadThrow m, ToRow a, FromRow b) => Query -> a -> DBT m (Maybe b)
queryOne q x = do
  rows <- query q x
  case rows of
    []  -> return Nothing
    [a] -> return $ Just a
    _   -> throwM $ Simple.QueryError
             ("queryOne: expected at most one row, got " ++ show (length rows))
             q
-- | Run a parameterless query that is expected to yield at most one row.
--
-- Returns 'Nothing' when the query yields no rows, and @'Just' row@ when it
-- yields exactly one.
--
-- Throws 'Simple.QueryError' (via 'throwM') when more than one row comes
-- back, rather than reporting that ambiguous result as 'Nothing'.
queryOne_ :: (MonadIO m, MonadThrow m, FromRow b) => Query -> DBT m (Maybe b)
queryOne_ q = do
  rows <- query_ q
  case rows of
    []  -> return Nothing
    [x] -> return $ Just x
    _   -> throwM $ Simple.QueryError
             ("queryOne_: expected at most one row, got " ++ show (length rows))
             q
-- | Create a new savepoint on the current connection with
-- 'Simple.newSavepoint'.
savepoint :: DB Savepoint
savepoint = do
  conn <- getConnection
  liftIO (Simple.newSavepoint conn)
-- | Roll the current connection back to @sp@ and release it, using
-- 'Simple.rollbackToAndReleaseSavepoint'.
rollbackToAndReleaseSavepoint :: Savepoint -> DB ()
rollbackToAndReleaseSavepoint sp = do
  conn <- getConnection
  liftIO (Simple.rollbackToAndReleaseSavepoint conn sp)
-- | Run an action, then undo its database effects.
--
-- A savepoint is taken before the action runs. Afterwards the connection is
-- rolled back to that savepoint, which is then released. This happens
-- whether the action returns normally or throws; an exception from the
-- action is re-thrown after the rollback.
rollback :: DB () -> DB ()
rollback actionToRollback =
  bracket savepoint rollbackToAndReleaseSavepoint (const actionToRollback)