-- | This helper module is intended for use by backend authors.
module Database.Groundhog.Generic
  ( migrateRecursively
  , createMigration
  , executeMigration
  , executeMigrationUnsafe
  , runMigration
  , runMigrationUnsafe
  , printMigration
  , getEntityName
  , mergeMigrations
  , silentMigrationLogger
  , defaultMigrationLogger
  , defaultSelect
  , defaultSelectAll
  ) where

import Database.Groundhog.Core

import Control.Monad(liftM, forM_)
import Control.Monad.Trans.State
import Control.Monad.Trans.Class(lift)
import Control.Monad.IO.Class (MonadIO (..))
import Data.Enumerator(Iteratee(..), run, (==<<))
import Data.Enumerator.List(consume)
import Data.Either(partitionEithers)
import Data.List(intercalate)
import qualified Data.Map as Map

-- | Build the migration for an entity together with the migrations of every
-- entity, tuple and list type reachable from it.
-- The map kept in the state is keyed by type name; a type whose name is
-- already present is skipped, so a type that occurs several times in a
-- datatype is migrated only once.
migrateRecursively :: (Monad m, PersistEntity e) => 
     (EntityDef -> m SingleMigration)          -- ^ migrate entity
  -> (Int -> [NamedType] -> m SingleMigration) -- ^ migrate tuple
  -> (NamedType -> m SingleMigration)          -- ^ migrate list
  -> e                                         -- ^ initial entity
  -> StateT NamedMigrations m ()
migrateRecursively migrateEntity migrateTuple migrateList = visit . namedType where
  -- Dispatch on the shape of the type; a Maybe wrapper is transparent and
  -- every other type needs no migration of its own.
  visit t = case getType t of
    DbList inner       -> once (getName t) (migrateList inner) (visit inner)
    DbTuple size parts -> once (getName t) (migrateTuple size parts) (mapM_ visit parts)
    DbEntity def       -> once (getName t) (migrateEntity def) (mapM_ visit (fieldTypes def))
    DbMaybe inner      -> visit inner
    _                  -> return ()
  -- Run the migration for a name unless it is already recorded. The result
  -- is stored before descending into the component types.
  once key mig descend = do
    known <- gets (Map.member key)
    if known
      then return ()
      else do
        result <- lift mig
        modify (Map.insert key result)
        descend
  -- Types of all constructor parameters of an entity.
  fieldTypes def = [snd param | constr <- constructors def, param <- constrParams constr]

-- Merge the migrations of all datatypes into one list of
-- (is-unsafe flag, statement) pairs. If any migration carries errors,
-- 'error' is called with all error messages joined by newlines.
getCorrectMigrations :: NamedMigrations -> [(Bool, String)]
getCorrectMigrations migs = case mergeMigrations (Map.elems migs) of
  Left messages    -> error (unlines messages)
  Right statements -> statements

-- | Produce the migrations without executing them. The migration action is
-- run starting from an empty map and the accumulated map is returned as is.
-- Nothing here inspects the result, so unsafe statements and migration errors
-- are handed back to the caller instead of causing a failure; use
-- 'executeMigration' or 'printMigration' to act on the returned map.
createMigration :: PersistBackend m => Migration m -> m NamedMigrations
createMigration m = execStateT m Map.empty

-- | Execute the migrations, logging each statement before it is run.
-- If at least one statement is flagged as unsafe, nothing is executed and
-- 'error' is called with a message listing the unsafe statements.
executeMigration :: (PersistBackend m, MonadIO m) => (String -> IO ()) -> NamedMigrations -> m ()
executeMigration logger m = case unsafe of
    [] -> mapM_ (executeMigrate logger . snd) statements
    _  -> error $ concat
            [ "\n\nDatabase migration: manual intervention required.\n"
            , "The following actions are considered unsafe:\n\n"
            , unlines ["    " ++ s ++ ";" | s <- unsafe]
            ]
  where
    statements = getCorrectMigrations m
    -- statements whose flag marks them as unsafe
    unsafe = [sql | (True, sql) <- statements]

-- | Execute the migrations, logging each statement before it is run.
-- Unsafe statements are executed like any other, without warnings.
executeMigrationUnsafe :: (PersistBackend m, MonadIO m) => (String -> IO ()) -> NamedMigrations -> m ()
executeMigrationUnsafe logger migs =
  forM_ (getCorrectMigrations migs) (executeMigrate logger . snd)

-- | Pretty print the migrations to stdout: one header line per datatype,
-- followed by either its error messages or its statements, each statement
-- labelled as safe or unsafe.
printMigration :: MonadIO m => NamedMigrations -> m ()
printMigration migs = liftIO $ mapM_ report (Map.assocs migs) where
  report (name, outcome) = do
    putStrLn $ "Datatype " ++ name ++ ":"
    mapM_ putStrLn (either (map ("\tError:\t" ++)) (map describe) outcome)
  -- one output line for a single statement
  describe (isUnsafe, sql)
    | isUnsafe  = "\t" ++ "Unsafe:\t" ++ sql
    | otherwise = "\t" ++ "Safe:\t" ++ sql

-- | Create the migrations and execute them with 'executeMigration',
-- logging each statement. Fails when an unsafe migration occurs.
runMigration :: (PersistBackend m, MonadIO m) => (String -> IO ()) -> Migration m -> m ()
runMigration logger m = do
  migs <- createMigration m
  executeMigration logger migs

-- | Create the migrations and execute them with 'executeMigrationUnsafe',
-- logging each statement. Unsafe migrations are executed without warnings.
runMigrationUnsafe :: (PersistBackend m, MonadIO m) => (String -> IO ()) -> Migration m -> m ()
runMigrationUnsafe logger m = do
  migs <- createMigration m
  executeMigrationUnsafe logger migs

-- Pass a single statement to the logger, then run it through 'executeRaw'
-- with no query parameters. The result of 'executeRaw' is discarded.
-- NOTE(review): the meaning of the 'False' flag is defined by 'executeRaw'
-- in the core module - confirm there.
executeMigrate :: (PersistBackend m, MonadIO m) => (String -> IO ()) -> String -> m ()
executeMigrate logger query =
  liftIO (logger query) >> executeRaw False query [] >> return ()

-- | A logger that ignores the query and does nothing.
silentMigrationLogger :: String -> IO ()
silentMigrationLogger = const (return ())

-- | Prints each query to stdout, prefixed with @Migrating: @
defaultMigrationLogger :: String -> IO ()
defaultMigrationLogger = putStrLn . ("Migrating: " ++)

-- | Joins the migrations. If no migration is a 'Left', the result is all
-- statements concatenated in order; otherwise it is all error messages
-- concatenated in order and the statements are dropped.
mergeMigrations :: [SingleMigration] -> SingleMigration
mergeMigrations ms = case partitionEithers ms of
  ([], statements) -> Right (concat statements)
  (errors, _)      -> Left (concat errors)

-- | Get the full entity name: the entity's own name followed by the names
-- of its type parameters, separated by @$@.
--
-- @ getEntityName (entityDef v) == persistName v @
getEntityName :: EntityDef -> String
getEntityName e = intercalate "$" (entityName e : [getName p | p <- typeParams e])

-- | Call 'selectEnum' but return the result as a list.
-- The enumerator is drained with 'consume'; if running it yields an error,
-- that error is shown and raised through 'error'.
defaultSelect :: (PersistBackend m, PersistEntity v, Constructor c) => Cond v c -> [Order v c] -> Int -> Int -> m [(Key v, v)]
defaultSelect cond ord off lim =
    run (selectEnum cond ord off lim ==<< consume) >>= either (error . show) return

-- | Call 'selectAllEnum' but return the result as a list.
-- The enumerator is drained with 'consume'; if running it yields an error,
-- that error is shown and raised through 'error'.
defaultSelectAll :: (PersistBackend m, PersistEntity v) => m [(Key v, v)]
defaultSelectAll = do
    -- '==<<' applies the enumerator to the iteratee, the same way
    -- 'defaultSelect' does, instead of spelling out its definition by hand.
    res <- run $ selectAllEnum ==<< consume
    case res of
        Left e -> error $ show e
        Right x -> return x