{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Database.GP.GenericPersistence
  ( retrieveById,
    retrieveAll,
    retrieveAllWhere,
    entitiesFromRows,
    persist,
    insert,
    insertMany,
    update,
    updateMany,
    delete,
    setupTableFor,
    idValue,
    Conn(..),
    connect,
    Database(..),
    Entity (..),
    GToRow,
    GFromRow,
    columnNameFor,
    maybeFieldTypeFor,
    toString,
    EntityId,
    entityId,
    TypeInfo (..),
    typeInfo,
  )
where

import           Data.Convertible         (ConvertResult, Convertible)
import           Data.Convertible.Base    (Convertible (safeConvert))
import           Data.List                (elemIndex)
import           Database.GP.Conn
import           Database.GP.Entity
import           Database.GP.SqlGenerator
import           Database.GP.TypeInfo
import           Database.HDBC
import Control.Monad (when)

{- |
 This module defines RDBMS persistence operations for record data types that are instances of 'Entity'.
 Values of such data types are called entities.

 The persistence operations use Haskell generics to provide compile-time reflection capabilities.
 HDBC is used to access the RDBMS.
-}

-- | Retrieve a single entity of type @a@ by its primary key.
--
--   The given id is converted to a 'SqlValue' and bound as the only parameter
--   of the statement produced by 'selectStmtFor'.
--
--   * If the query yields no row, @Nothing@ is returned.
--   * If it yields exactly one row, the entity built from it by 'fromRow' is
--     returned as a @Just@ value.
--   * If it yields more than one row, 'error' is called.
retrieveById :: forall a id. (Entity a, Convertible id SqlValue) => Conn -> id -> IO (Maybe a)
retrieveById conn idx = do
  resultRowsSqlValues <- quickQuery conn stmt [eid]
  case resultRowsSqlValues of
    [] -> pure Nothing
    [singleRow] -> Just <$> fromRow conn singleRow
    -- note the trailing blank after "one": it separates the text from the constructor name
    _ -> error $ "More than one " ++ constructorName ti ++ " found for id " ++ show eid
  where
    ti = typeInfo @a
    stmt = selectStmtFor @a
    eid = toSql idx

-- | Retrieve all entities of type @a@ from the database.
--
--   The statement produced by 'selectAllStmtFor' is executed without
--   parameters on the given 'Conn'; the resulting rows are converted with
--   'entitiesFromRows'.
--   The type @a@ is determined by the context of the function call.
retrieveAll :: forall a. (Entity a) => Conn -> IO [a]
retrieveAll conn = do
  resultRows <- quickQuery conn stmt []
  entitiesFromRows conn resultRows
  where
    stmt = selectAllStmtFor @a

-- | Retrieve all entities of type @a@ for which a given field has a given value.
--
--   The function takes a 'Conn', the name of the field and the value to
--   compare against. The value is bound as the only parameter of the statement
--   produced by 'selectAllWhereStmtFor' for that field.
--   The type @a@ is determined by the context of the function call.
--   The result is a (possibly empty) list of all matching entities.
retrieveAllWhere :: forall a. (Entity a) => Conn -> String -> SqlValue -> IO [a]
retrieveAllWhere conn field val = do
  resultRows <- quickQuery conn stmt [val]
  entitiesFromRows conn resultRows
  where
    stmt = selectAllWhereStmtFor @a field

-- | Convert a list of database rows, represented as @[[SqlValue]]@, to a list of entities.
--
--   Each row is converted with 'fromRow' using the given 'Conn'; the rows are
--   processed in order and the result list has one entity per row.
--   The type @a@ is determined by the context of the function call.
--   This function is used by 'retrieveAll' and 'retrieveAllWhere', but it can
--   also be used to convert the result of a custom SQL query to entities.
entitiesFromRows :: forall a. (Entity a) => Conn -> [[SqlValue]] -> IO [a]
entitiesFromRows = mapM . fromRow

-- | Persist an entity to the database.
--
--   The id of the entity is looked up with the statement produced by
--   'selectStmtFor':
--
--   * no row found: the entity is stored with 'insert';
--   * exactly one row found: the entity is stored with 'update';
--   * more than one row found: 'error' is called.
persist :: forall a. (Entity a) => Conn -> a -> IO ()
persist conn entity = do
  eid <- idValue conn entity
  resultRows <- quickQuery conn preparedSelectStmt [eid]
  case resultRows of
    []           -> insert conn entity
    [_singleRow] -> update conn entity
    _            -> error $ "More than one entity found for id " ++ show eid
  where
    preparedSelectStmt = selectStmtFor @a

-- | Explicitly insert an entity into the database.
--
--   The entity is converted to a row with 'toRow' and that row is bound to the
--   statement produced by 'insertStmtFor'.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
insert :: forall a. (Entity a) => Conn -> a -> IO ()
insert conn entity = do
  row <- toRow conn entity
  _rowcount <- run conn (insertStmtFor @a) row
  when (implicitCommit conn) $ commit conn

-- | Insert a list of entities into the database.
--
--   The function takes a 'Conn' and a list of entities as parameters.
--   All entities are converted to rows with 'toRow' first; the statement
--   produced by 'insertStmtFor' is then prepared only once and executed for
--   each row.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
insertMany :: forall a. (Entity a) => Conn -> [a] -> IO ()
insertMany conn entities = do
  rows <- mapM (toRow conn) entities
  stmt <- prepare conn (insertStmtFor @a)
  executeMany stmt rows
  when (implicitCommit conn) $ commit conn
  

-- | Explicitly update an entity in the database.
--
--   The statement produced by 'updateStmtFor' is executed with the row of the
--   entity (see 'toRow') followed by its id value (see 'idValue') as
--   parameters.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
update :: forall a. (Entity a) => Conn -> a -> IO ()
update conn entity = do
  eid <- idValue conn entity
  row <- toRow conn entity
  -- the id value is appended as the last statement parameter, after the row values
  _rowcount <- run conn (updateStmtFor @a) (row ++ [eid])
  when (implicitCommit conn) $ commit conn

-- | Update a list of entities in the database.
--
--   The function takes a 'Conn' and a list of entities as parameters.
--   The statement produced by 'updateStmtFor' is prepared only once and then
--   executed for each entity.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
updateMany :: forall a. (Entity a) => Conn -> [a] -> IO ()
updateMany conn entities = do
  eids <- mapM (idValue conn) entities
  rows <- mapM (toRow conn) entities
  stmt <- prepare conn (updateStmtFor @a)
  -- the update statement has one more parameter than the row: the id value for the where clause
  executeMany stmt (zipWith (\l x -> l ++ [x]) rows eids)
  when (implicitCommit conn) $ commit conn

-- | Delete an entity from the database.
--
--   The function takes a 'Conn' and an entity as parameters.
--   The id value of the entity (see 'idValue') is bound as the only parameter
--   of the statement produced by 'deleteStmtFor'.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
delete :: forall a. (Entity a) => Conn -> a -> IO ()
delete conn entity = do
  eid <- idValue conn entity
  _rowCount <- run conn (deleteStmtFor @a) [eid]
  when (implicitCommit conn) $ commit conn

-- | Set up a table for a given entity type. The table is dropped (if existing) and recreated.
--
--   The function takes a 'Conn' as parameter. It runs the statement produced
--   by 'dropTableStmtFor', followed by the statement produced by
--   'createTableStmtFor' for the 'Database' of the connection.
--   If 'implicitCommit' is set on the 'Conn', the transaction is committed
--   afterwards.
setupTableFor :: forall a. (Entity a) => Conn -> IO ()
setupTableFor conn = do
  runRaw conn $ dropTableStmtFor @a
  runRaw conn $ createTableStmtFor @a (db conn)
  when (implicitCommit conn) $ commit conn

-- | Compute the 'EntityId' of an entity.
--
--   The 'EntityId' of an entity is a tuple of the constructor name of its type
--   (taken from 'typeInfo') and its id value (see 'idValue').
--   The function takes a 'Conn' and an entity as parameters.
entityId :: forall a. (Entity a) => Conn -> a -> IO EntityId
entityId conn x = do
  eid <- idValue conn x
  return (tyName, eid)
  where
    tyName = constructorName (typeInfo @a)

-- | Return the primary key value of an entity as a 'SqlValue'.
--
--   The function takes a 'Conn' and an entity as parameters.
--   The entity is converted to a row with 'toRow'; the result is the element
--   of that row at the position of the field named by 'idField'.
--   An error is raised if that field is not a field of the type or if the row
--   has no element at that position.
idValue :: forall a. (Entity a) => Conn -> a -> IO SqlValue
idValue conn x = do
  sqlValues <- toRow conn x
  return (sqlValues !! idFieldIndex)
  where
    idFieldIndex = fieldIndex @a (idField @a)

-- | Return the index of a field of an entity type.
--
--   The index is the position of the field name in the list of field names
--   of the type (see 'fieldNames').
--   If no such field exists, an error is thrown.
--   The function takes a field name as parameter;
--   the type of the entity is determined by the context.
fieldIndex :: forall a. (Entity a) => String -> Int
fieldIndex fieldName =
  expectJust
    ("Field " ++ fieldName ++ " is not present in type " ++ constructorName ti)
    (elemIndex fieldName fieldList)
  where
    ti = typeInfo @a
    fieldList = fieldNames ti

-- | Unwrap a 'Maybe' value.
--
--   Returns the wrapped value of a @Just@. For @Nothing@, 'error' is called
--   with the given message prefixed by @"expectJust "@.
expectJust :: String -> Maybe a -> a
expectJust _ (Just x)  = x
expectJust err Nothing = error ("expectJust " ++ err)

-- | These instances are needed to make the 'Convertible' type class work with 'Enum' types out of the box.
--   This is needed because the 'Convertible' type class is used to convert SqlValues to Haskell types.
--
--   This instance reads an 'Enum' value from a 'SqlValue': the value is first
--   converted to an 'Int' and then mapped to the enumeration with 'toEnum'.
--   A failing conversion to 'Int' is reported as a @Left@ result instead of
--   being thrown from pure code.
--   NOTE: 'toEnum' itself may still raise an error for an 'Int' that has no
--   corresponding constructor.
instance {-# OVERLAPS #-} forall a. (Enum a) => Convertible SqlValue a where
  safeConvert :: SqlValue -> ConvertResult a
  safeConvert sqlValue = toEnum <$> (safeConvert sqlValue :: ConvertResult Int)

-- | Store an 'Enum' value as a 'SqlValue': the value is mapped to its 'Int'
--   representation with 'fromEnum', which is then converted with 'toSql'.
instance {-# OVERLAPS #-} forall a. (Enum a) => Convertible a SqlValue where
  safeConvert :: a -> ConvertResult SqlValue
  safeConvert = Right . toSql . fromEnum