module Database.Persist.GenericSql
( SqlPersist (..)
, Connection
, ConnectionPool
, Statement
, runSqlConn
, runSqlPool
, Key
, Checkmark(..)
, rawSql
, Entity(..)
, Single(..)
, RawSql
, Migration
, parseMigration
, parseMigration'
, printMigration
, getMigration
, runMigration
, runMigrationSilent
, runMigrationUnsafe
, migrate
, commit
, rollback
) where
import qualified Prelude as P
import Prelude hiding ((++), unlines, concat, show)
import Control.Applicative ((<$>), (<*>))
import Control.Arrow ((&&&))
import Database.Persist.Store
import Control.Monad.IO.Class
import Control.Monad.Trans.Reader
import Data.Conduit.Pool
import Database.Persist.GenericSql.Internal
import Database.Persist.GenericSql.Migration
import qualified Database.Persist.GenericSql.Raw as R
import Database.Persist.GenericSql.Raw (SqlPersist (..))
import Control.Monad.Trans.Control (MonadBaseControl, control)
import qualified Control.Exception as E
import Control.Exception (throw)
import Data.Text (Text, pack, unpack, concat)
import qualified Data.Text as T
import Web.PathPieces (PathPiece (..))
import qualified Data.Text.Read
import Data.Maybe (fromMaybe)
import Data.Monoid (Monoid, mappend)
import Database.Persist.EntityDef
import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL
import Control.Monad.Logger (MonadLogger)
import Control.Monad.Base (liftBase)
-- | A pool of SQL 'Connection's; this is the pool type 'runSqlPool' consumes.
type ConnectionPool = Pool Connection
-- | URL path-piece rendering for SQL keys.
--
-- Only integer ('PersistInt64') keys can be rendered; any other key
-- representation throws 'PersistInvalidField'. Parsing accepts an optionally
-- signed decimal integer and rejects any trailing characters.
instance PathPiece (KeyBackend R.SqlBackend entity) where
    toPathPiece key =
        case key of
            Key (PersistInt64 n) -> toPathPiece n
            _ -> throw $ PersistInvalidField $ "Invalid Key: " ++ show key
    fromPathPiece txt =
        either (const Nothing) fromParsed
            (Data.Text.Read.signed Data.Text.Read.decimal txt)
      where
        -- Only a parse that consumed the whole input is accepted.
        fromParsed (n, rest)
            | T.null rest = Just (Key (PersistInt64 n))
            | otherwise = Nothing
-- | Thin wrapper around 'R.execute' (SQL text plus bound parameters),
-- pinned to the constraints this module works with. Used here for
-- statements whose result rows are not read.
execute' :: (MonadIO m, MonadLogger m) => Text -> [PersistValue] -> SqlPersist m ()
execute' sql params = R.execute sql params
-- | Borrow a connection from the pool with 'withResource' and run the
-- action on it through 'runSqlConn' (i.e. inside a transaction).
runSqlPool :: MonadBaseControl IO m => SqlPersist m a -> Pool Connection -> m a
runSqlPool action pool = withResource pool (runSqlConn action)
-- | Run a 'SqlPersist' action on a single connection inside a transaction:
-- 'begin' first, 'commitC' after the action succeeds, and 'rollbackC' if
-- the action throws ('onException' then re-throws the exception).
runSqlConn :: MonadBaseControl IO m => SqlPersist m a -> Connection -> m a
runSqlConn (SqlPersist action) conn = do
    let getStmt = R.getStmt' conn
    liftBase (begin conn getStmt)
    result <- runReaderT action conn
        `onException` liftBase (rollbackC conn getStmt)
    liftBase (commitC conn getStmt)
    return result
-- | Generic SQL implementation of 'PersistStore': every operation builds
-- an SQL statement from the entity's 'EntityDef', escaping names with the
-- connection's 'escapeName', and binds values as @?@ parameters.
instance (C.MonadResource m, MonadLogger m) => PersistStore (SqlPersist m) where
    type PersistMonadBackend (SqlPersist m) = R.SqlBackend

    -- Insert a new row and return its generated key. 'insertSql' decides
    -- whether the ID comes back from the INSERT itself ('ISRSingle') or
    -- must be fetched by a second query ('ISRInsertGet').
    insert val = do
        conn <- SqlPersist ask
        let esql = insertSql conn (entityDB t) (map fieldDB $ entityFields t) (entityID t)
        i <- case esql of
            ISRSingle sql ->
                readGeneratedId (R.withStmt sql vals)
            ISRInsertGet sql1 sql2 -> do
                execute' sql1 vals
                readGeneratedId (R.withStmt sql2 [])
        return $ Key $ PersistInt64 i
      where
        t = entityDef val
        vals = map toPersistValue $ toPersistFields val
        -- Both strategies read the generated ID the same way, so a missing
        -- or malformed result produces a descriptive error in either case
        -- instead of a bare pattern-match failure.
        readGeneratedId src = src C.$$ do
            row <- CL.head
            case row of
                Just [PersistInt64 i] -> return i
                Nothing -> error "SQL insert did not return a result giving the generated ID"
                Just vals' -> error $ "Invalid result from a SQL insert, got: " P.++ P.show vals'

    -- Overwrite every field of the row with the given key.
    replace k val = do
        conn <- SqlPersist ask
        let t = entityDef val
        let sql = concat
                [ "UPDATE "
                , escapeName conn (entityDB t)
                , " SET "
                , T.intercalate "," (map (go conn . fieldDB) $ entityFields t)
                , " WHERE "
                , escapeName conn $ entityID t
                , "=?"
                ]
            -- Field values for the SET list, then the key for the WHERE.
            vals = map toPersistValue (toPersistFields val) `mappend` [unKey k]
        execute' sql vals
      where
        go conn x = escapeName conn x ++ "=?"

    insertKey = insrepHelper "INSERT"

    -- Implemented as delete followed by insertKey.
    repsert key value = do
        delete key
        insertKey key value

    -- Fetch the row with the given key, if any; a row that cannot be
    -- decoded is reported with 'error'.
    get k = do
        conn <- SqlPersist ask
        let t = entityDef $ dummyFromKey k
        let cols = T.intercalate ","
                 $ map (escapeName conn . fieldDB) $ entityFields t
        let sql = concat
                [ "SELECT "
                , cols
                , " FROM "
                , escapeName conn $ entityDB t
                , " WHERE "
                , escapeName conn $ entityID t
                , "=?"
                ]
            vals' = [unKey k]
        R.withStmt sql vals' C.$$ do
            res <- CL.head
            case res of
                Nothing -> return Nothing
                Just vals ->
                    case fromPersistValues vals of
                        Left e -> error $ unpack $ "get " ++ show (unKey k) ++ ": " ++ e
                        Right v -> return $ Just v

    -- Delete the row with the given key.
    delete k = do
        conn <- SqlPersist ask
        execute' (sql conn) [unKey k]
      where
        t = entityDef $ dummyFromKey k
        sql conn = concat
            [ "DELETE FROM "
            , escapeName conn $ entityDB t
            , " WHERE "
            , escapeName conn $ entityID t
            , "=?"
            ]
-- | Shared implementation for statements of the form
-- @<command> INTO table(id, field1, ...) VALUES(?, ?, ...)@.
-- The given key is bound to the ID column, followed by the entity's field
-- values in 'entityFields' order. 'insertKey' uses it with @"INSERT"@.
insrepHelper :: (MonadIO m, PersistEntity val, MonadLogger m)
             => Text
             -> Key val
             -> val
             -> SqlPersist m ()
insrepHelper command (Key k) val = do
    conn <- SqlPersist ask
    execute' (buildSql conn) (k : fieldValues)
  where
    edef = entityDef val
    fieldValues = map toPersistValue (toPersistFields val)
    -- ID column first, matching the order of the bound values.
    columns = entityID edef : map fieldDB (entityFields edef)
    buildSql conn = T.concat
        [ command
        , " INTO "
        , escapeName conn (entityDB edef)
        , "("
        , T.intercalate "," (map (escapeName conn) columns)
        , ") VALUES("
        , T.intercalate "," (map (const "?") columns)
        , ")"
        ]
-- | Generic SQL implementation of 'PersistUnique': the WHERE clause
-- compares every column of the unique constraint against a bound
-- parameter, joined with AND.
instance (C.MonadResource m, MonadLogger m) => PersistUnique (SqlPersist m) where
    deleteBy uniq = do
        conn <- SqlPersist ask
        execute' (deleteSql conn) (persistUniqueToValues uniq)
      where
        edef = entityDef $ dummyFromUnique uniq
        deleteSql conn = concat
            [ "DELETE FROM "
            , escapeName conn $ entityDB edef
            , " WHERE "
            , T.intercalate " AND "
                [ escapeName conn col ++ "=?"
                | (_, col) <- persistUniqueToFieldNames uniq
                ]
            ]

    -- Select the ID column plus all fields; the first value of the row
    -- must be the integer ID, the rest decode into the entity.
    getBy uniq = do
        conn <- SqlPersist ask
        let columns = map (escapeName conn)
                    $ entityID edef : map fieldDB (entityFields edef)
            selectSql = concat
                [ "SELECT "
                , T.intercalate "," columns
                , " FROM "
                , escapeName conn $ entityDB edef
                , " WHERE "
                , T.intercalate " AND "
                    [ escapeName conn col ++ "=?"
                    | (_, col) <- persistUniqueToFieldNames uniq
                    ]
                ]
        R.withStmt selectSql (persistUniqueToValues uniq) C.$$ do
            mrow <- CL.head
            case mrow of
                Nothing -> return Nothing
                Just (PersistInt64 key : rest) ->
                    case fromPersistValues rest of
                        Left err -> error $ unpack err
                        Right ent -> return $ Just (Entity (Key $ PersistInt64 key) ent)
                Just _ -> error "Database.Persist.GenericSql: Bad list in getBy"
      where
        edef = entityDef $ dummyFromUnique uniq
-- | Type witness only: lets 'entityDef' be chosen for the entity type of a
-- key. Evaluating the result raises an error.
dummyFromKey :: KeyBackend R.SqlBackend v -> v
dummyFromKey = const (error "dummyFromKey")
-- | Type witness only: lets 'entityDef' be chosen for the entity type of a
-- unique constraint. Evaluating the result raises an error.
dummyFromUnique :: Unique v -> v
dummyFromUnique = const (error "dummyFromUnique")
#if MIN_VERSION_monad_control(0, 3, 0)
-- | 'E.onException' lifted to any 'MonadBaseControl IO' monad: run the
-- first action; if it throws, run the second and re-throw.
onException :: MonadBaseControl IO m => m a -> m b -> m a
onException action handler =
    control $ \runInBase -> runInBase action `E.onException` runInBase handler
#endif
-- | 'Text' concatenation. The module hides the Prelude list operator so
-- that string-building code here works on 'Text' directly.
infixr 5 ++
(++) :: Text -> Text -> Text
x ++ y = T.append x y
-- | The Prelude 'P.show' rendering, returned as 'Text'.
show :: Show a => a -> Text
show x = pack (P.show x)
-- | A two-state marker. Its 'PersistField' instance stores 'Active' as a
-- TRUE boolean and 'Inactive' as NULL.
data Checkmark
    = Active
    | Inactive
    deriving (Bounded, Enum, Eq, Ord, Read, Show)
-- | 'Active' is stored as a TRUE boolean and 'Inactive' as NULL; a stored
-- FALSE is rejected, since it is neither state.
--
-- Backends without a native boolean column type may hand the value back
-- as an integer, so @PersistInt64 1@ is also accepted as 'Active' and
-- @PersistInt64 0@ is rejected the same way as FALSE.
instance PersistField Checkmark where
    toPersistValue Active = PersistBool True
    toPersistValue Inactive = PersistNull

    fromPersistValue PersistNull = Right Inactive
    fromPersistValue (PersistBool True) = Right Active
    fromPersistValue (PersistInt64 1) = Right Active
    fromPersistValue (PersistBool False) =
        Left "PersistField Checkmark: found unexpected FALSE value"
    fromPersistValue (PersistInt64 0) =
        Left "PersistField Checkmark: found unexpected FALSE value"
    fromPersistValue other =
        Left $ "PersistField Checkmark: unknown value " ++ show other

    sqlType _ = SqlBool
    -- Nullable because 'Inactive' is represented by NULL.
    isNullable _ = True
-- | Path pieces are the constructor names, i.e. the derived 'Show' /
-- 'Read' representation. Parsing requires the whole input to be consumed.
instance PathPiece Checkmark where
    toPathPiece = pack . P.show
    fromPathPiece = parseExact . reads . T.unpack
      where
        parseExact [(a, "")] = Just a
        parseExact _ = Nothing
-- | Wrapper for a single-column 'rawSql' result; 'unSingle' extracts the
-- value.
newtype Single a = Single { unSingle :: a }
    deriving (Eq, Ord, Read, Show)
-- | Run a raw SQL query and decode every returned row as @a@.
--
-- Each @??@ in the query text is replaced, in order, by the column text
-- that 'rawSqlCols' yields for @a@ (some components, e.g. 'Single',
-- yield none). Leftover @??@ placeholders stay in the query unchanged;
-- having more substitutions than placeholders is an error. The first row
-- is checked against the expected column count, and any row that fails
-- to decode aborts via 'fail'.
rawSql :: (RawSql a, C.MonadResource m, MonadLogger m)
       => Text
       -> [PersistValue]
       -> SqlPersist m [a]
rawSql stmt = run
  where
    -- Never evaluated: 'x' only pins the result type so the 'RawSql'
    -- methods for @a@ can be selected.
    getType :: (x -> SqlPersist m [a]) -> a
    getType = undefined

    x = getType run
    process = rawSqlProcessRow

    withStmt' colSubsts params = R.withStmt sql params
      where
        sql = T.concat $ makeSubsts colSubsts $ T.splitOn placeholder stmt
        placeholder = "??"
        -- Interleave query pieces with substitutions. A substitution may
        -- only follow a piece that is itself followed by another piece,
        -- i.e. one that ended at a "??"; otherwise it would be glued onto
        -- the end of the query.
        makeSubsts (s:ss) (t:ts@(_:_)) = t : s : makeSubsts ss ts
        makeSubsts [] ts = [T.intercalate placeholder ts]
        makeSubsts ss _ = error (P.concat err)
          where
            err = [ "rawSql: there are still ", P.show (length ss)
                  , " '??' placeholder substitutions to be made "
                  , "but all '??' placeholders have already been "
                  , "consumed. Please read 'rawSql's documentation "
                  , "on how '??' placeholders work."
                  ]

    run params = do
        conn <- SqlPersist ask
        let (colCount, colSubsts) = rawSqlCols (escapeName conn) x
        withStmt' colSubsts params C.$$ firstRow colCount

    -- Only the first row's width is validated; later rows go straight to
    -- the decoder.
    firstRow colCount = do
        mrow <- CL.head
        case mrow of
            Nothing -> return []
            Just row
                | colCount == length row -> getter mrow
                | otherwise -> fail $ P.concat
                    [ "rawSql: wrong number of columns, got "
                    , P.show (length row), " but expected ", P.show colCount
                    , " (", rawSqlColCountReason x, ")." ]

    -- Decode rows until the stream ends, accumulating with a difference
    -- list to keep the original row order.
    getter = go id
      where
        go acc Nothing = return (acc [])
        go acc (Just row) =
            case process row of
                Left err -> fail (T.unpack err)
                Right r -> CL.head >>= go (acc . (r:))
-- | Result types that 'rawSql' can decode a row into.
--
-- The value argument of 'rawSqlCols' and 'rawSqlColCountReason' is only a
-- type witness: 'rawSql' passes an undefined value, so implementations
-- must not force it.
class RawSql a where
    -- | Given a column-name escaper: the number of columns this type
    -- consumes, and the texts to substitute for its @??@ placeholders.
    rawSqlCols
        :: (DBName -> Text)
        -> a
        -> (Int, [Text])

    -- | Human-readable explanation of the expected column count, used in
    -- 'rawSql's column-mismatch error.
    rawSqlColCountReason
        :: a
        -> String

    -- | Decode one row's values, or explain why that failed.
    rawSqlProcessRow
        :: [PersistValue]
        -> Either Text a
-- | 'Single' consumes exactly one column and needs no @??@ substitution.
instance PersistField a => RawSql (Single a) where
    rawSqlCols _ _ = (1, [])
    rawSqlColCountReason _ = "one column for a 'Single' data type"
    rawSqlProcessRow row =
        case row of
            [pv] -> fmap Single (fromPersistValue pv)
            _ -> Left "RawSql (Single a): wrong number of columns."
-- | An 'Entity' occupies its ID column followed by one column per field;
-- its @??@ substitution is the table-qualified, comma-separated list of
-- those columns.
instance PersistEntity a => RawSql (Entity a) where
    rawSqlCols escape ent = (columnCount, [qualifiedColumns])
      where
        edef = entityDef (entityVal ent)
        columnCount = length (entityFields edef) + 1
        tablePrefix = escape (entityDB edef) ++ "."
        qualifiedColumns = T.intercalate ", "
            [ tablePrefix ++ escape col
            | col <- entityID edef : map fieldDB (entityFields edef)
            ]

    -- Only the count is needed here, so the escaper is never used.
    rawSqlColCountReason ent
        | n == 1 = "one column for an 'Entity' data type without fields"
        | otherwise = P.show n P.++ " columns for an 'Entity' data type"
      where
        n = fst (rawSqlCols undefined ent)

    rawSqlProcessRow row =
        case row of
            idCol : rest -> Entity <$> fromPersistValue idCol
                                   <*> fromPersistValues rest
            [] -> Left "RawSql (Entity a): wrong number of columns."
-- | 'Maybe' uses the same columns as the wrapped type. A row whose columns
-- are all NULL decodes as 'Nothing' (e.g. the missing side of an OUTER
-- JOIN); otherwise the inner decoder must succeed.
instance RawSql a => RawSql (Maybe a) where
    rawSqlCols escape m = rawSqlCols escape (extractMaybe m)
    rawSqlColCountReason m = rawSqlColCountReason (extractMaybe m)
    rawSqlProcessRow cols =
        if all isNull cols
            then Right Nothing
            else either (Left . explain) (Right . Just) (rawSqlProcessRow cols)
      where
        isNull PersistNull = True
        isNull _ = False
        explain msg = concat
            [ "RawSql (Maybe a): not all columns were Null "
            , "but the inner parser has failed. Its message "
            , "was \"", msg, "\". Did you apply Maybe "
            , "to a tuple, perhaps? The main use case for "
            , "Maybe is to allow OUTER JOINs to be written, "
            , "in which case 'Maybe (Entity v)' is used."
            ]
-- | Unwrap a 'Just'; 'Nothing' raises an error. The 'Maybe' instance of
-- 'RawSql' passes its value through this lazily, purely to reach the
-- inner type's methods.
extractMaybe :: Maybe a -> a
extractMaybe (Just v) = v
extractMaybe Nothing = error "Database.Persist.GenericSql.extractMaybe"
-- | A pair consumes the columns of its first component followed by those of
-- its second: column counts add up and substitution lists are concatenated.
-- Larger tuples reuse this instance by regrouping into nested pairs.
instance (RawSql a, RawSql b) => RawSql (a, b) where
    rawSqlCols e x = rawSqlCols e (fst x) # rawSqlCols e (snd x)
        where (cnta, lsta) # (cntb, lstb) = (cnta + cntb, lsta P.++ lstb)
    rawSqlColCountReason x = rawSqlColCountReason (fst x) P.++ ", " P.++
                             rawSqlColCountReason (snd x)
    -- Split the row after the first component's columns and decode each
    -- half with its own instance.
    -- 'x' is only a type witness ('getType' is 'undefined'): it ties the
    -- pair's type to 'processRow' so the first component's column count can
    -- be looked up through 'rawSqlCols'. The count is forced with 'seq'
    -- before the row decoder is returned.
    rawSqlProcessRow =
        let x = getType processRow
            getType :: (z -> Either y x) -> x
            getType = undefined
            colCountFst = fst $ rawSqlCols undefined (fst x)
            processRow row =
                let (rowFst, rowSnd) = splitAt colCountFst row
                in (,) <$> rawSqlProcessRow rowFst
                       <*> rawSqlProcessRow rowSnd
        in colCountFst `seq` processRow
-- | Triples reuse the pair instances by regrouping as @((a, b), c)@.
instance (RawSql a, RawSql b, RawSql c) => RawSql (a, b, c) where
    rawSqlCols escape tuple = rawSqlCols escape (from3 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from3 tuple)
    rawSqlProcessRow row = to3 <$> rawSqlProcessRow row
-- | Regroup a triple into nested pairs for the 'RawSql' pair instance.
from3 :: (a,b,c) -> ((a,b),c)
from3 = \(a, b, c) -> ((a, b), c)

-- | Inverse of 'from3'.
to3 :: ((a,b),c) -> (a,b,c)
to3 = \((a, b), c) -> (a, b, c)
-- | 4-tuples reuse the pair instances by regrouping as @((a, b), (c, d))@.
instance (RawSql a, RawSql b, RawSql c, RawSql d) => RawSql (a, b, c, d) where
    rawSqlCols escape tuple = rawSqlCols escape (from4 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from4 tuple)
    rawSqlProcessRow row = to4 <$> rawSqlProcessRow row
-- | Regroup a 4-tuple into a pair of pairs for the 'RawSql' instances.
from4 :: (a,b,c,d) -> ((a,b),(c,d))
from4 = \(a, b, c, d) -> ((a, b), (c, d))

-- | Inverse of 'from4'.
to4 :: ((a,b),(c,d)) -> (a,b,c,d)
to4 = \((a, b), (c, d)) -> (a, b, c, d)
-- | 5-tuples reuse existing instances by regrouping as
-- @((a, b), (c, d), e)@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e)
         => RawSql (a, b, c, d, e) where
    rawSqlCols escape tuple = rawSqlCols escape (from5 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from5 tuple)
    rawSqlProcessRow row = to5 <$> rawSqlProcessRow row
-- | Regroup a 5-tuple into a triple of (pair, pair, value).
from5 :: (a,b,c,d,e) -> ((a,b),(c,d),e)
from5 = \(a, b, c, d, e) -> ((a, b), (c, d), e)

-- | Inverse of 'from5'.
to5 :: ((a,b),(c,d),e) -> (a,b,c,d,e)
to5 = \((a, b), (c, d), e) -> (a, b, c, d, e)
-- | 6-tuples reuse existing instances by regrouping as
-- @((a, b), (c, d), (e, f))@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f)
         => RawSql (a, b, c, d, e, f) where
    rawSqlCols escape tuple = rawSqlCols escape (from6 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from6 tuple)
    rawSqlProcessRow row = to6 <$> rawSqlProcessRow row
-- | Regroup a 6-tuple into a triple of pairs.
from6 :: (a,b,c,d,e,f) -> ((a,b),(c,d),(e,f))
from6 = \(a, b, c, d, e, f) -> ((a, b), (c, d), (e, f))

-- | Inverse of 'from6'.
to6 :: ((a,b),(c,d),(e,f)) -> (a,b,c,d,e,f)
to6 = \((a, b), (c, d), (e, f)) -> (a, b, c, d, e, f)
-- | 7-tuples reuse existing instances by regrouping as
-- @((a, b), (c, d), (e, f), g)@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g)
         => RawSql (a, b, c, d, e, f, g) where
    rawSqlCols escape tuple = rawSqlCols escape (from7 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from7 tuple)
    rawSqlProcessRow row = to7 <$> rawSqlProcessRow row
-- | Regroup a 7-tuple into a 4-tuple of three pairs and a value.
from7 :: (a,b,c,d,e,f,g) -> ((a,b),(c,d),(e,f),g)
from7 = \(a, b, c, d, e, f, g) -> ((a, b), (c, d), (e, f), g)

-- | Inverse of 'from7'.
to7 :: ((a,b),(c,d),(e,f),g) -> (a,b,c,d,e,f,g)
to7 = \((a, b), (c, d), (e, f), g) -> (a, b, c, d, e, f, g)
-- | 8-tuples reuse existing instances by regrouping as
-- @((a, b), (c, d), (e, f), (g, h))@.
instance (RawSql a, RawSql b, RawSql c,
          RawSql d, RawSql e, RawSql f,
          RawSql g, RawSql h)
         => RawSql (a, b, c, d, e, f, g, h) where
    rawSqlCols escape tuple = rawSqlCols escape (from8 tuple)
    rawSqlColCountReason tuple = rawSqlColCountReason (from8 tuple)
    rawSqlProcessRow row = to8 <$> rawSqlProcessRow row
-- | Regroup an 8-tuple into a 4-tuple of pairs.
from8 :: (a,b,c,d,e,f,g,h) -> ((a,b),(c,d),(e,f),(g,h))
from8 = \(a, b, c, d, e, f, g, h) -> ((a, b), (c, d), (e, f), (g, h))

-- | Inverse of 'from8'.
to8 :: ((a,b),(c,d),(e,f),(g,h)) -> (a,b,c,d,e,f,g,h)
to8 = \((a, b), (c, d), (e, f), (g, h)) -> (a, b, c, d, e, f, g, h)