```-- |
-- Module:      Cantor.Huge
-- Copyright:   (c) 2020 Andrew Lelechenko
-- Licence:     MIT
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>

{-# LANGUAGE LambdaCase #-}

module Cantor.Huge
( Huge
, pow
, evalWith
) where

import Prelude hiding ((^^))
import Control.Exception
import Math.NumberTheory.Logarithms
import Math.NumberTheory.Roots
import Numeric.Natural

-- | Lazy huge numbers with an efficient 'Ord' instance.
--
-- A value is an unevaluated expression tree over natural numbers.
-- Values should be built through the 'Num' instance and 'pow'; the
-- comparison code assumes the invariants maintained by those smart
-- constructors.
data Huge
  = Nat Natural    -- ^ a literal natural number
  | Add Huge Huge  -- ^ sum of two terms
  | Mul Huge Huge  -- ^ product of two terms
  | Pow Huge Huge  -- ^ base raised to an exponent

-- | Renders the expression tree fully parenthesised,
-- e.g. @(2 ^ (3 + 4))@.
instance Show Huge where
  show h = case h of
      Nat n   -> show n
      Add x y -> binary " + " x y
      Mul x y -> binary " * " x y
      Pow x y -> binary " ^ " x y
    where
      binary op l r = "(" ++ show l ++ op ++ show r ++ ")"

-- | Arithmetic on 'Huge' builds expression nodes through the smart
-- constructors 'add' and 'mul' instead of computing results.
--
-- 'Huge' only represents natural numbers. 'negate' accepts only zero
-- and throws 'Underflow' otherwise, like 'Natural'. Consequently the
-- default @(-)@ fails whenever the subtrahend is non-zero.
instance Num Huge where
  (+) = add
  (*) = mul
  abs = id
  -- Assuming values come from the smart constructors, zero is always
  -- represented as @Nat 0@, so it is the only value with sign 0.
  signum (Nat 0) = Nat 0
  signum _       = Nat 1
  negate (Nat 0) = Nat 0
  negate _       = throw Underflow
  fromInteger = Nat . fromInteger

-- Rewrite rule: turn Prelude's '^' on 'Huge' into the smart constructor
-- 'pow', keeping a lazy 'Pow' node instead of Prelude's '*'-based
-- implementation. GHC only applies RULES when rewrite rules are enabled
-- (implied by -O).
-- NOTE(review): Prelude's '^' needs an 'Integral' exponent, whereas
-- 'pow' takes a 'Huge' exponent; confirm this rule typechecks and fires.
{-# RULES "Huge/pow" forall x p. x ^ p = x `pow` p #-}

-- | Smart constructor for addition: a zero summand is dropped,
-- otherwise an 'Add' node is built.
add :: Huge -> Huge -> Huge
add (Nat 0) y = y
add x (Nat 0) = x
-- Folding two literals, @add (Nat x) (Nat y) = Nat (x + y)@, is
-- currently disabled.
add x y = Add x y

-- | Smart constructor for multiplication.
--
-- A zero factor absorbs the product and a unit factor is dropped;
-- otherwise a 'Mul' node is built. Folding two literals,
-- @mul (Nat x) (Nat y) = Nat (x * y)@, is currently disabled.
mul :: Huge -> Huge -> Huge
mul a b = case (a, b) of
  (Nat 0, _) -> Nat 0
  (_, Nat 0) -> Nat 0
  (Nat 1, _) -> b
  (_, Nat 1) -> a
  _          -> Mul a b

-- | Exponentiation.
--
-- Trivial cases are folded, checked in this order:
--
-- * @x ^ 0 = 1@, including @0 ^ 0 = 1@
-- * @0 ^ y = 0@
-- * @x ^ 1 = x@
-- * @1 ^ y = 1@
--
-- Everything else becomes a 'Pow' node.
pow :: Huge -> Huge -> Huge
pow base ex = case (base, ex) of
  (_, Nat 0) -> Nat 1
  (Nat 0, _) -> Nat 0
  (_, Nat 1) -> base
  (Nat 1, _) -> Nat 1
  _          -> Pow base ex

-- | Convert 'Huge' to another numeric type, using the provided function
-- for exponentiation.
--
-- Literals go through 'fromIntegral'. Sums and products use the target
-- type's own '+' and '*'. Only powers call the supplied function, with
-- the evaluated base and exponent.
evalWith :: Num a => (a -> a -> a) -> Huge -> a
evalWith power = walk
  where
    walk (Nat n)   = fromIntegral n
    walk (Add x y) = walk x + walk y
    walk (Mul x y) = walk x * walk y
    walk (Pow x y) = walk x `power` walk y

-- | Fully evaluate a 'Huge' to a 'Natural': 'evalWith' specialised to
-- Prelude's '(^)'.
eval :: Huge -> Natural
eval h = evalWith (^) h

-- | Equality is decided by 'compare', so it uses the same lazy comparison
-- strategy as the 'Ord' instance.
instance Eq Huge where
  (==) x y = case compare x y of
    EQ -> True
    _  -> False

-- | Ordering by numeric value, delegated to 'compareHuge'. The remaining
-- 'Ord' methods use their default definitions via 'compare'.
instance Ord Huge where
  compare = compareHuge

-- | Compare a concrete 'Natural' with the value a 'Huge' represents.
--
-- A sub-term is only passed to 'eval' once it is known to be no larger
-- than @m@, so the values materialised here never exceed @m@.
--
-- Assumes the second argument has been constructed using the smart
-- constructors. In particular, no 'Add' operand is 0, and no 'Mul' or
-- 'Pow' operand is 0 or 1.
compareNat :: Natural -> Huge -> Ordering
compareNat m = go
  where
    go = \case
      Nat n -> m `compare` n
      Add x y
        -- Peel off a literal summand; m < n already decides the result.
        | Nat n <- x -> if m < n then LT else (m - n) `compareNat` y
        | Nat n <- y -> if m < n then LT else (m - n) `compareNat` x
        -- A summand that alone exceeds m makes the sum exceed m.
        | go x == LT -> LT
        | go y == LT -> LT
        -- Both summands are now <= m, so evaluating one is bounded by m.
        | x <= y     -> (m - eval x) `compareNat` y
        | otherwise  -> (m - eval y) `compareNat` x
      Mul x y
        -- Divide m by a literal factor and compare the quotient.
        | Nat n <- x -> if m < n then LT else unwrap quotPerf m n y
        | Nat n <- y -> if m < n then LT else unwrap quotPerf m n x
        -- Each factor is at least 2, so one factor >= m already wins.
        | go x /= GT -> LT
        | go y /= GT -> LT
        | x <= y     -> unwrap quotPerf m (eval x) y
        | otherwise  -> unwrap quotPerf m (eval y) x
      Pow x y
        -- Literal base: compare the integer logarithm of m with y.
        | Nat n <- x -> if m < n then LT else unwrap logPerf  m n y
        -- Literal exponent n: x ^ n >= 2 ^ n > n, so m < n gives LT;
        -- otherwise compare the integer n-th root of m with x.
        | Nat n <- y -> if m < n then LT else unwrap rootPerf m n x
        -- Base and exponent are both at least 2, so either one >= m wins.
        | go x /= GT -> LT
        | go y /= GT -> LT
        | x <= y     -> unwrap logPerf  m (eval x) y
        | otherwise  -> unwrap rootPerf m (eval y) x

-- | Whether an integer inverse operation (quotient, root or logarithm)
-- was exact.
--
-- The constructor order matters: the derived 'Ord' makes 'Perfect'
-- smaller than 'Imperfect', which 'unwrap' relies on.
data Perfectness
  = Perfect
  | Imperfect
  deriving (Show, Eq, Ord)

-- | Compare the candidate produced by @f m n@ against @y@.
--
-- @f@ returns a candidate and whether it was exact. The candidate is
-- compared with @y@ first. Only on a tie does the exactness flag decide:
-- 'Perfect' yields 'EQ' and 'Imperfect' yields 'GT'.
unwrap
  :: (Natural -> Natural -> (Natural, Perfectness))
  -> Natural
  -> Natural
  -> Huge
  -> Ordering
unwrap f m n y = byCandidate <> byExactness
  where
    (q, r)      = f m n
    byCandidate = compareNat q y
    byExactness = compare r Perfect

-- | Truncated quotient @m `quot` x@, flagged 'Perfect' when @x@ divides
-- @m@ exactly.
quotPerf :: Natural -> Natural -> (Natural, Perfectness)
quotPerf m x = (q, exactness)
  where
    (q, leftover) = m `quotRem` x
    exactness
      | leftover == 0 = Perfect
      | otherwise     = Imperfect

-- | Integer @x@-th root of @m@, computed with 'integerRoot'.
--
-- The result is flagged 'Perfect' when raising the root back to the
-- @x@-th power gives exactly @m@.
rootPerf :: Natural -> Natural -> (Natural, Perfectness)
rootPerf m x = (root, exactness)
  where
    root = integerRoot x m
    exactness
      | root ^ x == m = Perfect
      | otherwise     = Imperfect

-- | Integer logarithm of @m@ in base @x@, computed with 'naturalLogBase'.
--
-- The result is flagged 'Perfect' when @x@ raised to it gives exactly @m@.
logPerf :: Natural -> Natural -> (Natural, Perfectness)
logPerf m x = (fromIntegral e, exactness)
  where
    e = naturalLogBase x m
    exactness
      | x ^ e == m = Perfect
      | otherwise  = Imperfect

-- | Flip an 'Ordering', turning the result of @compare a b@ into that of
-- @compare b a@.
inverse :: Ordering -> Ordering
inverse LT = GT
inverse EQ = EQ
inverse GT = LT

-- | Compare two 'Huge' values, evaluating as little as possible.
--
-- Assumes both arguments have been constructed using smart constructors.
--
-- * Literals are handled by 'compareNat'.
-- * Nodes of different kinds go through 'compareAscNodes', which is
--   always called with the kinds ordered Add, Mul, Pow. The result is
--   inverted when the arguments arrive the other way round.
-- * Nodes of the same kind have a dedicated helper.
compareHuge :: Huge -> Huge -> Ordering
Nat m   `compareHuge` z       = compareNat m z
z       `compareHuge` Nat m   = inverse $ compareNat m z
Add x y `compareHuge` Add u v = compareAddAdd x y u v
Add x y `compareHuge` Mul u v = compareAscNodes Add Mul x y u v
Add x y `compareHuge` Pow u v = compareAscNodes Add Pow x y u v
Mul x y `compareHuge` Add u v = inverse $ compareAscNodes Add Mul u v x y
Mul x y `compareHuge` Mul u v = compareMulMul x y u v
Mul x y `compareHuge` Pow u v = compareAscNodes Mul Pow x y u v
Pow x y `compareHuge` Add u v = inverse $ compareAscNodes Add Pow u v x y
Pow x y `compareHuge` Mul u v = inverse $ compareAscNodes Mul Pow u v x y
Pow x y `compareHuge` Pow u v = comparePowPow x y u v

-- | Compare @fxy x y@ with @fuv u v@ for Add vs. Mul, Add vs. Pow or
-- Mul vs. Pow, but not vice versa.
--
-- The right operands are at least 2 under the smart-constructor
-- invariant, so @fxy u v <= fuv u v@. Hence, if each left operand is
-- smaller than a distinct right operand, the left side is smaller.
--
-- Otherwise one side is evaluated with 'eval' and compared with the
-- other using 'compareNat':
--
-- * The right side is evaluated when some left operand exceeds both
--   right operands, or when the left operands exceed the right ones
--   pairwise.
-- * The left side is evaluated in the remaining cases, where both left
--   operands are bounded by a single right operand.
compareAscNodes
  :: (Huge -> Huge -> Huge)
  -> (Huge -> Huge -> Huge)
  -> Huge
  -> Huge
  -> Huge
  -> Huge
  -> Ordering
compareAscNodes fxy fuv x y u v =
  case (x `compare` u, x `compare` v, y `compare` u, y `compare` v) of
    (LT,  _,  _, LT) -> LT
    ( _, LT, LT,  _) -> LT

    (GT, GT,  _,  _) -> uvSimpler
    (GT,  _,  _, GT) -> uvSimpler
    ( _, GT, GT,  _) -> uvSimpler
    ( _,  _, GT, GT) -> uvSimpler

    (LT,  _, LT,  _) -> xySimpler
    (LT,  _, EQ,  _) -> xySimpler
    (EQ,  _, LT,  _) -> xySimpler
    (EQ,  _, EQ,  _) -> xySimpler

    ( _, LT,  _, LT) -> xySimpler
    ( _, LT,  _, EQ) -> xySimpler
    ( _, EQ,  _, LT) -> xySimpler
    ( _, EQ,  _, EQ) -> xySimpler
  where
    uvSimpler = inverse $ compareNat (eval (fuv u v)) (fxy x y)
    xySimpler = compareNat (eval (fxy x y)) (fuv u v)

-- | Compare @x + y@ with @u + v@.
--
-- Equal operands cancel. Operands that dominate pairwise decide the
-- result directly. The remaining interleaved cases are rewritten and
-- recursed on. For example, with @x > u > y@ and @x > v > y@ the
-- comparison becomes @x@ vs. @(u - y) + v@; the subtraction is taken
-- from whichever of @u@ and @v@ is smaller.
compareAddAdd :: Huge -> Huge -> Huge -> Huge -> Ordering
compareAddAdd x y u v =
  case (x `compare` u, x `compare` v, y `compare` u, y `compare` v) of
    (EQ,  _,  _, yv) -> yv
    ( _, EQ, yu,  _) -> yu
    ( _, xv, EQ,  _) -> xv
    (xu,  _,  _, EQ) -> xu

    (GT,  _,  _, GT) -> GT
    ( _, GT, GT,  _) -> GT
    (LT,  _,  _, LT) -> LT
    ( _, LT, LT,  _) -> LT

    -- x > u > y, x > v > y
    (GT, GT, LT, LT)
      | u <= v    -> x `compare` Add (Nat (eval u - eval y)) v
      | otherwise -> x `compare` Add u (Nat (eval v - eval y))
    -- y > u > x, y > v > x
    (LT, LT, GT, GT)
      | u <= v    -> y `compare` Add (Nat (eval u - eval x)) v
      | otherwise -> y `compare` Add u (Nat (eval v - eval x))
    -- u > x > v, u > y > v
    (LT, GT, LT, GT)
      | x <= y    -> Add (Nat (eval x - eval v)) y `compare` u
      | otherwise -> Add x (Nat (eval y - eval v)) `compare` u
    -- v > x > u, v > y > u
    (GT, LT, GT, LT)
      | x <= y    -> Add (Nat (eval x - eval u)) y `compare` v
      | otherwise -> Add x (Nat (eval y - eval u)) `compare` v

-- | Compare @x * y@ with @u * v@.
--
-- Equal factors cancel, and factors that dominate pairwise decide the
-- result directly. In the interleaved cases one product is evaluated
-- with 'eval' and compared with the other using 'compareNat':
--
-- * the right product, when one left factor exceeds both right factors;
-- * the left product, when one right factor exceeds both left factors.
compareMulMul :: Huge -> Huge -> Huge -> Huge -> Ordering
compareMulMul x y u v =
  case (x `compare` u, x `compare` v, y `compare` u, y `compare` v) of
    (EQ,  _,  _, yv) -> yv
    ( _, EQ, yu,  _) -> yu
    ( _, xv, EQ,  _) -> xv
    (xu,  _,  _, EQ) -> xu

    (GT,  _,  _, GT) -> GT
    ( _, GT, GT,  _) -> GT
    (LT,  _,  _, LT) -> LT
    ( _, LT, LT,  _) -> LT

    (GT, GT, LT, LT) -> uvSimpler  -- x > u > y, x > v > y
    (LT, LT, GT, GT) -> uvSimpler  -- y > u > x, y > v > x
    (LT, GT, LT, GT) -> xySimpler  -- u > x > v, u > y > v
    (GT, LT, GT, LT) -> xySimpler  -- v > x > u, v > y > u
  where
    uvSimpler = inverse $ compareNat (eval (Mul u v)) (Mul x y)
    xySimpler = compareNat (eval (Mul x y)) (Mul u v)

-- | Compare @x ^ y@ with @u ^ v@.
--
-- * An equal base or equal exponent reduces the problem to comparing
--   the other component.
-- * If base and exponent differ in the same direction, that direction
--   decides.
-- * Otherwise the power with the smaller exponent is evaluated with
--   'eval' and compared with the other using 'compareNat'.
comparePowPow :: Huge -> Huge -> Huge -> Huge -> Ordering
comparePowPow x y u v = case (x `compare` u, y `compare` v) of
  (EQ, yv) -> yv
  (xu, EQ) -> xu
  (LT, LT) -> LT
  (GT, GT) -> GT
  -- x < u but y > v: evaluate u ^ v, whose exponent is smaller.
  (LT, GT) -> inverse $ compareNat (eval (Pow u v)) (Pow x y)
  -- x > u but y < v: evaluate x ^ y, whose exponent is smaller.
  (GT, LT) -> compareNat (eval (Pow x y)) (Pow u v)
```