{-# LANGUAGE FlexibleContexts, MultiParamTypeClasses #-}
-----------------------------------------------------------------------------
-- Copyright 2019, Advise-Me project team. This file is distributed under 
-- the terms of the Apache License 2.0. For more information, see the files
-- "LICENSE.txt" and "NOTICE.txt", which are included in the distribution.
-----------------------------------------------------------------------------
-- |
-- Maintainer  :  bastiaan.heeren@ou.nl
-- Stability   :  provisional
-- Portability :  portable (depends on ghc)
--
-- Common functions on expressions or math types.
-- 
-----------------------------------------------------------------------------

module Recognize.Expr.Functions where

import Control.Monad
import Data.Char                                (toLower)
import Data.Function                            (on)
import Data.Functor.Identity                    (runIdentity)
import Data.Generics.Str                        (strStructure)
import Data.Maybe
import Domain.Math.Data.Relation
import Domain.Math.Expr.Data
import Ideas.Common.Rewriting hiding (hasVar)
import Ideas.Utils.Uniplate
import Recognize.Data.Math
import Recognize.Expr.Normalform
import Recognize.Expr.Symbols

-- | Split every chained equation in a list of math values into its separate
-- equations (see 'unchain'); every other value is kept as it is.
--
-- The returned 'Bool' is 'True' when 'unchain' reported a rewrite for at
-- least one of the chains.
unchainAll :: [Math] -> ([Math], Bool)
unchainAll ms = (concatMap fst results, any snd results)
  where
    results = map split ms

    -- A chained equation becomes one 'Math' per equation, each one reusing
    -- the original record with only its result replaced.
    split m =
      case getExpr m of
        Just (Sym s xs)
          | s == chainedEqSymbol ->
              let (eqs, changed) = unchain xs
              in  (map (\eq -> m {getResult = Right (toExpr eq)}) eqs, changed)
        _ -> ([m], False)

-- | Turn the consecutive elements of a chain @x1 = x2 = ... = xn@ into the
-- equations @x1 = x2@, @x2 = x3@, and so on.
--
-- When the right-hand element @y@ of a pair has a left child and the normal
-- forms of @x@ and @y@ differ, the equation @x = left(y)@ is produced
-- instead and that pair is flagged. The returned 'Bool' is 'True' when at
-- least one pair was rewritten this way.
unchain :: [Expr] -> ([Equation Expr], Bool)
unchain xs = (map fst links, any snd links)
  where
    links = zipWith link xs (drop 1 xs)

    link x y
      | Just ly <- getLeft y, nf x /= nf y = (x :==: ly, True)
      | otherwise                          = (x :==: y, False)

-- | The first variable found in the expression, in 'universe' order, if any.
-- The whole 'Var' node is returned.
getVar :: Expr -> Maybe Expr
getVar e = listToMaybe [ v | v@(Var _) <- universe e ]

-- | The name of the first variable found in the expression, in 'universe'
-- order, if any.
getVarS :: Expr -> Maybe String
getVarS e = listToMaybe [ name | Var name <- universe e ]

-- | The names of all variables occurring in the expression, in 'universe'
-- order. Duplicates are kept.
vars :: Expr -> [String]
vars e = [ name | Var name <- universe e ]

-- | 'True' exactly when the expression is a 'Nat' literal.
isNat :: Expr -> Bool
isNat e =
  case e of
    Nat _ -> True
    _     -> False

-- | 'True' exactly when the expression is a 'Var' node.
isVar :: Expr -> Bool
isVar e =
  case e of
    Var _ -> True
    _     -> False

-- | 'True' when the expression is a variable or contains one.
--
-- The search recurses through the argument list that 'getFunction' returns
-- for each node. Nodes for which 'getFunction' fails have no arguments to
-- search.
hasVar :: Expr -> Bool
hasVar e =
  case e of
    Var _ -> True
    _     -> maybe False (any hasVar . snd) (getFunction e)

-- | 'True' exactly when the expression is a 'Number' literal.
isNumber :: Expr -> Bool
isNumber e =
  case e of
    Number _ -> True
    _        -> False

-- | 'True' when the expression is a 'Nat', a 'Var' or a 'Number'.
isAtom :: Expr -> Bool
isAtom e =
  case e of
    Nat _    -> True
    Var _    -> True
    Number _ -> True
    _        -> False

-- | 'True' exactly when the top-level constructor is a division (':/:').
isDiv :: Expr -> Bool
isDiv e =
  case e of
    _ :/: _ -> True
    _       -> False

-- | 'True' when the second expression equals the first (the key) or
-- contains it somewhere.
--
-- The search recurses through the argument list that 'getFunction' returns
-- for each node.
--
-- NOTE(review): that argument list may differ from the Uniplate children for
-- some constructors. Confirm if an exact subterm relation is required.
--
-- See "Recognize.SubExpr" for more complicated cases.
hasExpr :: Expr -> Expr -> Bool
hasExpr key = go
  where
    go e
      | key == e  = True
      | otherwise = maybe False (any go . snd) (getFunction e)

-- | Given a list of expressions and a target expression, return the element
-- of the list that is closest to the target. An empty list gives 'Nothing'.
--
-- How the list is processed:
--
-- * It is processed from the right. The head @x@ is compared with the best
--   candidate @y@ found in the tail.
-- * When both lie on the same side of the target, the one nearer to it (by
--   '<' / '>') wins.
-- * Otherwise the normal forms of the two distances are compared, and a tie
--   keeps @y@.
-- * An @x@ equal to the target is always chosen.
--
-- NOTE(review): every comparison uses the 'Ord' instance of 'Expr'. If that
-- instance is structural (e.g. derived) rather than numeric, "closest" does
-- not correspond to numeric distance. Confirm before relying on the result.
closestInList :: [Expr] -> Expr -> Maybe Expr
closestInList [] _ = Nothing
closestInList (x:xs) a =
  case closestInList xs a of
      -- x is the only candidate
      Nothing -> return x
      Just y -- need a way to use abs on expressions
        -- both above the target: keep the smaller one
        | x > a && y > a && x < y -> return x
        | x > a && y > a -> return y
        -- both below the target: keep the larger one
        | x < a && y < a && x > y -> return x
        | x < a && y < a -> return y
        -- y is on the other side of (or equal to) the target:
        -- compare the normalised distances
        | x > a && nf (x - a) < nf (a - y) -> return x
        | x > a -> return y
        | x < a && nf (a - x) < nf (y - a) -> return x
        | x < a -> return y
        -- x equals the target
        | x == a -> return x
        -- presumably unreachable for a lawful Ord/Eq instance
        | otherwise -> return y

-- | For an expression with exactly two children @x@ and @y@, build
-- @x + y@, @x - y@, @x * y@ and @x / y@, in that order.
-- Any other expression yields the empty list.
changeOp :: Expr -> [Expr]
changeOp e
  | [x, y] <- children e = [ op x y | op <- [(+), (-), (*), (/)] ]
  | otherwise            = []

-- | Determines whether two expressions share the same structure, ignoring
-- the values of atoms (see 'isAtom'). For example, @a + 6@ and @4 + 2@ match.
--
-- The rules are:
--
-- * Two atoms always match.
-- * An atom never matches a compound expression.
-- * Two compound expressions match when they have the same top-level symbol,
--   the same number of children and pairwise matching children.
--
-- The children are compared explicitly rather than by zipping the two
-- 'universe' lists. Zipping stops at the shorter list, so it would accept a
-- symbol applied to one argument as equivalent to the same symbol applied
-- to two.
equivalentStructure :: Expr -> Expr -> Bool
equivalentStructure a b
  | isAtom a && isAtom b = True
  | isAtom a || isAtom b = False
  | otherwise =
         sameSymbol
      && length xs == length ys
      && and (zipWith equivalentStructure xs ys)
  where
    xs = children a
    ys = children b
    -- Assumes 'getFunction' succeeds for every non-atomic expression, as the
    -- previous implementation did. NOTE(review): confirm.
    sameSymbol = ((==) `on` (fst . runIdentity . getFunction)) a b

-- | Pair up the nodes of two expressions position by position, in 'universe'
-- order, and return the pairs where both nodes are atoms and they differ.
--
-- The pairs are listed in reverse traversal order (last difference first).
-- Pairing stops at the end of the shorter traversal, so the result is only
-- meaningful for expressions with the same structure (see
-- 'equivalentStructure').
changeSet :: Expr -> Expr -> [(Expr, Expr)]
changeSet a b =
  -- A filter followed by a reverse, instead of a lazy 'foldl' that builds
  -- the reversed list through a chain of thunks.
  reverse
    [ (x, y)
    | (x, y) <- zip (universe a) (universe b)
    , isAtom x && isAtom y && x /= y
    ]

-- | The first ('getLeft') or second ('getRight') direct child of an
-- expression, in Uniplate 'children' order. Returns 'Nothing' when the
-- expression has no such child.
getLeft, getRight :: Expr -> Maybe Expr
getLeft  = listToMaybe . children
getRight = listToMaybe . drop 1 . children

-- | Follow first children ('getLeft') as far down as possible and return the
-- node reached, which has no children.
--
-- The input itself is never returned, so an expression without children
-- yields 'Nothing'.
getMostLeft :: Expr -> Maybe Expr
getMostLeft e = do
  l <- getLeft e
  return (fromMaybe l (getMostLeft l))

-- | Replace the first ('replaceLeft') or second ('replaceRight') direct child
-- of an expression with a new expression.
--
-- The expression is returned unchanged when it has too few children: none
-- for 'replaceLeft', or fewer than two for 'replaceRight'.
replaceLeft, replaceRight :: Expr -> Expr -> Expr
replaceLeft new e =
    case strStructure kids of
      (_ : others, toStr) -> rebuild (toStr (new : others))
      _                   -> e
  where
    (kids, rebuild) = uniplate e
replaceRight new e =
    case strStructure kids of
      (first : _ : others, toStr) -> rebuild (toStr (first : new : others))
      _                           -> e
  where
    (kids, rebuild) = uniplate e

-- | Round and simplify a 'Number' to the given precision using 'nf4'.
-- Any other expression is returned unchanged.
roundNumber :: Int -> Expr -> Expr
roundNumber d e =
  case e of
    Number _ -> nf4 d e
    _        -> e

-- | If the expression is a normal-form symbol (see 'isNormalformSymbol')
-- applied to exactly one argument, return 'nf' of that argument. Any other
-- expression is returned unchanged.
normalizeIfNF :: Expr -> Expr
normalizeIfNF e =
  case e of
    Sym s [arg] | isNormalformSymbol s -> nf arg
    _                                  -> e