{-# LANGUAGE AllowAmbiguousTypes #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module Database.GP.GenericPersistence
  ( retrieveById,
    retrieveAll,
    retrieveAllWhere,
    entitiesFromRows,
    persist,
    insert,
    insertMany,
    update,
    updateMany,
    delete,
    setupTableFor,
    idValue,
    Conn(..),
    connect,
    Database(..),
    Entity (..),
    GToRow,
    GFromRow,
    columnNameFor,
    maybeFieldTypeFor,
    toString,
    EntityId,
    entityId,
    TypeInfo (..),
    typeInfo,
  )
where

import           Data.Convertible         (ConvertResult, Convertible)
import           Data.Convertible.Base    (Convertible (safeConvert))
import           Data.List                (elemIndex)
import           Database.GP.Conn
import           Database.GP.Entity
import           Database.GP.SqlGenerator
import           Database.GP.TypeInfo
import           Database.HDBC
import Control.Monad (when)

{- |
 This module defines RDBMS persistence operations for record data types that are instances of 'Entity'.
 Values of such a data type are called entities.

 The persistence operations use Haskell generics to provide compile-time reflection capabilities.
 HDBC is used to access the RDBMS.
-}

-- | Retrieves the entity of type @a@ whose primary key equals the given id.
--   The function takes an HDBC connection and the id value as parameters;
--   the id is converted to a 'SqlValue' with 'toSql' before querying.
--   If exactly one matching row exists, the entity is returned as a 'Just' value.
--   If no row matches, 'Nothing' is returned.
--   An error is thrown if more than one row matches the given id.
retrieveById :: forall a id. (Entity a, Convertible id SqlValue) => Conn -> id -> IO (Maybe a)
retrieveById conn idx = do
  resultRows <- quickQuery conn stmt [eid]
  case resultRows of
    []          -> pure Nothing
    [singleRow] -> Just <$> fromRow conn singleRow
    _           ->
      -- keep a space between "one" and the type name in the message
      error $ "More than one " ++ constructorName ti ++ " found for id " ++ show eid
  where
    ti :: TypeInfo a
    ti = typeInfo @a
    stmt :: String
    stmt = selectStmtFor @a
    eid :: SqlValue
    eid = toSql idx

-- | Retrieves every entity of type @a@ stored in the database.
--   Takes an HDBC connection; the entity type @a@ is fixed by the caller
--   (through the expected result type or a type application).
--   Rows are converted with 'entitiesFromRows'.
retrieveAll :: forall a. (Entity a) => Conn -> IO [a]
retrieveAll conn =
  quickQuery conn (selectAllStmtFor @a) [] >>= entitiesFromRows conn

-- | Retrieves all entities of type @a@ whose column @field@ equals @val@.
--   Takes an HDBC connection, the field name and the value to match.
--   The entity type @a@ is fixed by the caller's context.
--   The result is a (possibly empty) list of all matching entities.
retrieveAllWhere :: forall a. (Entity a) => Conn -> String -> SqlValue -> IO [a]
retrieveAllWhere conn field val =
  entitiesFromRows conn =<< quickQuery conn whereQuery [val]
  where
    whereQuery :: String
    whereQuery = selectAllWhereStmtFor @a field

-- | Converts database rows, given as @[[SqlValue]]@, into entities of type @a@.
--   Takes an HDBC connection and the rows; each row is decoded with 'fromRow',
--   in order. The entity type @a@ is fixed by the caller's context.
--   Used internally by 'retrieveAll' and 'retrieveAllWhere', and also usable
--   to decode the result of a custom SQL query.
entitiesFromRows :: forall a. (Entity a) => Conn -> [[SqlValue]] -> IO [a]
entitiesFromRows conn = traverse (fromRow conn)

-- | Persists an entity to the database.
--   Takes an HDBC connection and the entity. The entity's id is looked up
--   with the generated SELECT statement: if no row exists it is inserted,
--   if exactly one row exists it is updated, and if several rows share the
--   id an error is raised. The SQL is generated via Haskell generics.
persist :: forall a. (Entity a) => Conn -> a -> IO ()
persist conn entity = do
  eid <- idValue conn entity
  matches <- quickQuery conn (selectStmtFor @a) [eid]
  case matches of
    (_ : _ : _) -> error $ "More than one entity found for id " ++ show eid
    [_]         -> update conn entity
    []          -> insert conn entity

-- | Explicitly inserts an entity into the database using the INSERT
--   statement generated for type @a@. The row count reported by the driver
--   is ignored. Commits afterwards if the connection has 'implicitCommit' set.
insert :: forall a. (Entity a) => Conn -> a -> IO ()
insert conn entity = do
  values <- toRow conn entity
  _ <- run conn (insertStmtFor @a) values
  when (implicitCommit conn) (commit conn)

-- | Inserts a list of entities into the database.
--   Takes an HDBC connection and the entities. All entities are first
--   converted to rows; the INSERT statement is then prepared once and
--   executed for every row. Commits afterwards if 'implicitCommit' is set.
insertMany :: forall a. (Entity a) => Conn -> [a] -> IO ()
insertMany conn entities = do
  allRows <- traverse (toRow conn) entities
  insertStmt <- prepare conn (insertStmtFor @a)
  executeMany insertStmt allRows
  when (implicitCommit conn) (commit conn)
  

-- | Explicitly updates an entity in the database using the UPDATE
--   statement generated for type @a@. Commits afterwards if the connection
--   has 'implicitCommit' set.
update :: forall a. (Entity a) => Conn -> a -> IO ()
update conn entity = do
  eid <- idValue conn entity
  values <- toRow conn entity
  -- the id comes last: it binds the WHERE clause after all column values
  _ <- run conn (updateStmtFor @a) (values <> [eid])
  when (implicitCommit conn) (commit conn)

-- | Updates a list of entities in the database.
--   Takes an HDBC connection and the entities. The UPDATE statement is
--   prepared once and executed for each entity. Commits afterwards if
--   'implicitCommit' is set.
updateMany :: forall a. (Entity a) => Conn -> [a] -> IO ()
updateMany conn entities = do
  ids <- traverse (idValue conn) entities
  rows <- traverse (toRow conn) entities
  updateStmt <- prepare conn (updateStmtFor @a)
  -- each parameter list is the row followed by its id (for the WHERE clause)
  executeMany updateStmt [row ++ [eid] | (row, eid) <- zip rows ids]
  when (implicitCommit conn) (commit conn)

-- | Deletes an entity from the database.
--   Takes an HDBC connection and the entity; the row is identified by the
--   entity's primary key value. Commits afterwards if 'implicitCommit' is set.
delete :: forall a. (Entity a) => Conn -> a -> IO ()
delete conn entity = do
  _ <- idValue conn entity >>= \eid -> run conn (deleteStmtFor @a) [eid]
  when (implicitCommit conn) $ commit conn

-- | Sets up the table for entity type @a@: runs the generated DROP TABLE
--   statement, then the CREATE TABLE statement for the connection's database.
--   Takes an HDBC connection. Commits afterwards if 'implicitCommit' is set.
setupTableFor :: forall a. (Entity a) => Conn -> IO ()
setupTableFor conn = do
  mapM_ (runRaw conn) [dropTableStmtFor @a, createTableStmtFor @a (db conn)]
  when (implicitCommit conn) $ commit conn

-- | Computes the 'EntityId' of an entity: a pair of the constructor name of
--   type @a@ (from its 'TypeInfo') and the entity's primary key value.
--   Takes an HDBC connection and the entity.
entityId :: forall a. (Entity a) => Conn -> a -> IO EntityId
entityId conn x = (,) typeName <$> idValue conn x
  where
    typeName :: String
    typeName = constructorName (typeInfo @a)

-- | Returns the primary key value of an entity as a 'SqlValue'.
--   Takes an HDBC connection and the entity; the value is picked from the
--   entity's row at the position of the id field ('idField').
idValue :: forall a. (Entity a) => Conn -> a -> IO SqlValue
idValue conn x = (!! idPosition) <$> toRow conn x
  where
    idPosition :: Int
    idPosition = fieldIndex @a (idField @a)

-- | Returns the zero-based index of a field of entity type @a@, i.e. its
--   position in the type's list of field names.
--   Takes a field name; the entity type is fixed by the caller's context.
--   Throws an error if the type has no such field.
fieldIndex :: forall a. (Entity a) => String -> Int
fieldIndex fieldName = expectJust missingMsg (elemIndex fieldName (fieldNames ti))
  where
    ti :: TypeInfo a
    ti = typeInfo @a
    missingMsg :: String
    missingMsg = concat ["Field ", fieldName, " is not present in type ", constructorName ti]

-- | Unwraps a 'Just' value; on 'Nothing' calls 'error' with the given message
--   prefixed by @"expectJust "@.
expectJust :: String -> Maybe a -> a
expectJust err = maybe (error ("expectJust " ++ err)) id

-- | These instances make the 'Convertible' type class work with 'Enum' types
--   out of the box. They are needed because 'Convertible' is used to convert
--   between 'SqlValue's and Haskell types.
--
--   Decoding reads the 'SqlValue' as an 'Int' and maps it with 'toEnum'.
--   A failed 'Int' conversion is returned as a 'Left' instead of being thrown,
--   so the conversion is genuinely "safe". Note that 'toEnum' itself may still
--   fail (lazily) for an 'Int' outside the target type's range.
instance {-# OVERLAPS #-} (Enum a) => Convertible SqlValue a where
  safeConvert :: SqlValue -> ConvertResult a
  safeConvert sqlVal = toEnum <$> (safeConvert sqlVal :: ConvertResult Int)

-- | Encodes an 'Enum' value as the 'SqlValue' of its 'fromEnum' index.
instance {-# OVERLAPS #-} (Enum a) => Convertible a SqlValue where
  safeConvert :: a -> ConvertResult SqlValue
  safeConvert x = Right (toSql (fromEnum x))