{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE TypeSynonymInstances  #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE PolyKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# OPTIONS_GHC -Wno-orphans       #-}
module Data.HDiff.Patch.Merge where
import Data.Functor.Sum
import qualified Data.Map as M
import Control.Monad.State
import Control.Monad.Writer hiding (Sum)
import Control.Monad.Except
import Generics.MRSOP.Util
import Generics.MRSOP.Base
import Generics.MRSOP.Holes
import Data.Exists
import Data.HDiff.Patch
import Data.HDiff.Change
import Data.HDiff.Change.Apply
import Data.HDiff.Change.Thinning
import Data.HDiff.MetaVar
-- | A merge conflict located at an atom @at@.
--
--   In this module conflicts are only built by 'reconcile', which stores
--   a short textual label describing why reconciliation failed together
--   with the two patches it was trying to reconcile (first @p@, then @q@).
data Conflict :: (kon -> *) -> [[[Atom kon]]] -> Atom kon -> * where
  -- Fields, in order: the label, the first patch, the second patch.
  Conflict :: String
           -> RawPatch ki codes at
           -> RawPatch ki codes at
           -> Conflict ki codes at
-- | A patch that may contain conflicts: a 'Holes' structure at atom
--   @'I ix@ whose holes carry either a 'Conflict' (tagged 'InL') or a
--   regular 'CChange' (tagged 'InR').
type PatchC ki codes ix
  = Holes ki codes (Sum (Conflict ki codes) (CChange ki codes)) ('I ix)
-- | Attempts to turn a 'PatchC' back into a plain 'Patch'.
--
--   Yields 'Nothing' when at least one hole holds a conflict ('InL');
--   otherwise every hole is unwrapped from its 'InR' tag and the
--   resulting patch is returned in 'Just'.
noConflicts :: PatchC ki codes ix -> Maybe (Patch ki codes ix)
noConflicts = holesMapM keepChange
  where
    -- A hole survives only when it carries a change.
    keepChange hole = case hole of
      InR chg -> Just chg
      InL _   -> Nothing
-- | Collects the label of every 'Conflict' found in the holes of a
--   'PatchC', in the order 'holesMapM' visits them. The patch itself is
--   traversed unchanged; only the accumulated labels are returned.
getConflicts :: (ShowHO ki) => PatchC ki codes ix -> [String]
getConflicts = execWriter . holesMapM record
  where
    -- Log the label of a conflict; pass every hole through untouched.
    record hole = case hole of
      InL (Conflict label _ _) -> hole <$ tell [label]
      _                        -> pure hole
-- | Merges patch @p@ against patch @q@, producing a patch that may
--   contain conflicts.
--
--   @q@ is first passed through 'withFreshNamesFrom' with respect to @p@
--   (presumably so that its metavariable names do not clash with those
--   of @p@ -- confirm against that function's definition). The two
--   patches are then paired up with 'holesLCP', each resulting pair is
--   handed to 'reconcile', and the nested result is flattened with
--   'holesJoin'.
(//) :: ( Applicable ki codes (Holes2 ki codes)
        , EqHO ki , ShowHO ki
        , HasDatatypeInfo ki fam codes
        )
     => Patch ki codes ix
     -> Patch ki codes ix
     -> PatchC ki codes ix
p // q = holesJoin (holesMap (uncurry' reconcile) paired)
  where
    -- @q@ after renaming relative to @p@.
    renamedQ = q `withFreshNamesFrom` p
    -- @p@ and the renamed @q@ paired by 'holesLCP'.
    paired   = holesLCP p renamedQ
-- | Reconciles two raw patches located at the same atom.
--
--   * When 'patchEq' holds for the two patches, the result is a single
--     hole holding @makeCopyFrom (distrCChange p)@.
--
--   * Otherwise both patches are rewritten by applying 'holesLCP' to the
--     two halves of every change and joining the result; 'process' then
--     decides what to do with these two rewritten patches:
--
--       - 'CantReconcile': a conflict labelled with the given message;
--       - 'ReturnNominator': @p@ itself, every change tagged 'InR';
--       - 'InstDenominator': 'transport' is run on @scIns@ of the
--         rewritten @q@ with the returned substitution. A failure
--         becomes a conflict labelled with the shown error; a result
--         that 'utx22change' rejects becomes a conflict labelled
--         @"chg"@; otherwise the resulting change is returned.
reconcile :: forall ki codes fam at
           . ( Applicable ki codes (Holes2 ki codes) , EqHO ki , ShowHO ki
             , HasDatatypeInfo ki fam codes
             )
          => RawPatch ki codes at
          -> RawPatch ki codes at
          -> Holes ki codes (Sum (Conflict ki codes) (CChange ki codes)) at
reconcile p q
  | patchEq p q = Hole' (InR (makeCopyFrom (distrCChange p)))
  | otherwise   =
      case process spinedP spinedQ of
        ReturnNominator       -> holesMap InR p
        CantReconcile reason  -> conflictWith reason
        InstDenominator subst ->
          case runExcept (transport (scIns spinedQ) subst) of
            Left failure -> conflictWith (show failure)
            Right moved  ->
              maybe (conflictWith "chg") (Hole' . InR) (utx22change moved)
  where
    -- Both patches with each change replaced by the 'holesLCP' of its
    -- two halves.
    spinedP = holesJoin (holesMap (uncurry' holesLCP . unCMatch) p)
    spinedQ = holesJoin (holesMap (uncurry' holesLCP . unCMatch) q)

    -- A single hole holding a conflict between @p@ and @q@.
    conflictWith label = Hole' (InL (Conflict label p q))
-- | The verdict produced by 'process' and consumed by 'reconcile'.
data ProcessOutcome ki codes
  -- 'reconcile' answers with its first patch, unchanged.
  = ReturnNominator
  -- 'reconcile' runs 'transport' with this substitution to build the
  -- resulting change.
  | InstDenominator (Subst ki codes (Holes2 ki codes))
  -- Reconciliation failed; the string becomes the label of the
  -- resulting 'Conflict'.
  | CantReconcile String
-- | Checks whether a change is a copy of a metavariable that is
--   recorded exactly once in the supplied map.
--
--   The answer is 'True' only when both halves of the change are holes
--   with the same metavariable number and the map associates that
--   number with @1@. Every other shape yields 'False'.
rawCpy :: M.Map Int Int
       -> Holes2 ki codes at
       -> Bool
rawCpy arities chg = case chg of
  Hole' src :*: Hole' dst ->
    let var = metavarGet src
     in var == metavarGet dst && M.lookup var arities == Just 1
  _ -> False
-- | Checks whether a change consists of two holes holding the same
--   metavariable number. Any other shape yields 'False'.
simpleCopy :: Holes2 ki codes at -> Bool
simpleCopy chg = case chg of
  Hole' src :*: Hole' dst -> metavarGet src == metavarGet dst
  _                       -> False
-- | Checks whether the first half of a change is a bare hole while the
--   second half starts with an 'HPeel' constructor.
isLocalIns :: Holes2 ki codes at -> Bool
isLocalIns chg = case chg of
  Hole _ _ :*: HPeel _ _ _ -> True
  _                        -> False
-- | Counts how many times each metavariable number occurs among the
--   holes of a tree.
--
--   The result maps every metavariable number (as given by
--   'metavarGet') that appears in at least one hole to its number of
--   occurrences; numbers that never appear are absent from the map.
arityMap :: Holes ki codes (MetaVarIK ki) at -> M.Map Int Int
arityMap = M.fromListWith (+)
         . map (\var -> (var, 1))
         . holesGetHolesAnnWith' metavarGet
-- | Inspects two patches in the form built by 'reconcile' and decides
--   how they should be reconciled.
--
--   The decision is taken in up to two phases:
--
--   1. 'step1' is run over every pair in @phiD@. If any pair yields
--      'Nothing' the outcome is @CantReconcile "p1n"@. If every pair
--      yields @Just True@, the outcome is 'ReturnNominator', unless some
--      pair in @phiID@ satisfies 'insins', in which case it is
--      @CantReconcile "p1ii"@.
--
--   2. Otherwise 'step2' is run over every pair in @phiID@, threading a
--      substitution that starts out as 'M.empty'. An error yields
--      'CantReconcile' with the message prefixed by @"p2n: "@; success
--      yields 'InstDenominator' with the final substitution.
process :: (Applicable ki codes (Holes2 ki codes) , EqHO ki , ShowHO ki)
        => HolesHoles2 ki codes at -> HolesHoles2 ki codes at
        -> ProcessOutcome ki codes
process sp sq =
  -- Phase 1: 'mapM' in 'Maybe' aborts on the first 'Nothing'; 'and'
  -- then tells whether every pair answered 'True'.
  case and <$> mapM (exElim $ uncurry' step1) phiD of
    Nothing    -> CantReconcile "p1n"
    Just True  -> if any (exElim $ uncurry' insins) phiID
                  then CantReconcile "p1ii"
                  else ReturnNominator
    Just False ->
      -- Phase 2: accumulate a substitution pair by pair, starting empty.
      let partial = runState (runExceptT $ mapM_ (exElim $ uncurry' step2) phiID) M.empty
       in case partial of
            (Left err  , _) -> CantReconcile $ "p2n: " ++ err
            (Right ()  , s) -> InstDenominator s
  where
    -- First component of 'utx2distr' applied to @sp@ (presumably the
    -- deletion context of @sp@ -- confirm against 'utx2distr').
    (delsp :*: _) = utx2distr sp
    -- Pairs obtained by running 'holesLCP' on @delsp@ and @sq@, each
    -- wrapped in 'Exists'.
    phiD  = holesGetHolesAnnWith' Exists $ holesLCP delsp sq
    -- Same, but pairing the whole of @sp@ with @sq@.
    phiID = holesGetHolesAnnWith' Exists $ holesLCP sp sq
    
    -- Occurrence count of every metavariable in the second component
    -- of 'utx2distr' applied to @sq@.
    varmap = arityMap (snd' (utx2distr sq))
    
    -- Phase-1 check of a single pair. 'Nothing' aborts the whole
    -- process; @Just b@ contributes @b@ to the conjunction. Clauses are
    -- tried in order.
    step1 :: Holes ki codes (MetaVarIK ki) at -> HolesHoles2 ki codes at
          -> Maybe Bool
    
    -- An opaque value on the left facing a change on the right is only
    -- acceptable when that change is a 'simpleCopy'.
    step1 (HOpq _ _) (Hole _ chg)
      | simpleCopy chg = Just True
      | otherwise      = Nothing
    
    -- A hole on the left is accepted regardless of the right-hand side.
    step1 (Hole _ _) _   = Just True
    
    -- Any other left-hand side facing a change: accepted only when
    -- 'rawCpy' holds for that change under @varmap@.
    step1 _ (Hole _ chg) = Just $ rawCpy varmap chg
    
    -- Everything else forces phase 2.
    step1 _ _ = Just False
    
    -- Phase-2 processing of a single pair: 'thinUTx2' is applied to
    -- @utx2distr pp@ and @scDel qq@, and the rebuilt result is matched
    -- with 'pmatch'' starting from the current substitution, which is
    -- then replaced by the substitution 'pmatch'' returns. A thinning
    -- failure is reported with the prefix @"th: "@; a matching failure
    -- is reported via 'show'.
    step2 :: (Applicable ki codes (Holes2 ki codes) , EqHO ki , ShowHO ki)
          => HolesHoles2 ki codes at -> HolesHoles2 ki codes at
          -> ExceptT String (State (Subst ki codes (Holes2 ki codes))) ()
    step2 pp qq = do
      s <- lift get
      let del = scDel qq
      case thinUTx2 (utx2distr pp) del of
        Left e    -> throwError ("th: " ++ show e)
        Right pp0 -> do
          let pp' = uncurry' holesLCP pp0
          case runExcept (pmatch' s del pp') of
            Left  e  -> throwError (show e)
            Right s' -> put s' >> return ()
    -- True when both sides are single changes and both satisfy
    -- 'isLocalIns'.
    insins :: HolesHoles2 ki codes at -> HolesHoles2 ki codes at -> Bool
    insins (Hole _ pp) (Hole _ qq) = isLocalIns pp && isLocalIns qq
    insins _ _ = False