{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE QuasiQuotes           #-}
{-# LANGUAGE ScopedTypeVariables   #-}

-- | PostgreSQL backend: bootstraps the @mallard@ bookkeeping schema, reads and
-- records applied migrations, and runs test scripts, all through a
-- "Hasql.Pool" connection pool stored in the caller's state.
module Database.Mallard.Postgre
    ( HasPostgreConnection (..)
    , DigestMismatchException (..)
    , ensureMigratonSchema
    , getAppliedMigrations
    , applyMigration
    , applyMigrations
    , runTests
    , runTest
    ) where

import           Control.Exception
import           Control.Lens
import           Control.Monad.IO.Class
import           Control.Monad.State
import           Crypto.Hash
import           Data.Byteable
import           Data.ByteString            (ByteString)
import           Data.Foldable
import qualified Data.HashMap.Strict        as Map
import           Data.Int
import           Data.Monoid
import           Data.String.Interpolation
import qualified Data.Text.Encoding         as T
import           Database.Mallard.Types
import           Database.Mallard.Validation
import qualified Hasql.Decoders             as D
import qualified Hasql.Encoders             as E
import qualified Hasql.Pool                 as Pool
import           Hasql.Query
import           Hasql.Session
import           Hasql.Transaction          (IsolationLevel (..), Mode (..), Transaction)
import qualified Hasql.Transaction          as HT
import qualified Hasql.Transaction.Sessions as HT

-- | State types that carry a PostgreSQL connection pool.
class HasPostgreConnection a where
    postgreConnection :: Lens' a Pool.Pool

-- | Version of the migrator's own bookkeeping schema.
--
-- The derived 'Ord' instance orders 'NotInit' below every 'MigrationVersion'
-- and compares two 'MigrationVersion's by their number.
data MigrationSchemaVersion
    = NotInit
    -- ^ The @mallard.migrator_version@ table does not exist yet.
    | MigrationVersion Int64
    -- ^ Highest version recorded in @mallard.migrator_version@.
    deriving (Eq, Ord, Show)

-- | Bring the migrator's bookkeeping schema up to date.
--
-- Looks up the current schema version and applies the entries of
-- 'scriptsMigrationSchema' that are still outstanding: all of them when the
-- schema is not initialised, otherwise the list with its first @v + 1@
-- entries dropped. Each script runs in its own serializable write transaction
-- and its version is printed to stdout afterwards.
ensureMigratonSchema :: (MonadIO m, MonadState s m, HasPostgreConnection s) => m ()
ensureMigratonSchema = do
    mversion <- getMigrationSchemaVersion
    let toApply = case mversion of
            NotInit            -> scriptsMigrationSchema
            MigrationVersion v -> drop (fromIntegral (v + 1)) scriptsMigrationSchema
    forM_ toApply $ \a@(version, _) -> do
        runDB $ HT.transaction Serializable Write $ applyMigrationSchemaMigration a
        liftIO $ putStrLn $ "Migrator Version: " <> show version

-- | Run a 'Session' on the pool found in the current state.
--
-- A pool or session failure is raised as an 'ErrorCall' carrying the shown
-- error. It is thrown with 'throwIO' so that it surfaces at this point of the
-- IO sequence rather than whenever a lazy bottom happens to be forced.
runDB :: (MonadIO m, MonadState s m, HasPostgreConnection s) => Session a -> m a
runDB session = do
    pool <- gets (^. postgreConnection)
    res <- liftIO $ Pool.use pool session
    case res of
        Left err    -> liftIO . throwIO . ErrorCall $ show err
        Right value -> return value

-- | Load every row of @mallard.applied_migrations@, keyed by migration name.
getAppliedMigrations :: (MonadIO m, MonadState s m, HasPostgreConnection s) => m MigrationTable
getAppliedMigrations = runDB $ do
    lst <- query () (statement stmt encoder decoder True)
    return $ Map.fromList $ fmap (\m -> (m ^. migrationName, m)) lst
    where
        stmt = "SELECT name, description, requires, checksum, script_text FROM mallard.applied_migrations;"
        encoder = E.unit
        decoder = D.rowsList $ Migration
            <$> D.value (MigrationId <$> D.text)
            <*> D.value D.text
            <*> D.value (D.array (D.arrayDimension replicateM (D.arrayValue (MigrationId <$> D.text))))
            <*> D.value valueDigest
            <*> D.value D.text

-- | Decode a column's raw bytes as a 'Digest'.
--
-- Decoding fails when 'digestFromByteString' rejects the bytes.
valueDigest :: (HashAlgorithm a) => D.Value (Digest a)
valueDigest = D.custom $ \_ bs ->
    maybe
        (Left "ByteString was incorrect length for selected Digest type.")
        Right
        (digestFromByteString bs)

-- | Apply the given migrations one after another, in list order.
applyMigrations :: (MonadIO m, MonadState s m, HasPostgreConnection s) => [Migration] -> m ()
applyMigrations = mapM_ applyMigration

-- | Apply a single migration.
--
-- Within one serializable write transaction the migration script is executed
-- and a row describing it is inserted into @mallard.applied_migrations@.
-- Afterwards the migration name is printed to stdout.
applyMigration :: (MonadIO m, MonadState s m, HasPostgreConnection s) => Migration -> m ()
applyMigration m = do
    runDB $ HT.transaction Serializable Write $ do
        HT.sql (T.encodeUtf8 (m ^. migrationScript))
        HT.query m (statement stmt encoder decoder True)
    liftIO $ putStrLn $ "Applied migration: " <> show (m ^. migrationName)
    where
        stmt = "INSERT INTO mallard.applied_migrations (name, description, requires, checksum, script_text) VALUES ($1, $2, $3, $4, $5)"
        -- Parameter order must match the $1..$5 placeholders above.
        encoder =
            contramap (unMigrationId . _migrationName) (E.value E.text) <>
            contramap _migrationDescription (E.value E.text) <>
            contramap (fmap unMigrationId . _migrationRequires) (E.value (E.array (E.arrayDimension foldl' (E.arrayValue E.text)))) <>
            contramap (toBytes . _migrationChecksum) (E.value E.bytea) <>
            contramap _migrationScript (E.value E.text)
        decoder = D.unit

-- | Run the given tests one after another, in list order.
runTests :: (MonadIO m, MonadState s m, HasPostgreConnection s) => [Test] -> m ()
runTests = mapM_ runTest

-- | Execute a test script inside a serializable write transaction that is then
-- condemned.
--
-- NOTE(review): a condemned hasql transaction is expected to be rolled back
-- instead of committed, so the script should leave no changes behind --
-- confirm against the hasql-transaction version in use.
runTest :: (MonadIO m, MonadState s m, HasPostgreConnection s) => Test -> m ()
runTest t =
    runDB $ HT.transaction Serializable Write $ do
        HT.sql (T.encodeUtf8 (t ^. testScript))
        HT.condemn

-- | Run one bookkeeping-schema script and record its version in
-- @mallard.migrator_version@.
applyMigrationSchemaMigration :: (Int64, ByteString) -> Transaction ()
applyMigrationSchemaMigration (version, script) = do
    HT.sql script
    HT.query version (statement stmt encoder decoder True)
    where
        stmt = "INSERT INTO mallard.migrator_version (version) VALUES ($1)"
        encoder = E.value E.int8
        decoder = D.unit

-- | Determine the current version of the bookkeeping schema.
--
-- Returns 'NotInit' when the version table is missing; otherwise the maximum
-- recorded version, with 0 standing in when the query yields no row.
getMigrationSchemaVersion :: (MonadIO m, MonadState s m, HasPostgreConnection s) => m MigrationSchemaVersion
getMigrationSchemaVersion = runDB $ do
    isInit <- isMigrationVersionZero
    if isInit
        then maybe (MigrationVersion 0) MigrationVersion
            <$> query () (statement stmt E.unit (D.maybeRow (D.value D.int8)) True)
        else return NotInit
    where
        stmt = "SELECT coalesce(max(version), 0) as max_version FROM mallard.migrator_version"

-- | Whether the @mallard.migrator_version@ table exists.
--
-- Despite the name, this checks for the table's presence in
-- @information_schema.tables@; it does not look at any version number.
isMigrationVersionZero :: Session Bool
isMigrationVersionZero =
    maybe False (const True)
        <$> query () (statement stmt E.unit (D.maybeRow (D.value D.text)) True)
    where
        stmt = "SELECT table_name FROM information_schema.tables WHERE table_schema = 'mallard' AND table_name = 'migrator_version';"

-- Exceptions

-- | Signals that a stored checksum has the wrong size for the digest type.
data DigestSizeMismatchException
    = DigestSizeMismatchException
    deriving (Show)

instance Exception DigestSizeMismatchException where
    displayException _ = [str| The size of the applied migration's checksum is not valid. This may imply the algorithm used by this tool has changed. |]

-- TODO: Add ability to reset all checksums in migration table.