module Database.PostgreSQL.Migrations (
defaultMain
, connectEnv
, runSqlFile
, Migration, migrate
, column
, create_table
, add_column
, create_index
, create_unique_index
, drop_table
, drop_column
, drop_index
, rename_column
, change_column
, create_table_stmt, add_column_stmt, create_index_stmt
, drop_table_stmt, drop_column_stmt, drop_index_stmt
, rename_column_stmt, change_column_stmt
) where
import Control.Exception (bracket)
import Control.Monad
import Data.Int
import Data.Maybe
import System.Environment
import System.Exit
import System.IO (hPutStrLn, stderr)

import Control.Monad.Reader
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import Database.PostgreSQL.Simple hiding (connect)
import Database.PostgreSQL.Simple.Internal (exec)
import Database.PostgreSQL.Simple.Types

import Database.PostgreSQL.Escape
-- | Open a connection using the @DATABASE_URL@ environment variable as the
-- libpq connection string.
--
-- When the variable is unset, the empty string is passed instead.
-- NOTE(review): an empty connection string presumably makes libpq fall back
-- to its own defaults (PG* environment variables) — confirm.
--
-- 'S8.pack' keeps only the low 8 bits of each 'Char', so non-ASCII characters
-- in the URL are not UTF-8 encoded.
connectEnv :: IO Connection
connectEnv = do
  env <- getEnvironment
  let databaseUrl = fromMaybe "" (lookup "DATABASE_URL" env)
  connectPostgreSQL (S8.pack databaseUrl)
-- | Monad in which migration steps run: a reader over the PostgreSQL
-- 'Connection' the steps execute against, layered on 'IO'.
type Migration = ReaderT Connection IO
-- | Run a 'Migration' against an already-open connection.
-- The migration's result value is thrown away.
migrate :: Migration a -> Connection -> IO ()
migrate m conn = void (runReaderT m conn)
-- | Execute a parameterless statement on the migration's connection.
-- Returns the row count reported by 'execute_'.
executeQuery_ :: Query -> Migration Int64
executeQuery_ q = do
  conn <- ask
  liftIO (execute_ conn q)
-- | Read a file of raw SQL strictly and send its whole contents to the
-- server in a single 'exec' call on the migration's connection.
--
-- The result returned by 'exec' is discarded.
-- NOTE(review): whether a failing statement in the file raises here depends
-- on how this postgresql-simple version's 'exec' treats error results —
-- confirm before relying on it to abort a migration.
runSqlFile :: FilePath -> Migration ()
runSqlFile sqlFile = do
  conn <- ask
  liftIO $ S.readFile sqlFile >>= void . exec conn
-- | Render one column definition for use in a table statement.
--
-- The name is quoted with 'quoteIdent'. The definition text is appended
-- verbatim after a single space; it is not quoted or escaped.
column :: S8.ByteString -- ^ column name
       -> S8.ByteString -- ^ raw SQL type/constraint text
       -> S8.ByteString
column name def = quoteIdent name `S8.append` S8.cons ' ' def
-- | Execute the statement built by 'create_table_stmt'.
create_table :: S8.ByteString   -- ^ table name
             -> [S8.ByteString] -- ^ column definitions (raw SQL)
             -> Migration Int64
create_table tableName colDefs =
  executeQuery_ (create_table_stmt tableName colDefs)
-- | Build @create table "name" (def1, def2, ...);@.
--
-- The table name is quoted with 'quoteIdent'. The column definitions are
-- joined with @", "@ and inserted verbatim, so they are typically built
-- with 'column'.
create_table_stmt :: S8.ByteString   -- ^ table name
                  -> [S8.ByteString] -- ^ column definitions (raw SQL)
                  -> Query
create_table_stmt tableName colDefs = Query (S8.concat pieces)
  where
    pieces =
      [ "create table ", quoteIdent tableName
      , " (", S8.intercalate ", " colDefs, ");"
      ]
-- | Execute the statement built by 'drop_table_stmt'.
drop_table :: S8.ByteString -> Migration Int64
drop_table tableName = executeQuery_ (drop_table_stmt tableName)
-- | Build @drop table "name";@ with the table name quoted by 'quoteIdent'.
drop_table_stmt :: S8.ByteString -> Query
drop_table_stmt tableName =
  Query ("drop table " `S8.append` quoteIdent tableName `S8.append` ";")
-- | Execute the statement built by 'add_column_stmt'.
add_column :: S8.ByteString -- ^ table name
           -> S8.ByteString -- ^ new column name
           -> S8.ByteString -- ^ column definition (raw SQL)
           -> Migration Int64
add_column tableName colName colDef =
  executeQuery_ (add_column_stmt tableName colName colDef)
-- | Build @alter table "t" add "col" def;@.
--
-- The column part is rendered by 'column'; the table name is quoted with
-- 'quoteIdent'.
add_column_stmt :: S8.ByteString -- ^ table name
                -> S8.ByteString -- ^ new column name
                -> S8.ByteString -- ^ column definition (raw SQL)
                -> Query
add_column_stmt tableName colName colDef = Query (S8.concat pieces)
  where
    pieces = [ "alter table ", quoteIdent tableName
             , " add ", column colName colDef, ";" ]
-- | Execute the statement built by 'drop_column_stmt'.
drop_column :: S8.ByteString -- ^ table name
            -> S8.ByteString -- ^ column name
            -> Migration Int64
drop_column tableName colName =
  executeQuery_ (drop_column_stmt tableName colName)
-- | Build @alter table "t" drop "col";@ with both identifiers quoted by
-- 'quoteIdent'.
drop_column_stmt :: S8.ByteString -- ^ table name
                 -> S8.ByteString -- ^ column name
                 -> Query
drop_column_stmt tableName colName = Query $ S8.concat
  [ "alter table ", table, " drop ", col, ";" ]
  where
    table = quoteIdent tableName
    col   = quoteIdent colName
-- | Execute the statement built by 'rename_column_stmt'.
rename_column :: S8.ByteString -- ^ table name
              -> S8.ByteString -- ^ current column name
              -> S8.ByteString -- ^ new column name
              -> Migration Int64
rename_column tableName colName colNameNew =
  executeQuery_ (rename_column_stmt tableName colName colNameNew)
-- | Build @alter table "t" rename "old" to "new";@ with every identifier
-- quoted by 'quoteIdent'.
rename_column_stmt :: S8.ByteString -- ^ table name
                   -> S8.ByteString -- ^ current column name
                   -> S8.ByteString -- ^ new column name
                   -> Query
rename_column_stmt tableName colName colNameNew = Query $ S8.concat
  [ "alter table ", table, " rename ", oldCol, " to ", newCol, ";" ]
  where
    table  = quoteIdent tableName
    oldCol = quoteIdent colName
    newCol = quoteIdent colNameNew
-- | Execute the statement built by 'change_column_stmt'.
change_column :: S8.ByteString -- ^ table name
              -> S8.ByteString -- ^ column name
              -> S8.ByteString -- ^ raw SQL action following the column name
              -> Migration Int64
change_column tableName colName action =
  executeQuery_ (change_column_stmt tableName colName action)
-- | Build @alter table "t" alter "col" ACTION;@.
--
-- The table and column names are quoted with 'quoteIdent'. The action text
-- is inserted verbatim after the column name.
change_column_stmt :: S8.ByteString -- ^ table name
                   -> S8.ByteString -- ^ column name
                   -> S8.ByteString -- ^ raw SQL action (not escaped)
                   -> Query
change_column_stmt tableName colName action = Query $ S8.concat
  [ "alter table ", table, " alter ", col, " ", action, ";" ]
  where
    table = quoteIdent tableName
    col   = quoteIdent colName
-- | Command line as understood by 'defaultMain'.
data CmdArgs = CmdArgs
  { cmd        :: String -- ^ requested direction, @"up"@ or @"down"@
  , cmdVersion :: String -- ^ migration version the command refers to
  , cmdCommit  :: Bool   -- ^ commit the transaction instead of rolling back
  }
-- | Execute a non-unique index creation built by 'create_index_stmt'.
create_index :: S8.ByteString   -- ^ index name
             -> S8.ByteString   -- ^ table name
             -> [S8.ByteString] -- ^ indexed column names
             -> Migration Int64
create_index indexName tableName colNames =
  executeQuery_ (create_index_stmt False indexName tableName colNames)
-- | Execute a unique index creation built by 'create_index_stmt'.
create_unique_index :: S8.ByteString   -- ^ index name
                    -> S8.ByteString   -- ^ table name
                    -> [S8.ByteString] -- ^ indexed column names
                    -> Migration Int64
create_unique_index indexName tableName colNames =
  executeQuery_ (create_index_stmt True indexName tableName colNames)
-- | Build @create [unique] index "idx" on "t" ("c1", "c2", ...);@.
--
-- The flag selects @create unique index@ versus @create index@. The index
-- name, table name and every column name are quoted with 'quoteIdent'.
create_index_stmt :: Bool            -- ^ whether the index is unique
                  -> S8.ByteString   -- ^ index name
                  -> S8.ByteString   -- ^ table name
                  -> [S8.ByteString] -- ^ indexed column names
                  -> Query
create_index_stmt unq indexName tableName colNames = Query $ S8.concat
  [ prefix, quoteIdent indexName, " on ", quoteIdent tableName
  , " (", columnList, ");" ]
  where
    prefix
      | unq       = "create unique index "
      | otherwise = "create index "
    columnList = S8.intercalate ", " (map quoteIdent colNames)
-- | Execute the statement built by 'drop_index_stmt'.
drop_index :: S8.ByteString -- ^ index name
           -> Migration Int64
drop_index indexName = executeQuery_ (drop_index_stmt indexName)
-- | Build @drop index "idx";@ with the index name quoted by 'quoteIdent'.
drop_index_stmt :: S8.ByteString -- ^ index name
                -> Query
drop_index_stmt indexName =
  Query ("drop index " `S8.append` quoteIdent indexName `S8.append` ";")
-- | Parse @COMMAND VERSION [FLAG ...]@.
--
-- Returns 'Nothing' when fewer than two arguments are given. Every argument
-- after the version is treated as a flag. @--with-db-commit@ sets
-- 'cmdCommit'; any other flag is ignored. The command string is not
-- validated here; 'defaultMain' dispatches on it.
--
-- Pattern matching replaces the former 'tail' calls, which were partial and
-- safe only because of the preceding 'listToMaybe' guards.
parseCmdArgs :: [String] -> Maybe CmdArgs
parseCmdArgs (command : version : flags) =
  Just CmdArgs { cmd        = command
               , cmdVersion = version
               , cmdCommit  = "--with-db-commit" `elem` flags
               }
parseCmdArgs _ = Nothing
-- | Command-line entry point for a migration executable.
--
-- Expected arguments are @up VERSION [--with-db-commit]@ or
-- @down VERSION [--with-db-commit]@. The connection comes from 'connectEnv'.
--
-- * @up@ proceeds only when the latest recorded version is less than
--   @VERSION@ (plain 'String' ordering). It then runs the @up@ action and
--   inserts @VERSION@ into @schema_migrations@.
-- * @down@ proceeds only when the latest recorded version equals @VERSION@.
--   It then runs the @down@ action and deletes @VERSION@.
--
-- Both directions run inside a transaction. It is committed only when
-- @--with-db-commit@ is given and rolled back otherwise, so a plain run is a
-- dry run.
--
-- Missing arguments or an unknown command print a usage line to stderr and
-- exit with status 1. A failed version check also exits with status 1.
defaultMain :: (Connection -> IO ())
            -> (Connection -> IO ())
            -> IO ()
defaultMain up down = do
  -- An explicit case replaces the old partial @(Just cmdArgs) <-@ bind,
  -- which died with an opaque pattern-match IOError on bad arguments.
  parsed <- fmap parseCmdArgs getArgs
  case parsed of
    Just cmdArgs | cmd cmdArgs == "up" ->
      runMigrationStep cmdArgs (<) up
        "insert into schema_migrations values(?)"
    Just cmdArgs | cmd cmdArgs == "down" ->
      runMigrationStep cmdArgs (==) down
        "delete from schema_migrations where version = ?"
    _ -> usageFailure

-- | Shared driver for both migration directions.
--
-- It connects via 'connectEnv' and compares the latest recorded version with
-- the requested one as @isAllowed current requested@. If the check fails it
-- exits with status 1. Otherwise it runs @action@ and then the bookkeeping
-- statement, parameterised by the requested version, inside a transaction.
-- That transaction is committed only when 'cmdCommit' is set.
--
-- 'bracket' closes the connection on every path. This includes the
-- 'exitWith' exit, which is delivered as an exception.
runMigrationStep :: CmdArgs
                 -> (String -> String -> Bool)
                 -> (Connection -> IO ())
                 -> Query
                 -> IO ()
runMigrationStep cmdArgs isAllowed action bookkeeping =
  bracket connectEnv close $ \conn -> do
    currentVersion <- latestVersion conn
    let version = cmdVersion cmdArgs
    unless (currentVersion `isAllowed` version) $ exitWith (ExitFailure 1)
    begin conn
    action conn
    void $ execute conn bookkeeping (Only version)
    if cmdCommit cmdArgs then commit conn else rollback conn

-- | The greatest @version@ in @schema_migrations@ according to the
-- database's ordering, or @""@ when the table has no rows.
latestVersion :: Connection -> IO String
latestVersion conn = do
  rows <- query_ conn
    "select version from schema_migrations order by version desc limit 1"
  return $ case rows of
    []         -> ""
    Only v : _ -> v

-- | Print a usage line to stderr and exit with status 1.
usageFailure :: IO a
usageFailure = do
  prog <- getProgName
  hPutStrLn stderr $
    "usage: " ++ prog ++ " (up|down) VERSION [--with-db-commit]"
  exitWith (ExitFailure 1)