{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Math.NumberTheory.Moduli.Equations
( solveLinear
, solveQuadratic
) where
import GHC.Integer.GMP.Internals
import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Class
import Math.NumberTheory.Moduli.Sqrt
import Math.NumberTheory.Primes
import Math.NumberTheory.Utils (recipMod)
-- | Find all solutions @x@ of the linear congruence
-- @a * x + b ≡ 0 (mod m)@, where @m@ is the modulus carried by the type.
--
-- The result is the empty list when the congruence has no solution.
solveLinear
  :: KnownNat m
  => Mod m   -- ^ a, coefficient of @x@
  -> Mod m   -- ^ b, constant term
  -> [Mod m] -- ^ every @x@ satisfying the congruence
solveLinear a b = fromInteger <$> solveLinear' modulus (getVal a) (getVal b)
  where
    modulus = getMod a
-- | Solve @a * x + b ≡ 0 (mod m)@ over plain 'Integer's.
--
-- Dividing @m@, @a@ and @b@ by @d = gcd m a b@ gives an equivalent
-- congruence modulo @m' = m / d@. When 'solveLinearCoprime' finds a
-- solution @x0@ of the reduced congruence, the residues
-- @x0, x0 + m', ..., x0 + (d - 1) * m'@ are the solutions modulo @m@.
-- Otherwise the result is empty.
solveLinear' :: Integer -> Integer -> Integer -> [Integer]
solveLinear' m a b =
  maybe [] spread (solveLinearCoprime m' (a `quot` d) (b `quot` d))
  where
    d = gcd (gcd m a) b
    m' = m `quot` d
    -- One residue modulo m' corresponds to d residues modulo m.
    spread x0 = [x0 + m' * i | i <- [0 .. d - 1]]
-- | Solve @a * x + b ≡ 0 (mod m)@ as @x = -b * a^(-1) mod m@.
--
-- Yields 'Nothing' when 'recipMod' returns no inverse of @a@ modulo @m@.
-- Modulus 1 is special-cased: every integer solves the congruence there,
-- so the canonical residue 0 is returned.
-- NOTE(review): presumably 'recipMod' cannot report an inverse modulo 1,
-- which is why that case is handled separately; confirm against its source.
solveLinearCoprime :: Integer -> Integer -> Integer -> Maybe Integer
solveLinearCoprime 1 _ _ = Just 0
solveLinearCoprime m a b = do
  aInv <- recipMod a m
  pure (negate (b * aInv) `mod` m)
-- | Find all solutions @x@ of the quadratic congruence
-- @a * x^2 + b * x + c ≡ 0 (mod m)@, where @m@ is the modulus carried by
-- the type.
--
-- The modulus is factorised into prime powers, and the congruence is solved
-- modulo each prime power. The partial solutions are then combined with
-- 'chineseRemainder2'.
solveQuadratic
  :: KnownNat m
  => Mod m   -- ^ a, coefficient of @x^2@
  -> Mod m   -- ^ b, coefficient of @x@
  -> Mod m   -- ^ c, constant term
  -> [Mod m] -- ^ every @x@ satisfying the congruence
solveQuadratic a b c = fromInteger <$> fst (foldl glue ([0], 1) localRoots)
  where
    a' = getVal a
    b' = getVal b
    c' = getVal c

    -- For every prime power p^n of the modulus: the roots modulo p^n,
    -- paired with p^n itself.
    localRoots :: [([Integer], Integer)]
    localRoots =
      [ (solveQuadraticPrimePower a' b' c' p n, unPrime p ^ n)
      | (p, n) <- factorise (getMod a)
      ]

    -- Merge roots modulo xm with roots modulo ym into roots modulo xm * ym.
    -- The fold seed ([0], 1) is the single residue modulo 1.
    glue :: ([Integer], Integer) -> ([Integer], Integer) -> ([Integer], Integer)
    glue (xs, xm) (ys, ym) =
      ([chineseRemainder2 (x, xm) (y, ym) | x <- xs, y <- ys], xm * ym)
-- | Roots of @f(x) = a * x^2 + b * x + c@ modulo @p^k@.
--
-- Roots modulo @p@ come from 'solveQuadraticPrime'. Each higher power is
-- reached by Hensel-lifting the roots found for the previous power.
solveQuadraticPrimePower
  :: Integer
  -> Integer
  -> Integer
  -> Prime Integer
  -> Word
  -> [Integer]
solveQuadraticPrimePower a b c p = rootsModPower
  where
    q :: Integer
    q = unPrime p

    -- Roots of f modulo q^k.
    rootsModPower :: Word -> [Integer]
    rootsModPower 0 = [0]
    rootsModPower 1 = solveQuadraticPrime a b c p
    rootsModPower k = rootsModPower (k - 1) >>= hensel k

    -- All roots modulo q^k that reduce to the root r modulo q^(k-1).
    hensel :: Word -> Integer -> [Integer]
    hensel k r = case recipMod (2 * a * r + b) qk of
        -- f'(r) is invertible: a single Newton step gives the unique lift.
        Just derivInv -> [(r - fr * derivInv) `mod` qk]
        -- f'(r) is not invertible. For k >= 2, f(r + t * q^(k-1)) ≡ f(r)
        -- (mod q^k) for every t, so either all q candidates are roots or
        -- none of them is.
        Nothing
          | fr == 0 -> [r + qkPrev * t | t <- [0 .. q - 1]]
          | otherwise -> []
      where
        qkPrev = q ^ (k - 1)
        qk = qkPrev * q
        fr = (a * r * r + b * r + c) `mod` qk
-- | Roots of @f(x) = a * x^2 + b * x + c@ modulo a prime @p@.
--
-- For @p = 2@, both residues are tested directly. For odd @p@:
--
-- * if @p@ divides @a@, the congruence degenerates to the linear one
--   @b * x + c ≡ 0@;
-- * otherwise the quadratic formula @x = (s - b) / (2a)@ is applied, where
--   @s@ ranges over the square roots of @b^2 - 4ac@ returned by
--   'sqrtsModPrime'.
solveQuadraticPrime
  :: Integer
  -> Integer
  -> Integer
  -> Prime Integer
  -> [Integer]
solveQuadraticPrime a b c (unPrime -> 2 :: Integer) =
  -- f(0) = c and f(1) = a + b + c, so 0 is a root iff c is even,
  -- and 1 is a root iff a + b + c is even.
  [0 | even c] ++ [1 | even (a + b + c)]
solveQuadraticPrime a b c p
  | a `mod` q == 0 = solveLinear' q b c
  | otherwise =
      [ (s - b) * inv2a `mod` q
      | s <- sqrtsModPrime (b * b - 4 * a * c) p
      ]
  where
    q :: Integer
    q = unPrime p
    -- In this branch q is odd and does not divide a, so 2a is invertible.
    inv2a :: Integer
    inv2a = recipModInteger (2 * a) q