{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
module Database.Persist.Sql.Raw where
import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Class
import qualified Data.Map as Map
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT, ask, MonadReader)
import Data.Acquire (allocateAcquire, Acquire, mkAcquire, with)
import Data.IORef (writeIORef, readIORef, newIORef)
import Control.Exception (throwIO)
import Control.Monad (when, liftM)
import Data.Text (Text, pack)
import Control.Monad.Logger (logDebugNS, runLoggingT)
import Data.Int (Int64)
import qualified Data.Text as T
import Data.Conduit
import Control.Monad.Trans.Resource (MonadResource,release)
-- | Run a raw SQL query and stream its result rows as a conduit source.
--
-- The row source is obtained from 'rawQueryRes' and registered with the
-- surrounding resource monad through 'allocateAcquire'. Once the source has
-- been run to completion its release key is released explicitly.
rawQuery :: (MonadResource m, MonadReader env m, HasPersistBackend env, BaseBackend env ~ SqlBackend)
         => Text            -- ^ SQL text of the query.
         -> [PersistValue]  -- ^ Values passed along with the query.
         -> ConduitM () [PersistValue] m ()
rawQuery sql vals =
    liftPersist (rawQueryRes sql vals)
        >>= allocateAcquire
        >>= \(key, rows) -> rows >> release key
-- | Build an 'Acquire' that yields the row source for a raw SQL query.
--
-- Nothing is logged or prepared when this action itself runs; it only reads
-- the backend from the environment. When the returned 'Acquire' is
-- acquired, the query text and its values are logged under the @SQL@
-- source, the statement is fetched with 'getStmtConn', and 'stmtQuery' is
-- invoked on it. 'stmtReset' is registered as the statement's release
-- action.
rawQueryRes
    :: (MonadIO m1, MonadIO m2, IsSqlBackend env)
    => Text            -- ^ SQL text of the query.
    -> [PersistValue]  -- ^ Values passed along with the query.
    -> ReaderT env m1 (Acquire (ConduitM () [PersistValue] m2 ()))
rawQueryRes sql vals = acquireRows `liftM` ask
  where
    -- Acquire the statement (resetting it on release), then open the query.
    acquireRows env = do
        stmt <- mkAcquire (prepare (persistBackend env)) stmtReset
        stmtQuery stmt vals

    -- Log the query, then look up or prepare its statement.
    prepare conn = do
        runLoggingT (logDebugNS (pack "SQL") logLine) (connLogFunc conn)
        getStmtConn conn sql

    logLine = T.append sql (pack ("; " ++ show vals))
-- | Execute a raw SQL statement, discarding the count that
-- 'rawExecuteCount' returns.
rawExecute :: (MonadIO m, BackendCompatible SqlBackend backend)
           => Text            -- ^ SQL text of the statement.
           -> [PersistValue]  -- ^ Values passed along with the statement.
           -> ReaderT backend m ()
rawExecute sql vals = rawExecuteCount sql vals >> return ()
-- | Execute a raw SQL statement and return the 'Int64' produced by
-- 'stmtExecute' (presumably the number of affected rows; confirm against
-- the backend).
--
-- The statement text and values are logged under the @SQL@ source, the
-- statement is obtained through 'getStmt', executed, and then reset.
rawExecuteCount :: (MonadIO m, BackendCompatible SqlBackend backend)
                => Text            -- ^ SQL text of the statement.
                -> [PersistValue]  -- ^ Values passed along with the statement.
                -> ReaderT backend m Int64
rawExecuteCount sql vals = do
    conn <- liftM projectBackend ask
    runLoggingT (logDebugNS (pack "SQL") logLine) (connLogFunc conn)
    stmt <- getStmt sql
    liftIO $ do
        count <- stmtExecute stmt vals
        -- Reset only after a successful execute; an exception above skips it.
        stmtReset stmt
        return count
  where
    logLine = T.append sql (pack ("; " ++ show vals))
-- | Fetch the 'Statement' for the given SQL text, using the 'SqlBackend'
-- projected from the reader environment. Delegates to 'getStmtConn'.
getStmt
    :: (MonadIO m, BackendCompatible SqlBackend backend)
    => Text -> ReaderT backend m Statement
getStmt sql =
    ask >>= \backend -> liftIO (getStmtConn (projectBackend backend) sql)
-- | Look up the 'Statement' for a piece of SQL text in the connection's
-- statement map, preparing and storing a new one when it is absent.
--
-- A freshly prepared statement is wrapped so that it tracks whether it has
-- been finalized:
--
-- * 'stmtFinalize' finalizes the underlying statement at most once;
-- * 'stmtReset' does nothing once the statement has been finalized;
-- * 'stmtExecute' and 'stmtQuery' throw 'StatementAlreadyFinalized' once
--   the statement has been finalized.
getStmtConn :: SqlBackend -> Text -> IO Statement
getStmtConn conn sql = do
    cache <- readIORef stmtCache
    maybe (prepareAndCache cache) return (Map.lookup sql cache)
  where
    stmtCache = connStmtMap conn

    -- Prepare the statement, wrap it, and store the wrapper in the map
    -- snapshot that was read above.
    prepareAndCache cache = do
        raw <- connPrepare conn sql
        alive <- newIORef True
        let guarded = Statement
                { stmtFinalize = do
                    active <- readIORef alive
                    when active $ do
                        stmtFinalize raw
                        writeIORef alive False
                , stmtReset =
                    readIORef alive >>= \active -> when active (stmtReset raw)
                , stmtExecute = \params -> do
                    active <- readIORef alive
                    if active
                        then stmtExecute raw params
                        else throwIO $ StatementAlreadyFinalized sql
                , stmtQuery = \params -> do
                    active <- liftIO $ readIORef alive
                    if active
                        then stmtQuery raw params
                        else liftIO $ throwIO $ StatementAlreadyFinalized sql
                }
        writeIORef stmtCache (Map.insert sql guarded cache)
        return guarded
-- | Run a raw SQL query and decode every result row with the 'RawSql'
-- instance of the result type.
--
-- The statement text is split on the @??@ placeholder and the pieces are
-- interleaved with the column substitutions reported by 'rawSqlCols' for
-- the result type. Placeholders left over once the substitutions run out
-- are put back verbatim. If substitutions are left over once every piece
-- of the statement text has been used, building the SQL text calls 'error'.
--
-- The column count reported by 'rawSqlCols' is compared against the first
-- result row only; a mismatch fails with a message that includes
-- 'rawSqlColCountReason'. A row that 'rawSqlProcessRow' rejects also fails,
-- with the message that function returned.
rawSql :: (RawSql a, MonadIO m)
       => Text             -- ^ SQL statement, possibly containing @??@ placeholders.
       -> [PersistValue]   -- ^ Values passed along with the query.
       -> ReaderT SqlBackend m [a]
rawSql stmt = run
    where
      -- Type-level helper: names the element type of the result so that it
      -- can be handed to the 'RawSql' class methods below. The 'error' is
      -- only hit if something actually evaluates 'x'.
      getType :: (x -> m [a]) -> a
      getType = error "rawSql.getType"

      x = getType run
      process = rawSqlProcessRow

      -- Substitute the placeholders, run the query and feed the rows to
      -- the given sink.
      withStmt' colSubsts params sink = do
            srcRes <- rawQueryRes sql params
            liftIO $ with srcRes (\src -> runConduit $ src .| sink)
          where
            sql = T.concat $ makeSubsts colSubsts $ T.splitOn placeholder stmt
            placeholder = "??"
            -- Interleave text pieces (second argument) with substitutions
            -- (first argument), starting with a text piece.
            makeSubsts (s:ss) (t:ts) = t : s : makeSubsts ss ts
            makeSubsts []     []     = []
            -- Out of substitutions: restore the remaining placeholders.
            makeSubsts []     ts     = [T.intercalate placeholder ts]
            makeSubsts ss     []     = error (concat err)
                where
                  -- The leading space before '??' keeps the count from
                  -- running into the following word.
                  err = [ "rawsql: there are still ", show (length ss)
                        , " '??' placeholder substitutions to be made "
                        , "but all '??' placeholders have already been "
                        , "consumed.  Please read 'rawSql's documentation "
                        , "on how '??' placeholders work."
                        ]

      run params = do
        conn <- ask
        let (colCount, colSubsts) = rawSqlCols (connEscapeName conn) x
        withStmt' colSubsts params $ firstRow colCount

      -- Validate the column count on the first row, then decode all rows.
      firstRow colCount = do
        mrow <- await
        case mrow of
          Nothing -> return []
          Just row
              | colCount == length row -> getter mrow
              | otherwise              -> fail $ concat
                  [ "rawSql: wrong number of columns, got "
                  , show (length row), " but expected ", show colCount
                  , " (", rawSqlColCountReason x, ")." ]

      -- Decode rows until the source is exhausted, accumulating results in
      -- a difference list so that row order is preserved.
      getter = go id
          where
            go acc Nothing = return (acc [])
            go acc (Just row) =
              case process row of
                Left err -> fail (T.unpack err)
                Right r  -> await >>= go (acc . (r:))