-- |
-- Module      : Data.Express.Map
-- Copyright   : (c) 2019-2024 Rudy Matela
-- License     : 3-Clause BSD  (see the file LICENSE)
-- Maintainer  : Rudy Matela <rudy@matela.com.br>
--
-- Utilities for mapping or transforming 'Expr's.
module Data.Express.Map
  ( mapValues
  , mapVars
  , mapConsts
  , mapSubexprs
  , (//-)
  , (//)
  , renameVarsBy
  )
where

import Data.Express.Core
import Data.Express.Utils.List
import Data.Maybe (fromMaybe)

-- | /O(n*m)/.
-- Applies a function to all terminal values in an expression.
-- (cf. '//-')
--
-- The expression is traversed through its applications (@:$@):
-- application nodes are rebuilt as they are and the given function
-- is applied to every leaf.  Application nodes themselves are never
-- passed to the function.
--
-- Given that:
--
-- > > let zero  = val (0 :: Int)
-- > > let one   = val (1 :: Int)
-- > > let two   = val (2 :: Int)
-- > > let three = val (3 :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let intToZero e = if typ e == typ zero then zero else e
--
-- Then:
--
-- > > one -+- (two -+- three)
-- > 1 + (2 + 3) :: Int
--
-- > > mapValues intToZero $ one -+- (two -+- three)
-- > 0 + (0 + 0) :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapValues :: (Expr -> Expr) -> Expr -> Expr
mapValues f  =  m
  where
  m :: Expr -> Expr
  -- recurse into both sides of an application, keeping the spine intact
  m (e1 :$ e2)  =  m e1 :$ m e2
  -- anything that is not an application is a terminal value
  m e           =  f e

-- | /O(n*m)/.
-- Applies a function to all variables in an expression.
-- Terminal values that are not variables (as decided by 'isVar')
-- are left unchanged.
-- (cf. 'mapValues', 'mapConsts')
--
-- Given that:
--
-- > > let primeify e = if isVar e
-- > |                  then case e of (Value n d) -> Value (n ++ "'") d
-- > |                  else e
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > xx -+- yy
-- > x + y :: Int
--
-- > > primeify xx
-- > x' :: Int
--
-- > > mapVars primeify $ xx -+- yy
-- > x' + y' :: Int
--
-- > > mapVars (primeify . primeify) $ xx -+- yy
-- > x'' + y'' :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapVars :: (Expr -> Expr) -> Expr -> Expr
mapVars f  =  mapValues f'
  where
  -- only variables are handed to the user function
  f' :: Expr -> Expr
  f' e  =  if isVar e
           then f e
           else e

-- | /O(n*m)/.
-- Applies a function to all terminal constants in an expression.
-- Terminal values that are not constants (as decided by 'isConst')
-- are left unchanged.
-- (cf. 'mapValues', 'mapVars')
--
-- Given that:
--
-- > > let zero  = val (0 :: Int)
-- > > let one   = val (1 :: Int)
-- > > let two   = val (2 :: Int)
-- > > let xx = var "x" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let intToZero e = if typ e == typ zero then zero else e
--
-- Then:
--
-- > > one -+- (two -+- xx)
-- > 1 + (2 + x) :: Int
--
-- > > mapConsts intToZero (one -+- (two -+- xx))
-- > 0 + (0 + x) :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapConsts :: (Expr -> Expr) -> Expr -> Expr
mapConsts f  =  mapValues f'
  where
  -- only constants are handed to the user function
  f' :: Expr -> Expr
  f' e  =  if isConst e
           then f e
           else e

-- | /O(n*m)/.
-- Substitute subexpressions of an expression using the given function.
-- Outer expressions have more precedence than inner expressions.
-- (cf. '//')
--
-- At each node the function is tried first: if it returns @Just r@,
-- the whole node is replaced by @r@ and its children are not visited.
-- If it returns @Nothing@, an application node is rebuilt from its
-- recursively mapped children and a terminal value is kept as is.
--
-- With:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let plus = value "+" ((+) :: Int->Int->Int)
-- > > let times = value "*" ((*) :: Int->Int->Int)
-- > > let xx -+- yy = plus :$ xx :$ yy
-- > > let xx -*- yy = times :$ xx :$ yy
--
-- > > let pluswap (o :$ xx :$ yy) | o == plus = Just $ o :$ yy :$ xx
-- > |     pluswap _                           = Nothing
--
-- Then:
--
-- > > mapSubexprs pluswap $ (xx -*- yy) -+- (yy -*- zz)
-- > y * z + x * y :: Int
--
-- > > mapSubexprs pluswap $ (xx -+- yy) -*- (yy -+- zz)
-- > (y + x) * (z + y) :: Int
--
-- Substitutions do not stack, in other words
-- a replaced expression or its subexpressions are not further replaced:
--
-- > > mapSubexprs pluswap $ (xx -+- yy) -+- (yy -+- zz)
-- > (y + z) + (x + y) :: Int
--
-- Given that the argument function is /O(m)/, this function is /O(n*m)/.
mapSubexprs :: (Expr -> Maybe Expr) -> Expr -> Expr
mapSubexprs f  =  m
  where
  m :: Expr -> Expr
  -- a successful replacement wins and stops the descent into @e@
  m e  =  fromMaybe e' (f e)
    where
    -- fallback used only when @f e@ is 'Nothing' (evaluated lazily)
    e' :: Expr
    e'  =  case e of
           e1 :$ e2 -> m e1 :$ m e2
           _        -> e

-- | /O(n*m)/.
-- Substitute occurrences of values in an expression
-- from the given list of substitutions.
-- (cf. 'mapValues')
--
-- Each terminal value is looked up in the substitution list;
-- values without an entry are left unchanged.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) //- [(xx, yy), (zz, yy)]
-- > (y + y) + (y + y) :: Int
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) //- [(yy, yy -+- zz)]
-- > (x + (y + z)) + ((y + z) + z) :: Int
--
-- This function does not work for substituting non-terminal subexpressions:
--
-- > > (xx -+- yy) //- [(xx -+- yy, zz)]
-- > x + y :: Int
--
-- Please use the slower '//' if you want the above replacement to work.
--
-- Replacement happens only once:
--
-- > > xx //- [(xx,yy), (yy,zz)]
-- > y :: Int
--
-- Given that the argument list has length /m/,
-- this function is /O(n*m)/.
(//-) :: Expr -> [(Expr,Expr)] -> Expr
e //- s  =  mapValues (`lookupId` s) e

-- | /O(n*n*m)/.
-- Substitute subexpressions in an expression
-- from the given list of substitutions.
-- (cf. 'mapSubexprs').
--
-- Please consider using '//-' if you are replacing just terminal values
-- as it is faster.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
--
-- Then:
--
-- > > ((xx -+- yy) -+- (yy -+- zz)) // [(xx -+- yy, yy), (yy -+- zz, yy)]
-- > y + y :: Int
--
-- > > ((xx -+- yy) -+- zz) // [(xx -+- yy, zz), (zz, xx -+- yy)]
-- > z + (x + y) :: Int
--
-- Replacement happens only once with outer expressions
-- having more precedence than inner expressions.
--
-- > > (xx -+- yy) // [(yy,xx), (xx -+- yy,zz), (zz,xx)]
-- > z :: Int
--
-- Given that the argument list has length /m/, this function is /O(n*n*m)/.
-- Remember that since /n/ is the size of an expression,
-- comparing two expressions is /O(n)/ in the worst case,
-- and we may need to compare with /n/ subexpressions in the worst case.
(//) :: Expr -> [(Expr,Expr)] -> Expr
-- every subexpression is looked up (outermost first); a miss descends further
e // s  =  mapSubexprs (`lookup` s) e

-- | Rename variables in an 'Expr'.
--
-- Only terminal values whose internal name starts with @'_'@ are renamed:
-- the leading @'_'@ is kept and the given function is applied to the rest
-- of the name.  All other values are left unchanged.
--
-- Given that:
--
-- > > let xx = var "x" (undefined :: Int)
-- > > let yy = var "y" (undefined :: Int)
-- > > let zz = var "z" (undefined :: Int)
-- > > let xx -+- yy = value "+" ((+) :: Int->Int->Int) :$ xx :$ yy
-- > > let abs' xx = value "abs" (abs :: Int->Int) :$ xx
--
-- Then:
--
-- > > renameVarsBy (++ "'") (xx -+- yy)
-- > x' + y' :: Int
--
-- > > renameVarsBy (++ "'") (yy -+- (zz -+- xx))
-- > y' + (z' + x') :: Int
--
-- > > renameVarsBy (++ "1") (abs' xx)
-- > abs x1 :: Int
--
-- > > renameVarsBy (++ "2") $ abs' (xx -+- yy)
-- > abs (x2 + y2) :: Int
--
-- NOTE: this will affect holes!
renameVarsBy :: (String -> String) -> Expr -> Expr
renameVarsBy f  =  mapValues f'
  where
  f' :: Expr -> Expr
  -- keep the '_' prefix and rename only the visible part of the name
  f' (Value ('_':n) t)  =  Value ('_':f n) t
  f' e                  =  e