```{-|
Copyright  :  (C) 2015-2016, University of Twente,
2017     , QBayLogic B.V.
Maintainer :  Christiaan Baaij <christiaan.baaij@gmail.com>

= SOP: Sum-of-Products, sorta

The arithmetic operations for 'GHC.TypeLits.Nat' are addition
(@'GHC.TypeLits.+'@), subtraction (@'GHC.TypeLits.-'@), multiplication
(@'GHC.TypeLits.*'@), and exponentiation (@'GHC.TypeLits.^'@). This means we
cannot write expressions in a canonical SOP normal form. We can get rid of
subtraction by working with integers, and translating @a - b@ to @a + (-1)*b@.
Exponentiation cannot be eliminated that way. So we define the following
grammar for our canonical SOP-like normal form of arithmetic expressions:

@
SOP      ::= Product \'+\' SOP | Product
Product  ::= Symbol \'*\' Product | Symbol
Symbol   ::= Integer
|  Var
|  Var \'^\' Product
|  SOP \'^\' ProductE

ProductE ::= SymbolE \'*\' ProductE | SymbolE
SymbolE  ::= Var
|  Var \'^\' Product
|  SOP \'^\' ProductE
@

So valid SOP terms are:

@
x*y + y^2
(x+y)^(k*z)
@

, but,

@
(x*y)^2
@

is not, and should be:

@
x^2 * y^2
@

The base of an exponentiation is thus not allowed to be a product. Likewise,
exponents are not allowed to contain sums, so for example, the expression:

@
(x + 2)^(y + 2)
@

in valid SOP form is:

@
4*x*(2 + x)^y + 4*(2 + x)^y + (2 + x)^y*x^2
@

Also, exponents can only be integer values when the base is a variable. Although
not enforced by the grammar, the exponentials are flattened as far as possible
in SOP form. So:

@
(x^y)^z
@

is flattened to:

@
x^(y*z)
@
-}
module GHC.TypeLits.Normalise.SOP
( -- * SOP types
Symbol (..)
, Product (..)
, SOP (..)
-- * Simplification
, reduceExp
, mergeS
, mergeP
, mergeSOPMul
, normaliseExp
)
where

-- External
import Data.Either (partitionEithers)
import Data.List   (sort)

-- GHC API
import Outputable  (Outputable (..), (<+>), text, hcat, integer, punctuate)

-- | A single factor of a 'Product'.
--
-- The constructor order is significant: the derived 'Ord' instance sorts
-- integer constants first, then other constants, exponentiations and finally
-- variables.
data Symbol v c
  = I Integer
    -- ^ Integer constant
  | C c
    -- ^ Non-integer constant
  | E (SOP v c) (Product v c)
    -- ^ Exponentiation: base and exponent
  | V v
    -- ^ Variable
  deriving (Eq, Ord)

-- | A product term: the list of 'Symbol's that are multiplied together.
newtype Product v c
  = P
  { unP :: [Symbol v c]
    -- ^ The factors of the product
  }
  deriving (Eq)

-- | Two single-symbol products are ordered by their symbol. A single-symbol
-- product sorts before any product with two or more symbols. Every other
-- combination falls back to the ordering on the symbol lists.
instance (Ord v, Ord c) => Ord (Product v c) where
  compare (P lhs) (P rhs) = case (lhs, rhs) of
    ([x], [y])   -> compare x y
    ([_], _ : _) -> LT
    (_ : _, [_]) -> GT
    _            -> compare lhs rhs

-- | A sum-of-products term: the list of 'Product's that are added together.
newtype SOP v c
  = S
  { unS :: [Product v c]
    -- ^ The summands of the term
  }
  deriving (Ord)

-- | Equality on the list of products, except that the empty sum and the
-- sum consisting of only the constant @0@ are considered equal.
instance (Eq v, Eq c) => Eq (SOP v c) where
  S lhs == S rhs = case (lhs, rhs) of
    ([], [P [I 0]]) -> True
    ([P [I 0]], []) -> True
    _               -> lhs == rhs

-- | Prints the products of the sum separated by @" + "@.
instance (Outputable v, Outputable c) => Outputable (SOP v c) where
  ppr (S products) = hcat (punctuate (text " + ") (map ppr products))

-- | Prints the symbols of the product separated by @" * "@.
instance (Outputable v, Outputable c) => Outputable (Product v c) where
  ppr (P symbols) = hcat (punctuate (text " * ") (map ppr symbols))

-- | Prints constants and variables as-is, and exponentiations as
-- @base ^ exponent@. An operand of @^@ that is more than a lone integer or
-- variable is wrapped in parentheses.
instance (Outputable v, Outputable c) => Outputable (Symbol v c) where
  ppr sym = case sym of
      I n         -> integer n
      C c         -> ppr c
      V v         -> ppr v
      E base expo -> operand base <+> text "^" <+> operand (S [expo])
    where
      -- Lone integers and variables need no parentheses
      operand (S [P [I n]]) = integer n
      operand (S [P [V v]]) = ppr v
      operand other         = text "(" <+> ppr other <+> text ")"

-- | Repeatedly merge the elements of a list with a pairwise merge function.
--
-- For the head @h@ of the list, the function is applied as @x \`op\` h@ to
-- every later element @x@. When none of these applications returns 'Left',
-- @h@ is kept and merging continues with the tail. Otherwise @h@ is dropped
-- and merging restarts on the 'Left' results followed by the 'Right' results.
mergeWith :: (a -> a -> Either a a) -> [a] -> [a]
mergeWith op items = case items of
  [] -> []
  pivot : rest
    | null merged -> pivot : mergeWith op rest
    | otherwise   -> mergeWith op (merged ++ kept)
    where
      (merged, kept) = partitionEithers (map (`op` pivot) rest)

-- | Reduce exponentials
--
-- Performs the following rewrites:
--
-- @
-- x^0          ==>  1
-- 0^x          ==>  0
-- 2^3          ==>  8
-- (k ^ i) ^ j  ==>  k ^ (i * j)
-- @
--
-- Any symbol that matches none of these rewrites is returned unchanged.
reduceExp :: (Ord v, Ord c) => Symbol v c -> Symbol v c
reduceExp sym = case sym of
  -- x^0 ==> 1
  E _ (P [I 0]) -> I 1
  -- 0^x ==> 0
  E (S [P [I 0]]) _ -> I 0
  -- 2^3 ==> 8, only for non-negative exponents
  E (S [P [I base]]) (P [I expo])
    | expo >= 0 -> I (base ^ expo)
  -- (k ^ i) ^ j ==> k ^ (i * j)
  E (S [P [E inner e1]]) e2 ->
    let -- multiply the two exponents by merging their symbols
        combined = P (sort (map reduceExp (mergeWith mergeS (unP e1 ++ unP e2))))
    in  case normaliseExp inner (S [combined]) of
          S [P [single]] -> single
          _              -> E inner combined
  _ -> sym

-- | Merge two symbols of a Product term
--
-- Returns 'Left' with the merged symbol when one of the rewrites below
-- applies, and 'Right' with the first argument otherwise.
--
-- @
-- 8 * 7         ==>  56
-- 1 * x         ==>  x
-- x * 1         ==>  x
-- 0 * x         ==>  0
-- x * 0         ==>  0
-- x * x^4       ==>  x^5
-- x^4 * x       ==>  x^5
-- 4^x * 2^x     ==>  8^x
-- y*y           ==>  y^2
-- x^y * x^(-y)  ==>  1
-- x^(-y) * x^y  ==>  1
-- @
mergeS :: (Ord v, Ord c) => Symbol v c -> Symbol v c
       -> Either (Symbol v c) (Symbol v c)
mergeS lhs rhs = case (lhs, rhs) of
  -- 8 * 7 ==> 56
  (I m, I n) -> Left (I (m * n))
  -- 1 * x ==> x
  (I 1, _) -> Left rhs
  -- x * 1 ==> x
  (_, I 1) -> Left lhs
  -- 0 * x ==> 0
  (I 0, _) -> Left (I 0)
  -- x * 0 ==> 0
  (_, I 0) -> Left (I 0)

  -- x * x^4 ==> x^5
  (_, E (S [P [base]]) (P [I n]))
    | lhs == base -> Left (E (S [P [base]]) (P [I (n + 1)]))

  -- x^4 * x ==> x^5
  (E (S [P [base]]) (P [I n]), _)
    | rhs == base -> Left (E (S [P [base]]) (P [I (n + 1)]))

  -- 4^x * 2^x ==> 8^x
  (E (S [P [I m]]) expo, E (S [P [I n]]) expo')
    | expo == expo' -> Left (E (S [P [I (m * n)]]) expo)

  -- y*y ==> y^2
  _ | lhs == rhs ->
        case normaliseExp (S [P [lhs]]) (S [P [I 2]]) of
          S [P [squared]] -> Left squared
          _               -> Right lhs

  -- x^y * x^(-y) ==> 1
  (E base1 (P expo1), E base2 (P (I n : expo2)))
    | n == (-1)
    , base1 == base2
    , expo1 == expo2 -> Left (I 1)

  -- x^(-y) * x^y ==> 1
  (E base1 (P (I n : expo1)), E base2 (P expo2))
    | n == (-1)
    , base1 == base2
    , expo1 == expo2 -> Left (I 1)

  -- nothing to merge
  _ -> Right lhs

-- | Merge two products of a SOP term
--
-- Returns 'Left' with the summed product when both products have the same
-- symbols apart from a leading integer coefficient, and 'Right' with the
-- first argument otherwise.
--
-- @
-- 2xy + 3xy  ==>  5xy
-- 2xy + xy   ==>  3xy
-- xy + 2xy   ==>  3xy
-- xy + xy    ==>  2xy
-- @
mergeP :: (Eq v, Eq c) => Product v c -> Product v c
       -> Either (Product v c) (Product v c)
mergeP (P lhs) (P rhs) = case (lhs, rhs) of
  -- 2xy + 3xy ==> 5xy
  (I m : ls, I n : rs)
    | ls == rs -> Left (P (I (m + n) : ls))
  -- 2xy + xy ==> 3xy
  (I m : ls, _)
    | ls == rhs -> Left (P (I (m + 1) : ls))
  -- xy + 2xy ==> 3xy
  (_, I n : rs)
    | lhs == rs -> Left (P (I (n + 1) : lhs))
  -- xy + xy ==> 2xy
  _ | lhs == rhs -> Left (P (I 2 : lhs))
    | otherwise  -> Right (P lhs)

-- | Expand or simplify \'complex\' exponentials
--
-- The first argument is the base, the second argument is the exponent.
--
-- Performs the following rewrites:
--
-- @
-- b^1              ==>  b
-- 2^3              ==>  8
-- (x^y)^z          ==>  x^(y*z)
-- (x + 2)^0        ==>  1
-- (x + 2)^2        ==>  x^2 + 4x + 4
-- (x + 2)^(2x)     ==>  (x^2 + 4x + 4)^x
-- (x + 2)^(y + 2)  ==>  4x(2 + x)^y + 4(2 + x)^y + (2 + x)^yx^2
-- @
--
-- A negative integer exponent on a composite base is not expanded; the
-- exponentiation is kept symbolic.
normaliseExp :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
-- b^1 ==> b
normaliseExp b (S [P [I 1]]) = b

-- x^(2xy) ==> x^(2xy)
normaliseExp b@(S [P [V _]]) (S [e]) = S [P [E b e]]

-- Single-symbol base and single-symbol exponent: let 'reduceExp' simplify,
-- e.g. 2^3 ==> 8 and (x^y)^z ==> x^(y*z)
normaliseExp b@(S [P [_]]) (S [e@(P [_])]) = S [P [reduceExp (E b e)]]

-- (x + 2)^0 ==> 1
-- (x + 2)^2 ==> x^2 + 4x + 4
normaliseExp b (S [P [I i]])
  | i == 0 = S [P [I 1]]
  -- 'foldr1' is partial: only expand when there is at least one factor.
  -- Negative exponents fall through to the symbolic clause further down.
  | i > 0  = foldr1 mergeSOPMul (replicate (fromInteger i) b)

-- (x + 2)^(2x) ==> (x^2 + 4x + 4)^x
normaliseExp b (S [P (e@(I i):es)]) | i >= 0 =
  -- Without the "| i >= 0" guard, normaliseExp can loop with itself
  -- for exponentials such as: 2^(n-k)
  normaliseExp (normaliseExp b (S [P [e]])) (S [P es])

-- (x + 2)^(xy) ==> (x+2)^(xy)
normaliseExp b (S [e]) = S [P [reduceExp (E b e)]]

-- The empty sum is equal to 0 (see the 'Eq' instance of 'SOP'), and the
-- 'foldr1' below is partial on an empty list of summands: b^0 ==> 1
normaliseExp _ (S []) = S [P [I 1]]

-- (x + 2)^(y + 2) ==> 4x(2 + x)^y + 4(2 + x)^y + (2 + x)^yx^2
normaliseExp b (S e) = foldr1 mergeSOPMul (map (normaliseExp b . S . (:[])) e)

-- | Does the product start with the integer constant @0@?
zeroP :: Product v c -> Bool
zeroP prod = case unP prod of
  I 0 : _ -> True
  _       -> False

-- | Replace the empty sum by the sum that only contains the constant @0@;
-- every other term is returned unchanged.
mkNonEmpty :: SOP v c -> SOP v c
mkNonEmpty sop
  | null (unS sop) = S [P [I 0]]
  | otherwise      = sop

-- | Simplifies SOP terms using
--
-- * 'mergeS'
-- * 'mergeP'
-- * 'reduceExp'
--
-- A simplification pass is repeated until the result is equal to its input.
simplifySOP :: (Ord v, Ord c) => SOP v c -> SOP v c
simplifySOP = fixpoint
  where
    -- Iterate 'pass' until it no longer changes the term
    fixpoint current
      | next == current = current
      | otherwise       = fixpoint next
      where
        next = pass current

    -- One pass: simplify every product, merge equal products, drop the
    -- products that start with 0, sort, and never return an empty sum
    pass sop =
      mkNonEmpty
        (S (sort (filter (not . zeroP)
                         (mergeWith mergeP (map simplifyProduct (unS sop))))))

    -- Merge the symbols of one product, reduce exponentials, and sort
    simplifyProduct (P symbols) =
      P (sort (map reduceExp (mergeWith mergeS symbols)))
{-# INLINEABLE simplifySOP #-}

-- | Add two SOP terms: concatenate their products and simplify the result
-- with 'simplifySOP'.
mergeSOPAdd :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
mergeSOPAdd lhs rhs = simplifySOP (S (unS lhs ++ unS rhs))