module GHC.TypeLits.Normalise.SOP
(
Symbol (..)
, Product (..)
, SOP (..)
, reduceExp
, mergeS
, mergeP
, mergeSOPAdd
, mergeSOPMul
, normaliseExp
)
where
import Data.Either (partitionEithers)
import Data.List (sort)
import Outputable (Outputable (..), (<+>), text, hcat, integer, punctuate)
-- | A single factor of a 'Product'.
--
-- NB: constructor order is significant. The derived 'Ord' places 'I' before
-- every other constructor, so sorting the symbols of a product moves its
-- integer coefficient to the head, which is where 'mergeP' and 'zeroP' look
-- for it.
data Symbol v c
  = I Integer                 -- ^ Integer literal
  | C c                       -- ^ Opaque term of type @c@; this module only
                              --   compares and pretty-prints it
  | E (SOP v c) (Product v c) -- ^ Exponentiation: base raised to the given
                              --   product exponent (printed as @b ^ e@)
  | V v                       -- ^ Variable; kept symbolic when used as the
                              --   base of an exponent (see 'normaliseExp')
  deriving (Eq,Ord)
-- | A product of 'Symbol's. After simplification the symbols are sorted, so
-- an integer coefficient (if any) is the first element.
newtype Product v c = P { unP :: [Symbol v c] }
  deriving (Eq)

-- | Products consisting of exactly one symbol are ordered by that symbol and
-- sort before any product with two or more symbols. All remaining pairs
-- (including the empty product, which sorts first) are compared
-- lexicographically by their symbol lists.
instance (Ord v, Ord c) => Ord (Product v c) where
  compare (P xs) (P ys) = case (xs, ys) of
    ([x], [y])   -> x `compare` y
    ([_], _ : _) -> LT
    (_ : _, [_]) -> GT
    _            -> xs `compare` ys
-- | A sum of 'Product's. The empty sum @S []@ is treated as zero: it is equal
-- to, and orders the same as, the literal zero @S [P [I 0]]@.
newtype SOP v c = S { unS :: [Product v c] }

-- | Structural equality, except that the empty sum equals literal zero.
instance (Eq v, Eq c) => Eq (SOP v c) where
  (S []) == (S [P [I 0]]) = True
  (S [P [I 0]]) == (S []) = True
  (S ps1) == (S ps2) = ps1 == ps2

-- | Ordering consistent with the 'Eq' instance above: both operands are
-- compared after rewriting the empty sum to literal zero. (A derived 'Ord'
-- would report @compare (S []) (S [P [I 0]]) == LT@ although the two values
-- are '==', violating the 'Ord' laws that 'sort' relies on.)
instance (Ord v, Ord c) => Ord (SOP v c) where
  compare x y = compare (unS (mkNonEmpty x)) (unS (mkNonEmpty y))
-- | Renders a sum as its products separated by @" + "@.
instance (Outputable v, Outputable c) => Outputable (SOP v c) where
  ppr (S terms) = hcat (punctuate (text " + ") (map ppr terms))
-- | Renders a product as its symbols separated by @" * "@.
instance (Outputable v, Outputable c) => Outputable (Product v c) where
  ppr (P factors) = hcat (punctuate (text " * ") (map ppr factors))
-- | Renders a symbol. An exponent is printed as @base ^ exponent@, where
-- each side is shown bare if it is a lone integer or variable and
-- parenthesised otherwise.
instance (Outputable v, Outputable c) => Outputable (Symbol v c) where
  ppr (I i)     = integer i
  ppr (C c)     = ppr c
  ppr (V s)     = ppr s
  ppr (E b e)   = operand b <+> text "^" <+> operand (S [e])
    where
      -- Lone literals and variables need no parentheses.
      operand (S [P [I n]]) = integer n
      operand (S [P [V x]]) = ppr x
      operand other         = text "(" <+> ppr other <+> text ")"
-- | Repeatedly combine pairs of list elements until no pair combines.
--
-- @op x f@ returns @Left m@ when @x@ and @f@ combine into @m@, and @Right x'@
-- (the element to keep) when they do not. The head @f@ is tried against each
-- later element in turn; as soon as one combination succeeds the merged value
-- replaces BOTH elements and the process restarts. If nothing combines with
-- the head, it is kept and the tail is processed.
--
-- The head must be merged with at most one element per step: merging it with
-- several elements at once would count it several times (e.g. @x * x * x@
-- would become @x^4@, and @xy + xy + xy@ would become @4xy@).
mergeWith :: (a -> a -> Either a a) -> [a] -> [a]
mergeWith _  []     = []
mergeWith op (f:fs) = go [] fs
  where
    -- 'kept' holds the elements already tried against 'f', in reverse order.
    go kept []     = f : mergeWith op (reverse kept)
    go kept (x:xs) = case x `op` f of
      Left merged -> mergeWith op (merged : reverse kept ++ xs)
      Right x'    -> go (x' : kept) xs
-- | Simplify an exponentiation symbol; any other symbol is returned as is.
--
-- Rewrites, tried in order:
--
-- * @x^0 = 1@ (checked first, so @0^0 = 1@ as well)
-- * @0^y = 0@
-- * @i^j@ is evaluated for integer literals with @j >= 0@
-- * @(k^i)^j = k^(i*j)@, re-normalised through 'normaliseExp'
reduceExp :: (Ord v, Ord c) => Symbol v c -> Symbol v c
reduceExp sym = case sym of
  E _ (P [I 0])                -> I 1
  E (S [P [I 0]]) _            -> I 0
  E (S [P [I i]]) (P [I j])
    | j >= 0                   -> I (i ^ j)
  E (S [P [E k i]]) j          ->
    let -- Combined exponent i * j, with its factors merged and sorted.
        expo = P (sort (map reduceExp (mergeWith mergeS (unP i ++ unP j))))
    in case normaliseExp k (S [expo]) of
         S [P [single]] -> single
         _              -> E k expo
  _                            -> sym
-- | Try to combine two factors of the same product into a single factor.
--
-- Returns @Left m@ when the two factors combine into @m@, and @Right l@ (the
-- first argument, unchanged) when they do not. Intended for use with
-- 'mergeWith', which calls it as @mergeS x f@ for the factor @f@ being merged
-- and each later factor @x@.
--
-- Clause order matters: earlier clauses shadow later ones (e.g. @0 * 1@ is
-- handled by the literal-multiplication clause before the identity clauses).
mergeS :: (Ord v, Ord c) => Symbol v c -> Symbol v c
       -> Either (Symbol v c) (Symbol v c)
-- i * j: multiply integer literals
mergeS (I i) (I j) = Left (I (i * j))
-- 1 is the multiplicative identity
mergeS (I 1) r = Left r
mergeS l (I 1) = Left l
-- 0 absorbs any other factor
mergeS (I 0) _ = Left (I 0)
mergeS _ (I 0) = Left (I 0)
-- s * s^i = s^(i+1), for a single-symbol base and an integer-literal exponent
mergeS s (E (S [P [s']]) (P [I i]))
  | s == s'
  = Left (E (S [P [s']]) (P [I (i + 1)]))
-- s^i * s = s^(i+1), the mirror image of the previous clause
mergeS (E (S [P [s']]) (P [I i])) s
  | s == s'
  = Left (E (S [P [s']]) (P [I (i + 1)]))
-- i^p * j^p = (i*j)^p, for integer-literal bases with identical exponents
mergeS (E (S [P [I i]]) p) (E (S [P [I j]]) p')
  | p == p'
  = Left (E (S [P [I (i*j)]]) p)
-- s * s = s^2, but only when normalising s^2 yields a single factor;
-- otherwise the two factors are left as they are
mergeS l r
  | l == r
  = case normaliseExp (S [P [l]]) (S [P [I 2]]) of
      (S [P [e]]) -> Left e
      _ -> Right l
-- s^p * s^(-1 * p) = 1: the second exponent is the first one with a leading
-- (-1) coefficient, so the two factors cancel
mergeS (E s1 (P p1)) (E s2 (P (I i:p2)))
  | i == (-1)
  , s1 == s2
  , p1 == p2
  = Left (I 1)
-- s^(-1 * p) * s^p = 1, the mirror image of the previous clause
mergeS (E s1 (P (I i:p1))) (E s2 (P p2))
  | i == (-1)
  , s1 == s2
  , p1 == p2
  = Left (I 1)
-- Anything else does not combine
mergeS l _ = Right l
-- | Try to add two products of a sum together.
--
-- Two products combine when they agree on everything except their leading
-- integer coefficient; a product without a leading literal counts as having
-- coefficient 1. On success the result carries the summed coefficient in
-- front (@Left@); otherwise the first product is returned unchanged
-- (@Right@). Expects the symbols of each product to be sorted, so that a
-- coefficient, if present, is the first symbol.
mergeP :: (Eq v, Eq c) => Product v c -> Product v c
       -> Either (Product v c) (Product v c)
mergeP lhs@(P xs) (P ys) = case (xs, ys) of
  (I i : xs', I j : ys') | xs' == ys' -> merged (i + j) xs'
  (I i : xs', _)         | xs' == ys  -> merged (i + 1) xs'
  (_, I j : ys')         | xs  == ys' -> merged (j + 1) xs
  _                      | xs  == ys  -> merged 2 xs
                         | otherwise  -> Right lhs
  where
    -- Rebuild a product with the given coefficient in front.
    merged n rest = Left (P (I n : rest))
-- | Raise a sum-of-products base to a sum-of-products exponent.
--
-- Cases, tried in order:
--
-- * A zero exponent (the literal @0@ or the empty sum) gives @1@ for every
--   base, in line with 'reduceExp', which rewrites @x^0@ (including @0^0@)
--   to @1@.
-- * An exponent of @1@ returns the base unchanged.
-- * A variable base is kept symbolic: @x^e@.
-- * A single-symbol base with a single-symbol exponent is handed to
--   'reduceExp'.
-- * A positive literal exponent @n@ on any other base is expanded into @n@
--   repeated multiplications of the base.
-- * An exponent @i * es@ with a non-negative literal coefficient @i@ is
--   evaluated as @(b^i)^es@.
-- * Any other single-product exponent is kept symbolic via 'reduceExp'.
-- * A sum exponent @e1 + e2 + ...@ becomes @b^e1 * b^e2 * ...@.
normaliseExp :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
-- b^0 = 1. These clauses must come first: with a zero exponent the
-- repeated-multiplication clause would call foldr1 on an empty list, and the
-- coefficient-splitting clause would recurse on the same arguments forever.
-- The empty sum is zero as far as the Eq instance of SOP is concerned, and
-- would otherwise reach foldr1 on an empty list in the final clause.
normaliseExp _ (S []) = S [P [I 1]]
normaliseExp _ (S [P [I 0]]) = S [P [I 1]]
normaliseExp b (S [P [I 1]]) = b
normaliseExp b@(S [P [V _]]) (S [e]) = S [P [E b e]]
normaliseExp b@(S [P [_]]) (S [e@(P [_])]) = S [P [reduceExp (E b e)]]
-- Only a positive count can be expanded into repeated multiplication;
-- negative literals fall through to the symbolic clause below instead of
-- crashing on an empty replicate.
normaliseExp b (S [P [I i]])
  | i > 0 = foldr1 mergeSOPMul (replicate (fromInteger i) b)
-- b^(i * es) = (b^i)^es
normaliseExp b (S [P (e@(I i):es)])
  | i >= 0 = normaliseExp (normaliseExp b (S [P [e]])) (S [P es])
normaliseExp b (S [e]) = S [P [reduceExp (E b e)]]
-- b^(e1 + e2 + ...) = b^e1 * b^e2 * ...; the list is non-empty here because
-- S [] and S [_] were handled above.
normaliseExp b (S e) = foldr1 mergeSOPMul (map (normaliseExp b . S . (:[])) e)
-- | Does the product start with a literal zero coefficient? Assumes the
-- product is sorted, so that a coefficient would be its first symbol.
zeroP :: Product v c -> Bool
zeroP (P syms) = case syms of
  I 0 : _ -> True
  _       -> False
-- | Represent the empty sum explicitly as the literal zero @S [P [I 0]]@;
-- a non-empty sum is returned unchanged.
mkNonEmpty :: SOP v c -> SOP v c
mkNonEmpty sop@(S terms)
  | null terms = S [P [I 0]]
  | otherwise  = sop
-- | Simplify a sum of products by applying a simplification pass until the
-- result stops changing (as judged by the 'Eq' instance of 'SOP').
--
-- One pass:
--
-- 1. Within each product, combine factors with 'mergeS', reduce exponents
--    with 'reduceExp' and sort the factors.
-- 2. Add like products together with 'mergeP'.
-- 3. Drop products whose coefficient is zero, then sort the products.
-- 4. Replace an empty result by literal zero.
simplifySOP :: (Ord v, Ord c) => SOP v c -> SOP v c
simplifySOP = untilStable simplifyOnce
  where
    simplifyOnce sop =
      let products = map simplifyProduct (unS sop)
          summed   = mergeWith mergeP products
          nonZero  = filter (not . zeroP) summed
      in  mkNonEmpty (S (sort nonZero))

    simplifyProduct = P . sort . map reduceExp . mergeWith mergeS . unP

    untilStable step current
      | next == current = current
      | otherwise       = untilStable step next
      where
        next = step current
{-# INLINEABLE simplifySOP #-}
-- | Add two sums of products: concatenate their terms and simplify.
mergeSOPAdd :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
mergeSOPAdd l r = simplifySOP . S $ concatMap unS [l, r]
{-# INLINEABLE mergeSOPAdd #-}
-- | Multiply two sums of products by distributing: every product of the
-- left operand is multiplied (factor lists concatenated) with every product
-- of the right operand, and the resulting sum is simplified.
mergeSOPMul :: (Ord v, Ord c) => SOP v c -> SOP v c -> SOP v c
mergeSOPMul l r = simplifySOP (S crossTerms)
  where
    -- The right operand drives the outer loop, matching the term order of
    -- the original concatMap/zipWith formulation.
    crossTerms = [ P (unP p1 ++ unP p2) | p2 <- unS r, p1 <- unS l ]
{-# INLINEABLE mergeSOPMul #-}