{-# LANGUAGE OverloadedStrings #-}
module Database.Schema.Migrations.Backend.HDBC
    ( hdbcBackend
    )
where

import Database.HDBC
  ( quickQuery'
  , fromSql
  , toSql
  , IConnection(getTables, run, runRaw)
  , commit
  , rollback
  , disconnect
  )

import Database.Schema.Migrations.Backend
    ( Backend(..)
    , rootMigrationName
    )
import Database.Schema.Migrations.Migration
    ( Migration(..)
    , newMigration
    )

import Data.Text ( Text )
import Data.String.Conversions ( cs, (<>) )

import Control.Applicative ( (<$>) )
import Data.Time.Clock (getCurrentTime)

-- | Name of the database table in which the identifiers of installed
-- migrations are recorded.
migrationTableName :: Text
migrationTableName = "installed_migrations"

-- | SQL statement that creates the 'migrationTableName' table with a single
-- @migration_id@ column of type @TEXT@.
createSql :: Text
createSql = "CREATE TABLE " <> migrationTableName <> " (migration_id TEXT)"

-- | SQL statement that drops the 'migrationTableName' table; the inverse of
-- 'createSql'.
revertSql :: Text
revertSql = "DROP TABLE " <> migrationTableName

-- |General Backend constructor for all HDBC connection implementations.
--
-- Every field of the resulting 'Backend' operates on the supplied
-- connection.  Only 'commitBackend' calls 'commit' and only
-- 'rollbackBackend' calls 'rollback'; the other operations issue their
-- statements without finishing the transaction themselves.
hdbcBackend :: (IConnection conn) => conn -> Backend
hdbcBackend conn =
    Backend
        { -- The backend counts as bootstrapped when the migration table's
          -- name appears (exact match) among the tables reported by the
          -- connection.
          isBootstrapped = elem (cs migrationTableName) <$> getTables conn

          -- The root migration installs the bookkeeping table itself and is
          -- stamped with the time at which it is requested.
        , getBootstrapMigration = do
            ts <- getCurrentTime
            return $ (newMigration rootMigrationName)
                { mApply = createSql
                , mRevert = Just revertSql
                , mDesc = Just "Migration table installation"
                , mTimestamp = Just ts
                }

          -- Run the migration's apply script verbatim, then record its
          -- identifier in the migration table.
        , applyMigration = \m -> do
            runRaw conn (cs $ mApply m)
            _ <- run conn
                     (cs $ "INSERT INTO " <> migrationTableName <>
                           " (migration_id) VALUES (?)")
                     [toSql $ mId m]
            return ()

        , revertMigration = \m -> do
            -- A migration without a revert script has nothing to execute.
            case mRevert m of
                Nothing -> return ()
                Just query -> runRaw conn (cs query)
            -- Remove migration from installed_migrations in either case.
            _ <- run conn
                     (cs $ "DELETE FROM " <> migrationTableName <>
                           " WHERE migration_id = ?")
                     [toSql $ mId m]
            return ()

          -- List the identifiers of all recorded migrations.  The query
          -- selects a single column, so the first value of each row is the
          -- identifier; matching on the row shape (rather than using the
          -- partial 'head') means an unexpected empty row is skipped instead
          -- of raising an exception.
        , getMigrations = do
            results <- quickQuery' conn
                           (cs $ "SELECT migration_id FROM " <> migrationTableName)
                           []
            return [fromSql v | (v : _) <- results]

        , commitBackend = commit conn

        , rollbackBackend = rollback conn

        , disconnectBackend = disconnect conn
        }