{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# OPTIONS_GHC -fno-warn-redundant-constraints #-}
module Database.Persist.Migration.Core
( MigrateSettings(..)
, defaultSettings
, validateMigration
, runMigration
, getMigration
) where
import Control.Monad (unless, when)
import Control.Monad.IO.Class (MonadIO(..))
import Control.Monad.Reader (mapReaderT)
import Data.List (nub)
import Data.Maybe (fromMaybe)
import qualified Data.Text as Text
import Data.Time.Clock (getCurrentTime)
import Database.Persist.Migration.Backend (MigrateBackend(..))
import Database.Persist.Migration.Operation
( Migration
, MigrationPath(..)
, Operation(..)
, Version
, opPath
, validateOperation
)
import Database.Persist.Migration.Operation.Types
(Column(..), ColumnProp(..), TableConstraint(..))
import Database.Persist.Migration.Utils.Plan (getPath)
import Database.Persist.Migration.Utils.Sql (MigrateSql, executeSql)
import Database.Persist.Sql
(PersistValue(..), Single(..), SqlPersistT, rawExecute, rawSql)
import Database.Persist.Types (SqlType(..))
-- | Read the most recently recorded version from the @persistent_migration@
-- table.
--
-- The backend is first asked for the SQL of a 'CreateTable' operation
-- describing the bookkeeping table, and that SQL is executed. How an already
-- existing table is handled is decided by the backend and is not visible here.
--
-- Returns 'Nothing' when the query yields no rows, and the version of the row
-- with the latest timestamp otherwise.
getCurrVersion :: MonadIO m => MigrateBackend -> SqlPersistT m (Maybe Version)
getCurrVersion backend = do
  setupSql <- mapReaderT liftIO $ getMigrationSql backend bookkeepingTable
  mapM_ executeSql setupSql
  rows <- rawSql latestVersionQuery []
  return $ case rows of
    [] -> Nothing
    [Single v] -> Just v
    -- The query is limited to one row, so anything else is unexpected.
    _ -> error "Invalid response from the database."
  where
    -- Table in which every completed migration run is recorded.
    bookkeepingTable = CreateTable
      { name = "persistent_migration"
      , schema =
          [ Column "id" SqlInt32 [NotNull, AutoIncrement]
          , Column "version" SqlInt32 [NotNull]
          , Column "label" SqlString []
          , Column "timestamp" SqlDayTime [NotNull]
          ]
      , constraints = [PrimaryKey ["id"]]
      }
    latestVersionQuery =
      "SELECT version FROM persistent_migration ORDER BY timestamp DESC LIMIT 1"
-- | Collect the operations needed to go from the given version (or the first
-- version of the migration when 'Nothing') to the latest version.
--
-- Returns 'Left' with the @(start, end)@ pair when 'getPath' finds no path
-- between the two versions.
getOperations :: Migration -> Maybe Version -> Either (Version, Version) [Operation]
getOperations migration mVersion =
  maybe (Left (source, target)) (Right . concat) $ getPath edges source target
  where
    -- Unpack each migration path into the (versions, operations) pair
    -- that 'getPath' works on.
    toEdge (path := ops) = (path, ops)
    edges = map toEdge migration
    source = fromMaybe (getFirstVersion migration) mVersion
    target = getLatestVersion migration
-- | The smallest first component of 'opPath' over all paths in the migration.
--
-- Partial: 'minimum' throws when the migration is empty.
getFirstVersion :: Migration -> Version
getFirstVersion migration = minimum [fst (opPath path) | path <- migration]
-- | The largest second component of 'opPath' over all paths in the migration.
--
-- Partial: 'maximum' throws when the migration is empty.
getLatestVersion :: Migration -> Version
getLatestVersion migration = maximum [snd (opPath path) | path <- migration]
-- | Settings that customise how a migration run is recorded.
newtype MigrateSettings = MigrateSettings
  { versionToLabel :: Version -> Maybe String
    -- ^ Optional label for a version. 'runMigration' stores it in the
    -- @label@ column and falls back to @show version@ on 'Nothing'.
  }
-- | Default settings: no version is given a label.
defaultSettings :: MigrateSettings
defaultSettings = MigrateSettings $ \_ -> Nothing
-- | Check the structure of a migration.
--
-- Fails with a message when some path does not go from a smaller to a
-- strictly larger version, or when two paths share the same pair of versions.
-- The ordering check is performed first.
validateMigration :: Migration -> Either String ()
validateMigration migration
  | any (not . uncurry (<)) paths =
      Left "Operation versions must be monotonically increasing"
  | length (nub paths) < length paths =
      Left "There may only be one operation per pair of versions"
  | otherwise = Right ()
  where
    -- The version pair of every path in the migration.
    paths = map opPath migration
-- | Execute the SQL produced by 'getMigration', then record the run in the
-- @persistent_migration@ table.
--
-- The recorded row holds the latest version of the migration, its label
-- (from 'versionToLabel', or @show version@ when that gives 'Nothing') and
-- the current time. The row is inserted on every call, including when
-- 'getMigration' returned no statements.
runMigration :: MonadIO m => MigrateBackend -> MigrateSettings -> Migration -> SqlPersistT m ()
runMigration backend settings@MigrateSettings{versionToLabel} migration = do
  statements <- getMigration backend settings migration
  mapM_ executeSql statements
  timestamp <- liftIO getCurrentTime
  rawExecute insertVersion
    [ PersistInt64 (fromIntegral latest)
    , PersistText (Text.pack label)
    , PersistUTCTime timestamp
    ]
  where
    latest = getLatestVersion migration
    label = fromMaybe (show latest) (versionToLabel latest)
    insertVersion =
      "INSERT INTO persistent_migration(version, label, timestamp) VALUES (?, ?, ?)"
-- | Compute the SQL statements needed to bring the database from its
-- currently recorded version to the latest version of the migration.
--
-- Steps, in order: validate the migration, read the current version (which
-- also runs the bookkeeping-table setup SQL), find the operations on the
-- path to the latest version, validate each operation, and convert the
-- operations to SQL with the backend. Any validation error or missing path
-- is reported through 'fail'. The settings argument is not used.
getMigration :: MonadIO m
  => MigrateBackend
  -> MigrateSettings
  -> Migration
  -> SqlPersistT m [MigrateSql]
getMigration backend _ migration = do
  orFail $ validateMigration migration
  recorded <- getCurrVersion backend
  operations <- case getOperations migration recorded of
    Right ops -> return ops
    Left (start, end) ->
      fail $ "Could not find path: " ++ show start ++ " ~> " ++ show end
  orFail $ mapM_ validateOperation operations
  concat <$> mapM toSql operations
  where
    -- Turn a validation failure into a failure of the surrounding monad.
    orFail result = either fail return result
    -- Ask the backend for the SQL of a single operation.
    toSql = mapReaderT liftIO . getMigrationSql backend