{-| Module : Database.Persist.Migration.Core Maintainer : Brandon Chinn Stability : experimental Portability : portable Defines a migration framework for the persistent library. -} {-# LANGUAGE DuplicateRecordFields #-} {-# LANGUAGE ExistentialQuantification #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE LambdaCase #-} {-# LANGUAGE NamedFieldPuns #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RankNTypes #-} {-# LANGUAGE RecordWildCards #-} {-# OPTIONS_GHC -fno-warn-redundant-constraints #-} module Database.Persist.Migration.Core ( MigrateSettings(..) , defaultSettings , validateMigration , runMigration , getMigration ) where import Control.Monad (unless, when) import Control.Monad.IO.Class (MonadIO(..)) import Control.Monad.Reader (mapReaderT) import Data.List (nub) import Data.Maybe (fromMaybe) import qualified Data.Text as Text import Data.Time.Clock (getCurrentTime) import Database.Persist.Migration.Backend (MigrateBackend(..)) import Database.Persist.Migration.Operation ( Migration , MigrationPath(..) , Operation(..) , Version , opPath , validateOperation ) import Database.Persist.Migration.Operation.Types (Column(..), ColumnProp(..), TableConstraint(..)) import Database.Persist.Migration.Utils.Plan (getPath) import Database.Persist.Migration.Utils.Sql (MigrateSql, executeSql) import Database.Persist.Sql (PersistValue(..), Single(..), SqlPersistT, rawExecute, rawSql) import Database.Persist.Types (SqlType(..)) -- | Get the current version of the database, or Nothing if none exists. 
--
-- Before querying, this runs the backend's SQL for the @persistent_migration@
-- table definition, so the bookkeeping table is set up on first use.
getCurrVersion :: MonadIO m => MigrateBackend -> SqlPersistT m (Maybe Version)
getCurrVersion backend = do
  -- create the persistent_migration table if it doesn't already exist
  setupSql <- mapReaderT liftIO $ getMigrationSql backend migrationTable
  mapM_ executeSql setupSql
  rows <- rawSql latestVersionQuery []
  return $ toVersion rows
  where
    -- Definition of the table that records every migration that has been run.
    migrationTable = CreateTable
      { name = "persistent_migration"
      , schema =
          [ Column "id" SqlInt32 [NotNull, AutoIncrement]
          , Column "version" SqlInt32 [NotNull]
          , Column "label" SqlString []
          , Column "timestamp" SqlDayTime [NotNull]
          ]
      , constraints =
          [ PrimaryKey ["id"]
          ]
      }

    -- The most recently recorded row decides the current version.
    latestVersionQuery = "SELECT version FROM persistent_migration ORDER BY timestamp DESC LIMIT 1"

    -- The query is limited to one row, so anything else is unexpected.
    toVersion [] = Nothing
    toVersion [Single v] = Just v
    toVersion _ = error "Invalid response from the database."

-- | Get the list of operations to run, given the current state of the database.
--
-- The path starts at the given current version (or at the first version of the
-- migration when there is none) and ends at the latest version of the migration.
-- Returns @Left (start, end)@ when no path between the two versions exists.
getOperations :: Migration -> Maybe Version -> Either (Version, Version) [Operation]
getOperations migration mVersion =
  maybe (Left (start, end)) (Right . concat) $ getPath edges start end
  where
    edges = map toEdge migration
    toEdge (path := ops) = (path, ops)
    start = fromMaybe (getFirstVersion migration) mVersion
    end = getLatestVersion migration

-- | Get the first version in the given migration.
--
-- This is the smallest first component over all operation paths; it is partial
-- and errors on an empty migration.
getFirstVersion :: Migration -> Version
getFirstVersion migration = minimum [fst (opPath step) | step <- migration]

-- | Get the most up-to-date version in the given migration.
--
-- This is the largest second component over all operation paths; it is partial
-- and errors on an empty migration.
getLatestVersion :: Migration -> Version
getLatestVersion migration = maximum [snd (opPath step) | step <- migration]

{- Migration plan and execution -}

-- | Settings to customize migration steps.
newtype MigrateSettings = MigrateSettings
  { versionToLabel :: Version -> Maybe String
    -- ^ A function to optionally label certain versions
  }

-- | Default migration settings: no version gets a custom label.
defaultSettings :: MigrateSettings
defaultSettings = MigrateSettings
  { versionToLabel = \_ -> Nothing
  }

-- | Validate the given migration.
--
-- A migration is accepted when every operation path goes from a version to a
-- strictly greater one, and no two operations share the same pair of versions.
-- Returns @Left@ with a description of the first check that failed.
validateMigration :: Migration -> Either String ()
validateMigration migration = do
  unless (allIncreasing opVersions) $
    Left "Operation versions must be monotonically increasing"
  when (hasDuplicates opVersions) $
    Left "There may only be one operation per pair of versions"
  where
    opVersions = map opPath migration
    allIncreasing = all (uncurry (<))
    hasDuplicates l = length (nub l) < length l

-- | Run the given migration. After successful completion, saves the migration to the database.
--
-- Executes every statement produced by 'getMigration', then inserts a row into
-- @persistent_migration@ holding the latest version of the migration, its label
-- and the current time. When 'versionToLabel' yields no label, the shown
-- version number is stored instead. The row is inserted on every call.
runMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m ()
runMigration backend settings@MigrateSettings{..} migration = do
  -- 'getMigration' rejects invalid (including empty) migrations before anything
  -- below is evaluated.
  getMigration backend settings migration >>= mapM_ executeSql
  now <- liftIO getCurrentTime
  let version = getLatestVersion migration
  rawExecute "INSERT INTO persistent_migration(version, label, timestamp) VALUES (?, ?, ?)"
    [ PersistInt64 $ fromIntegral version
    , PersistText $ Text.pack $ fromMaybe (show version) $ versionToLabel version
    , PersistUTCTime now
    ]

-- | Get the SQL queries for the given migration.
--
-- Validates the migration, looks up the current database version, finds the
-- operations leading from there to the latest version, validates each of them
-- and converts them to SQL with the given backend.
--
-- Throws an 'IOError' (via 'userError') when the migration is empty or invalid,
-- when no path between the versions exists, or when an operation is invalid.
getMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m [MigrateSql]
getMigration backend _ migration = do
  -- An empty migration has no first/latest version; reject it here instead of
  -- crashing later in 'minimum'/'maximum'.
  when (null migration) $
    throwUserError "Migration must contain at least one operation"
  either throwUserError return $ validateMigration migration
  currVersion <- getCurrVersion backend
  operations <- either badPath return $ getOperations migration currVersion
  either throwUserError return $ mapM_ validateOperation operations
  concatMapM (mapReaderT liftIO . getMigrationSql backend) operations
  where
    badPath (start, end) =
      throwUserError $ "Could not find path: " ++ show start ++ " ~> " ++ show end

    -- Only 'MonadIO' is available, so failures are raised in IO rather than
    -- through 'fail', which would need a 'MonadFail' instance.
    throwUserError :: MonadIO n => String -> n a
    throwUserError msg = liftIO $ ioError $ userError msg

    -- Utilities
    concatMapM f = fmap concat . mapM f