module Database.PostgreSQL.PQTypes.Checks (
  
    checkDatabase
  , checkDatabaseAllowUnknownTables
  , createTable
  , createDomain
  
  , ExtrasOptions(..)
  
  , migrateDatabase
  ) where
import Control.Arrow ((&&&))
import Control.Applicative ((<$>))
import Control.Monad.Catch
import Control.Monad.Reader
import Data.Int
import Data.Function (on)
import Data.Maybe
import Data.Monoid
import Data.Monoid.Utils
import Data.Ord (comparing)
import qualified Data.String
import Data.Text (Text)
import Database.PostgreSQL.PQTypes hiding (def)
import GHC.Stack (HasCallStack)
import Log
import Prelude
import TextShow
import qualified Data.List as L
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Text as T
import Database.PostgreSQL.PQTypes.ExtrasOptions
import Database.PostgreSQL.PQTypes.Checks.Util
import Database.PostgreSQL.PQTypes.Migrate
import Database.PostgreSQL.PQTypes.Model
import Database.PostgreSQL.PQTypes.SQL.Builder
import Database.PostgreSQL.PQTypes.Versions
-- | Variant of 'head' that fails with a caller-supplied message.
--
-- Returns the first element of the list, or calls 'error' with the
-- given message when the list is empty.
headExc :: String -> [a] -> a
headExc msg xs = case xs of
  first : _ -> first
  []        -> error msg
-- | Bring the database in line with the given definitions and validate
-- the result.
--
-- In order: set the database time zone to UTC, make sure the requested
-- extensions exist, run 'checkDBConsistency' (which either creates the
-- schema or runs migrations), pass the result of every structural check
-- to 'resultCheck', and finally 'commit'.
migrateDatabase
  :: (MonadDB m, MonadLog m, MonadMask m)
  => ExtrasOptions
  -> [Extension]
  -> [CompositeType]
  -> [Domain]
  -> [Table]
  -> [Migration m]
  -> m ()
migrateDatabase options@ExtrasOptions{..}
  extensions composites domains tables migrations = do
  setDBTimeZoneToUTC
  forM_ extensions checkExtension
  let versionedTables = tableVersions : tables
  tablesWithVersions <- getTableVersions versionedTables
  -- 'checkDBConsistency' also performs the migrations.
  checkDBConsistency options domains tablesWithVersions migrations
  -- The checks run in list order; each result goes through 'resultCheck'
  -- before the next check starts.
  mapM_ (>>= resultCheck)
    [ checkCompositesStructure True composites
    , checkDomainsStructure domains
    , checkDBStructure options tablesWithVersions
    , checkTablesWereDropped migrations
    , checkUnknownTables tables
    , checkExistenceOfVersionsForTables versionedTables
    ]
  -- All checks passed, make the changes permanent.
  commit
-- | Validate the database against the given definitions. Tables that
-- are unknown to the definitions are checked for (and reported).
checkDatabase
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [CompositeType] -> [Domain] -> [Table] -> m ()
checkDatabase options composites domains tables =
  checkDatabase_ options False composites domains tables
-- | Like 'checkDatabase', but skips the checks for unknown tables and
-- for the contents of the @table_versions@ table.
checkDatabaseAllowUnknownTables
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> [CompositeType] -> [Domain] -> [Table] -> m ()
checkDatabaseAllowUnknownTables options composites domains tables =
  checkDatabase_ options True composites domains tables
-- | Shared implementation of 'checkDatabase' and
-- 'checkDatabaseAllowUnknownTables'.
--
-- The 'Bool' argument, when 'True', disables the unknown-table check
-- and the @table_versions@ completeness check. The result of every
-- check is passed to 'resultCheck'.
checkDatabase_
  :: forall m . (MonadDB m, MonadLog m, MonadThrow m)
  => ExtrasOptions -> Bool -> [CompositeType] -> [Domain] -> [Table] -> m ()
checkDatabase_ options allowUnknownTables composites domains tables = do
  tablesWithVersions <- getTableVersions versionedTables
  resultCheck $ versionsResult tablesWithVersions
  resultCheck =<< checkCompositesStructure False composites
  resultCheck =<< checkDomainsStructure domains
  resultCheck =<< checkDBStructure options tablesWithVersions
  unless allowUnknownTables $ do
    resultCheck =<< checkUnknownTables tables
    resultCheck =<< checkExistenceOfVersionsForTables versionedTables
  -- Initial setups are validated last, after the structural checks.
  resultCheck =<< initialSetupsResult tables
  where
    -- The definitions plus the bookkeeping table itself.
    versionedTables :: [Table]
    versionedTables = tableVersions : tables

    versionsResult :: [(Table, Int32)] -> ValidationResult
    versionsResult versioned = mconcat (map versionResult versioned)

    -- Compare the version from the definition with the one in the
    -- database (second component of the pair).
    versionResult :: (Table, Int32) -> ValidationResult
    versionResult (table@Table{..}, dbVersion)
      | tblVersion `elem` tblAcceptedDbVersions
                  = validationError $
                    "Table '" <> tblNameText table <>
                    "' has its current table version in accepted db versions"
      | tblVersion == dbVersion || dbVersion `elem` tblAcceptedDbVersions
                  = mempty
      | dbVersion == 0
                  = validationError $
                    "Table '" <> tblNameText table <> "' must be created"
      | otherwise = validationError $
                    "Table '" <> tblNameText table
                    <> "' must be migrated" <+> showt dbVersion <+> "->"
                    <+> showt tblVersion

    initialSetupsResult :: [Table] -> m ValidationResult
    initialSetupsResult tbls = mconcat <$> mapM initialSetupResult tbls

    -- A table without an initial setup is trivially valid.
    initialSetupResult :: Table -> m ValidationResult
    initialSetupResult table@Table{..} = case tblInitialSetup of
      Nothing    -> return mempty
      Just setup -> do
        valid <- checkInitialSetup setup
        return $ if valid
          then mempty
          else validationError $ "Initial setup for table '"
               <> tblNameText table <> "' is not valid"
-- | Return the name of the current database as a double-quoted SQL
-- identifier.
--
-- Double quotes embedded in the name are doubled, as PostgreSQL requires
-- inside a quoted identifier, so the result is safe to splice into
-- statements such as @ALTER DATABASE@.
currentCatalog :: (MonadDB m, MonadThrow m) => m (RawSQL ())
currentCatalog = do
  runSQL_ "SELECT current_catalog::text"
  dbname :: String <- fetchOne runIdentity
  return . unsafeSQL $ "\"" ++ concatMap escapeQuote dbname ++ "\""
  where
    -- Inside a quoted identifier a literal '"' is written as '""'.
    escapeQuote :: Char -> String
    escapeQuote '"' = "\"\""
    escapeQuote c   = [c]
-- | Make sure the given extension is installed in the database.
--
-- Looks the extension up in @pg_extension@ and issues
-- @CREATE EXTENSION IF NOT EXISTS@ when it is not found. Progress is
-- logged at info level.
checkExtension :: (MonadDB m, MonadLog m, MonadThrow m) => Extension -> m ()
checkExtension (Extension extension) = do
  logInfo_ $ "Checking for extension '" <> extName <> "'"
  alreadyInstalled <- runQuery01 . sqlSelect "pg_extension" $ do
    sqlResult "TRUE"
    sqlWhereEq "extname" extName
  if alreadyInstalled
    then logInfo_ $ "Extension '" <> extName <> "' exists"
    else do
      logInfo_ $ "Creating extension '" <> extName <> "'"
      runSQL_ $ "CREATE EXTENSION IF NOT EXISTS" <+> raw extension
  where
    -- Textual form of the extension name, used for logging and lookup.
    extName = unRawSQL extension
-- | Configure the current database to use the UTC time zone.
--
-- Reads the active setting with @SHOW timezone@; when it is anything
-- other than the string @UTC@, issues @ALTER DATABASE ... SET TIMEZONE@
-- for the current catalog.
setDBTimeZoneToUTC :: (MonadDB m, MonadLog m, MonadThrow m) => m ()
setDBTimeZoneToUTC = do
  runSQL_ "SHOW timezone"
  currentZone :: String <- fetchOne runIdentity
  unless (currentZone == "UTC") $ do
    catalog <- currentCatalog
    logInfo_ $ "Setting '" <> unRawSQL catalog
      <> "' database to return timestamps in UTC"
    runQuery_ $ "ALTER DATABASE" <+> catalog <+> "SET TIMEZONE = 'UTC'"
-- | Names of all base tables, except @table_versions@, whose schema is
-- one of those returned by @current_schemas(false)@.
getDBTableNames :: (MonadDB m) => m [Text]
getDBTableNames = do
  runQuery_ . sqlSelect "information_schema.tables" $ do
    sqlResult "table_name::text"
    sqlWhere "table_name <> 'table_versions'"
    sqlWhere "table_type = 'BASE TABLE'"
    -- Restrict the result to schemas on the current search path.
    sqlWhereExists $ sqlSelect "unnest(current_schemas(false)) as cs" $ do
      sqlResult "TRUE"
      sqlWhere "cs = table_schema"
  fetchMany runIdentity
-- | Compare the tables present in the database with the given
-- definitions.
--
-- Tables that exist only in the database and tables that exist only in
-- the definitions are each logged at info level and reported as
-- validation errors.
checkUnknownTables :: (MonadDB m, MonadLog m) => [Table] -> m ValidationResult
checkUnknownTables tables = do
  dbTableNames <- getDBTableNames
  let definedNames = map (unRawSQL . tblName) tables
      -- In the database, but without a definition.
      unknown      = dbTableNames L.\\ definedNames
      -- Defined, but missing from the database.
      missing      = definedNames L.\\ dbTableNames
  if null unknown && null missing
    then return mempty
    else do
      forM_ unknown $ logInfo_ . ("Unknown table:" <+>)
      forM_ missing $ logInfo_ . ("Table not present in the database:" <+>)
      return $
        validateIsNull "Unknown tables:" unknown <>
        validateIsNull "Tables not present in the database:" missing
-- | Succeed when the list is empty; otherwise produce a validation
-- error made of the message followed by the comma-separated items.
validateIsNull :: Text -> [Text] -> ValidationResult
validateIsNull msg ts
  | null ts   = mempty
  | otherwise = validationError $ msg <+> T.intercalate ", " ts
-- | Check that the @table_versions@ table has exactly one kind of entry:
-- those for the given table definitions.
--
-- Entries without a matching definition and definitions without an
-- entry are each logged at info level and reported as validation
-- errors.
checkExistenceOfVersionsForTables
  :: (MonadDB m, MonadLog m)
  => [Table] -> m ValidationResult
checkExistenceOfVersionsForTables tables = do
  runQuery_ $ sqlSelect "table_versions" $ do
    sqlResult "name::text"
  (existingTableNames :: [Text]) <- fetchMany runIdentity
  let tableNames = map (unRawSQL . tblName) tables
      -- Present in 'table_versions', but without a definition.
      absent     = existingTableNames L.\\ tableNames
      -- Defined, but missing from 'table_versions'.
      notPresent = tableNames   L.\\ existingTableNames
  if (not . null $ absent) || (not . null $ notPresent)
    then do
    mapM_ (logInfo_ . (<+>) "Unknown entry in 'table_versions':") absent
    mapM_ (logInfo_ . (<+>) "Table not present in the 'table_versions':")
      notPresent
    return $
      -- The opening quote around table_versions was missing here, which
      -- made this message differ from the log message above.
      (validateIsNull "Unknown entry in 'table_versions':"  absent ) <>
      (validateIsNull "Tables not present in the 'table_versions':" notPresent)
    else return mempty
-- | Compare every given domain definition with the domain of the same
-- name found in @pg_catalog.pg_type@.
--
-- A domain missing from the database is an error; for an existing
-- domain every differing attribute (name, type, nullability, default,
-- checks) is reported.
checkDomainsStructure :: (MonadDB m, MonadThrow m)
                      => [Domain] -> m ValidationResult
checkDomainsStructure defs = do
  results <- forM defs $ \expected -> do
    runQuery_ . sqlSelect "pg_catalog.pg_type t1" $ do
      sqlResult "t1.typname::text" -- name
      sqlResult "(SELECT pg_catalog.format_type(t2.oid, t2.typtypmod) \
                \FROM pg_catalog.pg_type t2 \
                \WHERE t2.oid = t1.typbasetype)" -- base type
      sqlResult "NOT t1.typnotnull" -- nullable
      sqlResult "t1.typdefault" -- default value
      sqlResult "ARRAY(SELECT c.conname::text FROM pg_catalog.pg_constraint c \
                \WHERE c.contypid = t1.oid ORDER by c.oid)" -- constraint names
      sqlResult "ARRAY(SELECT regexp_replace(pg_get_constraintdef(c.oid, true), '\
                \CHECK \\((.*)\\)', '\\1') FROM pg_catalog.pg_constraint c \
                \WHERE c.contypid = t1.oid \
                \ORDER by c.oid)" -- constraint conditions
      sqlResult "ARRAY(SELECT c.convalidated FROM pg_catalog.pg_constraint c \
                \WHERE c.contypid = t1.oid \
                \ORDER by c.oid)" -- constraint validation flags
      sqlWhereEq "t1.typname" . unRawSQL $ domName expected
    found <- fetchMaybe $
      \(name, baseType, nullable, defaultValue, conNames, conConds, conValids) ->
        Domain
        { domName = unsafeSQL name
        , domType = baseType
        , domNullable = nullable
        , domDefault = unsafeSQL <$> defaultValue
        , domChecks =
            -- The three arrays are ordered by constraint oid, so they
            -- line up element by element.
            mkChecks $ zipWith3
            (\conName conCond conValid ->
               Check
               { chkName = unsafeSQL conName
               , chkCondition = unsafeSQL conCond
               , chkValidated = conValid
               }) (unArray1 conNames) (unArray1 conConds) (unArray1 conValids)
        }
    return $ case found of
      Nothing -> validationError $ "Domain '" <> unRawSQL (domName expected)
                 <> "' doesn't exist in the database"
      Just actual
        | actual /= expected ->
            topMessage "domain" (unRawSQL $ domName actual) $ mconcat [
                compareAttr actual expected "name" domName
              , compareAttr actual expected "type" domType
              , compareAttr actual expected "nullable" domNullable
              , compareAttr actual expected "default" domDefault
              , compareAttr actual expected "checks" domChecks
              ]
        | otherwise -> mempty
  return $ mconcat results
  where
    -- Report one attribute when the database value and the definition
    -- value differ.
    compareAttr :: (Eq a, Show a)
                => Domain -> Domain -> Text -> (Domain -> a) -> ValidationResult
    compareAttr actual expected label getter =
      if getter actual == getter expected
        then mempty
        else validationError $
          "Attribute '" <> label
          <> "' does not match (database:" <+> T.pack (show $ getter actual)
          <> ", definition:" <+> T.pack (show $ getter expected) <> ")"
-- | Check that every table targeted by a drop table migration is gone.
--
-- For each such table 'checkTableVersion' is consulted; any result
-- other than 'Nothing' is reported as a validation error.
checkTablesWereDropped :: (MonadDB m, MonadThrow m) =>
                          [Migration m] -> m ValidationResult
checkTablesWereDropped mgrs = do
  results <- forM droppedTableNames $ \droppedName -> do
    mver <- checkTableVersion (T.unpack . unRawSQL $ droppedName)
    return $ case mver of
      Nothing -> mempty
      Just _  -> validationError $ "The table '" <> unRawSQL droppedName
                 <> "' that must have been dropped"
                 <> " is still present in the database."
  return $ mconcat results
  where
    -- Tables that have a drop table migration in the list.
    droppedTableNames = map mgrTableName $ filter isDropTableMigration mgrs
-- | Compare the composite types defined in code with those in the
-- database.
--
-- When the database has no composite types at all and the first
-- argument is 'True', the given types are created and the check
-- succeeds. Otherwise missing types, types without a code definition
-- and column mismatches are reported.
checkCompositesStructure
  :: MonadDB m
  => Bool -- ^ Create the composite types if none are present in the database.
  -> [CompositeType]
  -> m ValidationResult
checkCompositesStructure createTypes compositeList = do
  dbCompositeTypes <- getDBCompositeTypes
  if null dbCompositeTypes && createTypes
    then do
      -- Nothing in the database yet: create everything from scratch.
      mapM_ (runQuery_ . sqlCreateComposite) compositeList
      return mempty
    else return $ mconcat
      [ missingFromDatabase dbCompositeTypes
      , mismatchedInDatabase dbCompositeTypes
      ]
  where
    -- Columns of the code definitions, keyed by composite type name.
    definedColumns :: M.Map Text [CompositeColumn]
    definedColumns = M.fromList
      [ (unRawSQL (ctName ct), ctColumns ct) | ct <- compositeList ]

    -- Types that are defined in code but absent from the database.
    missingFromDatabase :: [CompositeType] -> ValidationResult
    missingFromDatabase dbTypes =
      let dbNames    = S.fromList $ map (unRawSQL . ctName) dbTypes
          notPresent = S.toList $ M.keysSet definedColumns S.\\ dbNames
      in validateIsNull "Composite types not present in the database:" notPresent

    mismatchedInDatabase :: [CompositeType] -> ValidationResult
    mismatchedInDatabase dbTypes = mconcat $ map checkDbComposite dbTypes

    -- Check one database type against its code definition, if any.
    checkDbComposite :: CompositeType -> ValidationResult
    checkDbComposite dbComposite =
      case M.lookup cname definedColumns of
        Just columns -> topMessage "composite type" cname $
          checkColumns 1 columns (ctColumns dbComposite)
        Nothing -> validationError $ "Composite type '" <> T.pack (show dbComposite)
          <> "' from the database doesn't have a corresponding code definition"
      where
        cname = unRawSQL $ ctName dbComposite

    -- Walk the definition columns (second argument) and the database
    -- columns (third argument) in parallel; the first argument is the
    -- 1-based position used in messages.
    checkColumns
      :: Int -> [CompositeColumn] -> [CompositeColumn] -> ValidationResult
    checkColumns _ [] [] = mempty
    checkColumns _ rest [] = validationError $
      objectHasLess "Composite type" "columns" rest
    checkColumns _ [] rest = validationError $
      objectHasMore "Composite type" "columns" rest
    checkColumns !n (expected:defs) (actual:cols) = mconcat [
        validateNames $ ccName expected == ccName actual
      , validateTypes $ ccType expected == ccType actual
      , checkColumns (n+1) defs cols
      ]
      where
        validateNames same
          | same      = mempty
          | otherwise = validationError $
              errorMsg ("no. " <> showt n) "names" (unRawSQL . ccName)

        validateTypes same
          | same      = mempty
          | otherwise = validationError $
              errorMsg (unRawSQL $ ccName expected) "types"
                       (T.pack . show . ccType)

        errorMsg ident attr f =
          "Column '" <> ident <> "' differs in"
          <+> attr <+> "(database:" <+> f actual
          <> ", definition:" <+> f expected <> ")."
-- | Compare the structure of every given table with what the database
-- reports for it: columns, primary key, checks, indexes and foreign
-- keys.
--
-- Each table comes paired with a version number. When that number is
-- one of the table's 'tblAcceptedDbVersions', the result for the table
-- is passed through 'validationErrorsToInfos'.
checkDBStructure
  :: forall m. (MonadDB m, MonadThrow m)
  => ExtrasOptions
  -> [(Table, Int32)]
  -> m ValidationResult
checkDBStructure options tables = do
  results <- forM tables $ \(table, version) -> do
    result <- topMessage "table" (tblNameText table)
              <$> checkTableStructure table
    -- A table at an accepted (older) version is allowed to differ, so
    -- its errors are turned into infos.
    return $ if version `elem` tblAcceptedDbVersions table
             then validationErrorsToInfos result
             else result
  return $ mconcat results
  where
    checkTableStructure :: Table -> m ValidationResult
    checkTableStructure table@Table{..} = do
      -- Fetch the live, non-system columns of the table in their
      -- physical order, together with type, nullability and default.
      runQuery_ $ sqlSelect "pg_catalog.pg_attribute a" $ do
        sqlResult "a.attname::text"
        sqlResult "pg_catalog.format_type(a.atttypid, a.atttypmod)"
        sqlResult "NOT a.attnotnull"
        sqlResult . parenthesize . toSQLCommand $
          sqlSelect "pg_catalog.pg_attrdef d" $ do
            sqlResult "pg_catalog.pg_get_expr(d.adbin, d.adrelid)"
            sqlWhere "d.adrelid = a.attrelid"
            sqlWhere "d.adnum = a.attnum"
            sqlWhere "a.atthasdef"
        sqlWhere "a.attnum > 0"
        sqlWhere "NOT a.attisdropped"
        sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
        sqlOrderBy "a.attnum"
      dbColumns <- fetchMany fetchTableColumn
      -- Fetch primary key, checks, indexes and foreign keys.
      dbPrimaryKey <- sqlGetPrimaryKey table
      runQuery_ $ sqlGetChecks table
      dbChecks <- fetchMany fetchTableCheck
      runQuery_ $ sqlGetIndexes table
      dbIndexes <- fetchMany fetchTableIndex
      runQuery_ $ sqlGetForeignKeys table
      dbForeignKeys <- fetchMany fetchForeignKey
      return $ mconcat [
          checkColumns 1 tblColumns dbColumns
        , checkPrimaryKey tblPrimaryKey dbPrimaryKey
        , checkChecks tblChecks dbChecks
        , checkIndexes tblIndexes dbIndexes
        , checkForeignKeys tblForeignKeys dbForeignKeys
        ]
      where
        -- Build a 'TableColumn' from one row of the column query.
        fetchTableColumn
          :: (String, ColumnType, Bool, Maybe String) -> TableColumn
        fetchTableColumn (name, ctype, nullable, mdefault) = TableColumn {
            colName = unsafeSQL name
          , colType = ctype
          , colNullable = nullable
          , colDefault = unsafeSQL <$> mdefault
          }

        -- Walk the definition columns (second argument) and the
        -- database columns (third argument) in parallel; the first
        -- argument is the 1-based position used in messages.
        checkColumns
          :: Int -> [TableColumn] -> [TableColumn] -> ValidationResult
        checkColumns _ [] [] = mempty
        checkColumns _ rest [] = validationError $
          objectHasLess "Table" "columns" rest
        checkColumns _ [] rest = validationError $
          objectHasMore "Table" "columns" rest
        checkColumns !n (def:defs) (col:cols) = mconcat [
            validateNames $ colName def == colName col
          -- A column defined as bigserial is accepted when the database
          -- reports it as bigint.
          , validateTypes $ colType def == colType col ||
            (colType def == BigSerialT && colType col == BigIntT)
          -- A column defined without a default is accepted when the
          -- database default starts with "nextval('".
          , validateDefaults $ colDefault def == colDefault col ||
            (isNothing (colDefault def) && hasNextvalDefault col)
          , validateNullables $ colNullable def == colNullable col
          , checkColumns (n+1) defs cols
          ]
          where
            hasNextvalDefault =
              maybe False (T.isPrefixOf "nextval('" . unRawSQL) . colDefault

            validateNames same
              | same      = mempty
              | otherwise = validationError $
                  errorMsg ("no. " <> showt n) "names" (unRawSQL . colName)

            validateTypes same
              | same      = mempty
              | otherwise = validationError $
                  errorMsg cname "types" (T.pack . show . colType)
                  <+> sqlHint ("TYPE" <+> columnTypeToSQL (colType def))

            validateNullables same
              | same      = mempty
              | otherwise = validationError $
                  errorMsg cname "nullables" (showt . colNullable)
                  <+> sqlHint ((if colNullable def then "DROP" else "SET")
                                <+> "NOT NULL")

            validateDefaults same
              | same      = mempty
              | otherwise = validationError $
                  (errorMsg cname "defaults"
                            (showt . fmap unRawSQL . colDefault))
                  <+> sqlHint setDefault
              where
                setDefault = case colDefault def of
                  Just v  -> "SET DEFAULT" <+> v
                  Nothing -> "DROP DEFAULT"

            cname = unRawSQL $ colName def

            errorMsg ident attr f =
              "Column '" <> ident <> "' differs in"
              <+> attr <+> "(table:" <+> f col
              <> ", definition:" <+> f def <> ")."

            -- Suggest the ALTER TABLE statement that would fix the column.
            sqlHint sql =
              "(HINT: SQL for making the change is: ALTER TABLE"
              <+> tblNameText table <+> "ALTER COLUMN"
              <+> unRawSQL (colName def)
              <+> unRawSQL sql <> ")"

        checkPrimaryKey :: Maybe PrimaryKey -> Maybe (PrimaryKey, RawSQL ())
                        -> ValidationResult
        checkPrimaryKey mdef mpk = mconcat [
            checkEquality "PRIMARY KEY" definedKeys (map fst dbKeys)
          , checkNames (const (pkName tblName)) dbKeys
          -- Presence of a primary key is only enforced on request.
          , if eoEnforcePKs options
            then checkPKPresence tblName mdef mpk
            else mempty
          ]
          where
            definedKeys = maybeToList mdef
            dbKeys      = maybeToList mpk

        checkChecks :: [Check] -> [Check] -> ValidationResult
        checkChecks defs checks =
          mapValidationResult id addHint (checkEquality "CHECKs" defs checks)
          where
            -- Append a hint to the error messages, but only when there
            -- are any.
            addHint errmsgs
              | null errmsgs = []
              | otherwise    = errmsgs <>
                  [ " (HINT: If checks are equal modulo number of \
                    \ parentheses/whitespaces used in conditions, \
                    \ just copy and paste expected output into source code)"
                  ]

        checkIndexes :: [TableIndex] -> [(TableIndex, RawSQL ())]
                     -> ValidationResult
        checkIndexes defs indexes = mconcat [
            checkEquality "INDEXes" defs (map fst indexes)
          , checkNames (indexName tblName) indexes
          ]

        checkForeignKeys :: [ForeignKey] -> [(ForeignKey, RawSQL ())]
                         -> ValidationResult
        checkForeignKeys defs fkeys = mconcat [
            checkEquality "FOREIGN KEYs" defs (map fst fkeys)
          , checkNames (fkName tblName) fkeys
          ]
checkDBConsistency
  :: forall m. (MonadDB m, MonadLog m, MonadMask m)
  => ExtrasOptions -> [Domain] -> [(Table, Int32)] -> [Migration m]
  -> m ()
checkDBConsistency options domains tablesWithVersions migrations = do
  autoTransaction <- tsAutoTransaction <$> getTransactionSettings
  unless autoTransaction $ do
    error "checkDBConsistency: tsAutoTransaction setting needs to be True"
  
  validateMigrations
  validateDropTableMigrations
  
  dbTablesWithVersions <- getDBTableVersions
  if all ((==) 0 . snd) tablesWithVersions
    
    then do
      createDBSchema
      initializeDB
    
    else do
      
      validateMigrationsAgainstDB [ (tblName table, tblVersion table, actualVer)
                                  | (table, actualVer) <- tablesWithVersions ]
      validateDropTableMigrationsAgainstDB dbTablesWithVersions
      
      runMigrations dbTablesWithVersions
  where
    tables = map fst tablesWithVersions
    errorInvalidMigrations :: HasCallStack => [RawSQL ()] -> a
    errorInvalidMigrations tblNames =
      error $ "checkDBConsistency: invalid migrations for tables"
              <+> (L.intercalate ", " $ map (T.unpack . unRawSQL) tblNames)
    checkMigrationsListValidity :: Table -> [Int32] -> [Int32] -> m ()
    checkMigrationsListValidity table presentMigrationVersions
      expectedMigrationVersions = do
      when (presentMigrationVersions /= expectedMigrationVersions) $ do
        logAttention "Migrations are invalid" $ object [
            "table"                       .= tblNameText table
          , "migration_versions"          .= presentMigrationVersions
          , "expected_migration_versions" .= expectedMigrationVersions
          ]
        errorInvalidMigrations [tblName $ table]
    validateMigrations :: m ()
    validateMigrations = forM_ tables $ \table -> do
      let presentMigrationVersions
            = [ mgrFrom | Migration{..} <- migrations
                        , mgrTableName == tblName table ]
          expectedMigrationVersions
            = reverse $ take (length presentMigrationVersions) $
              reverse  [0 .. tblVersion table - 1]
      checkMigrationsListValidity table presentMigrationVersions
        expectedMigrationVersions
    validateDropTableMigrations :: m ()
    validateDropTableMigrations = do
      let droppedTableNames =
            [ mgrTableName $ mgr | mgr <- migrations
                                 , isDropTableMigration mgr ]
          tableNames =
            [ tblName tbl | tbl <- tables ]
      
      
      let intersection = L.intersect droppedTableNames tableNames
      when (not . null $ intersection) $ do
          logAttention ("The intersection between tables "
                        <> "and dropped tables is not empty")
            $ object
            [ "intersection" .= map unRawSQL intersection ]
          errorInvalidMigrations [ tblName tbl
                                 | tbl <- tables
                                 , tblName tbl `elem` intersection ]
      
      
      
      let migrationsByTable     = L.groupBy ((==) `on` mgrTableName)
                                  migrations
          dropMigrationLists    = [ mgrs | mgrs <- migrationsByTable
                                         , any isDropTableMigration mgrs ]
          invalidMigrationLists =
            [ mgrs | mgrs <- dropMigrationLists
                   , (not . isDropTableMigration . last $ mgrs) ||
                     (length . filter isDropTableMigration $ mgrs) > 1 ]
      when (not . null $ invalidMigrationLists) $ do
        let tablesWithInvalidMigrationLists =
              [ mgrTableName mgr | mgrs <- invalidMigrationLists
                                 , let mgr = head mgrs ]
        logAttention ("Migration lists for some tables contain "
                      <> "either multiple drop table migrations or "
                      <> "a drop table migration in non-tail position.")
          $ object [ "tables" .=
                     [ unRawSQL tblName
                     | tblName <- tablesWithInvalidMigrationLists ] ]
        errorInvalidMigrations tablesWithInvalidMigrationLists
    createDBSchema :: m ()
    createDBSchema = do
      logInfo_ "Creating domains..."
      mapM_ createDomain domains
      
      logInfo_ "Creating tables..."
      mapM_ (createTable False) tables
      logInfo_ "Creating table constraints..."
      mapM_ createTableConstraints tables
      logInfo_ "Done."
    initializeDB :: m ()
    initializeDB = do
      logInfo_ "Running initial setup for tables..."
      forM_ tables $ \t -> case tblInitialSetup t of
        Nothing -> return ()
        Just tis -> do
          logInfo_ $ "Initializing" <+> tblNameText t <> "..."
          initialSetup tis
      logInfo_ "Done."
    
    
    validateMigrationsAgainstDB :: [(RawSQL (), Int32, Int32)] -> m ()
    validateMigrationsAgainstDB tablesWithVersions_
      = forM_ tablesWithVersions_ $ \(tableName, expectedVer, actualVer) ->
        when (expectedVer /= actualVer) $
        case [ m | m@Migration{..} <- migrations
                 , mgrTableName == tableName ] of
          [] ->
            error $ "checkDBConsistency: no migrations found for table '"
              ++ (T.unpack . unRawSQL $ tableName) ++ "', cannot migrate "
              ++ show actualVer ++ " -> " ++ show expectedVer
          (m:_) | mgrFrom m > actualVer ->
                  error $ "checkDBConsistency: earliest migration for table '"
                    ++ (T.unpack . unRawSQL $ tableName) ++ "' is from version "
                    ++ show (mgrFrom m) ++ ", cannot migrate "
                    ++ show actualVer ++ " -> " ++ show expectedVer
                | otherwise -> return ()
    validateDropTableMigrationsAgainstDB :: [(Text, Int32)] -> m ()
    validateDropTableMigrationsAgainstDB dbTablesWithVersions = do
      let dbTablesToDropWithVersions =
            [ (tblName, mgrFrom mgr, fromJust mver)
            | mgr <- migrations
            , isDropTableMigration mgr
            , let tblName = mgrTableName mgr
            , let mver = lookup (unRawSQL tblName) $ dbTablesWithVersions
            , isJust mver ]
      forM_ dbTablesToDropWithVersions $ \(tblName, fromVer, ver) ->
        when (fromVer /= ver) $
          
          
          
          validateMigrationsAgainstDB [(tblName, fromVer, ver)]
    findMigrationsToRun :: [(Text, Int32)] -> [Migration m]
    findMigrationsToRun dbTablesWithVersions =
      let tableNamesToDrop = [ mgrTableName mgr | mgr <- migrations
                                                , isDropTableMigration mgr ]
          droppedEventually :: Migration m -> Bool
          droppedEventually mgr = mgrTableName mgr `elem` tableNamesToDrop
          lookupVer :: Migration m -> Maybe Int32
          lookupVer mgr = lookup (unRawSQL $ mgrTableName mgr)
                          dbTablesWithVersions
          tableDoesNotExist = isNothing . lookupVer
          
          
          
          migrationsToRun' = dropWhile
            (\mgr ->
               case lookupVer mgr of
                 
                 
                 
                 Nothing -> not $
                            (mgrFrom mgr == 0) &&
                            (not . droppedEventually $ mgr)
                 
                 
                 Just ver -> not $
                             mgrFrom mgr >= ver)
            migrations
          
          
          
          
          
          
          
          
          
          
          l                     = length migrationsToRun'
          initialMigrations     = drop l $ reverse migrations
          additionalMigrations' = takeWhile
            (\mgr -> droppedEventually mgr && tableDoesNotExist mgr)
            initialMigrations
          
          
          
          additionalMigrations  =
            let ret  = reverse additionalMigrations'
                grps = L.groupBy ((==) `on` mgrTableName) ret
            in if any ((/=) 0 . mgrFrom . head) grps
               then []
               else ret
          
          
          migrationsToRun       = if not . null $ migrationsToRun'
                                  then additionalMigrations ++ migrationsToRun'
                                  else []
      in migrationsToRun
    runMigration :: (Migration m) -> m ()
    runMigration Migration{..} = do
      case mgrAction of
        StandardMigration mgrDo -> do
          logMigration
          mgrDo
          updateTableVersion
        DropTableMigration mgrDropTableMode -> do
          logInfo_ $ arrListTable mgrTableName <> "drop table"
          runQuery_ $ sqlDropTable mgrTableName
            mgrDropTableMode
          runQuery_ $ sqlDelete "table_versions" $ do
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)
        CreateIndexConcurrentlyMigration tname idx -> do
          logMigration
          
          
          
          
          
          runQuery_ $ "DROP INDEX IF EXISTS" <+> indexName tname idx
          
          
          
          
          
          runSQL_ "COMMIT"
          runQuery_ (sqlCreateIndexConcurrently tname idx) `finally` begin
          updateTableVersion
      where
        logMigration = do
          logInfo_ $ arrListTable mgrTableName
            <> showt mgrFrom <+> "->" <+> showt (succ mgrFrom)
        updateTableVersion = do
          runQuery_ $ sqlUpdate "table_versions" $ do
            sqlSet "version"  (succ mgrFrom)
            sqlWhereEq "name" (T.unpack . unRawSQL $ mgrTableName)
    runMigrations :: [(Text, Int32)] -> m ()
    runMigrations dbTablesWithVersions = do
      let migrationsToRun = findMigrationsToRun dbTablesWithVersions
      validateMigrationsToRun migrationsToRun dbTablesWithVersions
      when (not . null $ migrationsToRun) $ do
        logInfo_ "Running migrations..."
        forM_ migrationsToRun $ \mgr -> do
          runMigration mgr
          when (eoForceCommit options) $ do
            logInfo_ $ "Committing migration changes..."
            commit
            logInfo_ $ "Committing migration changes done."
            logInfo_ "!IMPORTANT! Database has been permanently changed"
        logInfo_ "Running migrations... done."
    validateMigrationsToRun :: [Migration m] -> [(Text, Int32)] -> m ()
    validateMigrationsToRun migrationsToRun dbTablesWithVersions = do
      let migrationsToRunGrouped :: [[Migration m]]
          migrationsToRunGrouped =
            L.groupBy ((==) `on` mgrTableName) .
            L.sortBy (comparing mgrTableName) $ 
            migrationsToRun
          loc_common = "Database.PostgreSQL.PQTypes.Checks."
            ++ "checkDBConsistency.validateMigrationsToRun"
          lookupDBTableVer :: [Migration m] -> Maybe Int32
          lookupDBTableVer mgrGroup =
            lookup (unRawSQL . mgrTableName . headExc head_err
                    $ mgrGroup) dbTablesWithVersions
            where
              head_err = loc_common ++ ".lookupDBTableVer: broken invariant"
          groupsWithWrongDBTableVersions :: [([Migration m], Int32)]
          groupsWithWrongDBTableVersions =
            [ (mgrGroup, dbTableVer)
            | mgrGroup <- migrationsToRunGrouped
            , let dbTableVer = fromMaybe 0 $ lookupDBTableVer mgrGroup
            , dbTableVer /= (mgrFrom . headExc head_err $ mgrGroup)
            ]
            where
              head_err = loc_common
                ++ ".groupsWithWrongDBTableVersions: broken invariant"
          mgrGroupsNotInDB :: [[Migration m]]
          mgrGroupsNotInDB =
            [ mgrGroup
            | mgrGroup <- migrationsToRunGrouped
            , isNothing $ lookupDBTableVer mgrGroup
            ]
          groupsStartingWithDropTable :: [[Migration m]]
          groupsStartingWithDropTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , isDropTableMigration . headExc head_err $ mgrGroup
            ]
            where
              head_err = loc_common
                ++ ".groupsStartingWithDropTable: broken invariant"
          groupsNotStartingWithCreateTable :: [[Migration m]]
          groupsNotStartingWithCreateTable =
            [ mgrGroup
            | mgrGroup <- mgrGroupsNotInDB
            , mgrFrom (headExc head_err mgrGroup) /= 0
            ]
            where
              head_err = loc_common
                ++ ".groupsNotStartingWithCreateTable: broken invariant"
          tblNames :: [[Migration m]] -> [RawSQL ()]
          tblNames grps =
            [ mgrTableName . headExc head_err $ grp | grp <- grps ]
            where
              head_err = loc_common ++ ".tblNames: broken invariant"
      when (not . null $ groupsWithWrongDBTableVersions) $ do
        let tnms = tblNames . map fst $ groupsWithWrongDBTableVersions
        logAttention
          ("There are migration chains selected for execution "
            <> "that expect a different starting table version number "
            <> "from the one in the database. "
            <> "This likely means that the order of migrations is wrong.")
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms
      when (not . null $ groupsStartingWithDropTable) $ do
        let tnms = tblNames groupsStartingWithDropTable
        logAttention "There are drop table migrations for non-existing tables."
          $ object [ "tables" .= map unRawSQL tnms ]
        errorInvalidMigrations tnms
      
      -- Migration groups for tables that have no version in the database must
      -- start from version 0. Log the offending table names and hand them to
      -- 'errorInvalidMigrations' otherwise.
      -- The first literal ends with a space because the two halves of the
      -- message are concatenated verbatim.
      when (not . null $ groupsNotStartingWithCreateTable) $ do
        let tnms = tblNames groupsNotStartingWithCreateTable
        logAttention
          ("Some tables haven't been created yet, but " <>
            "their migration lists don't start with a create table migration.")
          $ object [ "tables" .=  map unRawSQL tnms ]
        errorInvalidMigrations tnms
-- | Pair each of the given tables with the version that
-- 'checkTableVersion' reports for it. A table for which no version is
-- reported ('Nothing') is paired with version 0.
getTableVersions :: (MonadDB m, MonadThrow m) => [Table] -> m [(Table, Int32)]
getTableVersions tbls = forM tbls $ \tbl -> do
  mver <- checkTableVersion (tblNameString tbl)
  return (tbl, fromMaybe 0 mver)
-- | Pair each table name returned by 'getDBTableNames' with the version
-- that 'checkTableVersion' reports for it, defaulting to 0 when no
-- version is reported.
getDBTableVersions :: (MonadDB m, MonadThrow m) => m [(Text, Int32)]
getDBTableVersions = do
  dbTableNames <- getDBTableNames
  forM dbTableNames $ \name -> do
    mver <- checkTableVersion (T.unpack name)
    return (name, fromMaybe 0 mver)
-- | Look up the version recorded in @table_versions@ for the given table.
--
-- Returns 'Nothing' when no visible relation with that name is found in
-- @pg_catalog.pg_class@. When the relation exists but has no row in
-- @table_versions@, this calls 'error'.
checkTableVersion :: (MonadDB m, MonadThrow m) => String -> m (Maybe Int32)
checkTableVersion tblName = do
  tableExists <- runQuery01 . sqlSelect "pg_catalog.pg_class c" $ do
    sqlResult "TRUE"
    sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
    sqlWhereEq "c.relname" $ tblName
    sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"
  if not tableExists
    then return Nothing
    else do
      runQuery_ $ "SELECT version FROM table_versions WHERE name ="
        <?> tblName
      mver <- fetchMaybe runIdentity
      -- An existing table without version info is treated as a fatal
      -- inconsistency.
      maybe (error missingVersionInfo) (return . Just) mver
  where
    missingVersionInfo :: String
    missingVersionInfo = "checkTableVersion: table '"
      ++ tblName
      ++ "' is present in the database, "
      ++ "but there is no corresponding version info in 'table_versions'."
-- | Parenthesized subquery selecting the @oid@ of the visible
-- @pg_catalog.pg_class@ entry whose @relname@ equals the table's name.
-- It is embedded into other catalog queries via 'sqlWhereEqSql'.
sqlGetTableID :: Table -> SQL
sqlGetTableID table = parenthesize (toSQLCommand tableOidQuery)
  where
    tableOidQuery = sqlSelect "pg_catalog.pg_class c" $ do
      sqlResult "c.oid"
      sqlLeftJoinOn "pg_catalog.pg_namespace n" "n.oid = c.relnamespace"
      sqlWhereEq "c.relname" $ tblNameString table
      sqlWhere "pg_catalog.pg_table_is_visible(c.oid)"
-- | Fetch the primary key of the given table together with the name of
-- its constraint.
--
-- Returns 'Nothing' when the table has no constraint of type @p@ in
-- @pg_catalog.pg_constraint@, or when 'fetchPrimaryKey' yields 'Nothing'
-- for the fetched row.
sqlGetPrimaryKey
  :: (MonadDB m, MonadThrow m)
  => Table -> m (Maybe (PrimaryKey, RawSQL ()))
sqlGetPrimaryKey table = do
  -- Step 1: attribute numbers that make up the primary key (if any).
  runQuery_ . sqlSelect "pg_catalog.pg_constraint" $ do
    sqlResult "conkey"
    sqlWhereEqSql "conrelid" (sqlGetTableID table)
    sqlWhereEq "contype" 'p'
  (mKeyAttNums :: Maybe [Int16]) <- fetchMaybe $ unArray1 . runIdentity
  case mKeyAttNums of
    Nothing -> return Nothing
    Just keyAttNums -> do
      -- Step 2: resolve each attribute number to its column name, one
      -- query per key column, keeping the order of 'conkey'.
      pkColumnNames <- forM keyAttNums $ \attNum -> do
        runQuery_ . sqlSelect "pk_columns" $ do
          sqlWith "key_series" . sqlSelect "pg_constraint as c2" $ do
            sqlResult "unnest(c2.conkey) as k"
            sqlWhereEqSql "c2.conrelid" $ sqlGetTableID table
            sqlWhereEq "c2.contype" 'p'
          sqlWith "pk_columns" . sqlSelect "key_series" $ do
            sqlJoinOn  "pg_catalog.pg_attribute as a" "a.attnum = key_series.k"
            sqlResult "a.attname::text as column_name"
            sqlResult "key_series.k as column_order"
            sqlWhereEqSql "a.attrelid" $ sqlGetTableID table
          sqlResult "pk_columns.column_name"
          sqlWhereEq "pk_columns.column_order" attNum
        fetchOne (\(Identity colName) -> colName :: String)
      -- Step 3: select the constraint name alongside the ordered column
      -- names, passed back as a literal text[] expression.
      let columnsArray :: String
          columnsArray =
            "array['" <> mintercalate "', '" pkColumnNames <> "']::text[]"
      runQuery_ . sqlSelect "pg_catalog.pg_constraint as c" $ do
        sqlWhereEq "c.contype" 'p'
        sqlWhereEqSql "c.conrelid" $ sqlGetTableID table
        sqlResult "c.conname::text"
        sqlResult $ Data.String.fromString columnsArray
      fmap join (fetchMaybe fetchPrimaryKey)
-- | Convert a fetched (constraint name, column names) row into a primary
-- key paired with its constraint name. The result is 'Nothing' whenever
-- 'pkOnColumns' returns 'Nothing' for the columns.
fetchPrimaryKey :: (String, Array1 String) -> Maybe (PrimaryKey, RawSQL ())
fetchPrimaryKey (name, Array1 columns) =
  fmap (\pk -> (pk, unsafeSQL name)) (pkOnColumns (map unsafeSQL columns))
-- | Query selecting the check constraints (@contype = 'c'@) of the given
-- table: constraint name, constraint body, and the @convalidated@ flag.
sqlGetChecks :: Table -> SQL
sqlGetChecks table = toSQLCommand checksQuery
  where
    checksQuery = sqlSelect "pg_catalog.pg_constraint c" $ do
      sqlResult "c.conname::text"
      -- Strip the surrounding "CHECK (...)" from the constraint definition
      -- so that only the condition itself is returned.
      sqlResult "regexp_replace(pg_get_constraintdef(c.oid, true), \
                \'CHECK \\((.*)\\)', '\\1') AS body"
      sqlResult "c.convalidated"
      sqlWhereEq "c.contype" 'c'
      sqlWhereEqSql "c.conrelid" $ sqlGetTableID table
-- | Build a 'Check' from a fetched (name, condition, validated) row.
fetchTableCheck :: (String, String, Bool) -> Check
fetchTableCheck (checkName, checkBody, isValidated) =
  Check
    { chkName      = unsafeSQL checkName
    , chkCondition = unsafeSQL checkBody
    , chkValidated = isValidated
    }
-- | Query selecting the indexes of the given table that are not tied to a
-- row in @pg_catalog.pg_constraint@ (the left-joined @r.contype@ is NULL).
--
-- Result columns, in order: index name, array of index coordinates,
-- access method name, @indisunique@, @indisvalid@, and the partial index
-- predicate expression.
sqlGetIndexes :: Table -> SQL
sqlGetIndexes table = toSQLCommand indexesQuery
  where
    indexesQuery = sqlSelect "pg_catalog.pg_class c" $ do
      sqlResult "c.relname::text"
      sqlResult $ "ARRAY(" <> indexCoordinates <> ")"
      sqlResult "am.amname::text"
      sqlResult "i.indisunique"
      sqlResult "i.indisvalid"
      sqlResult "pg_catalog.pg_get_expr(i.indpred, i.indrelid, true)"
      sqlJoinOn "pg_catalog.pg_index i" "c.oid = i.indexrelid"
      sqlJoinOn "pg_catalog.pg_am am" "c.relam = am.oid"
      sqlLeftJoinOn "pg_catalog.pg_constraint r"
        "r.conrelid = i.indrelid AND r.conindid = i.indexrelid"
      sqlWhereEqSql "i.indrelid" $ sqlGetTableID table
      sqlWhereIsNULL "r.contype"

    -- Recursively asks pg_get_indexdef for the definition of coordinate
    -- k+1 of the index, stopping at the first one that comes back empty.
    indexCoordinates :: SQL
    indexCoordinates = smconcat
      [ "WITH RECURSIVE coordinates(k, name) AS ("
      , "  VALUES (0, NULL)"
      , "  UNION ALL"
      , "    SELECT k+1, pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true)"
      , "      FROM coordinates"
      , "     WHERE pg_catalog.pg_get_indexdef(i.indexrelid, k+1, true) != ''"
      , ")"
      , "SELECT name FROM coordinates WHERE k > 0"
      ]
-- | Build a 'TableIndex', paired with the index name, from a row fetched
-- with the query produced by 'sqlGetIndexes'.
--
-- The tuple holds, in order: index name, index columns, access method
-- name, uniqueness flag, validity flag and the optional predicate of a
-- partial index.
--
-- Calls 'error' with a message naming the offending access method if it
-- cannot be parsed via the 'Read' instance of the 'idxMethod' field type.
fetchTableIndex :: (String, Array1 String, String, Bool, Bool, Maybe String)
                -> (TableIndex, RawSQL ())
fetchTableIndex (name, Array1 columns, method, unique, valid, mconstraint) =
  (TableIndex
   { idxColumns = map unsafeSQL columns
   , idxMethod = parsedMethod
   , idxUnique = unique
   , idxValid = valid
   , idxWhere = unsafeSQL <$> mconstraint
   }
  , unsafeSQL name)
  where
    -- Accepts exactly the inputs 'read' would accept (a single complete
    -- parse, trailing whitespace allowed), but fails with a descriptive
    -- message instead of "Prelude.read: no parse".
    parsedMethod =
      case [m | (m, rest) <- reads method, ("", "") <- lex rest] of
        [m] -> m
        _   -> error $ "fetchTableIndex: unsupported index method: "
                       ++ show method
-- | Query selecting the foreign key constraints (@contype = 'f'@) defined
-- on the given table.
--
-- Result columns, in order: constraint name, constrained column names,
-- referenced table name, referenced column names, @confupdtype@,
-- @confdeltype@, @condeferrable@, @condeferred@ and @convalidated@.
sqlGetForeignKeys :: Table -> SQL
sqlGetForeignKeys table = toSQLCommand foreignKeysQuery
  where
    foreignKeysQuery = sqlSelect "pg_catalog.pg_constraint r" $ do
      sqlResult "r.conname::text"
      sqlResult constrainedColumns
      sqlResult "c.relname::text"
      sqlResult referencedColumns
      sqlResult "r.confupdtype"
      sqlResult "r.confdeltype"
      sqlResult "r.condeferrable"
      sqlResult "r.condeferred"
      sqlResult "r.convalidated"
      sqlJoinOn "pg_catalog.pg_class c" "c.oid = r.confrelid"
      sqlWhereEqSql "r.conrelid" $ sqlGetTableID table
      sqlWhereEq "r.contype" 'f'

    -- Names of the constrained columns, ordered by their position in
    -- r.conkey.
    constrainedColumns :: SQL
    constrainedColumns =
      "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN ("
      <> unnestWithOrdinality "r.conkey"
      <> ") conkeys ON (a.attnum = conkeys.item) \
         \WHERE a.attrelid = r.conrelid \
         \ORDER BY conkeys.n)"

    -- Names of the referenced columns, ordered by their position in
    -- r.confkey.
    referencedColumns :: SQL
    referencedColumns =
      "ARRAY(SELECT a.attname::text FROM pg_catalog.pg_attribute a JOIN ("
      <> unnestWithOrdinality "r.confkey"
      <> ") confkeys ON (a.attnum = confkeys.item) \
         \WHERE a.attrelid = r.confrelid \
         \ORDER BY confkeys.n)"

    -- Expands an array expression into (n, item) rows, where n is the
    -- subscript of item within the array.
    unnestWithOrdinality :: RawSQL () -> SQL
    unnestWithOrdinality arrayExpr =
      "SELECT n, " <> raw arrayExpr
      <> "[n] AS item FROM generate_subscripts(" <> raw arrayExpr <> ", 1) AS n"
-- | Build a 'ForeignKey', paired with the constraint name, from a row
-- fetched with the query produced by 'sqlGetForeignKeys'.
--
-- Calls 'error' if either action code is not one of @a@, @r@, @c@, @n@
-- or @d@.
fetchForeignKey ::
  (String, Array1 String, String, Array1 String, Char, Char, Bool, Bool, Bool)
  -> (ForeignKey, RawSQL ())
fetchForeignKey
  ( name, Array1 columns, reftable, Array1 refcolumns
  , on_update, on_delete, deferrable, deferred, validated ) =
  (foreignKey, unsafeSQL name)
  where
    foreignKey = ForeignKey
      { fkColumns    = map unsafeSQL columns
      , fkRefTable   = unsafeSQL reftable
      , fkRefColumns = map unsafeSQL refcolumns
      , fkOnUpdate   = charToForeignKeyAction on_update
      , fkOnDelete   = charToForeignKeyAction on_delete
      , fkDeferrable = deferrable
      , fkDeferred   = deferred
      , fkValidated  = validated
      }

    -- Decode the single-character action code stored in the
    -- confupdtype / confdeltype columns.
    charToForeignKeyAction 'a' = ForeignKeyNoAction
    charToForeignKeyAction 'r' = ForeignKeyRestrict
    charToForeignKeyAction 'c' = ForeignKeyCascade
    charToForeignKeyAction 'n' = ForeignKeySetNull
    charToForeignKeyAction 'd' = ForeignKeySetDefault
    charToForeignKeyAction c   =
      error $ "fetchForeignKey: invalid foreign key action code: "
              ++ show c