{-# LANGUAGE EmptyDataDecls    #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses, ScopedTypeVariables #-}
-- | PostgreSQL migration strategy for the ODBC persistent backend.
module Database.Persist.MigratePostgres
    ( getMigrationStrategy 
    ) where

import Database.Persist.Sql
import Control.Monad.IO.Class (MonadIO (..))
import Control.Monad.Trans.Resource (runResourceT)
import qualified Data.Text as T
import Data.Text (pack,Text)

import Data.Either (partitionEithers)
import Control.Arrow
import Data.List (find, intercalate, sort, groupBy)
import Data.Function (on)
import Data.Conduit
import qualified Data.Conduit.List as CL
import Data.Maybe (mapMaybe)

import qualified Data.Text.Encoding as T

import Database.Persist.ODBCTypes
import Data.Acquire (with)

-- Debug tracing helper used throughout this module.
-- When the DEBUG CPP macro is set, tracex is Debug.Trace.trace; otherwise it
-- ignores the message and returns its second argument unchanged.
#if DEBUG
import Debug.Trace
tracex::String -> a -> a
tracex = trace
#else
tracex::String -> a -> a
tracex _ b = b
#endif

-- | Build the 'MigrationStrategy' record for a Postgres database type.
--
-- Any 'DBType' other than 'Postgres' is a caller error and raises an
-- exception via 'error' that includes the offending value.
getMigrationStrategy :: DBType -> MigrationStrategy
getMigrationStrategy dbtype =
    case dbtype of
        Postgres {} ->
            MigrationStrategy
                { dbmsLimitOffset = decorateSQLWithLimitOffset "LIMIT ALL"
                , dbmsMigrate     = migrate'
                , dbmsInsertSql   = insertSql'
                , dbmsEscape      = escape
                , dbmsType        = dbtype
                }
        _ -> error ("Postgres: calling with invalid dbtype " ++ show dbtype)
                     
-- | Compute the SQL needed to bring one entity's table in line with its
-- definition.
--
-- The existing columns and unique constraints are read with 'getColumns'.
-- If any of them cannot be interpreted, the collected error messages are
-- returned in 'Left'. Otherwise the result is a list of statements, each
-- paired with a flag that is 'True' when the statement is considered unsafe
-- (see 'showAlterDb').
--
-- 'getColumns' never reports the id column, so an empty result does not by
-- itself prove the table is missing: a table consisting only of its id column
-- looks the same. In that case the table's existence is checked explicitly,
-- so that an existing table is altered rather than created a second time.
migrate' :: [EntityDef]
         -> (Text -> IO Statement)
         -> EntityDef
         -> IO (Either [Text] [(Bool, Text)])
migrate' allDefs getter val = fmap (fmap $ map showAlterDb) $ do
    let name = entityDB val
    old <- getColumns getter val
    case partitionEithers old of
        ([], old'') -> do
            let old' = partitionEithers old''
            let (newcols', udefs, fdefs) = mkColumns allDefs val
            let newcols = filter (not . safeToRemove val . cName) newcols'
            let udspair = map udToPair udefs
            -- Only pay for the extra query when the column listing is empty.
            exists <- if null old
                          then doesTableExist getter name
                          else return True
            if not exists
                then do
                    let idtxt = case entityPrimary val of
                                  Just pdef -> concat [" PRIMARY KEY (", intercalate "," $ map (T.unpack . escape . fieldDB) $ compositeFields pdef, ")"]
                                  Nothing   -> concat [T.unpack $ escape $ fieldDB $ entityId val
                                        , " SERIAL PRIMARY KEY UNIQUE"]
                    let addTable = AddTable $ concat
                            -- Lower case e: see Database.Persist.Sql.Migration
                            [ "CREATe TABLE "
                            , T.unpack $ escape name
                            , "("
                            , idtxt
                            , if null newcols then [] else ","
                            , intercalate "," $ map showColumn newcols
                            , ")"
                            ]
                    let uniques = flip concatMap udspair $ \(uname, ucols) ->
                            [AlterTable name $ AddUniqueConstraint uname ucols]
                    -- One ADD CONSTRAINT per column that carries a reference;
                    -- columns without a reference contribute nothing.
                    let addRef c = case cReference c of
                            Nothing -> Nothing
                            Just (refTblName, _) ->
                                getAddReference allDefs name refTblName (cName c) (cReference c)
                    let references = mapMaybe addRef newcols
                    -- Explicit foreign definitions: pair up child and parent
                    -- DB column names from each field mapping.
                    let foreignsAlt = map (\fdef -> let (childfields, parentfields) = unzip (map (\((_,b),(_,d)) -> (b,d)) (foreignFields fdef))
                                                    in AlterColumn name (foreignRefTableDBName fdef, AddReference (foreignConstraintNameDBName fdef) childfields parentfields)) fdefs
                    return $ Right $ addTable : uniques ++ references ++ foreignsAlt
                else do
                    let (acs, ats) = getAlters allDefs val (newcols, udspair) old'
                    let acs' = map (AlterColumn name) acs
                    let ats' = map (AlterTable name) ats
                    return $ Right $ acs' ++ ats'
        (errs, _) -> return $ Left errs

-- | Check whether a table with the given name exists in the current database
-- and schema, by looking for a matching row in @information_schema.tables@.
doesTableExist :: (Text -> IO Statement) -> DBName -> IO Bool
doesTableExist getter (DBName name) = do
    stmt <- getter $ pack $ concat
        [ "SELECT table_name "
        , "FROM information_schema.tables "
        , "WHERE table_catalog=current_database() "
        , "AND table_schema=current_schema() "
        , "AND table_name=?"
        ]
    -- Only the presence of a row matters, so the value's wire type is ignored.
    row <- with (stmtQuery stmt [PersistText name]) ($$ CL.head)
    return $ maybe False (const True) row

-- | Flag carried by 'Drop'; it is filled in from 'safeToRemove' when a column
-- drop is planned, and 'showAlterDb' treats the drop as unsafe when it is
-- 'False'.
type SafeToRemove = Bool

-- | A single change to a column, rendered to SQL by 'showAlter':
-- change type, drop or set NOT NULL, add or drop the column, set or drop the
-- default, back-fill NULLs with a value ('Update''), and add or drop a
-- foreign-key constraint.
data AlterColumn = Type SqlType | IsNull | NotNull | Add' Column | Drop SafeToRemove
                 | Default String | NoDefault | Update' String
                 | AddReference DBName [DBName] [DBName] | DropReference DBName
-- | A column change paired with a name. The name is the affected column,
-- except for 'AddReference', where 'showAlter' uses it as the referenced
-- table.
type AlterColumn' = (DBName, AlterColumn)

-- | A table-level constraint change, rendered to SQL by 'showAlterTable'.
data AlterTable = AddUniqueConstraint DBName [DBName]
                | DropConstraint DBName

-- | One migration step: a ready-made CREATE TABLE statement, or a column or
-- table alteration on the named table. Rendered by 'showAlterDb'.
data AlterDB = AddTable String
             | AlterColumn DBName AlterColumn'
             | AlterTable DBName AlterTable

-- | Returns all of the columns in the given table currently in the database.
--
-- Two queries are run against @information_schema@ for the entity's table in
-- the current database and schema, both excluding the entity's id column:
--
-- * the column listing, each row being converted with 'getColumn' (a row that
--   cannot be interpreted yields a 'Left' error message);
-- * the constraints that are neither primary nor foreign keys, grouped by
--   constraint name into @(constraint, columns)@ pairs.
--
-- The combined result is reported through 'tracex', so it is only printed in
-- DEBUG builds rather than on every migration.
getColumns :: (Text -> IO Statement)
           -> EntityDef
           -> IO [Either Text (Either Column (DBName, [DBName]))]
getColumns getter def = do
    let sqlv=concat ["SELECT "
                          ,"column_name "
                          ,",is_nullable "
                          ,",udt_name "
                          ,",column_default "
                          ,",numeric_precision "
                          ,",numeric_scale "
                          ,"FROM information_schema.columns "
                          ,"WHERE table_catalog=current_database() "
                          ,"AND table_schema=current_schema() "
                          ,"AND table_name=? "
                          ,"AND column_name <> ?"]

    stmt <- getter $ pack sqlv
    -- Both queries take the same parameters: table name and id column name.
    let vals =
            [ PersistText $ unDBName $ entityDB def
            , PersistText $ unDBName $ fieldDB $ entityId def
            ]
    cs <- with (stmtQuery stmt vals) ($$ helperClmns)
    -- NOTE(review): the ordinal_position=1 filter looks like it limits each
    -- constraint to its first column, which would make multi-column unique
    -- constraints compare unequal in getAlters -- confirm before relying on it.
    let sqlc=concat ["SELECT "
                          ,"c.constraint_name, "
                          ,"c.column_name "
                          ,"FROM information_schema.key_column_usage c, "
                          ,"information_schema.table_constraints k "
                          ,"WHERE c.table_catalog=current_database() "
                          ,"AND c.table_catalog=k.table_catalog "
                          ,"AND c.table_schema=current_schema() "
                          ,"AND c.table_schema=k.table_schema "
                          ,"AND c.table_name=? "
                          ,"AND c.table_name=k.table_name "
                          ,"AND c.column_name <> ? "
                          ,"AND c.ordinal_position=1 "
                          ,"AND c.constraint_name=k.constraint_name "
                          ,"AND k.constraint_type not in ('PRIMARY KEY','FOREIGN KEY') "
                          ,"ORDER BY c.constraint_name, c.column_name"]

    stmt' <- getter $ pack sqlc

    us <- with (stmtQuery stmt' vals) ($$ helperU)
    tracex ("\n\ngetColumns cs="++show cs++"\n\nus="++show us) $ return $ cs ++ us
  where
    -- Drain the constraint query into (constraint name, column name) pairs,
    -- accepting either text or UTF-8 byte strings from the driver.
    getAll front = do
        x <- CL.head
        case x of
            Nothing -> return $ front []
            Just [PersistText con, PersistText col] -> getAll (front . (:) (con, col))
            Just [PersistByteString con, PersistByteString col] -> getAll (front . (:) (T.decodeUtf8 con, T.decodeUtf8 col))
            Just xx -> error $ "oops: unexpected datatype returned odbc postgres  xx="++show xx
    -- Group consecutive rows by constraint name (the query is ordered by it).
    -- groupBy never yields an empty group, so the pattern below keeps every
    -- group while avoiding a partial head.
    helperU = do
        rows <- getAll id
        return [ Right (Right (DBName con, map (DBName . snd) grp))
               | grp@((con, _):_) <- groupBy ((==) `on` fst) rows
               ]

    -- Convert every column row with getColumn and collect the results.
    helperClmns = CL.mapM getIt =$ CL.consume
        where
          getIt = fmap (either Left (Right . Left)) .
                  liftIO .
                  getColumn getter (entityDB def)
{-
    helper = do
        x <- CL.head
        case x of
            Nothing -> tracex "getColumns helper Nothing!!!" $ return []
            Just x' -> do
                col <- liftIO $ getColumn getter (entityDB def) x'
                let col' = tracex ("getColumns helper: col="++show col) $ case col of
                            Left e -> Left e
                            Right c -> Right $ Left c
                cols <- helper
                return $ col' : cols
-}
-- | Tell whether the entity declares a field with the given database name
-- that carries the @SafeToRemove@ attribute.
safeToRemove :: EntityDef -> DBName -> Bool
safeToRemove def column = any flagged (entityFields def)
  where
    flagged field =
        fieldDB field == column && "SafeToRemove" `elem` fieldAttrs field

-- | Diff the desired columns and unique constraints (first pair) against the
-- ones currently in the database (second pair).
--
-- Returns the column alterations and the table-level constraint alterations
-- needed to get from the existing state to the desired one.
getAlters :: [EntityDef]
          -> EntityDef
          -> ([Column], [(DBName, [DBName])])
          -> ([Column], [(DBName, [DBName])])
          -> ([AlterColumn'], [AlterTable])
getAlters allDefs def (c1, u1) (c2, u2) =
    (columnAlters c1 c2, uniqueAlters u1 u2)
  where
    tableName = entityDB def

    -- Once every desired column is handled, whatever is left over in the
    -- database gets dropped.
    columnAlters [] leftovers =
        [ (cName c, Drop (safeToRemove def (cName c))) | c <- leftovers ]
    columnAlters (wanted:rest) existing =
        alters ++ columnAlters rest remaining
      where
        (alters, remaining) = findAlters allDefs tableName wanted existing

    uniqueAlters :: [(DBName, [DBName])]
                 -> [(DBName, [DBName])]
                 -> [AlterTable]
    uniqueAlters [] leftovers =
        [ DropConstraint n | (n, _) <- leftovers, not (isManual n) ]
    uniqueAlters ((name, cols):rest) existing =
        case lookup name existing of
            Nothing -> AddUniqueConstraint name cols : uniqueAlters rest existing
            Just ocols
                -- Same column set (order-insensitive): nothing to do.
                | sort cols == sort ocols -> uniqueAlters rest remaining
                | otherwise ->
                    DropConstraint name
                        : AddUniqueConstraint name cols
                        : uniqueAlters rest remaining
      where
        remaining = filter ((/= name) . fst) existing

    -- Don't drop constraints which were manually added.
    isManual (DBName x) = "__manual_" `T.isPrefixOf` x

-- | Build a 'Column' from one row of the column listing produced by
-- 'getColumns'.
--
-- The row is expected to hold six values: column name, nullability
-- (@YES@ means nullable), type name, default, numeric precision and numeric
-- scale. The three textual values may arrive either as 'PersistText' or as
-- UTF-8 'PersistByteString', matching what the other queries in this module
-- accept. Any other row shape gives a 'Left' with a description of the row.
--
-- A second query looks up the foreign-key constraint named by 'refName' for
-- this table and column, to fill in 'cReference'.
getColumn :: (Text -> IO Statement)
          -> DBName -> [PersistValue]
          -> IO (Either Text Column)
getColumn getter tname row =
    case row of
        [xv, yv, zv, d, npre, nscl]
            | Just [x, y, z] <- mapM asText [xv, yv, zv] ->
                case parseDefault d of
                    Left s -> return $ Left s
                    Right dflt ->
                        case getType npre nscl z of
                            Left s -> return $ Left s
                            Right t -> do
                                let cname = DBName x
                                ref <- getRef cname
                                return $ Right Column
                                    { cName = cname
                                    , cNull = y == "YES"
                                    , cSqlType = t
                                    , cDefault = dflt
                                    , cDefaultConstraintName = Nothing
                                    , cMaxLen = Nothing
                                    , cReference = ref
                                    }
        _ -> return $ Left $ pack $ "Invalid result from information_schema: " ++ show row ++ " a2[" ++ show tname ++ "]"
  where
    -- Textual values may come back as text or as UTF-8 bytes.
    asText (PersistText t) = Just t
    asText (PersistByteString bs) = Just (T.decodeUtf8 bs)
    asText _ = Nothing

    -- Look up the foreign key (if any) whose constraint name follows the
    -- refName convention for this table and column.
    getRef cname = do
        let sql = pack $ concat
                [ "SELECT "
                  ,"tc.table_name, "
                  ,"kcu.column_name, "
                  ,"ccu.table_name AS foreign_table_name, "
                  ,"ccu.column_name AS foreign_column_name, "
                  ,"kcu.ordinal_position "
                  ,"FROM "
                  ,"information_schema.table_constraints AS tc "
                  ,"JOIN information_schema.key_column_usage "
                  ,"AS kcu ON tc.constraint_name = kcu.constraint_name "
                  ,"JOIN information_schema.constraint_column_usage "
                  ,"AS ccu ON ccu.constraint_name = tc.constraint_name "
                  ,"WHERE constraint_type = 'FOREIGN KEY' "
                  ,"and tc.table_name=? "
                  ,"and tc.constraint_name=? "
                  ,"and tc.table_catalog=current_database() "
                  ,"AND tc.table_catalog=kcu.table_catalog "
                  ,"AND tc.table_catalog=ccu.table_catalog "
                  ,"AND tc.table_schema=current_schema() "
                  ,"AND tc.table_schema=kcu.table_schema "
                  ,"AND tc.table_schema=ccu.table_schema "
                  ,"order by tc.table_name,tc.constraint_name, kcu.ordinal_position "
                  ]

        let ref = refName tname cname
        stmt <- getter sql
        with (stmtQuery stmt
                     [ PersistText $ unDBName tname
                     , PersistText $ unDBName ref
                     ]) ($$ do
            -- Only the first row is inspected.
            m <- CL.head

            return $ case m of
              Just [PersistText _table, PersistText _col, PersistText reftable, PersistText _refcol, PersistInt64 _pos] -> Just (DBName reftable, ref)
              Just [PersistByteString _table, PersistByteString _col, PersistByteString reftable, PersistByteString _refcol, PersistInt64 _pos] -> Just (DBName (T.decodeUtf8 reftable), ref)
              Nothing -> Nothing
              _ -> error $ "unexpected result found ["++ show m ++ "]" )

    -- The column default: NULL means no default.
    parseDefault PersistNull = Right Nothing
    parseDefault (PersistText t) = Right $ Just t
    parseDefault (PersistByteString bs) = Right $ Just $ T.decodeUtf8 bs
    parseDefault other = Left $ pack $ "Invalid default column: " ++ show other

    -- Map the Postgres type name onto a SqlType; unknown names are kept
    -- verbatim in SqlOther.
    getType npre nscl typ =
        case typ of
            "int4"      -> Right SqlInt32
            "int8"      -> Right SqlInt64
            "varchar"   -> Right SqlString
            "date"      -> Right SqlDay
            "bool"      -> Right SqlBool
            "timestamp" -> Right SqlDayTime
            "float4"    -> Right SqlReal
            "float8"    -> Right SqlReal
            "bytea"     -> Right SqlBlob
            "time"      -> Right SqlTime
            "numeric"   -> getNumeric npre nscl
            other       -> Right $ SqlOther other

    getNumeric (PersistInt64 a) (PersistInt64 b) = Right $ SqlNumeric (fromIntegral a) (fromIntegral b)
    getNumeric a b = Left $ pack $ "Can not get numeric field precision, got: " ++ show a ++ " and " ++ show b ++ " as precision and scale"

-- | Work out the alterations needed for one desired column, given the columns
-- currently in the table.
--
-- Returns the alterations together with the existing columns that are still
-- unaccounted for (every column with the same name as the desired one is
-- removed). If no existing column has that name, the only alteration is to
-- add the column and the existing list is returned unchanged.
findAlters :: [EntityDef] -> DBName -> Column -> [Column] -> ([AlterColumn'], [Column])
findAlters defs tablename col@(Column name isNull sqltype def _defConstraintName _maxLen ref) cols =
    tracex ("\n\n\nfindAlters tablename="++show tablename++ " name="++ show name++" col="++show col++"\ncols="++show cols++"\n\n\n") $ case filter ((name ==) . cName) cols of
        [] -> ([(name, Add' col)], cols)
        -- Only the first existing column with this name is compared.
        Column _ isNull' sqltype' def' defConstraintName' _maxLen' ref':_ ->
            let refDrop Nothing = []
                refDrop (Just (_, cname)) = tracex ("\n\n\n44444 findAlters dropping fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" cname="++show cname++" tablename="++show tablename++"\n\n\n") $ 
                                             [(name, DropReference cname)]
                refAdd Nothing = []
                -- The pair is keyed by the referenced table, and the target
                -- column is that entity's id column.
                refAdd (Just (tname, a)) = tracex ("\n\n\n33333 findAlters adding fkey defConstraintName'="++show defConstraintName' ++" name="++show name++" tname="++show tname++" a="++show a++" tablename="++show tablename++"\n\n\n") $ 
                                           case find ((==tname) . entityDB) defs of
                                                Just refdef -> [(tname, AddReference a [name] [fieldDB $ entityId refdef])]
                                                Nothing -> error $ "could not find the entityDef for reftable[" ++ show tname ++ "]"
                -- References are compared by constraint name only (the second
                -- component); on mismatch the old one is dropped and the new
                -- one added.
                -- NOTE(review): the trace below is labelled "modType" although
                -- this is modRef -- looks like a copy/paste slip, confirm.
                modRef = tracex ("modType: sqltype[" ++ show sqltype ++ "] sqltype'[" ++ show sqltype' ++ "] name=" ++ show name) $ 
                    if fmap snd ref == fmap snd ref'
                        then []
                        else tracex ("\n\n\nmodRef findAlters drop/add cos ref doesnt match ref[" ++ show ref ++ "] ref'[" ++ show ref' ++ "] tablename="++show tablename++"\n\n\n") $ 
                              refDrop ref' ++ refAdd ref
                -- Nullability: when tightening to NOT NULL and a default is
                -- declared, existing NULLs are first back-filled with it.
                modNull = case (isNull, isNull') of
                            (True, False) -> [(name, IsNull)]
                            (False, True) ->
                                let up = case def of
                                            Nothing -> id
                                            Just s -> (:) (name, Update' $ T.unpack s)
                                 in up [(name, NotNull)]
                            _ -> []
                modType = tracex ("modType: sqltype[" ++ show sqltype ++ "] sqltype'[" ++ show sqltype' ++ "] name=" ++ show name) $ 
                          if sqltype == sqltype' then [] else [(name, Type sqltype)]
                -- Defaults are compared with cmpdef, which also accepts the
                -- existing default carrying a trailing "::type" cast.
                modDef = tracex ("modDef col=" ++ show col ++ " def=" ++ show def ++ " def'=" ++ show def') $
                    if cmpdef def def'
                        then []
                        else case def of
                                Nothing -> [(name, NoDefault)]
                                Just s -> [(name, Default $ T.unpack s)]
             in (modRef ++ modDef ++ modNull ++ modType,
                 filter (\c -> cName c /= name) cols)

-- | Compare a desired column default with the one reported by the database.
--
-- The two match when both are absent, when they are textually equal, or when
-- the reported default is the desired one followed by a @::type@ cast
-- (everything from the last @::@ onwards is ignored).
cmpdef :: Maybe Text -> Maybe Text -> Bool
cmpdef wanted reported =
    case (wanted, reported) of
        (Nothing, Nothing)    -> True
        (Just def, Just def') -> def == def' || withoutCast def' == Just def
        _                     -> False
  where
    -- Text before the final "::", if the last colon is part of a "::".
    withoutCast = T.stripSuffix "::" . fst . T.breakOnEnd ":"

-- | Build the foreign-key constraint to add for a column of a new table.
--
-- Returns 'Nothing' when the column has no reference. Otherwise the result
-- adds a constraint named by 'refName' from the column to the id column of
-- the referenced entity; if that entity is not among the definitions,
-- evaluating the id raises an error.
getAddReference :: [EntityDef] -> DBName -> DBName -> DBName -> Maybe (DBName, DBName) -> Maybe AlterDB
getAddReference _ _ _ _ Nothing = Nothing
getAddReference allDefs table reftable cname (Just (s, z)) =
    tracex debugMsg (Just addFkey)
  where
    debugMsg = "\n\ngetaddreference table="++ show table++" reftable="++show reftable++" s="++show s++" z=" ++ show z++"\n\n"
    addFkey = AlterColumn table (s, AddReference (refName table cname) [cname] [refId])
    refId =
        case find ((== reftable) . entityDB) allDefs of
            Just entDef -> fieldDB (entityId entDef)
            Nothing     -> error $ "Could not find ID of entity " ++ show reftable
                          

-- | Render a column definition as used in CREATE TABLE and ADD COLUMN:
-- quoted name, SQL type, nullability and, if present, the default.
showColumn :: Column -> String
showColumn col =
    T.unpack (escape (cName col))
        ++ " "
        ++ showSqlType (cSqlType col) (cMaxLen col)
        ++ " "
        ++ nullability
        ++ defaultClause
  where
    nullability
        | cNull col = "NULL"
        | otherwise = "NOT NULL"
    defaultClause = maybe "" (\s -> " DEFAULT " ++ T.unpack s) (cDefault col)

-- | Render a 'SqlType' as a Postgres type name. The optional length is only
-- used for string columns, where it becomes the VARCHAR size.
showSqlType :: SqlType -> Maybe Integer -> String
showSqlType sqlType maxLen =
    case sqlType of
        SqlString         -> maybe "VARCHAR" (\len -> "VARCHAR(" ++ show len ++ ")") maxLen
        SqlInt32          -> "INT4"
        SqlInt64          -> "INT8"
        SqlReal           -> "DOUBLE PRECISION"
        SqlNumeric s prec -> "NUMERIC(" ++ show s ++ "," ++ show prec ++ ")"
        SqlDay            -> "DATE"
        SqlTime           -> "TIME"
        SqlDayTime        -> "TIMESTAMP"
        SqlBlob           -> "BYTEA"
        SqlBool           -> "BOOLEAN"
        SqlOther t        -> T.unpack t

-- | Render a migration step as SQL, paired with a flag that is 'True' when
-- the step is unsafe. The only unsafe step is dropping a column that is not
-- marked as safe to remove.
showAlterDb :: AlterDB -> (Bool, Text)
showAlterDb step =
    case step of
        AddTable s                -> (False, pack s)
        AlterColumn t change@(_, ac) -> (unsafeDrop ac, pack (showAlter t change))
        AlterTable t at           -> (False, pack (showAlterTable t at))
  where
    unsafeDrop (Drop isSafe) = not isSafe
    unsafeDrop _             = False

-- | Render a table-level constraint change as an ALTER TABLE statement, with
-- all identifiers quoted via 'escape'.
showAlterTable :: DBName -> AlterTable -> String
showAlterTable table change = "ALTER TABLE " ++ quoted table ++ clause
  where
    quoted = T.unpack . escape
    clause =
        case change of
            AddUniqueConstraint cname cols ->
                " ADD CONSTRAINT " ++ quoted cname
                    ++ " UNIQUE(" ++ intercalate "," (map quoted cols) ++ ")"
            DropConstraint cname ->
                " DROP CONSTRAINT " ++ quoted cname

-- | Render a column alteration on the given table as a SQL statement, with
-- all identifiers quoted via 'escape'.
--
-- The name paired with the alteration is the affected column, except for
-- 'AddReference', where it is the referenced table. 'Update'' produces an
-- UPDATE that fills NULLs; every other alteration is an ALTER TABLE.
showAlter :: DBName -> AlterColumn' -> String
showAlter table (n, change) =
    case change of
        Type t      -> alterColumn (" TYPE " ++ showSqlType t Nothing)
        IsNull      -> alterColumn " DROP NOT NULL"
        NotNull     -> alterColumn " SET NOT NULL"
        Add' col    -> alterTable (" ADD COLUMN " ++ showColumn col)
        Drop _      -> alterTable (" DROP COLUMN " ++ quoted n)
        Default s   -> alterColumn (" SET DEFAULT " ++ s)
        NoDefault   -> alterColumn " DROP DEFAULT"
        Update' s   ->
            "UPDATE " ++ quoted table
                ++ " SET " ++ quoted n ++ "=" ++ s
                ++ " WHERE " ++ quoted n ++ " IS NULL"
        -- Here n names the referenced table, not a column.
        AddReference fkeyname t2 id2 ->
            alterTable $
                " ADD CONSTRAINT " ++ quoted fkeyname
                    ++ " FOREIGN KEY(" ++ quotedList t2
                    ++ ") REFERENCES " ++ quoted n
                    ++ "(" ++ quotedList id2 ++ ")"
        DropReference cname ->
            alterTable (" DROP CONSTRAINT " ++ quoted cname)
  where
    quoted = T.unpack . escape
    quotedList = T.unpack . T.intercalate "," . map escape
    alterTable rest = "ALTER TABLE " ++ quoted table ++ rest
    alterColumn rest = alterTable (" ALTER COLUMN " ++ quoted n ++ rest)

-- | Quote a database identifier: wrap it in double quotes and double every
-- double quote inside it.
escape :: DBName -> Text
escape (DBName s) = T.concat [quote, T.replace quote doubled s, quote]
  where
    quote   = T.singleton '"'
    doubled = T.pack "\"\""


-- | Name of the foreign-key constraint for a column:
-- @\<table\>_\<column\>_fkey@.
refName :: DBName -> DBName -> DBName
refName (DBName table) (DBName column) =
    DBName (T.intercalate "_" [table, column, "fkey"])

-- | Reduce a unique definition to its constraint name and the database names
-- of its fields.
udToPair :: UniqueDef -> (DBName, [DBName])
udToPair ud = (uniqueDBName ud, [ dbName | (_, dbName) <- uniqueFields ud ])

-- | Build the INSERT statement for an entity, listing every entity field with
-- one placeholder each.
--
-- Entities with an explicit primary definition get a plain INSERT wrapped in
-- 'ISRManyKeys' together with the values; all others get an INSERT with a
-- RETURNING clause for the id column, wrapped in 'ISRSingle'.
insertSql' :: EntityDef -> [PersistValue] -> InsertSqlResult
insertSql' ent vals = tracex ("\n\n\nGBTEST " ++ show (entityFields ent) ++ "\n\n\n") $
    maybe returningId (const (ISRManyKeys (pack insertStmt) vals)) (entityPrimary ent)
  where
    fields = entityFields ent
    quoted = T.unpack . escape
    -- Shared by both variants: INSERT INTO "t"("a","b") VALUES(?,?)
    insertStmt = concat
        [ "INSERT INTO "
        , quoted (entityDB ent)
        , "("
        , intercalate "," (map (quoted . fieldDB) fields)
        , ") VALUES("
        , intercalate "," (map (const "?") fields)
        , ")"
        ]
    returningId =
        ISRSingle $ pack $
            insertStmt ++ " RETURNING " ++ quoted (fieldDB (entityId ent))