{-# LANGUAGE FlexibleContexts, ExistentialQuantification, ScopedTypeVariables, MultiParamTypeClasses, FlexibleInstances #-}

-- | Helper functions intended for use by authors of database backends.
module Database.Groundhog.Generic
  ( -- * Migration
    createMigration
  , executeMigration
  , executeMigrationSilent
  , executeMigrationUnsafe
  , runMigration
  , runMigrationSilent
  , runMigrationUnsafe
  , getQueries
  , printMigration
  , mergeMigrations
  -- * Helper functions for defining *PersistValue instances
  , primToPersistValue
  , primFromPersistValue
  , primToPurePersistValues
  , primFromPurePersistValues
  , primToSinglePersistValue
  , primFromSinglePersistValue
  , pureToPersistValue
  , pureFromPersistValue
  , singleToPersistValue
  , singleFromPersistValue
  , toSinglePersistValueUnique
  , fromSinglePersistValueUnique
  , toPersistValuesUnique
  , fromPersistValuesUnique
  , toSinglePersistValueAutoKey
  , fromSinglePersistValueAutoKey
  , failMessage
  , failMessageNamed
  -- * Other
  , bracket
  , finally
  , onException
  , PSFieldDef(..)
  , applyDbTypeSettings
  , findOne
  , replaceOne
  , matchElements
  , haveSameElems
  , phantomDb
  , getDefaultAutoKeyType
  , getUniqueFields
  , isSimple
  , firstRow
  , streamToList
  , mapStream
  , joinStreams
  , deleteByKey
  ) where

import Database.Groundhog.Core

import Control.Applicative ((<|>))
import Control.Monad (liftM, forM_, unless)
import Control.Monad.Trans.Reader (ReaderT(..), runReaderT, ask)
import Control.Monad.Trans.State (StateT(..))
import Control.Monad.Trans.Control (MonadBaseControl, control, restoreM)
import qualified Control.Exception as E
import Control.Monad.IO.Class (MonadIO (..))
import Data.Acquire (with)
import Data.Acquire.Internal (Acquire(..), Allocated(..), ReleaseType(..))
import Data.Either (partitionEithers)
import Data.Function (on)
import Data.IORef
import Data.List (partition, sortBy)
import qualified Data.Map as Map
import System.IO (hPutStrLn, stderr)

-- | Run a migration computation starting from an empty set of migrations and
-- return the 'NamedMigrations' it collected. Nothing is executed here and
-- unsafe migrations are not rejected; pass the result to 'executeMigration'
-- (or one of its variants) to check and run it.
createMigration :: Monad m => Migration m -> m NamedMigrations
createMigration m = liftM snd $ runStateT m Map.empty

-- | Flatten a migration into the queries that should be executed.
--
-- Queries are ordered by their integer priority. 'sortBy' is stable, so ties
-- keep their original order. The result is:
--
--   * 'Left' with the migration's own errors if the migration failed;
--
--   * 'Left' with a warning header followed by the unsafe queries, if unsafe
--     queries exist and are not allowed;
--
--   * 'Right' with all queries otherwise.
getQueries :: Bool -- ^ True - support unsafe queries
             -> SingleMigration -> Either [String] [String]
getQueries _ (Left errs) = Left errs
getQueries runUnsafe (Right migs)
  | runUnsafe || null unsafeQueries = Right allQueries
  | otherwise = Left (header ++ unsafeQueries)
  where
    priority (_, p, _) = p
    queryText (_, _, q) = q
    ordered = sortBy (compare `on` priority) migs
    allQueries = map queryText ordered
    unsafeQueries = [q | (True, _, q) <- ordered]
    header =
      [ "Database migration: manual intervention required."
      , "The following actions are considered unsafe:"
      ]

-- | Shared implementation of the @executeMigration*@ functions.
--
-- The first flag allows unsafe queries. The second flag suppresses the
-- "Migrating: ..." lines written to stderr. All migrations are merged, and
-- the action fails (via 'fail') with every error line if the merged migration
-- is rejected. Otherwise each query is executed in order with 'executeRaw'.
executeMigration' :: (PersistBackend m, MonadIO m) => Bool -> Bool -> NamedMigrations -> m ()
executeMigration' runUnsafe silent m =
  case getQueries runUnsafe (mergeMigrations (Map.elems m)) of
    Left errs -> fail (unlines errs)
    Right qs -> mapM_ runQuery qs
  where
    runQuery q = do
      unless silent $ liftIO $ hPutStrLn stderr ("Migrating: " ++ q)
      executeRaw False q []

-- | Execute the given migrations, logging each query to stderr.
-- Unsafe migrations are refused: the action fails instead of running them.
executeMigration :: (PersistBackend m, MonadIO m) => NamedMigrations -> m ()
executeMigration migs = executeMigration' False False migs

-- | Execute the given migrations without logging anything.
-- Unsafe migrations are refused: the action fails instead of running them.
executeMigrationSilent :: (PersistBackend m, MonadIO m) => NamedMigrations -> m ()
executeMigrationSilent migs = executeMigration' False True migs

-- | Execute the given migrations, including unsafe ones, without asking for
-- confirmation. Every query is logged to stderr.
executeMigrationUnsafe :: (PersistBackend m, MonadIO m) => NamedMigrations -> m ()
executeMigrationUnsafe migs = executeMigration' True False migs

-- | Print the migrations to stdout, grouped by datatype name.
--
-- Each error is marked @Error:@. Each query is marked @Safe:@ or @Unsafe:@.
printMigration :: MonadIO m => NamedMigrations -> m ()
printMigration migs = liftIO $ mapM_ printEntry (Map.toList migs)
  where
    printEntry (name, migration) = do
      putStrLn ("Datatype " ++ name ++ ":")
      either (mapM_ (putStrLn . ("\tError:\t" ++)))
             (mapM_ (putStrLn . ("\t" ++) . describe))
             migration
    describe (True, _, sql) = "Unsafe:\t" ++ sql
    describe (False, _, sql) = "Safe:\t" ++ sql

-- | Build the migrations and execute them, logging each query to stderr.
-- Fails when an unsafe migration occurs.
--
-- > runMigration m = createMigration m >>= executeMigration
runMigration :: (PersistBackend m, MonadIO m) => Migration m -> m ()
runMigration m = executeMigration =<< createMigration m

-- | Build the migrations and execute them without logging.
-- Fails when an unsafe migration occurs.
--
-- > runMigrationSilent m = createMigration m >>= executeMigrationSilent
runMigrationSilent :: (PersistBackend m, MonadIO m) => Migration m -> m ()
runMigrationSilent m = createMigration m >>= executeMigrationSilent

-- | Build the migrations and execute all of them, unsafe ones included,
-- logging each query to stderr.
--
-- > runMigrationUnsafe m = createMigration m >>= executeMigrationUnsafe
runMigrationUnsafe :: (PersistBackend m, MonadIO m) => Migration m -> m ()
runMigrationUnsafe m = executeMigrationUnsafe =<< createMigration m

-- | Combine several migrations into one. If any of them failed, the result
-- holds the errors of all failed migrations. Otherwise it holds the
-- concatenated statements of all of them.
mergeMigrations :: [SingleMigration] -> SingleMigration
mergeMigrations ms
  | null errors = Right (concat statements)
  | otherwise = Left (concat errors)
  where
    (errors, statements) = partitionEithers ms

-- | Error message for a value list that cannot be decoded into the type of the
-- first argument. The argument is only used through 'persistName'.
failMessage :: PersistField a => a -> [PersistValue] -> String
failMessage = failMessageNamed . persistName

-- | Error message for a value list that cannot be decoded into the type with
-- the given name. The offending values are included via 'show'.
failMessageNamed :: String -> [PersistValue] -> String
failMessageNamed name xs = concat ["Invalid list for ", name, ": ", show xs]

-- | Lifted 'E.finally'. Runs the first computation, then always runs the
-- second one, even if the first threw. Only the monadic state produced by the
-- first computation is restored; the second one's state is discarded.
finally :: MonadBaseControl IO m
        => m a -- ^ computation to run first
        -> m b -- ^ computation to run afterward (even if an exception was raised)
        -> m a
finally a sequel = control $ \run -> run a `E.finally` run sequel

-- | Lifted 'E.bracket'. The state produced by the acquire step is restored
-- before both the release step and the body run. The body's resulting state
-- is kept; the release step's state is discarded.
bracket :: MonadBaseControl IO m
        => m a        -- ^ computation to run first ("acquire resource")
        -> (a -> m b) -- ^ computation to run last ("release resource")
        -> (a -> m c) -- ^ computation to run in-between
        -> m c
bracket before after thing = control $ \runInIO ->
  E.bracket
    (runInIO before)
    (\acquired -> runInIO (after =<< restoreM acquired))
    (\acquired -> runInIO (thing =<< restoreM acquired))

-- | Lifted 'E.onException'. Runs the first computation. If it throws, runs
-- the second one and rethrows the original exception.
onException :: MonadBaseControl IO m
        => m a
        -> m b
        -> m a
onException io what = control $ \run -> run io `E.onException` run what

-- | Per-field settings that override what would otherwise be derived for a
-- field. 'applyDbTypeSettings' applies the database-related settings to a
-- field's 'DbType'.
data PSFieldDef str = PSFieldDef {
    psFieldName :: str  -- ^ name of the field in the record, e.g. @bar@
  , psDbFieldName :: Maybe str -- ^ column name, e.g. @SQLbar@
  , psDbTypeName :: Maybe str -- ^ column type, e.g. @inet@, @NUMERIC(5, 2)@, @VARCHAR(50)@
  , psExprName :: Maybe str -- ^ name of the constructor in the Field GADT, e.g. @BarField@
  , psEmbeddedDef :: Maybe [PSFieldDef str] -- ^ settings for the fields of an embedded datatype, matched against them by 'psFieldName'
  , psDefaultValue :: Maybe str -- ^ default value in the database
  , psReferenceParent :: Maybe (Maybe ((Maybe str, str), [str]), Maybe ReferenceActionType, Maybe ReferenceActionType)
    -- ^ reference settings: an optional parent (presumably @((schema, table), columns)@; TODO confirm),
    -- then the on-delete and on-update actions. Merged into the type by 'applyReferencesSettings'.
  , psFieldConverter :: Maybe str -- ^ name of a pair of functions (presumably converting to and from the stored value; TODO confirm)
} deriving (Eq, Show)

-- | Apply the database-related parts of a field's settings to the 'DbType'
-- derived for that field.
--
-- When 'psEmbeddedDef' is 'Nothing':
--
--   * a 'DbTypePrimitive' gets its column type replaced by a 'DbOther' type
--     built from 'psDbTypeName' (if set). Its default value is replaced by
--     'psDefaultValue' (if set), and its reference is merged with
--     'psReferenceParent' via 'applyReferencesSettings';
--
--   * a 'DbEmbedded' only has its reference merged;
--
--   * any other type is returned unchanged.
--
-- When 'psEmbeddedDef' is @Just subs@, the type must be 'DbEmbedded';
-- otherwise this calls 'error'. Each setting in @subs@ is applied recursively
-- to the embedded field whose name equals its 'psFieldName', and a setting
-- with a 'psDbFieldName' renames that field. The resulting 'EmbeddedDef'
-- flag is 'True' iff at least one matched setting had a 'psDbFieldName'; the
-- input flag is ignored. A setting that matches no field, or a field name
-- targeted by more than one setting, leads to 'error'. In this case the
-- outer 'psDbTypeName' and 'psDefaultValue' are not used.
applyDbTypeSettings :: PSFieldDef String -> DbType -> DbType
applyDbTypeSettings (PSFieldDef _ _ dbTypeName _ Nothing def psRef _) typ = case typ of
  DbTypePrimitive t nullable def' ref -> DbTypePrimitive (maybe t (\typeName -> DbOther $ OtherTypeDef [Left typeName]) dbTypeName) nullable (def <|> def') (applyReferencesSettings psRef ref)
  DbEmbedded emb ref -> DbEmbedded emb (applyReferencesSettings psRef ref)
  t -> t
applyDbTypeSettings (PSFieldDef _ _ _ _ (Just subs) _ psRef _) typ = (case typ of
  DbEmbedded (EmbeddedDef _ fields) ref -> DbEmbedded (uncurry EmbeddedDef $ go subs fields) (applyReferencesSettings psRef ref)
  t -> error $ "applyDbTypeSettings: expected DbEmbedded, got " ++ show t) where
  -- go pendingSettings fields = (anyFieldRenamed, updatedFields)
  go [] fs = (False, fs)
  -- Settings remain but the fields are exhausted: some setting matched nothing.
  go st [] = error $ "applyDbTypeSettings: embedded datatype does not have expected fields: " ++ show st
  go st (field@(fName, fType):fs) = case partition ((== fName) . psFieldName) st of
    -- Exactly one setting for this field: apply it and drop it from the pending list.
    ([fDef], rest) -> result where
      (flag, fields') = go rest fs
      result = case psDbFieldName fDef of
        Nothing -> (flag, (fName, applyDbTypeSettings fDef fType):fields')
        Just name' -> (True, (name', applyDbTypeSettings fDef fType):fields')
    -- No setting (or several) for this field: keep it unchanged. Leftover
    -- settings are reported once the field list runs out.
    _ -> let (flag, fields') = go st fs in (flag, field:fields')

-- | Merge reference settings into a type's parent-table reference.
--
-- With no settings the reference is kept. With an existing reference, each
-- setting that is present overrides the corresponding part. Without an
-- existing reference, the settings must name a parent; otherwise this calls
-- 'error'.
applyReferencesSettings :: Maybe (Maybe ((Maybe String, String), [String]), Maybe ReferenceActionType, Maybe ReferenceActionType) -> Maybe ParentTableReference -> Maybe ParentTableReference
applyReferencesSettings settings ref = case (settings, ref) of
  (Nothing, _) -> ref
  (Just (parent, onDel, onUpd), Just (parent', onDel', onUpd')) ->
    Just (maybe parent' Right parent, onDel <|> onDel', onUpd <|> onUpd')
  (Just (Just parent, onDel, onUpd), Nothing) ->
    Just (Right parent, onDel, onUpd)
  (_, Nothing) ->
    error "applyReferencesSettings: expected type with reference, got Nothing"

-- | Serialize a primitive value by prepending its single 'PersistValue'.
primToPersistValue :: (PersistBackend m, PrimitivePersistField a) => a -> m ([PersistValue] -> [PersistValue])
primToPersistValue = return . (:) . toPrimitivePersistValue

-- | Decode a primitive value from the head of the list, returning it together
-- with the remaining values. On an empty list, calls 'fail' with 'failMessage'.
primFromPersistValue :: (PersistBackend m, PrimitivePersistField a) => [PersistValue] -> m (a, [PersistValue])
primFromPersistValue values = case values of
  x : rest -> return (fromPrimitivePersistValue x, rest)
  [] -> reject undefined
  where
    -- The dummy argument only tells 'failMessage' which type failed.
    reject a = fail (failMessage a values) >> return (a, values)

-- | Pure serializer for a primitive value: prepends its single 'PersistValue'.
primToPurePersistValues :: PrimitivePersistField a => a -> ([PersistValue] -> [PersistValue])
primToPurePersistValues = (:) . toPrimitivePersistValue

-- | Pure decoder for a primitive value from the head of the list. Calls
-- 'error' with 'failMessage' when the list is empty.
primFromPurePersistValues :: PrimitivePersistField a => [PersistValue] -> (a, [PersistValue])
primFromPurePersistValues values = case values of
  x : rest -> (fromPrimitivePersistValue x, rest)
  [] -> invalid undefined
  where
    -- The dummy argument only tells 'failMessage' which type failed.
    invalid a = error (failMessage a values) `asTypeOf` (a, values)

-- | Convert a primitive value to its single 'PersistValue'.
primToSinglePersistValue :: (PersistBackend m, PrimitivePersistField a) => a -> m PersistValue
primToSinglePersistValue = return . toPrimitivePersistValue

-- | Convert a single 'PersistValue' back to a primitive value.
primFromSinglePersistValue :: (PersistBackend m, PrimitivePersistField a) => PersistValue -> m a
primFromSinglePersistValue = return . fromPrimitivePersistValue

-- | Monadic wrapper around 'toPurePersistValues'.
pureToPersistValue :: (PersistBackend m, PurePersistField a) => a -> m ([PersistValue] -> [PersistValue])
pureToPersistValue = return . toPurePersistValues

-- | Monadic wrapper around 'fromPurePersistValues'.
pureFromPersistValue :: (PersistBackend m, PurePersistField a) => [PersistValue] -> m (a, [PersistValue])
pureFromPersistValue = return . fromPurePersistValues

-- | Serialize a value through 'toSinglePersistValue' and prepend the result.
singleToPersistValue :: (PersistBackend m, SinglePersistField a) => a -> m ([PersistValue] -> [PersistValue])
singleToPersistValue a = liftM (:) (toSinglePersistValue a)

-- | Decode a value from the head of the list through 'fromSinglePersistValue',
-- returning it with the remaining values. On an empty list, calls 'fail' with
-- 'failMessage'.
singleFromPersistValue :: (PersistBackend m, SinglePersistField a) => [PersistValue] -> m (a, [PersistValue])
singleFromPersistValue values = case values of
  x : rest -> liftM (\a -> (a, rest)) (fromSinglePersistValue x)
  [] -> reject undefined
  where
    -- The dummy argument only tells 'failMessage' which type failed.
    reject a = fail (failMessage a values) >> return (a, values)

-- | Store the entity with 'insertBy' (discarding its result) and represent it
-- by the single 'PersistValue' of its unique key.
toSinglePersistValueUnique :: forall m v u . (PersistBackend m, PersistEntity v, IsUniqueKey (Key v (Unique u)), PrimitivePersistField (Key v (Unique u)))
                           => u (UniqueMarker v) -> v -> m PersistValue
toSinglePersistValueUnique u v = do
  _ <- insertBy u v
  primToSinglePersistValue (extractUnique v :: Key v (Unique u))

-- | Interpret a 'PersistValue' as a unique key and load the entity with
-- 'getBy'. Calls 'fail' when no entity is found.
fromSinglePersistValueUnique :: forall m v u . (PersistBackend m, PersistEntity v, IsUniqueKey (Key v (Unique u)), PrimitivePersistField (Key v (Unique u)))
                             => u (UniqueMarker v) -> PersistValue -> m v
fromSinglePersistValueUnique _ x = do
  found <- getBy (fromPrimitivePersistValue x :: Key v (Unique u))
  case found of
    Just entity -> return entity
    Nothing -> fail $ "No data with id " ++ show x

-- | Store the entity with 'insertBy' (discarding its result) and serialize
-- its unique key.
toPersistValuesUnique :: forall m v u . (PersistBackend m, PersistEntity v, IsUniqueKey (Key v (Unique u)))
                      => u (UniqueMarker v) -> v -> m ([PersistValue] -> [PersistValue])
toPersistValuesUnique u v = do
  _ <- insertBy u v
  toPersistValues (extractUnique v :: Key v (Unique u))

-- | Decode a unique key from the front of the list and load the entity with
-- 'getBy', returning it with the remaining values. Calls 'fail' (showing the
-- whole input list) when no entity is found.
fromPersistValuesUnique :: forall m v u . (PersistBackend m, PersistEntity v, IsUniqueKey (Key v (Unique u)))
                        => u (UniqueMarker v) -> [PersistValue] -> m (v, [PersistValue])
fromPersistValuesUnique _ xs = do
  (k, rest) <- fromPersistValues xs
  found <- getBy (k :: Key v (Unique u))
  case found of
    Nothing -> fail $ "No data with id " ++ show xs
    Just entity -> return (entity, rest)

-- | Store the entity with 'insertByAll' and represent it by the resulting
-- auto key. The key is taken from either side of the 'Either' (presumably
-- existing vs. newly inserted; TODO confirm).
toSinglePersistValueAutoKey :: forall m v . (PersistBackend m, PersistEntity v, PrimitivePersistField (AutoKey v))
                            => v -> m PersistValue
toSinglePersistValueAutoKey a = do
  key <- insertByAll a
  primToSinglePersistValue (either id id key)

-- | Interpret a 'PersistValue' as a backend-specific key and load the entity
-- with 'get'. Calls 'fail' when no entity is found.
fromSinglePersistValueAutoKey :: forall m v . (PersistBackend m, PersistEntity v, PrimitivePersistField (Key v BackendSpecific))
                              => PersistValue -> m v
fromSinglePersistValueAutoKey x = do
  found <- get (fromPrimitivePersistValue x :: Key v BackendSpecific)
  case found of
    Just entity -> return entity
    Nothing -> fail $ "No data with id " ++ show x

-- | Replace the single element of @bs@ whose key (via @getter2@) equals the
-- key of @a@ (via @getter1@) with @apply a b@. Calls 'error' if no element
-- or more than one element has that key; @what@ names the element kind in
-- the message.
replaceOne :: (Eq x, Show x) => String -> (a -> x) -> (b -> x) -> (a -> b -> b) -> a -> [b] -> [b]
replaceOne what getter1 getter2 apply a bs = case filter matches bs of
  [_] -> map (\b -> if matches b then apply a b else b) bs
  []  -> error $ "Not found " ++ what ++ " with name " ++ show key
  _   -> error $ "Found more than one " ++ what ++ " with name " ++ show key
  where
    key = getter1 a
    matches b = key == getter2 b

-- | Return the single element whose key (via @getter@) equals @x@. Calls
-- 'error' if there is none or more than one; @what@ names the element kind
-- in the message.
findOne :: (Eq x, Show x) => String -> (a -> x) -> x -> [a] -> a
findOne what getter x candidates = case [c | c <- candidates, x == getter c] of
  [found] -> found
  []      -> error $ "Not found " ++ what ++ " with name " ++ show x
  _       -> error $ "Found more than one " ++ what ++ " with name " ++ show x

-- | Split elements into those only in the old list, those only in the new
-- list, and matched @(old, new)@ pairs.
--
-- The new elements come from the datatype and the old ones from the DB.
-- Matching is done with @eq@ (typically by name; the matched elements may
-- otherwise differ). Each old element is consumed by at most one match. If
-- a new element matches several old ones, this calls 'error'. New-only
-- elements and matches keep the order of the new list.
matchElements :: Show a => (a -> b -> Bool) -> [a] -> [b] -> ([a], [b], [(a, b)])
matchElements eq oldElems newElems = go newElems
  where
    go [] = (oldElems, [], [])
    go (new:rest) = case go rest of
      (olds, news, matches) -> case partition (`eq` new) olds of
        ([], unmatched) -> (unmatched, new:news, matches)
        ([old], unmatched) -> (unmatched, news, (old, new):matches)
        (several, _) -> error $ "matchElements: more than one element matched " ++ show several

-- | 'True' when every old element is matched by a new one and vice versa,
-- according to 'matchElements'.
haveSameElems :: Show a => (a -> b -> Bool) -> [a] -> [b] -> Bool
haveSameElems p xs ys = null onlyOld && null onlyNew
  where
    (onlyOld, onlyNew, _) = matchElements p xs ys

-- | A proxy for the connection type of the current monad. The proxy value is
-- an 'error' thunk: use it only for its type and never evaluate it.
phantomDb :: PersistBackend m => m (proxy (Conn m))
phantomDb = return (error "phantomDb")

-- | The primitive column type used for the default auto-increment key of the
-- database. Calls 'error' if 'dbType' does not report a primitive type.
getDefaultAutoKeyType :: DbDescriptor db => proxy db -> DbTypePrimitive
getDefaultAutoKeyType proxy =
  case dbType proxy (autoKeyFor proxy) of
    DbTypePrimitive t _ _ _ -> t
    other -> error $ "getDefaultAutoKeyType: unexpected key type " ++ show other
  where
    -- Always 'undefined': only the type is meant to be used.
    autoKeyFor :: p d -> AutoKeyType d
    autoKeyFor _ = undefined

-- | Acquire the stream, fetch at most one row, and release the stream again
-- (via 'with').
firstRow :: MonadIO m => RowStream a -> m (Maybe a)
firstRow s = liftIO (with s (\fetchNext -> fetchNext))

-- | Acquire the stream, read every row in order, and release the stream.
streamToList :: MonadIO m => RowStream a -> m [a]
streamToList s = liftIO $ with s collect
  where
    collect fetchNext = do
      row <- fetchNext
      case row of
        Nothing -> return []
        Just a -> liftM (a:) (collect fetchNext)

-- | Transform every row of a stream with a database action. The action runs
-- lazily, once per row as the row is fetched, using the connection captured
-- when 'mapStream' itself runs.
mapStream :: PersistBackendConn conn => (a -> Action conn b) -> RowStream a -> Action conn (RowStream b)
mapStream f s = do
  conn <- ask
  let step fetchNext =
        fetchNext >>= maybe (return Nothing) (liftM Just . flip runReaderT conn . f)
  return (fmap step s)

-- | Concatenate several row streams into one.
--
-- The streams are opened one at a time. When the current stream is
-- exhausted, it is released and the next action in the list runs to open the
-- following stream; this repeats until a row is found or no streams remain.
-- Releasing the joined stream releases whichever stream is currently open.
--
-- NOTE(review): the cursor state lives in an 'IORef' created when this action
-- runs, so the returned 'RowStream' should be acquired at most once.
joinStreams :: [Action conn (RowStream a)] -> Action conn (RowStream a)
joinStreams streams = do
  conn <- ask
  -- Start from an already exhausted stream with a no-op release, so the first
  -- fetch immediately opens the first real stream.
  var <- liftIO $ newIORef ((return Nothing, const $ return ()), streams)
  return $ Acquire $ \restore -> do
    let joinedNext = do
          ((next, close), queue) <- readIORef var
          val <- next
          case val of
            Just a -> return $ Just a
            Nothing -> case queue of
              [] -> return Nothing
              (makeStream:queue') -> do
                close ReleaseNormal
                -- Forget the stream just released before opening the next one,
                -- so a failure below cannot make it be released a second time.
                writeIORef var ((return Nothing, const $ return ()), queue')
                Acquire f <- runReaderT makeStream conn
                Allocated next' close' <- f restore
                writeIORef var ((next', close'), queue')
                -- The new stream may itself be empty, so keep pulling.
                joinedNext
        joinedClose typ = readIORef var >>= \((_, close), _) -> close typ
    return $ Allocated joinedNext joinedClose

-- | The fields of a unique definition. Forcing an element that is an
-- expression ('Right') calls 'error'.
getUniqueFields :: UniqueDef' str (Either field str) -> [field]
getUniqueFields (UniqueDef _ _ uFields) = [either id rejectExpr f | f <- uFields]
  where
    rejectExpr _ = error "A unique key may not contain expressions"

-- | 'True' when the list contains exactly one constructor definition.
isSimple :: [ConstructorDef] -> Bool
isSimple constructors = case constructors of
  [_] -> True
  _   -> False

{-# DEPRECATED deleteByKey "Use deleteBy instead" #-}
-- | Alias for 'deleteBy', kept for backward compatibility.
deleteByKey :: (PersistBackend m, PersistEntity v, PrimitivePersistField (Key v BackendSpecific)) => Key v BackendSpecific -> m ()
deleteByKey key = deleteBy key