{-# LANGUAGE PatternSynonyms, ViewPatterns #-}
module Math.NumberTheory.Canon.AurifCyclo (
aurCandDec,
aurDec,
applyCycloPair, applyCycloPairWithMap, CycloPair, Poly,
cyclo, cycloWithMap,
cycloDivSet, cycloDivSetWithMap,
chineseAurif, chineseAurifWithMap,
crCycloAurifApply, applyCrCycloPair, divvy,
CycloMap, getIntegerBasedCycloMap, showCyclo, crCycloInitMap,
multPoly, divPoly, addPoly,
CR_, CanonRep_, CanonElement_
)
where
import Math.NumberTheory.Canon.Internals
import Math.NumberTheory.Moduli.Jacobi (JacobiSymbol(..), jacobi)
import Data.Array (array, (!), Array(), elems)
import GHC.Real (numerator, denominator)
import Data.List (sort, sortBy, (\\))
import qualified Data.Map as M
-- | The canonical representation of the integer 2, kept as a constant so it
-- can be compared against other canonical values.
cr2 :: CR_
cr2 = fst (crFromI 2)
-- | Compute a canonical representation of @x + y@ (when @b@ is True) or
-- @x - y@ (when @b@ is False) by assembling integer factors and multiplying
-- their canonical forms together.
--
-- Arguments, as used below:
--
--   * @b@  – selects the operation: @(+)@ when True, @(-)@ when False.
--   * @x@, @y@ – the two operands in canonical form.
--   * @g@, @gi@ – a canonical value and an Integer; @gi@ is the root index
--     passed to 'crRoot'. NOTE(review): presumably @g@ is the canonical form
--     of @gi@ – confirm against the caller.
--   * @m@  – the cyclotomic memo map threaded through the computation.
--
-- Returns the resulting canonical value together with a (possibly updated) map.
crCycloAurifApply :: Bool -> CR_ -> CR_ -> CR_ -> Integer -> CycloMap -> (CR_, CycloMap)
crCycloAurifApply b x y g gi m
  -- g is prime (but not the "2 with addition" case): use two integer pieces,
  -- op applied to the gi-th roots and the cofactor of that inside op x y.
  | (crPrime g) && not (g == cr2 && b)
                   = eA ([term1, termNM1], m)
  -- Subtraction: evaluate the cyclotomic pairs for g at the gi-th roots.
  | not b          = eA (cycA grtx grty g)
  -- Addition where g is not exactly [(2, _)]: work with the odd part of g and
  -- negate the second root.
  | b && not gpwr2 = eA (cycA (oddRoot x) (-1 * oddRoot y) odd')
  -- Remaining case: apply the operation directly; the map is returned untouched.
  | otherwise      = (fst $ crSimpleApply op x y, m)
  where op           = if b then (+) else (-)
        -- Lazy pattern: only forced when gp or gs is demanded.
        ((gp, _):gs) = g
        -- True when g consists of a single entry whose base is 2.
        gpwr2        = gp == 2 && gs == []
        gth_root v   = crToI $ crRoot v gi
        grtx         = gth_root x
        grty         = gth_root y
        term1        = integerApply op (crRoot x gi) (crRoot y gi)
        termNM1      = div (integerApply op x y) term1
        -- Sorted list of values from applyCrCycloPair, plus the updated map.
        cycA x' y' n = (sort ia, m')
                       where (ia, m') = applyCrCycloPair x' y' n m
        -- Refine the integer list @a@ with up to two optional splits (first
        -- from aurCandDecI, then from chineseAurifCr, each applied through
        -- divvy), convert every integer with crFromI and multiply the results.
        eA (a,mp)    = (foldr1 crMult $ map (\v' -> fst $ crFromI v') v, m')
                       where (v, m')   = case aurCandDecI x y gi g b of
                                           Nothing -> auL a mp
                                           Just (a1, a2) -> auL (divvy a a1 a2) mp
                             auL al ma = case c of
                                           Nothing -> (al, mp')
                                           Just (a3, a4) -> (divvy al a3 a4, mp')
                               where (c, mp') = chineseAurifCr x y b ma
        -- g with its leading entry dropped when that entry's base is 2.
        odd' | gp == 2 = tail g
             | otherwise = g
        oddRoot v    = crToI $ crRoot v (crToI odd')
-- | Integer front end for 'aurCandDecI'.
--
-- Both inputs are converted to canonical form and divided by their common
-- gcd. The root index handed on is the gcd of the two reduced operands'
-- 'crMaxRoot' values (taken on absolute values). The flag @b@ is passed
-- through unchanged.
aurCandDec :: Integer -> Integer -> Bool -> Maybe (Integer, Integer)
aurCandDec xi yi b = aurCandDecI xRed yRed rootIdx (fst (crFromI rootIdx)) b
  where
    xCr     = fst (crFromI xi)
    yCr     = fst (crFromI yi)
    common  = crGCD xCr yCr
    xRed    = crDivStrict xCr common
    yRed    = crDivStrict yCr common
    rootIdx = gcd (crMaxRoot (crAbs xRed)) (crMaxRoot (crAbs yRed))
-- | Look for a two-factor split of @x + y@ (@b@ True) or @x - y@ (@b@ False)
-- using the coefficient arrays produced by 'aurDecI'.
--
-- @n@ is an Integer and @cr@ a canonical value. NOTE(review): presumably @cr@
-- is the canonical form of @n@ – confirm against callers.
--
-- Returns Nothing when any of the preconditions in the first guard fails or
-- when 'aurDecI' yields Nothing; otherwise Just a pair of integers.
aurCandDecI :: CR_ -> CR_ -> Integer -> CR_ -> Bool -> Maybe (Integer, Integer)
aurCandDecI x y n cr b| (nm4 == 1 && b) || (nm4 /= 1 && not b) ||
                        (xdg == x && ydg == y) || (m /= 0)
                                  = Nothing
                      | otherwise = case aurDecI n' cr' of
                                      Nothing -> Nothing
                                      Just (gamma, delta) -> apply gamma delta
  where
        -- When neither operand is cr1 the supplied (n, cr) are used as is;
        -- otherwise they are replaced by the radical of gcd(other operand, cr).
        (n', cr') | x /= cr1 && y /= cr1 = (n, cr)
                  | otherwise = (gcd1i, gcd1)
                  where x1 = if y /= cr1 then y else x
                        gcd1 = crRadical $ crGCD x1 cr
                        gcd1i = crToI gcd1
        nm4 = mod n' 4
        -- Divide by cr'^n' when crDiv succeeds (Right); otherwise keep the
        -- value unchanged.
        divTry a = case crDiv a (crExp cr' n' False) of
                     Left _ -> a
                     Right quotient -> quotient
        xdg = divTry x
        ydg = divTry y
        -- The gcd of the max roots of the divided operands must be a
        -- multiple of 2*n' (checked as m /= 0 in the guard above).
        mrGCD = gcd (crMaxRoot $ crAbs xdg) (crMaxRoot $ crAbs ydg)
        m = mod mrGCD (2*n')
        -- Pick the ratio according to which operand the division changed;
        -- ml is the sign multiplier applied to the first returned factor.
        (x', ml) | ydg /= y = ( (crDivRational ydg x), if (not b) then (-1) else 1)
                 | otherwise = ( (crDivRational xdg y), 1);
        xrtn = crMult cr' (crRoot x' n')
        xrtnr = crToRational xrtn
        sqrtnxr = crToRational $ crRoot (crMult cr' xrtn) 2
        -- Evaluate both coefficient arrays at xrtnr and combine them as
        -- c -/+ sqrtnxr * d; only the numerators are returned.
        apply gm dt = Just (ml * numerator f1, numerator f2)
                      where f1 = c - sqrtnxr * d
                            f2 = c + sqrtnxr * d
                            c = aA gm xrtnr
                            d = aA dt xrtnr
                            -- Evaluate the array's elements as polynomial
                            -- coefficients (ascending powers) at x''.
                            aA a x'' = f (elems a) 1 0
                                       where f [] _ a' = a'
                                             f (c:cs) m' a' = f cs (m'*x'') (a' + (toRational c)*m')
-- | Integer front end for 'aurDecI': converts @n@ to canonical form and
-- delegates. Calls 'error' when @n@ is not greater than 1.
aurDec :: Integer -> Maybe (Array Integer Integer, Array Integer Integer)
aurDec n
  | n > 1     = aurDecI n (fst (crFromI n))
  | otherwise = error "aurifDecomp: n must be greater than 1"
-- | Build the two integer coefficient arrays (@gamma@, @delta@) for @n@.
--
-- Returns Nothing when @cr@ has a square factor ('crHasSquare'), or when @n@
-- or the derived @n'@ is below 2. Otherwise returns @gamma@ (indices 0..d)
-- and @delta@ (indices 0..d-1), where @d = totient n' / 2@.
--
-- NOTE(review): the recurrences look like Brent's algorithm for
-- Aurifeuillian factor coefficients – confirm against the reference.
-- NOTE(review): presumably @cr@ is the canonical form of @n@ – confirm.
aurDecI :: Integer -> CR_ -> Maybe (Array Integer Integer, Array Integer Integer)
aurDecI n cr | crHasSquare cr || n < 2 || n' < 2
                         = Nothing
             | otherwise = Just (gamma, delta)
  where nm4 = mod n 4
        -- n itself when n is 1 mod 4, otherwise 2n.
        n' = if (nm4 == 1) then n else (2*n)
        d = div (totient n') 2
        dm2 = mod d 2
        -- Largest index computed directly for gamma / delta; higher indices
        -- are mirrored from lower ones (see gf and df below).
        mg | dm2 == 1 = div (d-1) 2
           | otherwise = div d 2
        md | dm2 == 1 = div (d-1) 2
           | otherwise = div (d-2) 2
        -- Auxiliary table: odd indices hold a Jacobi symbol value, even
        -- indices a product of a Moebius-style sign, a totient and cos'.
        q = array (1, d) ([(i, f i) | i <- [1..d]])
            where f i | mod i 2 == 1 = convJacobi $ jacobi n i
                      | otherwise = eQ
                      where eQ | ff == True = moeb cr' * (totient g) * (cos' $ (n-1)*i)
                               | otherwise = error "Moebius fcn can't be called if number not totally factored"
                            -- ff is the second component of crFromI; eQ
                            -- refuses to continue when it is False.
                            (cr', ff)= crFromI $ div n' g
                            g = gcd n' i
                            -- 0 if a square is present, else -1 / 1 by the
                            -- parity of the number of entries.
                            moeb crm | crHasSquare crm = 0
                                     | mod (length crm) 2 == 1 = -1
                                     | otherwise = 1
                            -- Value selected by the argument mod 8; only
                            -- even residues are accepted.
                            cos' c | m8 == 2 || m8 == 6 = 0
                                   | m8 == 4 = -1
                                   | m8 == 0 = 1
                                   | otherwise = error "Logic error: bad/odd value passed to cos'"
                                   where m8 = mod c 8
        -- gamma and delta are lazily defined arrays that refer to each
        -- other; each entry only depends on lower-indexed entries.
        gamma = array (0, d) ([(0,1)] ++ [(i, gf i) | i <- [1..d]])
                where gf k | k > mg = gamma!(d-k)
                           | otherwise = div gTerm (2*k)
                           where gTerm = sum $ map f [0..k-1]
                                         where f j = n * q!(2*k-2*j-1) * delta!j - q!(2*k-2*j) * gamma!j
        delta = array (0, d-1) ([(0,1)] ++ [(i, df i) | i <- [1..d-1]])
                where df k | k > md = delta!(d-k-1)
                           | otherwise = div dTerm (2*k+1)
                           where dTerm = gamma!k + sum (map f [0..k-1])
                                         where f j = q!(2*k-2*j+1) * gamma!j - q!(2*k-2*j) * delta!j
-- | Split the entries of @a@ using two known factors @x@ and @y@.
--
-- The list is processed in descending order. For each entry the part shared
-- with the remaining @x@ and then with the remaining @y@ is peeled off; the
-- pieces greater than 1 are emitted and the peeled parts are divided out of
-- the running @x@ / @y@. Once both have reached 1 the rest of the list is
-- returned untouched. Only the absolute values of @x@ and @y@ are used.
--
-- Calls 'error' if the list runs out while either factor is still not 1.
divvy :: [Integer] -> Integer -> Integer -> [Integer]
divvy a x y = peel (sortBy descending a) (abs x) (abs y)
  where
    -- Comparator that never reports EQ, matching a strict ">" test.
    descending p q
      | p > q     = LT
      | otherwise = GT

    peel [] rx ry
      | rx == 1 && ry == 1         = []
      | abs rx == 1 && abs ry == 1 = [rx * ry]
      | otherwise                  = error "Empty list passed as first param but x' and y' weren't both 1"
    peel whole@(c : rest) rx ry
      | rx == 1 && ry == 1 = whole
      | otherwise          = pieces ++ peel rest (rx `div` fromX) (ry `div` fromY)
      where
        fromX    = gcd c rx
        cofactor = c `div` fromX
        fromY    = gcd cofactor ry
        pieces   = [t | t <- [cofactor `div` fromY, fromX, fromY], t > 1]
-- | A polynomial given by its integer coefficients, lowest degree first
-- (see 'gen_xNm1', which puts the constant term at the head).
type Poly = [Integer]
-- | A polynomial paired with an Integer that is multiplied into every
-- exponent when the pair is evaluated (see 'applyCrCycloPairI').
type CycloPair = (Integer, Poly)
-- | Underlying memo table: canonical value to its 'CycloPair'.
type CycloMapInternal = M.Map CR_ CycloPair
-- | Opaque wrapper around the memo table.
newtype CycloMap = MakeCM CycloMapInternal deriving (Eq, Show)
-- | Unwrap a 'CycloMap' to reach the underlying map.
fromCM :: CycloMap -> CycloMapInternal
fromCM (MakeCM inner) = inner
-- | View of the memo table with every canonical key converted to an Integer
-- via 'crToI'.
getIntegerBasedCycloMap :: CycloMap -> M.Map Integer CycloPair
getIntegerBasedCycloMap = M.mapKeys crToI . fromCM
-- | Starting memo table: a single entry for 'cr1' holding multiplier 1 and
-- the polynomial @gen_xNm1 1@.
crCycloInitMap :: CycloMap
crCycloInitMap = MakeCM (M.singleton cr1 (1, gen_xNm1 1))
-- | Look a canonical key up in the memo table.
cmLookup :: CR_ -> CycloMap -> Maybe CycloPair
cmLookup c (MakeCM inner) = M.lookup c inner
-- | Insert (or overwrite) the pair stored for a canonical key.
cmInsert :: CR_ -> CycloPair -> CycloMap -> CycloMap
cmInsert c p (MakeCM inner) = MakeCM (M.insert c p inner)
-- | 'crCyclo' for an Integer argument, starting from 'crCycloInitMap'.
cyclo :: Integer -> (CycloPair, CycloMap)
cyclo n = crCyclo nCr crCycloInitMap
  where
    nCr = fst (crFromI n)
-- | 'crCyclo' for an Integer argument, reusing the supplied memo table.
cycloWithMap :: Integer -> CycloMap -> (CycloPair, CycloMap)
cycloWithMap n m = crCyclo nCr m
  where
    nCr = fst (crFromI n)
-- | First component of 'crCycloDivSet' for an Integer argument, computed from
-- 'crCycloInitMap'.
cycloDivSet :: Integer -> CycloMap
cycloDivSet n = fst (crCycloDivSet nCr crCycloInitMap)
  where
    nCr = fst (crFromI n)
-- | 'crCycloDivSet' for an Integer argument, reusing the supplied memo table.
cycloDivSetWithMap :: Integer -> CycloMap -> (CycloMap, CycloMap)
cycloDivSetWithMap n m = crCycloDivSet nCr m
  where
    nCr = fst (crFromI n)
-- | Pair for a positive canonical value: the polynomial stored for its
-- radical (via 'crCycloRad') together with the multiplier @cr / radical@.
-- The second component is the memo table returned by 'crCycloRad'.
--
-- Calls 'error' when @cr@ is not positive.
crCyclo :: CR_ -> CycloMap -> (CycloPair, CycloMap)
crCyclo cr m
  | not (crPositive cr) = error "crCyclo: Positive integer needed"
  | otherwise           = ((multiplier, radPoly), m')
  where
    rad                = crRadical cr
    multiplier         = crToI (crDivStrict cr rad)
    ((_, radPoly), m') = crCycloRad rad m
-- | Make sure the memo table holds an entry for every divisor of @cr@ and
-- return two tables: first the entries whose keys are divisors of @cr@, then
-- the complete (possibly enlarged) memo table.
--
-- The first component is always restricted to the divisors of @cr@, even when
-- the incoming table already holds entries for unrelated values. Callers such
-- as 'applyCrCycloPair' evaluate every entry of that component, so leaking
-- unrelated entries would produce wrong factors.
--
-- If @cr@ is already a key of @m@, the table is reused unchanged.
--
-- Calls 'error' when @cr@ is not positive.
crCycloDivSet :: CR_ -> CycloMap -> (CycloMap, CycloMap)
crCycloDivSet cr m
  | crPositive cr = (divisorMap, fullMap)
  | otherwise     = error "crCycloDivSet: Positive integer needed"
  where
    crd = crDivisors cr
    rad = crRadical cr

    -- Only the entries keyed by a divisor of cr.
    divisorMap = MakeCM $ M.filterWithKey (\k _ -> k `elem` crd) (fromCM fullMap)

    -- The memo table after every divisor of cr has been added.
    fullMap = case cmLookup cr m of
      Just _ -> m
      Nothing
        | rad == cr -> radMap
        | otherwise -> fillIn (crd \\ crDivisors rad) radMap

    -- Table after the radical (and hence all squarefree divisors) is present.
    radMap = snd (crCycloRad rad m)

    -- Add the non-squarefree divisors one at a time, threading the table.
    fillIn [] _ = error "Logic Error in mfn: Empty list is forbidden"
    fillIn (n : ns) mp
      | null ns   = mp'
      | otherwise = fillIn ns mp'
      where
        mp' = snd (crCycloAll n mp)
-- | Memoised lookup of the pair for @cr@.
--
-- On a cache hit the stored pair and the unchanged table are returned. On a
-- miss:
--
--   * if @cr@ has a single entry, the pair is @genPrimePoly r@;
--   * otherwise the polynomial is @gen_xNm1 r@ divided by the product of the
--     polynomials of all divisors except the last one listed by
--     'crDivisors'. Each of those is obtained recursively, threading the
--     memo table.
--
-- In both miss cases the new pair is inserted under @cr@.
crCycloRad :: CR_ -> CycloMap -> (CycloPair, CycloMap)
crCycloRad cr m = maybe computed (\hit -> (hit, m)) (cmLookup cr m)
  where
    computed
      | null otherEntries = (primePair, cmInsert cr primePair m)
      | otherwise         = (compositePair, cmInsert cr compositePair mAfter)

    -- Lazy pattern: only forced on a cache miss.
    (_ : otherEntries) = cr

    r             = fromInteger (crToI (crRadical cr))
    primePair     = genPrimePoly r
    compositePair = (1, divPoly (gen_xNm1 r) divisorProduct)

    (divisorProduct, mAfter) = productOver (init (crDivisors cr))

    -- Multiply the polynomials of the listed divisors together.
    productOver []        = error "Logic error: Blank list can't be passed to mf aka crCycloMemoFold"
    productOver (d0 : ds) = go ds p0 m0
      where
        ((_, p0), m0) = crCycloRad d0 m

        go (d : rest) acc mp =
          let ((_, pd), mp') = crCycloRad d mp
          in  go rest (multPoly acc pd) mp'
        go [] acc mp = (acc, mp)
-- | Pair for an arbitrary canonical value, derived from a smaller one.
--
-- The first entry of @cr@ whose exponent exceeds 1 has that exponent lowered
-- by one. The pair of the reduced value is looked up, or computed
-- recursively, and its multiplier is scaled by that entry's base. The result
-- is inserted under @cr@.
--
-- When no exponent exceeds 1 the value must already be in the table;
-- otherwise 'error' is called.
crCycloAll :: CR_ -> CycloMap -> (CycloPair, CycloMap)
crCycloAll cr m
  | pulled == cr1 =
      case cmLookup cr m of
        Just found -> (found, m)
        Nothing    -> error "Logic error: Radical value not found for crCycloAll"
  | otherwise = (lifted, cmInsert cr lifted mReduced)
  where
    (pulled, reduced) = pullSquare [] cr

    ((mult, poly), mReduced) =
      maybe (crCycloAll reduced m) (\hit -> (hit, m)) (cmLookup reduced m)

    lifted = (fst (head pulled) * mult, poly)

    -- Scan for the first exponent above 1; report the pulled entry and the
    -- remaining value, or (cr1, everything seen) when there is none.
    pullSquare seen [] = (cr1, seen)
    pullSquare seen (entry@(base, expo) : rest)
      | expo > 1  = ([(base, 1)], seen ++ (base, expo - 1) : rest)
      | otherwise = pullSquare (seen ++ [entry]) rest
-- | Evaluate every pair in the divisor table of @cr@ at @(l, r)@.
-- Returns the list of values and the updated memo table.
applyCrCycloPair :: Integer -> Integer -> CR_ -> CycloMap -> ([Integer], CycloMap)
applyCrCycloPair l r cr m = (values, mNew)
  where
    (divMap, mNew) = crCycloDivSet cr m
    values         = applyCrCycloPairI l r cr (M.elems (fromCM divMap))
-- | Evaluate each pair in @cds@ as a homogeneous polynomial in @l@ and @r@.
--
-- For a pair @(mult, coeffs)@, the coefficient at position @k@ contributes
-- @coeff * l^(mult*k) * r^(mult*(deg-k))@, where @deg@ is the index of the
-- last coefficient. Powers of @l@ and @r@ are tabulated once, up to
-- @crTotient cr@.
applyCrCycloPairI :: Integer -> Integer -> CR_ -> [CycloPair] -> [Integer]
applyCrCycloPairI l r cr cds = [evalPair pair | pair <- cds]
  where
    top = crTotient cr

    -- Lazily built table of base^0 .. base^top.
    powers base = table
      where
        table = array (0, top) ([(0, 1)] ++ [(i, base * table ! (i - 1)) | i <- [1 .. top]])

    lPow = powers l
    rPow = powers r

    evalPair (mult, coeffs) = foldr1 (+) (zipWith term coeffs [0 ..])
      where
        degree = toInteger (length coeffs - 1)
        -- Zero coefficients are skipped so their powers are never looked up.
        term coeff k
          | coeff == 0 = 0
          | otherwise  = coeff * lPow ! (mult * k) * rPow ! (mult * (degree - k))
-- | 'applyCycloPairWithMap' starting from 'crCycloInitMap', discarding the
-- resulting memo table.
applyCycloPair :: Integer -> Integer -> Integer -> [Integer]
applyCycloPair x y e = values
  where
    (values, _) = applyCycloPairWithMap x y e crCycloInitMap
-- | 'applyCrCycloPair' with the third argument given as an Integer and
-- converted to canonical form first.
applyCycloPairWithMap :: Integer -> Integer -> Integer -> CycloMap -> ([Integer], CycloMap)
applyCycloPairWithMap x y e m = applyCrCycloPair x y eCr m
  where
    eCr = fst (crFromI e)
-- | Render the polynomial described by @crCyclo n m@ as text, lowest degree
-- first, e.g. @"1 + x + x^2"@.
--
-- 'crCyclo' returns a pair @(mult, coeffs)@ in which @mult@ scales every
-- exponent (see how 'applyCrCycloPairI' evaluates a pair). The coefficient at
-- position @k@ is therefore printed with exponent @mult * k@, so a value with
-- a repeated factor shows its own polynomial rather than that of its radical.
-- When @mult@ is 1 the output is the plain coefficient listing.
--
-- Zero coefficients are omitted, and a coefficient of absolute value 1 is
-- printed without the number. An empty coefficient list yields the empty
-- string.
showCyclo :: CR_ -> CycloMap -> [Char]
showCyclo n m = render coeffs
  where
    (mult, coeffs) = fst (crCyclo n m)

    render (c : cs) = show c ++ terms cs mult
    render []       = []

    -- e is the exponent belonging to the head coefficient.
    terms (c : cs) e
      | c == 0    = rest
      | otherwise = sign ++ magnitude ++ "x" ++ power ++ rest
      where
        rest      = terms cs (e + mult)
        ac        = abs c
        sign      = if c > 0 then " + " else " - "
        magnitude = if ac == 1 then "" else show ac
        power     = if e == 1 then "" else "^" ++ show e
    terms [] _ = []
-- | True when every exponent in the canonical value is even (vacuously true
-- for an empty list).
crSquareFlag :: CR_ -> Bool
crSquareFlag cr = and [mod ce 2 == 0 | (_, ce) <- cr]
-- | 'chineseAurifWithMap' starting from 'crCycloInitMap', discarding the
-- resulting memo table.
chineseAurif :: Integer -> Integer -> Bool -> Maybe (Integer, Integer)
chineseAurif x y b = result
  where
    (result, _) = chineseAurifWithMap x y b crCycloInitMap
-- | 'chineseAurifCr' with both operands given as Integers and converted to
-- canonical form first.
chineseAurifWithMap :: Integer -> Integer -> Bool -> CycloMap -> (Maybe (Integer, Integer), CycloMap)
chineseAurifWithMap x y b m = chineseAurifCr xCr yCr b m
  where
    xCr = fst (crFromI x)
    yCr = fst (crFromI y)
-- | Try 'chineseAurifI' on the ratio @x/y@ and, if that finds nothing, on the
-- reciprocal ratio. The memo table from the first attempt is passed to the
-- second.
--
-- The operands are first divided by their common gcd. @n@ is the gcd of
-- their 'crMaxRoot' values, and each ratio is the @n@-th root of the
-- corresponding quotient. The candidate passed with a ratio is the gcd of its
-- numerator with the canonical form of @n@.
chineseAurifCr :: CR_ -> CR_ -> Bool -> CycloMap -> (Maybe (Integer, Integer), CycloMap)
chineseAurifCr xp yp b m =
  case firstTry of
    Nothing -> chineseAurifI ratioYX n candYX (crToI candYX) b m'
    found   -> (found, m')
  where
    (firstTry, m') = chineseAurifI ratioXY n candXY (crToI candXY) b m

    common = crGCD xp yp
    x      = crDivStrict xp common
    y      = crDivStrict yp common

    n   = gcd (crMaxRoot (crAbs x)) (crMaxRoot (crAbs y))
    nCr = fst (crFromI n)

    ratioXY = crRoot (crDivRational x y) n
    candXY  = crGCD (crNumer ratioXY) nCr
    ratioYX = crRecip ratioXY
    candYX  = crGCD (crNumer ratioYX) nCr
-- | Core of the "Chinese" Aurifeuillian attempt.
--
-- Arguments, as used below:
--
--   * @mbcr@ – a canonical (possibly rational) base value.
--   * @n@    – an Integer exponent; must be odd and a multiple of @m@.
--   * @mcr@, @m@ – a canonical value and an Integer. NOTE(review): presumably
--     @mcr@ is the canonical form of @m@ – confirm against the caller.
--   * @b@    – sign selector; which residue of @m@ mod 4 is acceptable
--     depends on it.
--   * @mp@   – the cyclotomic memo table.
--
-- Returns @Just (gd1, gd2)@ with the updated table only when the two gcds
-- multiply back to the cyclotomic value @cv@. Otherwise returns Nothing with
-- the table that was passed in.
chineseAurifI :: CR_ -> Integer -> CR_ -> Integer -> Bool -> CycloMap -> (Maybe (Integer, Integer), CycloMap)
chineseAurifI mbcr n mcr m b mp
  -- Reject: even n or m, m below 3, m not dividing n, the wrong residue of m
  -- mod 4 for the requested sign, or mbcr / mcr failing (crDiv gives Left)
  -- or not having all-even exponents.
  | mod n 2 == 0 || mod m 2 == 0 ||
    m < 3 || km /= 0 ||
    (mm4 == 1 && b) ||
    (mm4 == 3 && not b) ||
    mbdm == cr0 || not (crSquareFlag mbdm)
                = (Nothing, mp)
  | otherwise   = case cv - (gd1 * gd2) of
                    0 -> (Just (gd1, gd2), mp')
                    _ -> (Nothing, mp)
  where mm4 = mod m 4
        -- -1 when m is 3 mod 4, otherwise the residue itself.
        e = toRational $ if (mm4 == 3) then (-1) else mm4
        (k, km) = quotRem n m
        -- cr0 is used as the "division failed" marker.
        mbdm = case crDiv mbcr mcr of
                 Left _ -> cr0
                 Right q -> q
        r = crToRational $ crRoot (crMult mbcr mcr) 2
        mb = crToRational mbcr
        -- Jacobi symbol of c over m, as a Rational.
        jR c = toRational $ convJacobi $ jacobi c m
        eM = e * mb
        v1 = (toRational m) * mb^(div (k * (m + 1)) 2)
        -- Jacobi-weighted sum over the residues coprime to m, scaled by t.
        v2 = t * s
             where t = (jR 2) * r * (mb ^ (div (k-1) 2))
                   s = sum $ map (\c -> (jR c) * eM^(k*c))
                           $ filter (\c -> gcd c m == 1) [1..m]
        ncr = fst $ crFromI n
        -- Value of the pair for n evaluated at numerator / denominator of eM.
        cv = head $ applyCrCycloPairI (numerator eM) (denominator eM) ncr [cp]
        (cp, mp') = crCyclo ncr mp
        gd1 = gcd cv (numerator $ v1 - v2)
        gd2 = gcd cv (numerator $ v1 + v2)
-- | Convert a 'JacobiSymbol' to the integer it stands for (-1, 0 or 1).
convJacobi :: JacobiSymbol -> Integer
convJacobi MinusOne = -1
convJacobi Zero     = 0
convJacobi One      = 1
-- | Coefficient list, lowest degree first, with -1 as the constant term, 1 as
-- the leading coefficient and @r - 1@ zeros in between (that is, @x^r - 1@
-- for @r >= 1@).
gen_xNm1 :: Int -> Poly
gen_xNm1 r = [-1] ++ replicate (r - 1) 0 ++ [1]
-- | Pair with multiplier 1 and a polynomial made of @r@ coefficients that are
-- all 1.
genPrimePoly :: Int -> (Integer, Poly)
genPrimePoly r = (1, [1 | _ <- [1 .. r]])
-- | Multiply two polynomials given as coefficient lists, lowest degree first.
--
-- Each coefficient of the first polynomial scales the second, and the partial
-- products are combined with 'addPoly' after shifting by one degree per
-- position. An empty first argument gives the empty list.
multPoly :: Num a => [a] -> [a] -> [a]
multPoly p1 p2 = foldr step [] p1
  where
    -- acc is the product of the remaining higher coefficients; prepending a
    -- zero multiplies it by x.
    step c acc = addPoly (map (c *) p2) (0 : acc)
-- | Exact division of two polynomials given as coefficient lists, lowest
-- degree first. The quotient is returned in the same lowest-degree-first
-- order.
--
-- The division proceeds from the constant term upwards. Each step divides the
-- current lowest coefficient of the running dividend by the constant term of
-- the divisor, subtracts that multiple of the divisor and drops the (now
-- zero) lowest coefficient. It is therefore only meaningful when @p2@ divides
-- @p1@ exactly and the constant term of @p2@ is non-zero. No remainder is
-- reported.
--
-- Returns the empty list when @p1@ is shorter than @p2@. Calls 'error' when
-- the divisor is the empty list.
divPoly :: Integral a => [a] -> [a] -> [a]
divPoly _ [] = error "divPoly: empty divisor"
divPoly p1 p2@(c0 : _) = go [] p1 (length p1 - length p2)
  where
    -- acc collects quotient coefficients most-recent first, so it is
    -- reversed once at the end to restore lowest-degree-first order.
    go acc u n
      | n < 0     = reverse acc
      | otherwise = go (q0 : acc) u' (n - 1)
      where
        -- u always has at least as many coefficients as p2 while n >= 0,
        -- so head / tail are safe here.
        q0 = head u `div` c0
        u' = tail (zipWith (-) u (map (q0 *) p2 ++ repeat 0))
-- | Add two polynomials given as coefficient lists. The shorter list is
-- padded with zeros, so the result is as long as the longer input.
addPoly :: Num a => [a] -> [a] -> [a]
addPoly p1 p2
  | length p1 < length p2 = padAdd p2 p1
  | otherwise             = padAdd p1 p2
  where
    padAdd longer shorter = zipWith (+) longer (shorter ++ repeat 0)