module Database.PostgreSQL.PQTypes.Checks (
  -- * Checks
    checkDatabase
  , checkDatabaseAllowUnknownTables
  , createTable
  , createDomain

  -- * Options
  , ExtrasOptions(..)

  -- * Migrations
  , migrateDatabase
  ) where

import Control.Applicative ((<$>))
import Control.Monad.Catch
import Control.Monad.Reader
import Data.Int
import Data.Function (on)
import Data.Maybe
import Data.Monoid
import Data.Monoid.Utils
import Data.Ord (comparing)
import qualified Data.String
import Data.Text (Text)
import Database.PostgreSQL.PQTypes hiding (def)
import Log
import Prelude
import TextShow
import qualified Data.List as L
import qualified Data.Text as T

import Database.PostgreSQL.PQTypes.ExtrasOptions
import Database.PostgreSQL.PQTypes.Checks.Util
import Database.PostgreSQL.PQTypes.Migrate
import Database.PostgreSQL.PQTypes.Model
import Database.PostgreSQL.PQTypes.SQL.Builder
import Database.PostgreSQL.PQTypes.Versions

-- | Partial 'head' replacement: return the first element of the list,
-- or call 'error' with the supplied message when the list is empty.
headExc :: String -> [a] -> a
headExc msg xs = case xs of
  (x:_) -> x
  []    -> error msg

----------------------------------------

-- | Run migrations and check the database structure.
--
-- Steps, in order: force the DB time zone to UTC, ensure the requested
-- extensions exist, migrate (via 'checkDBConsistency'), then validate
-- domains, table structure, dropped tables, unknown tables and the
-- contents of @table_versions@. The transaction is committed only when
-- every check passes.
migrateDatabase
  :: (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Extension] -> [Domain] -> [Table] -> [Migration m]
  -> m ()
migrateDatabase options extensions domains tables migrations = do
  -- The bookkeeping table is checked alongside the user tables.
  let allTables = tableVersions : tables
  setDBTimeZoneToUTC
  mapM_ checkExtension extensions
  tablesWithVersions <- getTableVersions allTables
  -- 'checkDBConsistency' also performs migrations.
  checkDBConsistency options domains tablesWithVersions migrations
  -- Run each post-migration check in sequence, failing on the first
  -- invalid result.
  mapM_ (resultCheck =<<)
    [ checkDomainsStructure domains
    , checkDBStructure options tablesWithVersions
    , checkTablesWereDropped migrations
    , checkUnknownTables tables
    , checkExistenceOfVersionsForTables allTables
    ]
  -- Everything is OK, commit changes.
  commit

-- | Run checks on the database structure and whether the database
-- needs to be migrated. Will do a full check of DB structure.
checkDatabase
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [Table] -> m ()
checkDatabase options = checkDatabase_ options False

-- | Same as 'checkDatabase', but will not fail if there are
-- additional tables in database.
checkDatabaseAllowUnknownTables
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [Table] -> m ()
checkDatabaseAllowUnknownTables options = checkDatabase_ options True

-- | Shared implementation of 'checkDatabase' and
-- 'checkDatabaseAllowUnknownTables'.
--
-- When the 'Bool' argument is 'True', the unknown-table check and the
-- @table_versions@ consistency check are skipped.
checkDatabase_
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> Bool -> [Domain] -> [Table] -> m ()
checkDatabase_ options allowUnknownTables domains tables = do
  tablesWithVersions <- getTableVersions (tableVersions : tables)

  resultCheck $ checkVersions tablesWithVersions
  resultCheck =<< checkDomainsStructure domains
  resultCheck =<< checkDBStructure options tablesWithVersions
  when (not $ allowUnknownTables) $ do
    resultCheck =<< checkUnknownTables tables
    resultCheck =<< checkExistenceOfVersionsForTables (tableVersions : tables)

  -- Check initial setups only after database structure is considered
  -- consistent as before that some of the checks may fail internally.
  resultCheck =<< checkInitialSetups tables

  where
    -- Combine the per-table version verdicts.
    checkVersions :: [(Table, Int32)] -> ValidationResult
    checkVersions vs = mconcat . map checkVersion $ vs

    -- Compare the version stored in the DB ('v', 0 when the table is
    -- missing) against the table definition.
    checkVersion :: (Table, Int32) -> ValidationResult
    checkVersion (t@Table{..}, v)
      | tblVersion `elem` tblAcceptedDbVersions
        = validationError $ "Table '" <> tblNameText t
          <> "' has its current table version in accepted db versions"
      | tblVersion == v || v `elem` tblAcceptedDbVersions
        = mempty
      | v == 0
        = validationError $ "Table '" <> tblNameText t <> "' must be created"
      | otherwise
        = validationError $ "Table '" <> tblNameText t
          <> "' must be migrated" <+> showt v <+> "->" <+> showt tblVersion

    checkInitialSetups :: [Table] -> m ValidationResult
    checkInitialSetups tbls = liftM mconcat . mapM checkInitialSetup' $ tbls

    -- Tables without an initial setup are trivially valid.
    checkInitialSetup' :: Table -> m ValidationResult
    checkInitialSetup' t@Table{..} = case tblInitialSetup of
      Nothing -> return mempty
      Just is -> checkInitialSetup is >>= \case
        True -> return mempty
        False -> return . validationError $ "Initial setup for table '"
          <> tblNameText t <> "' is not valid"

-- | Return SQL fragment of current catalog within quotes.
--
-- The name is emitted as a PostgreSQL quoted identifier, so every double
-- quote inside it is doubled; otherwise a database name containing @"@
-- would produce malformed SQL (e.g. in @ALTER DATABASE@).
currentCatalog :: (MonadDB m, MonadThrow m) => m (RawSQL ())
currentCatalog = do
  runSQL_ "SELECT current_catalog::text"
  dbname <- fetchOne runIdentity
  return . unsafeSQL $ "\"" ++ concatMap escapeQuote dbname ++ "\""
  where
    escapeQuote :: Char -> String
    escapeQuote '"' = "\"\""
    escapeQuote c   = [c]

-- | Check for a given extension. We need to read from 'pg_extension'
-- table as Amazon RDS limits usage of 'CREATE EXTENSION IF NOT EXISTS'.
--
-- The extension is created only when no matching @pg_extension@ row
-- exists.
checkExtension :: (MonadDB m, MonadLog m, MonadThrow m) => Extension -> m ()
checkExtension (Extension extension) = do
  logInfo_ $ "Checking for extension '" <> txtExtension <> "'"
  extensionExists <- runQuery01 . sqlSelect "pg_extension" $ do
    sqlResult "TRUE"
    sqlWhereEq "extname" $ unRawSQL extension
  if not extensionExists
    then do
      logInfo_ $ "Creating extension '" <> txtExtension <> "'"
      runSQL_ $ "CREATE EXTENSION IF NOT EXISTS" <+> raw extension
    else logInfo_ $ "Extension '" <> txtExtension <> "' exists"
  where
    txtExtension = unRawSQL extension

-- | Check whether the database returns timestamps in UTC, and set the
-- timezone to UTC if it doesn't.
setDBTimeZoneToUTC :: (MonadDB m, MonadLog m, MonadThrow m) => m ()
setDBTimeZoneToUTC = do
  runSQL_ "SHOW timezone"
  timezone :: String <- fetchOne runIdentity
  when (timezone /= "UTC") $ do
    dbname <- currentCatalog
    logInfo_ $ "Setting '" <> unRawSQL dbname
      <> "' database to return timestamps in UTC"
    runQuery_ $ "ALTER DATABASE" <+> dbname <+> "SET TIMEZONE = 'UTC'"

-- | Get the names of all user-defined tables that actually exist in
-- the DB.
getDBTableNames :: (MonadDB m) => m [Text]
getDBTableNames = do
  runQuery_ $ sqlSelect "information_schema.tables" $ do
    sqlResult "table_name::text"
    sqlWhere "table_name <> 'table_versions'"
    sqlWhere "table_type = 'BASE TABLE'"
    -- Only consider schemas on the current search path.
    sqlWhereExists $ sqlSelect "unnest(current_schemas(false)) as cs" $ do
      sqlResult "TRUE"
      sqlWhere "cs = table_schema"
  fetchMany runIdentity

-- | Check that there's a 1-to-1 correspondence between the list of
-- 'Table's and what's actually in the database.
checkUnknownTables :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult
checkUnknownTables tables = do
  dbTableNames <- getDBTableNames
  let tableNames = map (unRawSQL . tblName) tables
      absent     = dbTableNames L.\\ tableNames
      notPresent = tableNames L.\\ dbTableNames

  if (not . null $ absent) || (not . null $ notPresent)
    then do
      mapM_ (logInfo_ . (<+>) "Unknown table:") absent
      mapM_ (logInfo_ . (<+>) "Table not present in the database:") notPresent
      return $ (validateIsNull "Unknown tables:" absent)
        <> (validateIsNull "Tables not present in the database:" notPresent)
    else return mempty

-- | Produce a validation error listing the given names after @msg@,
-- or no error at all when the list is empty.
validateIsNull :: Text -> [Text] -> ValidationResult
validateIsNull _ [] = mempty
validateIsNull msg ts = validationError $ msg <+> T.intercalate ", " ts

-- | Check that there's a 1-to-1 correspondence between the list of
-- 'Table's and what's actually in the table 'table_versions'.
checkExistenceOfVersionsForTables
  :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult
checkExistenceOfVersionsForTables tables = do
  runQuery_ $ sqlSelect "table_versions" $ do
    sqlResult "name::text"
  (existingTableNames :: [Text]) <- fetchMany runIdentity

  let tableNames = map (unRawSQL . tblName) tables
      absent     = existingTableNames L.\\ tableNames
      notPresent = tableNames L.\\ existingTableNames

  if (not . null $ absent) || (not . null $ notPresent)
    then do
      mapM_ (logInfo_ . (<+>) "Unknown entry in 'table_versions':") absent
      mapM_ (logInfo_ . (<+>) "Table not present in the 'table_versions':") notPresent
      return $ (validateIsNull "Unknown entry in 'table_versions':" absent)
        <> (validateIsNull "Tables not present in the 'table_versions':" notPresent)
    else return mempty

-- | Compare each domain definition with the domain of the same name
-- found in @pg_catalog.pg_type@, reporting missing domains and every
-- mismatching attribute.
checkDomainsStructure :: (MonadDB m, MonadThrow m) => [Domain] -> m ValidationResult
checkDomainsStructure defs = fmap mconcat . forM defs $ \def -> do
  runQuery_ . sqlSelect "pg_catalog.pg_type t1" $ do
    sqlResult "t1.typname::text" -- name
    sqlResult "(SELECT pg_catalog.format_type(t2.oid, t2.typtypmod) FROM pg_catalog.pg_type t2 WHERE t2.oid = t1.typbasetype)" -- type
    sqlResult "NOT t1.typnotnull" -- nullable
    sqlResult "t1.typdefault" -- default value
    sqlResult "ARRAY(SELECT c.conname::text FROM pg_catalog.pg_constraint c WHERE c.contypid = t1.oid ORDER by c.oid)" -- constraint names
    sqlResult "ARRAY(SELECT regexp_replace(pg_get_constraintdef(c.oid, true), 'CHECK \\((.*)\\)', '\\1') FROM pg_catalog.pg_constraint c WHERE c.contypid = t1.oid ORDER by c.oid)" -- constraint definitions
    sqlResult "ARRAY(SELECT c.convalidated FROM pg_catalog.pg_constraint c WHERE c.contypid = t1.oid ORDER by c.oid)" -- are constraints validated?
    sqlWhereEq "t1.typname" $ unRawSQL $ domName def
  mdom <- fetchMaybe $ \(dname, dtype, nullable, defval, cnames, conds, valids) -> Domain
    { domName     = unsafeSQL dname
    , domType     = dtype
    , domNullable = nullable
    , domDefault  = unsafeSQL <$> defval
    , domChecks   = mkChecks $ zipWith3
        (\cname cond validated -> Check
          { chkName      = unsafeSQL cname
          , chkCondition = unsafeSQL cond
          , chkValidated = validated
          })
        (unArray1 cnames) (unArray1 conds) (unArray1 valids)
    }
  return $ case mdom of
    Just dom
      | dom /= def -> topMessage "domain" (unRawSQL $ domName dom) $ mconcat
          [ compareAttr dom def "name" domName
          , compareAttr dom def "type" domType
          , compareAttr dom def "nullable" domNullable
          , compareAttr dom def "default" domDefault
          , compareAttr dom def "checks" domChecks
          ]
      | otherwise -> mempty
    Nothing -> validationError $ "Domain '" <> unRawSQL (domName def)
      <> "' doesn't exist in the database"
  where
    -- Report a single attribute that differs between the database and
    -- the definition, showing both values.
    compareAttr :: (Eq a, Show a)
                => Domain -> Domain -> Text -> (Domain -> a) -> ValidationResult
    compareAttr dom def attrname attr
      | attr dom == attr def = mempty
      | otherwise = validationError $ "Attribute '" <> attrname
        <> "' does not match (database:" <+> T.pack (show $ attr dom)
        <> ", definition:" <+> T.pack (show $ attr def) <> ")"

-- | Check that the tables that must have been dropped are actually
-- missing from the DB.
checkTablesWereDropped :: (MonadDB m, MonadThrow m) =>
                          [Migration m] -> m ValidationResult
checkTablesWereDropped mgrs = do
  let droppedTableNames = [ mgrTableName mgr
                          | mgr <- mgrs, isDropTableMigration mgr ]
  fmap mconcat . forM droppedTableNames $
    \tblName -> do
      mver <- checkTableVersion (T.unpack . unRawSQL $ tblName)
      return $ if isNothing mver
               then mempty
               else validationError $ "The table '" <> unRawSQL tblName
                    <> "' that must have been dropped"
                    <> " is still present in the database."

-- | Checks whether the database structure (columns, primary key,
-- checks, indexes and foreign keys) matches the table definitions.
--
-- For tables whose DB version is one of 'tblAcceptedDbVersions',
-- differences are downgraded to informational messages.
checkDBStructure
  :: forall m. (MonadDB m, MonadThrow m)
  => ExtrasOptions -> [(Table, Int32)] -> m ValidationResult
checkDBStructure options tables = fmap mconcat . forM tables $ \(table, version) -> do
  result <- topMessage "table" (tblNameText table) <$> checkTableStructure table
  -- If one of the accepted versions defined for the table is the current table
  -- version in the database, show inconsistencies as info messages only.
  return $ if version `elem` tblAcceptedDbVersions table
    then validationErrorsToInfos result
    else result
  where
    checkTableStructure :: Table -> m ValidationResult
    checkTableStructure table@Table{..} = do
      -- get table description from pg_catalog as describeTable
      -- mechanism from HDBC doesn't give accurate results
      runQuery_ $ sqlSelect "pg_catalog.pg_attribute a" $ do
        sqlResult "a.attname::text"
        sqlResult "pg_catalog.format_type(a.atttypid, a.atttypmod)"
        sqlResult "NOT a.attnotnull"
        sqlResult . parenthesize . toSQLCommand $
          sqlSelect "pg_catalog.pg_attrdef d" $ do
            sqlResult "pg_catalog.pg_get_expr(d.adbin, d.adrelid)"
            sqlWhere "d.adrelid = a.attrelid"
            sqlWhere "d.adnum = a.attnum"
            sqlWhere "a.atthasdef"
        sqlWhere "a.attnum > 0"
        sqlWhere "NOT a.attisdropped"
        sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
        sqlOrderBy "a.attnum"
      desc <- fetchMany fetchTableColumn
      -- get info about constraints from pg_catalog
      pk <- sqlGetPrimaryKey table
      runQuery_ $ sqlGetChecks table
      checks <- fetchMany fetchTableCheck
      runQuery_ $ sqlGetIndexes table
      indexes <- fetchMany fetchTableIndex
      runQuery_ $ sqlGetForeignKeys table
      fkeys <- fetchMany fetchForeignKey
      return $ mconcat
        [ checkColumns 1 tblColumns desc
        , checkPrimaryKey tblPrimaryKey pk
        , checkChecks tblChecks checks
        , checkIndexes tblIndexes indexes
        , checkForeignKeys tblForeignKeys fkeys
        ]
      where
        fetchTableColumn
          :: (String, ColumnType, Bool, Maybe String) -> TableColumn
        fetchTableColumn (name, ctype, nullable, mdefault) = TableColumn
          { colName     = unsafeSQL name
          , colType     = ctype
          , colNullable = nullable
          , colDefault  = unsafeSQL `liftM` mdefault
          }

        -- Compare definition columns with DB columns position by
        -- position; 'n' is the 1-based column number used in messages.
        checkColumns
          :: Int -> [TableColumn] -> [TableColumn] -> ValidationResult
        checkColumns _ [] [] = mempty
        checkColumns _ rest [] = validationError $ tableHasLess "columns" rest
        checkColumns _ [] rest = validationError $ tableHasMore "columns" rest
        checkColumns !n (d:defs) (c:cols) = mconcat
          [ validateNames $ colName d == colName c
            -- bigserial == bigint + autoincrement and there is no
            -- distinction between them after table is created.
          , validateTypes $ colType d == colType c || (colType d == BigSerialT && colType c == BigIntT)
            -- there is a problem with default values determined by sequences as
            -- they're implicitly specified by db, so let's omit them in such case
          , validateDefaults $ colDefault d == colDefault c || (colDefault d == Nothing && ((T.isPrefixOf "nextval('" . unRawSQL) `liftM` colDefault c) == Just True)
          , validateNullables $ colNullable d == colNullable c
          , checkColumns (n+1) defs cols
          ]
          where
            validateNames True = mempty
            validateNames False = validationError $
              errorMsg ("no. " <> showt n) "names" (unRawSQL . colName)

            validateTypes True = mempty
            validateTypes False = validationError $
              errorMsg cname "types" (T.pack . show . colType)
              <+> sqlHint ("TYPE" <+> columnTypeToSQL (colType d))

            validateNullables True = mempty
            validateNullables False = validationError $
              errorMsg cname "nullables" (showt . colNullable)
              <+> sqlHint ((if colNullable d then "DROP" else "SET")
                            <+> "NOT NULL")

            validateDefaults True = mempty
            validateDefaults False = validationError $
              (errorMsg cname "defaults" (showt . fmap unRawSQL . colDefault))
              <+> sqlHint set_default
              where
                set_default = case colDefault d of
                  Just v  -> "SET DEFAULT" <+> v
                  Nothing -> "DROP DEFAULT"

            cname = unRawSQL $ colName d
            errorMsg ident attr f =
              "Column '" <> ident <> "' differs in"
              <+> attr <+> "(table:" <+> f c <> ", definition:" <+> f d <> ")."
            sqlHint sql =
              "(HINT: SQL for making the change is: ALTER TABLE"
              <+> tblNameText table <+> "ALTER COLUMN" <+> unRawSQL (colName d)
              <+> unRawSQL sql <> ")"

        checkPrimaryKey :: Maybe PrimaryKey -> Maybe (PrimaryKey, RawSQL ())
                        -> ValidationResult
        checkPrimaryKey mdef mpk = mconcat
          [ checkEquality "PRIMARY KEY" def (map fst pk)
          , checkNames (const (pkName tblName)) pk
          , if (eoEnforcePKs options)
            then checkPKPresence tblName mdef mpk
            else mempty
          ]
          where
            def = maybeToList mdef
            pk = maybeToList mpk

        checkChecks :: [Check] -> [Check] -> ValidationResult
        checkChecks defs checks =
          mapValidationResult id mapErrs (checkEquality "CHECKs" defs checks)
          where
            -- Append a hint only when there is at least one error.
            mapErrs []      = []
            mapErrs errmsgs = errmsgs <>
              [ " (HINT: If checks are equal modulo number of \
                \ parentheses/whitespaces used in conditions, \
                \ just copy and paste expected output into source code)"
              ]

        checkIndexes :: [TableIndex] -> [(TableIndex, RawSQL ())]
                     -> ValidationResult
        checkIndexes defs indexes = mconcat
          [ checkEquality "INDEXes" defs (map fst indexes)
          , checkNames (indexName tblName) indexes
          ]

        checkForeignKeys :: [ForeignKey] -> [(ForeignKey, RawSQL ())]
                         -> ValidationResult
        checkForeignKeys defs fkeys = mconcat
          [ checkEquality "FOREIGN KEYs" defs (map fst fkeys)
          , checkNames (fkName tblName) fkeys
          ]

-- | Checks whether database is consistent, performing migrations if
-- necessary. Requires all table names to be in lower case.
--
-- The migrations list must have the following properties:
--
--   * consecutive 'mgrFrom' numbers
--
--   * no duplicates
--
--   * all 'mgrFrom' are less than table version number of the table in
--     the 'tables' list
--
-- Invalid migration lists are reported with 'logAttention' and then
-- abort via 'error'.
checkDBConsistency
  :: forall m. (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [Domain] -> [(Table, Int32)] -> [Migration m]
  -> m ()
checkDBConsistency options domains tablesWithVersions migrations = do
  -- Check the validity of the migrations list.
  validateMigrations
  validateDropTableMigrations

  -- Load version numbers of the tables that actually exist in the DB.
  dbTablesWithVersions <- getDBTableVersions

  if all ((==) 0 . snd) tablesWithVersions
    -- No tables are present, create everything from scratch.
    then do
      createDBSchema
      initializeDB
    -- Migration mode.
    else do
      -- Additional validity checks for the migrations list.
      validateMigrationsAgainstDB [ (tblName table, tblVersion table, actualVer)
                                  | (table, actualVer) <- tablesWithVersions ]
      validateDropTableMigrationsAgainstDB dbTablesWithVersions
      -- Run migrations, if necessary.
      runMigrations dbTablesWithVersions

  where
    tables = map fst tablesWithVersions

    errorInvalidMigrations :: [RawSQL ()] -> a
    errorInvalidMigrations tblNames =
      error $ "checkDBConsistency: invalid migrations for tables"
      <+> (L.intercalate ", " $ map (T.unpack . unRawSQL) tblNames)

    checkMigrationsListValidity :: Table -> [Int32] -> [Int32] -> m ()
    checkMigrationsListValidity table presentMigrationVersions
      expectedMigrationVersions = do
      when (presentMigrationVersions /= expectedMigrationVersions) $ do
        logAttention "Migrations are invalid" $ object
          [ "table"                       .= tblNameText table
          , "migration_versions"          .= presentMigrationVersions
          , "expected_migration_versions" .= expectedMigrationVersions
          ]
        errorInvalidMigrations [tblName $ table]

    -- For every defined table, its migrations must form the tail of
    -- [0 .. tblVersion - 1].
    validateMigrations :: m ()
    validateMigrations = forM_ tables $ \table -> do
      let presentMigrationVersions
            = [ mgrFrom | Migration{..} <- migrations
                        , mgrTableName == tblName table ]
          expectedMigrationVersions
            = reverse $ take (length presentMigrationVersions) $
              reverse [0 .. tblVersion table - 1]
      checkMigrationsListValidity table presentMigrationVersions
        expectedMigrationVersions

    validateDropTableMigrations :: m ()
    validateDropTableMigrations = do
      let droppedTableNames =
            [ mgrTableName $ mgr | mgr <- migrations
                                 , isDropTableMigration mgr ]
          tableNames =
            [ tblName tbl | tbl <- tables ]

      -- Check that the intersection between the 'tables' list and
      -- dropped tables is empty.
      let intersection = L.intersect droppedTableNames tableNames
      when (not . null $ intersection) $ do
        logAttention ("The intersection between tables "
                      <> "and dropped tables is not empty")
          $ object
          [ "intersection" .= map unRawSQL intersection ]
        errorInvalidMigrations [ tblName tbl
                               | tbl <- tables
                               , tblName tbl `elem` intersection ]

      -- Check that if a list of migrations for a given table has a
      -- drop table migration, it is unique and is the last migration
      -- in the list.
      --
      -- 'L.groupBy' only merges adjacent elements and migrations of
      -- different tables may be interleaved, so sort by table name first.
      -- 'L.sortBy' is stable, which keeps each table's migrations in
      -- their original relative order.
      let migrationsByTable = L.groupBy ((==) `on` mgrTableName)
                              . L.sortBy (comparing mgrTableName)
                              $ migrations
          dropMigrationLists = [ mgrs | mgrs <- migrationsByTable
                                      , any isDropTableMigration mgrs ]
          invalidMigrationLists =
            [ mgrs | mgrs <- dropMigrationLists
                   , (not . isDropTableMigration . last $ mgrs) ||
                     (length . filter isDropTableMigration $ mgrs) > 1 ]

      when (not . null $ invalidMigrationLists) $ do
        -- Groups produced by 'L.groupBy' are never empty.
        let tablesWithInvalidMigrationLists =
              [ mgrTableName mgr | (mgr : _) <- invalidMigrationLists ]
        logAttention ("Migration lists for some tables contain "
                      <> "either multiple drop table migrations or "
                      <> "a drop table migration in non-tail position.")
          $ object [ "tables" .=
                     [ unRawSQL tblName
                     | tblName <- tablesWithInvalidMigrationLists ] ]
        errorInvalidMigrations tablesWithInvalidMigrationLists

    createDBSchema :: m ()
    createDBSchema = do
      logInfo_ "Creating domains..."
      mapM_ createDomain domains
      -- Create all tables with no constraints first to allow cyclic references.
      logInfo_ "Creating tables..."
      mapM_ (createTable False) tables
      logInfo_ "Creating table constraints..."
      mapM_ createTableConstraints tables
      logInfo_ "Done."

    initializeDB :: m ()
    initializeDB = do
      logInfo_ "Running initial setup for tables..."
      forM_ tables $ \t -> case tblInitialSetup t of
        Nothing -> return ()
        Just tis -> do
          logInfo_ $ "Initializing" <+> tblNameText t <> "..."
          initialSetup tis
      logInfo_ "Done."

    -- Input is a list of (table name, expected version, actual version)
    -- triples. Fails when a table needs migrating but no migration chain
    -- can start from its actual version.
    validateMigrationsAgainstDB :: [(RawSQL (), Int32, Int32)] -> m ()
    validateMigrationsAgainstDB tablesWithVersions_
      = forM_ tablesWithVersions_ $ \(tableName, expectedVer, actualVer) ->
        when (expectedVer /= actualVer) $
        case [ m | m@Migration{..} <- migrations
                 , mgrTableName == tableName ] of
          [] ->
            error $ "checkDBConsistency: no migrations found for table '"
              ++ (T.unpack . unRawSQL $ tableName) ++ "', cannot migrate "
              ++ show actualVer ++ " -> " ++ show expectedVer
          (m:_) | mgrFrom m > actualVer ->
                  error $ "checkDBConsistency: earliest migration for table '"
                    ++ (T.unpack . unRawSQL $ tableName) ++ "' is from version "
                    ++ show (mgrFrom m) ++ ", cannot migrate "
                    ++ show actualVer ++ " -> " ++ show expectedVer
                | otherwise -> return ()

    validateDropTableMigrationsAgainstDB :: [(Text, Int32)] -> m ()
    validateDropTableMigrationsAgainstDB dbTablesWithVersions = do
      -- Tables scheduled for dropping that currently exist in the DB,
      -- with the version the drop expects and the version they have.
      let dbTablesToDropWithVersions =
            [ (tblName, mgrFrom mgr, ver)
            | mgr <- migrations
            , isDropTableMigration mgr
            , let tblName = mgrTableName mgr
            , Just ver <- [lookup (unRawSQL tblName) dbTablesWithVersions] ]
      forM_ dbTablesToDropWithVersions $ \(tblName, fromVer, ver) ->
        when (fromVer /= ver) $
          -- In case when the table we're going to drop is an old
          -- version, check that there are migrations that bring it to
          -- a new one.
          validateMigrationsAgainstDB [(tblName, fromVer, ver)]

    findMigrationsToRun :: [(Text, Int32)] -> [Migration m]
    findMigrationsToRun dbTablesWithVersions =
      let tableNamesToDrop = [ mgrTableName mgr | mgr <- migrations
                                                , isDropTableMigration mgr ]
          droppedEventually :: Migration m -> Bool
          droppedEventually mgr = mgrTableName mgr `elem` tableNamesToDrop

          lookupVer :: Migration m -> Maybe Int32
          lookupVer mgr = lookup (unRawSQL $ mgrTableName mgr)
                          dbTablesWithVersions

          tableDoesNotExist = isNothing . lookupVer

          -- The idea here is that we find the first migration we need
          -- to run and then just run all migrations in order after
          -- that one.
          migrationsToRun' = dropWhile
            (\mgr ->
               case lookupVer mgr of
                 -- Table doesn't exist in the DB. If it's a create
                 -- table migration and we're not going to drop the
                 -- table afterwards, this is our starting point.
                 Nothing -> not $
                            (mgrFrom mgr == 0) &&
                            (not . droppedEventually $ mgr)
                 -- Table exists in the DB. Run only those migrations
                 -- that have mgrFrom >= table version in the DB.
                 Just ver -> not $
                             mgrFrom mgr >= ver)
            migrations

          -- Special case: also include migrations for tables that do
          -- not exist in the DB and ARE going to be dropped if they
          -- come as a consecutive list before the starting point that
          -- we've found.
          --
          -- Case in point: createTable t, doSomethingTo t,
          -- doSomethingTo t1, dropTable t.
          l = length migrationsToRun'
          initialMigrations = drop l $ reverse migrations
          additionalMigrations = takeWhile
            (\mgr -> droppedEventually mgr && tableDoesNotExist mgr)
            initialMigrations
          migrationsToRun = (reverse additionalMigrations) ++ migrationsToRun'
      in migrationsToRun

    -- Run a single migration and record its effect in 'table_versions'.
    runMigration :: Migration m -> m ()
    runMigration Migration{..} = do
      case mgrAction of
        StandardMigration mgrDo -> do
          logInfo_ $ arrListTable mgrTableName
            <> showt mgrFrom <+> "->" <+> showt (succ mgrFrom)
          mgrDo
          runQuery_ $ sqlUpdate "table_versions" $ do
            sqlSet "version" (succ mgrFrom)
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)

        DropTableMigration mgrDropTableMode -> do
          logInfo_ $ arrListTable mgrTableName <> "drop table"
          runQuery_ $ sqlDropTable mgrTableName mgrDropTableMode
          runQuery_ $ sqlDelete "table_versions" $ do
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)

    runMigrations :: [(Text, Int32)] -> m ()
    runMigrations dbTablesWithVersions = do
      let migrationsToRun = findMigrationsToRun dbTablesWithVersions
      validateMigrationsToRun migrationsToRun dbTablesWithVersions
      when (not . null $ migrationsToRun) $ do
        logInfo_ "Running migrations..."
        forM_ migrationsToRun $ \mgr -> do
          runMigration mgr
          -- Optionally commit after every single migration.
          when (eoForceCommit options) $ do
            logInfo_ $ "Committing migration changes..."
            commit
            logInfo_ $ "Committing migration changes done."
            logInfo_ "!IMPORTANT! Database has been permanently changed"
        logInfo_ "Running migrations... done."

    validateMigrationsToRun :: [Migration m] -> [(Text, Int32)] -> m ()
    validateMigrationsToRun migrationsToRun dbTablesWithVersions = do

      let migrationsToRunGrouped :: [[Migration m]]
          migrationsToRunGrouped =
            L.groupBy ((==) `on` mgrTableName) .
            L.sortBy (comparing mgrTableName) $ -- NB: stable sort
            migrationsToRun

          loc_common = "Database.PostgreSQL.PQTypes.Checks."
            ++ "checkDBConsistency.validateMigrationsToRun"

          lookupDBTableVer :: [Migration m] -> Maybe Int32
          lookupDBTableVer mgrGroup =
            lookup (unRawSQL . mgrTableName . headExc head_err
                    $ mgrGroup) dbTablesWithVersions
            where
              head_err = loc_common ++ ".lookupDBTableVer: broken invariant"

          groupsWithWrongDBTableVersions :: [([Migration m], Int32)]
          groupsWithWrongDBTableVersions =
            [ (mgrGroup, dbTableVer)
            | mgrGroup <- migrationsToRunGrouped
            , let dbTableVer = fromMaybe 0 $ lookupDBTableVer mgrGroup
            , dbTableVer /= (mgrFrom . headExc head_err $ mgrGroup)
            ]
            where
              head_err = loc_common
                ++ ".groupsWithWrongDBTableVersions: broken invariant"

          mgrGroupsNotInDB :: [[Migration m]]
          mgrGroupsNotInDB =
            [ mgrGroup
            | mgrGroup <- migrationsToRunGrouped
            , isNothing $ lookupDBTableVer mgrGroup
            ]

          groupsStartingWithDropTable :: [[Migration m]]
          groupsStartingWithDropTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , isDropTableMigration . headExc head_err $ mgrGroup
            ]
            where
              head_err = loc_common
                ++ ".groupsStartingWithDropTable: broken invariant"

          groupsNotStartingWithCreateTable :: [[Migration m]]
          groupsNotStartingWithCreateTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , mgrFrom (headExc head_err mgrGroup) /= 0
            ]
            where
              head_err = loc_common
                ++ ".groupsNotStartingWithCreateTable: broken invariant"

          tblNames :: [[Migration m]] -> [RawSQL ()]
          tblNames grps =
            [ mgrTableName . headExc head_err $ grp | grp <- grps ]
            where
              head_err = loc_common ++ ".tblNames: broken invariant"

      when (not . null $ groupsWithWrongDBTableVersions) $ do
        let tnms = tblNames . map fst $ groupsWithWrongDBTableVersions
        logAttention
          ("There are migration chains selected for execution "
            <> "that expect a different starting table version number "
            <> "from the one in the database. "
            <> "This likely means that the order of migrations is wrong.")
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms

      when (not . null $ groupsStartingWithDropTable) $ do
        let tnms = tblNames groupsStartingWithDropTable
        logAttention "There are drop table migrations for non-existing tables."
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms

      -- NB: the following check can break if we allow renaming tables.
      when (not . null $ groupsNotStartingWithCreateTable) $ do
        let tnms = tblNames groupsNotStartingWithCreateTable
        logAttention
          ("Some tables haven't been created yet, but "
            <> "their migration lists don't start with a create table migration.")
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms

-- | Associate each table in the list with its version as it exists in
-- the DB, or 0 if it's missing from the DB.
getTableVersions :: (MonadDB m, MonadThrow m) => [Table] -> m [(Table, Int32)]
getTableVersions tbls =
  sequence
  [ (\mver -> (tbl, fromMaybe 0 mver)) <$> checkTableVersion (tblNameString tbl)
  | tbl <- tbls ]

-- | Like 'getTableVersions', but for all user-defined tables that
-- actually exist in the DB.
getDBTableVersions :: (MonadDB m, MonadThrow m) => m [(Text, Int32)]
getDBTableVersions = do
  dbTableNames <- getDBTableNames
  sequence
    [ (\mver -> (name, fromMaybe 0 mver)) <$> checkTableVersion (T.unpack name)
    | name <- dbTableNames ]

-- | Check whether the table exists in the DB, and return 'Just' its
-- version if it does, or 'Nothing' if it doesn't.
--
-- Calls 'error' if the table exists but has no row in @table_versions@.
checkTableVersion :: (MonadDB m, MonadThrow m) => String -> m (Maybe Int32)
checkTableVersion tblName = do
  doesExist <- runQuery01 . sqlSelect "pg_catalog.pg_class c" $ do
    sqlResult "TRUE"
    sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
    sqlWhereEq "c.relname" $ tblName
    sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"
  if doesExist
    then do
      -- The table name is passed as a query parameter.
      runQuery_ $ "SELECT version FROM table_versions WHERE name =" <?> tblName
      mver <- fetchMaybe runIdentity
      case mver of
        Just ver -> return $ Just ver
        Nothing  -> error $ "checkTableVersion: table '"
          ++ tblName
          ++ "' is present in the database, "
          ++ "but there is no corresponding version info in 'table_versions'."
    else do
      return Nothing

-- *** TABLE STRUCTURE ***

-- | Parenthesized subquery yielding the OID of the given table among
-- the relations visible on the current search path.
sqlGetTableID :: Table -> SQL
sqlGetTableID table = parenthesize . toSQLCommand $
  sqlSelect "pg_catalog.pg_class c" $ do
    sqlResult "c.oid"
    sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
    sqlWhereEq "c.relname" $ tblNameString table
    sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"

-- *** PRIMARY KEY ***

-- | Fetch the primary key of the table (columns in key order) together
-- with its constraint name, or 'Nothing' if the table has none.
sqlGetPrimaryKey
  :: (MonadDB m, MonadThrow m)
  => Table -> m (Maybe (PrimaryKey, RawSQL ()))
sqlGetPrimaryKey table = do

  (mColumnNumbers :: Maybe [Int16]) <- do
    runQuery_ . sqlSelect "pg_catalog.pg_constraint" $ do
      sqlResult "conkey"
      sqlWhereEqSql "conrelid" (sqlGetTableID table)
      sqlWhereEq "contype" 'p'
    fetchMaybe $ unArray1 . runIdentity

  case mColumnNumbers of
    Nothing -> do return Nothing
    Just columnNumbers -> do
      -- Resolve each key attribute number to its column name, one query
      -- per key column.
      columnNames <- do
        forM columnNumbers $ \k -> do
          runQuery_ . sqlSelect "pk_columns" $ do

            sqlWith "key_series" . sqlSelect "pg_constraint as c2" $ do
              sqlResult "unnest(c2.conkey) as k"
              sqlWhereEqSql "c2.conrelid" $ sqlGetTableID table
              sqlWhereEq "c2.contype" 'p'

            sqlWith "pk_columns" . sqlSelect "key_series" $ do
              sqlJoinOn "pg_catalog.pg_attribute as a" "a.attnum = key_series.k"
              sqlResult "a.attname::text as column_name"
              sqlResult "key_series.k as column_order"
              sqlWhereEqSql "a.attrelid" $ sqlGetTableID table

            sqlResult "pk_columns.column_name"
            sqlWhereEq "pk_columns.column_order" k

          fetchOne (\(Identity t) -> t :: String)

      runQuery_ . sqlSelect "pg_catalog.pg_constraint as c" $ do
        sqlWhereEq "c.contype" 'p'
        sqlWhereEqSql "c.conrelid" $ sqlGetTableID table
        sqlResult "c.conname::text"
        sqlResult $ Data.String.fromString
          ("array['" <> (mintercalate "', '" (map escapeLiteral columnNames))
           <> "']::text[]")

      join <$> fetchMaybe fetchPrimaryKey
  where
    -- Column names are spliced into SQL string literals above, so any
    -- embedded single quote must be doubled to keep the query valid.
    escapeLiteral :: String -> String
    escapeLiteral = concatMap $ \c -> if c == '\'' then "''" else [c]

fetchPrimaryKey :: (String, Array1 String) -> Maybe (PrimaryKey, RawSQL ())
fetchPrimaryKey (name, Array1 columns) = (, unsafeSQL name)
  <$> (pkOnColumns $ map unsafeSQL columns)

-- *** CHECKS ***

-- | Query for the table's CHECK constraints: name, condition with the
-- surrounding @CHECK (...)@ stripped, and validation flag.
sqlGetChecks :: Table -> SQL
sqlGetChecks table = toSQLCommand . sqlSelect "pg_catalog.pg_constraint c" $ do
  sqlResult "c.conname::text"
  sqlResult "regexp_replace(pg_get_constraintdef(c.oid, true), 'CHECK \\((.*)\\)', '\\1') AS body" -- check body
  sqlResult "c.convalidated" -- validated?
  sqlWhereEq "c.contype" 'c'
  sqlWhereEqSql "c.conrelid" $ sqlGetTableID table

fetchTableCheck :: (String, String, Bool) -> Check
fetchTableCheck (name, condition, validated) = Check
  { chkName      = unsafeSQL name
  , chkCondition = unsafeSQL condition
  , chkValidated = validated
  }

-- *** INDEXES ***

-- | Query for the table's indexes that do not back a constraint.
sqlGetIndexes :: Table -> SQL
sqlGetIndexes table = toSQLCommand . sqlSelect "pg_catalog.pg_class c" $ do
  sqlResult "c.relname::text" -- index name
  sqlResult $ "ARRAY(" <> selectCoordinates <> ")" -- array of index coordinates
  sqlResult "am.amname::text" -- the method used (btree, gin etc)
  sqlResult "i.indisunique" -- is it unique?
  sqlResult "i.indisvalid" -- is it valid?
  -- if partial, get constraint def
  sqlResult "pg_catalog.pg_get_expr(i.indpred, i.indrelid, true)"
  sqlJoinOn "pg_catalog.pg_index i" "c.oid = i.indexrelid"
  sqlJoinOn "pg_catalog.pg_am am" "c.relam = am.oid"
  sqlLeftJoinOn "pg_catalog.pg_constraint r" "r.conrelid = i.indrelid AND r.conindid = i.indexrelid"
  sqlWhereEqSql "i.indrelid" $ sqlGetTableID table
  sqlWhereIsNULL "r.contype" -- fetch only "pure" indexes
  where
    -- Get all coordinates of the index.
    selectCoordinates = smconcat
      [ "WITH RECURSIVE coordinates(k, name) AS ("
      , " VALUES (0, NULL)"
      , " UNION ALL"
      , " SELECT k+1, pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true)"
      , " FROM coordinates"
      , " WHERE pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true) != ''"
      , ")"
      , "SELECT name FROM coordinates WHERE k > 0"
      ]

fetchTableIndex
  :: (String, Array1 String, String, Bool, Bool, Maybe String)
  -> (TableIndex, RawSQL ())
fetchTableIndex (name, Array1 columns, method, unique, valid, mconstraint) =
  (TableIndex
   { idxColumns = map unsafeSQL columns
   , idxMethod  = read method
   , idxUnique  = unique
   , idxValid   = valid
   , idxWhere   = unsafeSQL `liftM` mconstraint
   }, unsafeSQL name)

-- *** FOREIGN KEYS ***

-- | Query for the table's foreign keys; column arrays keep the order of
-- the constraint's key positions.
sqlGetForeignKeys :: Table -> SQL
sqlGetForeignKeys table = toSQLCommand
                          . sqlSelect "pg_catalog.pg_constraint r" $ do
  sqlResult "r.conname::text" -- fk name
  sqlResult $ "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN (" <> unnestWithOrdinality "r.conkey" <> ") conkeys ON (a.attnum = conkeys.item) WHERE a.attrelid = r.conrelid ORDER BY conkeys.n)" -- constrained columns
  sqlResult "c.relname::text" -- referenced table
  sqlResult $ "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN (" <> unnestWithOrdinality "r.confkey" <> ") confkeys ON (a.attnum = confkeys.item) WHERE a.attrelid = r.confrelid ORDER BY confkeys.n)" -- referenced columns
  sqlResult "r.confupdtype" -- on update
  sqlResult "r.confdeltype" -- on delete
  sqlResult "r.condeferrable" -- deferrable?
  sqlResult "r.condeferred" -- initially deferred?
  sqlResult "r.convalidated" -- validated?
  sqlJoinOn "pg_catalog.pg_class c" "c.oid = r.confrelid"
  sqlWhereEqSql "r.conrelid" $ sqlGetTableID table
  sqlWhereEq "r.contype" 'f'
  where
    unnestWithOrdinality :: RawSQL () -> SQL
    unnestWithOrdinality arr =
      "SELECT n, " <> raw arr
      <> "[n] AS item FROM generate_subscripts(" <> raw arr <> ", 1) AS n"

fetchForeignKey
  :: (String, Array1 String, String, Array1 String, Char, Char, Bool, Bool, Bool)
  -> (ForeignKey, RawSQL ())
fetchForeignKey
  ( name, Array1 columns, reftable, Array1 refcolumns
  , on_update, on_delete, deferrable, deferred, validated ) = (ForeignKey
  { fkColumns    = map unsafeSQL columns
  , fkRefTable   = unsafeSQL reftable
  , fkRefColumns = map unsafeSQL refcolumns
  , fkOnUpdate   = charToForeignKeyAction on_update
  , fkOnDelete   = charToForeignKeyAction on_delete
  , fkDeferrable = deferrable
  , fkDeferred   = deferred
  , fkValidated  = validated
  }, unsafeSQL name)
  where
    -- Decode pg_constraint's single-letter confupdtype/confdeltype codes.
    charToForeignKeyAction c = case c of
      'a' -> ForeignKeyNoAction
      'r' -> ForeignKeyRestrict
      'c' -> ForeignKeyCascade
      'n' -> ForeignKeySetNull
      'd' -> ForeignKeySetDefault
      _   -> error $ "fetchForeignKey: invalid foreign key action code: "
             ++ show c