-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- This module defines comparison functions used in the subexpression recognizer.
--
-----------------------------------------------------------------------------

module Recognize.SubExpr.Compare
   ( pCompareBySimplify, pCompareByNormalize,
     pCompare, pSubstituteVars
   ) where

import           Control.Monad
import           Data.List                as L
import qualified Data.Map                 as M
import           Data.Maybe
import Domain.Math.Expr
import Ideas.Common.Library               hiding (option, simplify)
import Recognize.Data.Attribute
import Recognize.Data.Math
import Recognize.Expr.Functions         as F
import Recognize.Expr.Normalform
import Recognize.Parsing.Parse
import Recognize.Parsing.Derived
import Recognize.Strategy.Applications
import qualified Recognize.SubExpr.Functions as SF
import Recognize.SubExpr.SEParser
import Util.Cache

-- | Simplifies both expressions using the strategy module and then compares them.
--
-- Each expression is first sorted with 'nfComAssoc' and simplified with
-- 'pSimplify'. The two simplified forms are then compared with
-- 'pCompareWith', using @nfComAssoc . nf4 1@ as the pre-comparison
-- transformation. The parser fails (via 'guard') when they do not match.
--
-- We return (simplified expression #1, simplified expression #2, rewrites #1 - rewrites #2)
pCompareBySimplify :: Expr -> Expr -> SEParser (Expr, Expr, [Attribute])
pCompareBySimplify m e = do
  (sm, attrs1) <- pSimplify (nfComAssoc m)
  (se, attrs2) <- pSimplify (nfComAssoc e)
  b <- pCompareWith (nfComAssoc . nf4 1) sm se
  pLog ("pCompareBySimplify: " ++ show (m,sm,attrs1) ++ " " ++ show (e,se,attrs2) ++ " " ++ show b)
  guard b
  return (sm, se, attrs1 L.\\ attrs2)

-- | Simplify one expression and report which rewrites were applied.
--
-- The input is cleaned ('SF.cleanExpr') and simplified under 'SF.underSubst'.
-- The rewrite rules cannot handle the custom subexpression variables, so
-- 'SF.underSubst' temporarily swaps them for ordinary variables and undoes
-- that afterwards. Finally the custom variables are replaced by their current
-- values ('pSubstituteVars') and the result is cleaned once more.
pSimplify :: Expr -> SEParser (Expr, [Attribute])
pSimplify e = do
   let cleaned = SF.cleanExpr e
   (simplified, rewrites) <- maybeToParse (SF.underSubst simplify cleaned)
   pLog $ "simplified: " ++ show simplified ++ " " ++ show rewrites
   substituted <- pSubstituteVars simplified
   return (SF.cleanExpr substituted, rewrites)

-- | Check whether two expressions share the same normal form.
--
-- Both expressions are normalized with 'pNormalize'. Normalizing here means
-- sorting on commutativity and associativity, simplifying fractions and
-- applying distributivity. The normal forms are then compared with
-- 'pCompare', and the parser fails (via 'guard') when they differ.
--
-- On success, returns the first expression unchanged. It is paired with the
-- rewrites (simpler fractions, distribution) of the first expression, minus
-- (list difference) the rewrites of the second.
pCompareByNormalize :: Expr -> Expr -> SEParser (Expr, [Attribute])
pCompareByNormalize e1 e2 = do
  pLog ("pCompareByNormalize " ++ show e1 ++ " " ++ show e2)
  (n1, rewrites1) <- pNormalize e1
  (n2, rewrites2) <- pNormalize e2
  mapM_ (pLog . ("Normalized: " ++) . show) [n1, n2]
  same <- pCompare n1 n2
  pLog ("N: " ++ show n1 ++ " | " ++ show n2 ++ " " ++ show same)
  guard same
  pLog "Normalize equal"
  pure (e1, rewrites1 L.\\ rewrites2)

-- | Normalize an expression with 'cachedNormalize', lifting its 'Maybe'
-- result into the parser with 'maybeToParse'.
pNormalize :: Expr -> SEParser (Expr, [Attribute])
pNormalize = maybeToParse . cachedNormalize

-- | Run 'normalize' under 'SF.underSubst', then clean the resulting
-- expression and sort it with 'nfComAssoc'. The rewrites are passed through
-- unchanged.
--
-- The function is wrapped with 'cached' under the key "cachedNormalize"
-- (presumably memoization via "Util.Cache" — confirm).
cachedNormalize :: Expr -> Maybe (Expr, [Attribute])
cachedNormalize = cached "cachedNormalize" $ \expr ->
   let finish (nfe, rewrites) = (nfComAssoc (SF.cleanExpr nfe), rewrites)
   in  finish <$> SF.underSubst normalize expr

-- | Compare two expressions without any pre-comparison transformation.
--
-- Subexpression variables and magic types are taken into account (see
-- 'pCompareWith'). Both the inputs and the outcome are logged.
pCompare :: Expr -> Expr -> SEParser Bool
pCompare e1 e2 = do
  pLog ("pCompare: " ++ show e1 ++ " " ++ show e2)
  result <- pCompareWith id e1 e2
  result <$ pLog (show result)

-- | Clean both expressions ('SF.cleanExpr'), apply @f@ to each, and match
-- them with 'pCompareExpr'.
--
-- The match runs inside 'option', and the result is 'True' exactly when
-- 'option' yields a 'Just'. Presumably a failed match therefore becomes
-- 'False' rather than a parser failure — confirm against 'option'.
pCompareWith :: (Expr -> Expr) -> Expr -> Expr -> SEParser Bool
pCompareWith f m e =
  let prepare = f . SF.cleanExpr
  in  fmap isJust (option (pCompareExpr f (prepare m) (prepare e)))

-- | Match two expressions structurally. The parser fails (via 'guard' or a
-- nested failure) when they do not agree.
--
-- * Two 'Nat', two 'Var' or two 'Number' literals must be equal.
-- * An expression satisfying 'SF.isMagicNat', 'SF.isMagicNumber' or
--   'SF.isMagicVar' on one side matches, respectively, any 'isNat',
--   'isNumber' or 'F.isVar' expression on the other side.
-- * A subexpression variable ('SF.isVar') on either side is expanded with
--   'pSubstituteVars', passed through @f@, and compared with the other side.
--   When they match, the variable's entry is updated to the other side via
--   'pUpdateVars'.
-- * Otherwise, the shape of the first expression decides the decomposition:
--   product, quotient, sum, difference, or 'getFunction' for anything else.
--   Both sides are then compared with 'pCompareWithFunction'.
pCompareExpr :: (Expr -> Expr) -> Expr -> Expr -> SEParser ()
pCompareExpr _ (Nat n1) (Nat n2)       = guard $ n1 == n2
pCompareExpr _ (Var x) (Var y)         = guard $ x == y
pCompareExpr _ (Number n1) (Number n2) = guard $ n1 == n2
pCompareExpr f e1 e2
  | (SF.isMagicNat e1 && isNat e2) || (isNat e1 && SF.isMagicNat e2)     = return ()
  | (SF.isMagicNumber e1 && isNumber e2) || (isNumber e1 && SF.isMagicNumber e2) = return ()
  | (SF.isMagicVar e1 && F.isVar e2) || (F.isVar e1 && SF.isMagicVar e2) = return ()
  | SF.isVar e1 = subAndCompare e1 e2
  | SF.isVar e2 = subAndCompare e2 e1
  | isJust (isTimes e1)  = pCompareWithFunction (\e -> Just (timesSymbol,  snd $ from productView e)) f e1 e2
  | isJust (isDivide e1) = pCompareWithFunction (\e -> Just (divideSymbol, snd $ from productView e)) f e1 e2
  | isJust (isPlus e1)   = pCompareWithFunction (\e -> Just (plusSymbol,  from sumView e)) f e1 e2
  | isJust (isMinus e1)  = pCompareWithFunction (\e -> Just (minusSymbol, from sumView e)) f e1 e2
  | otherwise = pCompareWithFunction getFunction f e1 e2
  where
    -- Compare the current value of subexpression variable @var@ with @other@.
    -- Only when they match is @other@ stored under the variable's key.
    -- The parameters are deliberately not named e1/e2, to avoid shadowing
    -- the arguments of the enclosing equation.
    subAndCompare :: Expr -> Expr -> SEParser ()
    subAndCompare var other = do
      expanded <- pSubstituteVars var
      pCompareExpr f (f expanded) other
      -- Reached only if (f expanded) matched other
      varKey <- getVarKey var
      pUpdateVars varKey other

-- | Compare two expressions by decomposing both with @fun@ into a head
-- symbol and a list of operands. The parser fails when either decomposition
-- yields 'Nothing' or the symbols differ.
--
-- Operands are matched as multisets. Each operand of one side cancels
-- against at most one equal operand of the other side. The leftovers are
-- handled as follows:
--
-- * If both leftover lists have the same length and at least one side
--   consists only of operands satisfying 'SF.hasMagicNat',
--   'SF.hasMagicNumber', 'SF.hasMagicVar' or 'SF.hasVar', they are compared
--   pairwise, in order, with 'pCompareExpr'.
-- * Otherwise both leftover lists must be empty.
--
-- Limited in the sense that only one (sub)expression may contain a magic
-- natural number or variable.
pCompareWithFunction :: (Expr -> Maybe (Symbol, [Expr])) -> (Expr -> Expr) -> Expr -> Expr -> SEParser ()
pCompareWithFunction fun f e1 e2 = do
  pLog ("pCompareExpr: " ++ show e1 ++ " " ++ show e2)
  (s1, xs) <- maybeToParse $ fun e1
  (s2, ys) <- maybeToParse $ fun e2
  guard (s1 == s2)
  -- Use multiset difference rather than 'intersect'. 'intersect' keeps every
  -- duplicate from its first argument ([a,a] `intersect` [a] == [a,a]),
  -- which made a*a match a, while a did not match a*a.
  let diffX  = xs \\ ys
      diffY  = ys \\ xs
      common = xs \\ diffX
      isMagic x = SF.hasMagicNat x || SF.hasMagicNumber x || SF.hasMagicVar x || SF.hasVar x
  pLog ("P: " ++ show common)
  pLog ("P: " ++ show diffX)
  pLog ("P: " ++ show diffY)
  if length diffX == length diffY && (all isMagic diffX || all isMagic diffY)
    -- The leftovers pair up and one side is entirely "magic": match pairwise
    then zipWithM_ (pCompareExpr f) diffX diffY
    -- Otherwise nothing may be left over on either side
    else guard (null diffX && null diffY)

-- | Replace every subexpression variable ('SF.isVar') in an expression by
-- the value recorded for it in the 'usedVariables' state. The 'Maybe'
-- result of 'SF.substituteAllIf' is lifted into the parser with
-- 'maybeToParse'.
pSubstituteVars :: Expr -> SEParser Expr
pSubstituteVars e =
  gets usedVariables >>= \vars -> maybeToParse (SF.substituteAllIf SF.isVar vars e)

-- | Overwrite the value stored under key @k@ in the 'usedVariables' state.
-- The map is left unchanged when @k@ is not already present.
pUpdateVars :: String -> Expr -> SEParser ()
pUpdateVars k v =
   modify (\s -> s { usedVariables = M.adjust (const v) k (usedVariables s) })