{-# LANGUAGE FlexibleContexts, GeneralizedNewtypeDeriving, RecordWildCards, OverloadedStrings #-}
module Database.PostgreSQL.Transact where
import Control.Monad.Trans.Reader
import Database.PostgreSQL.Simple as Simple
import Database.PostgreSQL.Simple.Transaction
import Control.Monad.IO.Class
import Control.Monad.Trans.Class
import Control.Monad.Trans.Control
import Control.Monad.Catch
import Data.Int
import Control.Monad
import qualified Data.ByteString as BS

-- | A database monad transformer: a 'ReaderT' whose environment is the
-- 'Connection' every statement in the computation is issued against.
--
-- 'MonadTrans' and 'MonadThrow' are lifted from 'ReaderT' by
-- generalized newtype deriving.
newtype DBT m a = DBT
  { unDBT :: ReaderT Connection m a
    -- ^ Unwrap to the underlying reader computation.
  }
  deriving (MonadTrans, MonadThrow)

-- | 'DBT' specialised to 'IO'. The synonym omits the result-type parameter,
-- so 'DB' may appear partially applied wherever a monad (kind @* -> *@) is
-- expected.
type DB = DBT IO

-- | Map over the result; the connection environment is passed through
-- untouched.
instance Functor m => Functor (DBT m) where
  fmap f (DBT inner) = DBT (fmap f inner)

-- | Applicative structure is inherited from the wrapped 'ReaderT': both
-- sides see the same 'Connection'.
instance Applicative m => Applicative (DBT m) where
  pure x = DBT (pure x)
  DBT mf <*> DBT mx = DBT (mf <*> mx)

-- | Lift an 'IO' action into 'DBT' by way of the underlying 'ReaderT'.
instance MonadIO m => MonadIO (DBT m) where
  liftIO io = DBT (liftIO io)

-- | Sequencing delegates to the wrapped 'ReaderT', so every step of a bind
-- chain sees the same 'Connection'.
--
-- 'return' is deliberately left to its class default (@return = pure@).
-- Defining it separately from 'pure' is flagged by GHC's
-- @-Wnoncanonical-monad-instances@, and it invites the two to drift apart.
instance Monad m => Monad (DBT m) where
  DBT m >>= k = DBT $ m >>= unDBT . k

-- | Does this error's SQLSTATE fall in class @25@?
--
-- Only the two-character class prefix of 'sqlState' is inspected. The
-- PostgreSQL error-code appendix names class 25 "Invalid Transaction State".
isClass25 :: SqlError -> Bool
isClass25 = BS.isPrefixOf "25" . sqlState

-- | Exceptions caught inside 'DBT' are paired with a savepoint on the
-- ambient connection:
--
--   1. a savepoint is created ('newSavepoint');
--   2. the action runs; if it throws an exception of the handler's type, the
--      connection is rolled back to the savepoint ('rollbackToSavepoint')
--      and then the handler runs;
--   3. in every case the savepoint is released afterwards
--      ('releaseSavepoint'). A 'SqlError' with SQLSTATE class 25 from that
--      release is ignored ('isClass25'); any other exception from it
--      propagates.
--
-- NOTE(review): the whole body runs under 'mask' and only the protected
-- action is passed through @restore@, so the handler itself executes with
-- asynchronous exceptions masked.
instance (MonadIO m, MonadMask m) => MonadCatch (DBT m) where
  catch (DBT act) handler = DBT $ mask $ \restore -> do
    conn <- ask
    -- Point to roll back to if the protected action fails.
    sp   <- liftIO $ newSavepoint conn
    -- Only exceptions matching the handler's type are caught here; for those,
    -- undo the action's database work back to the savepoint before handing
    -- control to the handler.
    let setup = catch (restore act) $ \e -> do
                  liftIO $ rollbackToSavepoint conn sp
                  unDBT $ handler e

    -- Release runs whether or not an exception occurred. Presumably class-25
    -- errors are tolerated because the release can fail when the transaction
    -- is already in an invalid state (e.g. an exception the handler did not
    -- match skipped the rollback) - TODO confirm.
    setup `finally` liftIO (tryJust (guard . isClass25) (releaseSavepoint conn sp))

-- | Fetch the 'Connection' the surrounding 'DBT' computation runs against.
-- The connection is the reader environment, returned unchanged.
getConnection :: Monad m => DBT m Connection
getConnection = DBT (ReaderT return)

-- | Run a 'DBT' computation on @conn@ inside one transaction at the given
-- isolation level, via 'withTransactionLevel'.
--
-- 'control' is what lets the 'IO'-only transaction wrapper enclose an action
-- in any monad @m@ with an 'IO' base.
--
-- NOTE(review): a failure that @m@ represents as a value rather than an
-- exception (e.g. 'Left' from an 'ExceptT' layer) is not an exception to
-- 'withTransactionLevel', so the transaction presumably still commits -
-- confirm if callers rely on that.
runDBT :: MonadBaseControl IO m => DBT m a -> IsolationLevel -> Connection -> m a
runDBT action level conn = control $ \runInBase ->
  withTransactionLevel level conn (runInBase transaction)
  where
    transaction = runReaderT (unDBT action) conn

-- | Run a 'DBT' computation on @conn@ inside a transaction opened with
-- 'withTransactionSerializable'.
--
-- NOTE(review): postgresql-simple documents that
-- 'withTransactionSerializable' retries the whole action after a
-- serialization failure. Any effects the action performs in @m@ may
-- therefore run more than once - confirm against the pinned version.
runDBTSerializable :: MonadBaseControl IO m => DBT m a -> Connection -> m a
runDBTSerializable action conn = control $ \runInBase ->
  withTransactionSerializable conn (runInBase transaction)
  where
    transaction = runReaderT (unDBT action) conn

-- | Run a parameterised query on the current connection and decode the
-- resulting rows ('Simple.query').
query :: (ToRow a, FromRow b, MonadIO m) => Query -> a -> DBT m [b]
query q x = do
  conn <- getConnection
  liftIO (Simple.query conn q x)

-- | Run a query that takes no parameters on the current connection and decode
-- the resulting rows ('Simple.query_').
query_ :: (FromRow b, MonadIO m) => Query -> DBT m [b]
query_ q = do
  conn <- getConnection
  liftIO (Simple.query_ conn q)

-- | Execute a parameterised statement on the current connection, returning
-- the row count reported by 'Simple.execute'.
execute :: (ToRow q, MonadIO m) => Query -> q -> DBT m Int64
execute q x = do
  conn <- getConnection
  liftIO (Simple.execute conn q x)

-- | Execute a statement that takes no parameters on the current connection,
-- returning the row count reported by 'Simple.execute_'.
execute_ :: MonadIO m => Query -> DBT m Int64
execute_ q = do
  conn <- getConnection
  liftIO (Simple.execute_ conn q)

-- | Execute a statement against a list of parameter rows on the current
-- connection, returning the row count reported by 'Simple.executeMany'.
executeMany :: (ToRow q, MonadIO m) => Query -> [q] -> DBT m Int64
executeMany q xs = do
  conn <- getConnection
  liftIO (Simple.executeMany conn q xs)

-- | Execute a statement against a list of parameter rows on the current
-- connection and decode the rows it yields ('Simple.returning').
-- Presumably intended for statements with a @RETURNING@ clause.
returning :: (ToRow q, FromRow r, MonadIO m) => Query -> [q] -> DBT m [r]
returning q xs = do
  conn <- getConnection
  liftIO (Simple.returning conn q xs)