{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE Rank2Types #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
module Database.Persist.Sql.Orphan.PersistStore
( withRawQuery
, BackendKey(..)
, toSqlKey
, fromSqlKey
, getFieldName
, getTableName
, tableDBName
, fieldDBName
) where
import Database.Persist
import Database.Persist.Sql.Types
import Database.Persist.Sql.Raw
import Database.Persist.Sql.Util (
dbIdColumns, keyAndEntityColumnNames, parseEntityValues, entityColumnNames
, updatePersistValue, mkUpdateText, commaSeparated)
import Data.Conduit (ConduitM, (=$=), (.|), runConduit)
import qualified Data.Conduit.List as CL
import qualified Data.Text as T
import Data.Text (Text, unpack)
import Data.Monoid (mappend, (<>))
import Control.Monad.IO.Class
import Data.ByteString.Char8 (readInteger)
import Data.Maybe (isJust)
import Data.List (find)
import Data.Void (Void)
import Control.Monad.Trans.Reader (ReaderT, ask, withReaderT)
import Data.Acquire (with)
import Data.Int (Int64)
import Web.PathPieces (PathPiece)
import Web.HttpApiData (ToHttpApiData, FromHttpApiData)
import Database.Persist.Sql.Class (PersistFieldSql)
import qualified Data.Aeson as A
import Control.Exception (throwIO)
import Database.Persist.Class ()
import qualified Data.Map as Map
import qualified Data.Foldable as Foldable
withRawQuery :: MonadIO m
=> Text
-> [PersistValue]
-> ConduitM [PersistValue] Void IO a
-> ReaderT SqlBackend m a
withRawQuery sql vals sink = do
srcRes <- rawQueryRes sql vals
liftIO $ with srcRes (\src -> runConduit $ src .| sink)
toSqlKey :: ToBackendKey SqlBackend record => Int64 -> Key record
toSqlKey = fromBackendKey . SqlBackendKey
fromSqlKey :: ToBackendKey SqlBackend record => Key record -> Int64
fromSqlKey = unSqlBackendKey . toBackendKey
whereStmtForKey :: PersistEntity record => SqlBackend -> Key record -> Text
whereStmtForKey conn k =
T.intercalate " AND "
$ map (<> "=? ")
$ dbIdColumns conn entDef
where
entDef = entityDef $ dummyFromKey k
whereStmtForKeys :: PersistEntity record => SqlBackend -> [Key record] -> Text
whereStmtForKeys conn ks = T.intercalate " OR " $ whereStmtForKey conn `fmap` ks
getTableName :: forall record m backend.
( PersistEntity record
, PersistEntityBackend record ~ SqlBackend
, IsSqlBackend backend
, Monad m
) => record -> ReaderT backend m Text
getTableName rec = withReaderT persistBackend $ do
conn <- ask
return $ connEscapeName conn $ tableDBName rec
tableDBName :: (PersistEntity record) => record -> DBName
tableDBName rec = entityDB $ entityDef (Just rec)
getFieldName :: forall record typ m backend.
( PersistEntity record
, PersistEntityBackend record ~ SqlBackend
, IsSqlBackend backend
, Monad m
)
=> EntityField record typ -> ReaderT backend m Text
getFieldName rec = withReaderT persistBackend $ do
conn <- ask
return $ connEscapeName conn $ fieldDBName rec
fieldDBName :: forall record typ. (PersistEntity record) => EntityField record typ -> DBName
fieldDBName = fieldDB . persistFieldDef
instance PersistCore SqlBackend where
newtype BackendKey SqlBackend = SqlBackendKey { unSqlBackendKey :: Int64 }
deriving (Show, Read, Eq, Ord, Num, Integral, PersistField, PersistFieldSql, PathPiece, ToHttpApiData, FromHttpApiData, Real, Enum, Bounded, A.ToJSON, A.FromJSON)
instance PersistCore SqlReadBackend where
newtype BackendKey SqlReadBackend = SqlReadBackendKey { unSqlReadBackendKey :: Int64 }
deriving (Show, Read, Eq, Ord, Num, Integral, PersistField, PersistFieldSql, PathPiece, ToHttpApiData, FromHttpApiData, Real, Enum, Bounded, A.ToJSON, A.FromJSON)
instance PersistCore SqlWriteBackend where
newtype BackendKey SqlWriteBackend = SqlWriteBackendKey { unSqlWriteBackendKey :: Int64 }
deriving (Show, Read, Eq, Ord, Num, Integral, PersistField, PersistFieldSql, PathPiece, ToHttpApiData, FromHttpApiData, Real, Enum, Bounded, A.ToJSON, A.FromJSON)
instance BackendCompatible SqlBackend SqlBackend where
projectBackend = id
instance BackendCompatible SqlBackend SqlReadBackend where
projectBackend = unSqlReadBackend
instance BackendCompatible SqlBackend SqlWriteBackend where
projectBackend = unSqlWriteBackend
instance PersistStoreWrite SqlBackend where
update _ [] = return ()
update k upds = do
conn <- ask
let wher = whereStmtForKey conn k
let sql = T.concat
[ "UPDATE "
, connEscapeName conn $ tableDBName $ recordTypeFromKey k
, " SET "
, T.intercalate "," $ map (mkUpdateText conn) upds
, " WHERE "
, wher
]
rawExecute sql $
map updatePersistValue upds `mappend` keyToValues k
insert val = do
conn <- ask
let esql = connInsertSql conn t vals
key <-
case esql of
ISRSingle sql -> withRawQuery sql vals $ do
x <- CL.head
case x of
Just [PersistInt64 i] -> case keyFromValues [PersistInt64 i] of
Left err -> error $ "SQL insert: keyFromValues: PersistInt64 " `mappend` show i `mappend` " " `mappend` unpack err
Right k -> return k
Nothing -> error $ "SQL insert did not return a result giving the generated ID"
Just vals' -> case keyFromValues vals' of
Left e -> error $ "Invalid result from a SQL insert, got: " ++ show vals' ++ ". Error was: " ++ unpack e
Right k -> return k
ISRInsertGet sql1 sql2 -> do
rawExecute sql1 vals
withRawQuery sql2 [] $ do
mm <- CL.head
let m = maybe
(Left $ "No results from ISRInsertGet: " `mappend` tshow (sql1, sql2))
Right mm
let convert x =
case x of
[PersistByteString i] -> case readInteger i of
Just (ret,"") -> [PersistInt64 $ fromIntegral ret]
_ -> x
_ -> x
onLeft Left{} x = x
onLeft x _ = x
case m >>= (\x -> keyFromValues x `onLeft` keyFromValues (convert x)) of
Right k -> return k
Left err -> throw $ "ISRInsertGet: keyFromValues failed: " `mappend` err
ISRManyKeys sql fs -> do
rawExecute sql vals
case entityPrimary t of
Nothing -> error $ "ISRManyKeys is used when Primary is defined " ++ show sql
Just pdef ->
let pks = map fieldHaskell $ compositeFields pdef
keyvals = map snd $ filter (\(a, _) -> let ret=isJust (find (== a) pks) in ret) $ zip (map fieldHaskell $ entityFields t) fs
in case keyFromValues keyvals of
Right k -> return k
Left e -> error $ "ISRManyKeys: unexpected keyvals result: " `mappend` unpack e
return key
where
tshow :: Show a => a -> Text
tshow = T.pack . show
throw = liftIO . throwIO . userError . T.unpack
t = entityDef $ Just val
vals = map toPersistValue $ toPersistFields val
insertMany [] = return []
insertMany vals = do
conn <- ask
case connInsertManySql conn of
Nothing -> mapM insert vals
Just insertManyFn -> do
case insertManyFn ent valss of
ISRSingle sql -> rawSql sql (concat valss)
_ -> error "ISRSingle is expected from the connInsertManySql function"
where
ent = entityDef vals
valss = map (map toPersistValue . toPersistFields) vals
insertMany_ vals0 = runChunked (length $ entityFields t) insertMany_' vals0
where
t = entityDef vals0
insertMany_' vals = do
conn <- ask
let valss = map (map toPersistValue . toPersistFields) vals
let sql = T.concat
[ "INSERT INTO "
, connEscapeName conn (entityDB t)
, "("
, T.intercalate "," $ map (connEscapeName conn . fieldDB) $ entityFields t
, ") VALUES ("
, T.intercalate "),(" $ replicate (length valss) $ T.intercalate "," $ map (const "?") (entityFields t)
, ")"
]
rawExecute sql (concat valss)
replace k val = do
conn <- ask
let t = entityDef $ Just val
let wher = whereStmtForKey conn k
let sql = T.concat
[ "UPDATE "
, connEscapeName conn (entityDB t)
, " SET "
, T.intercalate "," (map (go conn . fieldDB) $ entityFields t)
, " WHERE "
, wher
]
vals = map toPersistValue (toPersistFields val) `mappend` keyToValues k
rawExecute sql vals
where
go conn x = connEscapeName conn x `T.append` "=?"
insertKey k v = insrepHelper "INSERT" [Entity k v]
insertEntityMany es' = do
conn <- ask
let entDef = entityDef $ map entityVal es'
let columnNames = keyAndEntityColumnNames entDef conn
runChunked (length columnNames) go es'
where
go es = insrepHelper "INSERT" es
repsert key value = do
mExisting <- get key
case mExisting of
Nothing -> insertKey key value
Just _ -> replace key value
repsertMany krs = do
let es = (uncurry Entity) `fmap` krs
let ks = entityKey `fmap` es
let mEs = Map.fromList $ zip ks es
mRsExisting <- getMany ks
let mEsNew = Map.difference mEs mRsExisting
let esNew = snd `fmap` Map.toList mEsNew
insertEntityMany esNew
delete k = do
conn <- ask
rawExecute (sql conn) (keyToValues k)
where
wher conn = whereStmtForKey conn k
sql conn = T.concat
[ "DELETE FROM "
, connEscapeName conn $ tableDBName $ recordTypeFromKey k
, " WHERE "
, wher conn
]
instance PersistStoreWrite SqlWriteBackend where
insert v = withReaderT persistBackend $ insert v
insertMany vs = withReaderT persistBackend $ insertMany vs
insertMany_ vs = withReaderT persistBackend $ insertMany_ vs
insertEntityMany vs = withReaderT persistBackend $ insertEntityMany vs
insertKey k v = withReaderT persistBackend $ insertKey k v
repsert k v = withReaderT persistBackend $ repsert k v
replace k v = withReaderT persistBackend $ replace k v
delete k = withReaderT persistBackend $ delete k
update k upds = withReaderT persistBackend $ update k upds
repsertMany krs = withReaderT persistBackend $ repsertMany krs
instance PersistStoreRead SqlBackend where
get k = do
mEs <- getMany [k]
return $ Map.lookup k mEs
getMany [] = return Map.empty
getMany ks@(k:_)= do
conn <- ask
let t = entityDef . dummyFromKey $ k
let cols = commaSeparated . entityColumnNames t
let wher = whereStmtForKeys conn ks
let sql = T.concat
[ "SELECT "
, cols conn
, " FROM "
, connEscapeName conn $ entityDB t
, " WHERE "
, wher
]
let parse vals
= case parseEntityValues t vals of
Left s -> liftIO $ throwIO $ PersistMarshalError s
Right row -> return row
withRawQuery sql (Foldable.foldMap keyToValues ks) $ do
es <- CL.mapM parse =$= CL.consume
return $ Map.fromList $ fmap (\e -> (entityKey e, entityVal e)) es
instance PersistStoreRead SqlReadBackend where
get k = withReaderT persistBackend $ get k
getMany ks = withReaderT persistBackend $ getMany ks
instance PersistStoreRead SqlWriteBackend where
get k = withReaderT persistBackend $ get k
getMany ks = withReaderT persistBackend $ getMany ks
dummyFromKey :: Key record -> Maybe record
dummyFromKey = Just . recordTypeFromKey
recordTypeFromKey :: Key record -> record
recordTypeFromKey _ = error "dummyFromKey"
insrepHelper :: (MonadIO m, PersistEntity val)
=> Text
-> [Entity val]
-> ReaderT SqlBackend m ()
insrepHelper _ [] = return ()
insrepHelper command es = do
conn <- ask
let columnNames = keyAndEntityColumnNames entDef conn
rawExecute (sql conn columnNames) vals
where
entDef = entityDef $ map entityVal es
sql conn columnNames = T.concat
[ command
, " INTO "
, connEscapeName conn (entityDB entDef)
, "("
, T.intercalate "," columnNames
, ") VALUES ("
, T.intercalate "),(" $ replicate (length es) $ T.intercalate "," $ map (const "?") columnNames
, ")"
]
vals = Foldable.foldMap entityValues es
runChunked
:: (Monad m)
=> Int
-> ([a] -> ReaderT SqlBackend m ())
-> [a]
-> ReaderT SqlBackend m ()
runChunked _ _ [] = return ()
runChunked width m xs = do
conn <- ask
case connMaxParams conn of
Nothing -> m xs
Just maxParams -> let chunkSize = maxParams `div` width in
mapM_ m (chunksOf chunkSize xs)
chunksOf :: Int -> [a] -> [[a]]
chunksOf _ [] = []
chunksOf size xs = let (chunk, rest) = splitAt size xs in chunk : chunksOf size rest