{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}
module Math.NumberTheory.Moduli.Sqrt
(
sqrtsMod
, sqrtsModFactorisation
, sqrtsModPrimePower
, sqrtsModPrime
, Old.sqrtModP
, Old.sqrtModPList
, Old.sqrtModP'
, Old.tonelliShanks
, Old.sqrtModPP
, Old.sqrtModPPList
, Old.sqrtModF
, Old.sqrtModFList
) where
import Control.Monad (liftM2)
import Data.Bits
import Math.NumberTheory.Moduli.Chinese
import Math.NumberTheory.Moduli.Class (Mod, getVal, getMod, KnownNat)
import Math.NumberTheory.Moduli.Jacobi
import Math.NumberTheory.Powers.Modular (powMod)
import Math.NumberTheory.Primes.Types
import Math.NumberTheory.Primes.Sieve (sieveFrom)
import Math.NumberTheory.Primes (Prime, factorise)
import Math.NumberTheory.Utils (shiftToOddCount, splitOff, recipMod)
import Math.NumberTheory.Utils.FromIntegral
import qualified Math.NumberTheory.Moduli.SqrtOld as Old
-- | List the square roots of a residue, as residues of the same modulus.
--
-- The modulus obtained from 'getMod' is factorised with 'factorise', the
-- work is delegated to 'sqrtsModFactorisation', and every resulting
-- 'Integer' is converted back with 'fromInteger'.
sqrtsMod :: KnownNat m => Mod m -> [Mod m]
sqrtsMod a = fromInteger <$> sqrtsModFactorisation value factors
  where
    value = getVal a
    factors = factorise (getMod a)
-- | List the square roots of @n@ modulo a number given by its factorisation
-- into prime powers.
--
-- An empty factorisation yields @[0]@.  Otherwise the roots modulo every
-- prime power are computed with 'sqrtsModPrimePower', each root is paired
-- with its modulus, and the per-factor lists are combined left to right:
-- every pair of roots is merged with 'chineseRemainder2' while the moduli
-- are multiplied.  If any prime power has no root the result is empty.
sqrtsModFactorisation :: Integer -> [(Prime Integer, Word)] -> [Integer]
sqrtsModFactorisation _ [] = [0]
sqrtsModFactorisation n pps = fst <$> foldl1 merge tagged
  where
    -- One list per prime power; every root is tagged with that prime power.
    tagged :: [[(Integer, Integer)]]
    tagged =
      [ [(root, modulus) | root <- sqrtsModPrimePower n p pow]
      | (p@(Prime base), pow) <- pps
      , let modulus = base ^ pow
      ]

    -- All combinations of a root from the left with a root from the right,
    -- the left root varying slowest.
    merge xs ys = [glue x y | x <- xs, y <- ys]

    glue x@(_, m1) y@(_, m2) = (chineseRemainder2 x y, m1 * m2)
-- | List the square roots of @nn@ modulo @prime ^ expo@.
--
-- Exponent 1 is handed to 'sqrtsModPrime'.  Otherwise @nn@ is reduced
-- modulo the prime power and decomposed with 'splitOff' (presumably into the
-- multiplicity of @prime@ and the remaining cofactor -- confirm against
-- "Math.NumberTheory.Utils"):
--
-- * a zero cofactor gives the multiples of @prime ^ ((expo + 1) `quot` 2)@
--   below the modulus;
-- * an odd multiplicity gives no roots;
-- * otherwise a root of the cofactor is sought with 'sqM2P' (for prime 2) or
--   'sqrtModPP'' (other primes), scaled by @prime ^ (multiplicity / 2)@ and
--   translated by every multiple of @prime ^ step@ below the modulus.
sqrtsModPrimePower :: Integer -> Prime Integer -> Word -> [Integer]
sqrtsModPrimePower nn p 1 = sqrtsModPrime nn p
sqrtsModPrimePower nn (Prime prime) expo =
  rootsOf (splitOff prime (nn `mod` modulus))
  where
    modulus = prime ^ expo
    isTwo = prime == 2

    rootsOf (_, 0) = [0, prime ^ ((expo + 1) `quot` 2) .. modulus - 1]
    rootsOf (valuation, unitPart)
      | odd valuation = []
      | otherwise = maybe [] expand baseRoot
      where
        half = valuation `quot` 2
        step =
          (if isTwo then expo - half - 1 else expo - half)
            `max` ((expo + 1) `quot` 2)
        reducedExpo = expo - 2 * half
        offsets = [0, prime ^ step .. modulus - 1]

        baseRoot
          | isTwo = sqM2P unitPart reducedExpo
          | otherwise = sqrtModPP' unitPart prime reducedExpo

        -- Emit the translates of the scaled root and, unless the prime is 2
        -- with @half + 1 == step@, also those of its negation.
        expand r
          | isTwo && half + 1 == step = shiftBy scaled offsets
          | otherwise =
              shiftBy scaled offsets ++ shiftBy (modulus - scaled) offsets
          where
            scaled = r * prime ^ half

    -- Add @r@ to every element; elements that would reach the modulus are
    -- wrapped by subtracting it and are emitted after the unwrapped ones.
    shiftBy r xs = map (+ r) below ++ map (+ (r - modulus)) above
      where
        (below, above) = span (< modulus - r) xs
-- | List the square roots of @n@ modulo a prime.
--
-- For the prime 2 the single root is @n `mod` 2@.  For any other prime the
-- answer is driven by the Jacobi symbol of @n@ over the prime: no roots for
-- 'MinusOne', only @0@ for 'Zero', and for 'One' a root @r@ from
-- 'sqrtModP'' together with @prime - r@.
sqrtsModPrime :: Integer -> Prime Integer -> [Integer]
sqrtsModPrime n (Prime prime)
  | prime == 2 = [n `mod` 2]
  | otherwise = fromSymbol (jacobi n prime)
  where
    fromSymbol MinusOne = []
    fromSymbol Zero = [0]
    fromSymbol One = [root, prime - root]

    root = sqrtModP' (n `mod` prime) prime
-- | Compute one square root of @square@ modulo @prime@.
--
-- No check is made that a root exists; the only caller visible in this
-- module, 'sqrtsModPrime', invokes it when the Jacobi symbol is 'One'.
--
-- Strategy, in order:
--
-- * modulus 2: the argument itself;
-- * @prime@ congruent to 3 modulo 4: one modular exponentiation with
--   exponent @(prime + 1) `quot` 4@;
-- * @square@ congruent to @prime - 1@: 'sqrtOfMinusOne';
-- * anything else: 'tonelliShanks'.
sqrtModP' :: Integer -> Integer -> Integer
sqrtModP' square prime
  | prime == 2 = square
  | otherwise = case rem4 prime of
      3 -> powMod square ((prime + 1) `quot` 4) prime
      _ | square `mod` prime == prime - 1 -> sqrtOfMinusOne prime
        | otherwise -> tonelliShanks square prime
-- | Find a square root of @-1@ modulo @p@.
--
-- Every candidate @n@ in @[2 .. p - 2]@ is raised to the power
-- @(p - 1) `quot` 4@ modulo @p@; the first result that is neither @1@ nor
-- @p - 1@ is returned.  The caller visible in this module, 'sqrtModP'',
-- reaches this function only after ruling out @p == 2@ and
-- @p@ congruent to 3 modulo 4.
--
-- Calls 'error' with an explicit message if no candidate qualifies (for
-- example when the candidate range is empty because @p <= 3@), instead of
-- failing inside a partial list function.
sqrtOfMinusOne :: Integer -> Integer
sqrtOfMinusOne p =
  case filter (\n -> n /= 1 && n /= p - 1) powers of
    root : _ -> root
    [] ->
      error
        "sqrtOfMinusOne: no square root of -1 found; modulus must be a prime congruent to 1 modulo 4"
  where
    k = (p - 1) `quot` 4
    -- Lazily generated, so only as many powers are computed as are needed
    -- to reach the first acceptable one.
    powers = map (\n -> powMod n k p) [2 .. p - 2]
-- | Compute a square root of @square@ modulo @prime@ by an iterative
-- correction loop.
--
-- @prime - 1@ is decomposed with 'shiftToOddCount' (presumably into the
-- count of factors of two and the odd cofactor -- confirm against
-- "Math.NumberTheory.Utils").  The loop carries a candidate root, a check
-- value, a generator derived from 'findNonSquare', and a level; it stops
-- and returns the candidate as soon as the check value equals 1.
tonelliShanks :: Integer -> Integer -> Integer
tonelliShanks square prime = go startRoot startCheck startGen twoExp
  where
    (twoExpW, oddPart) = shiftToOddCount (prime - 1)
    twoExp = wordToInt twoExpW

    startGen = powMod (findNonSquare prime) oddPart prime
    startRoot = powMod square ((oddPart + 1) `quot` 2) prime
    startCheck = powMod square oddPart prime

    -- One modular squaring.
    sqrMod x = (x * x) `rem` prime

    -- Apply 'sqrMod' the given number of times.
    sqrTimes 0 x = x
    sqrTimes i x = sqrTimes (i - 1) (sqrMod x)

    -- Count how many squarings it takes for the value to become 1.
    stepsToOne acc 1 = acc
    stepsToOne acc x = stepsToOne (acc + 1) (sqrMod x)

    go :: Integer -> Integer -> Integer -> Int -> Integer
    go !root 1 _ _ = root
    go !root check gen level = go root' check' gen' level'
      where
        level' = stepsToOne 0 check
        factor = sqrTimes (level - 1 - level') gen
        root' = (root * factor) `rem` prime
        gen' = sqrMod factor
        check' = (check * gen') `rem` prime
-- | Try to find a square root of @n@ modulo @prime ^ expo@.
--
-- Starts from the first root modulo @prime@ reported by 'sqrtsModPrime'
-- ('Nothing' if there is none) and repeatedly corrects it.  After each
-- step the difference @root * root - n@ is decomposed with 'splitOff'
-- (presumably into the multiplicity of @prime@ and the remaining cofactor
-- -- confirm against "Math.NumberTheory.Utils"); the process stops once the
-- difference is zero or the multiplicity reaches @expo@.  The correction
-- uses the result of @recipMod (2 * root) prime@; when that is 'Nothing'
-- the whole result is 'Nothing'.
sqrtModPP' :: Integer -> Integer -> Word -> Maybe Integer
sqrtModPP' n prime expo = firstRoot (sqrtsModPrime n (Prime prime))
  where
    firstRoot [] = Nothing
    firstRoot (r0 : _) = refine r0

    -- Decide whether the root modulo @prime@ is already good enough.
    refine r0
      | err == 0 = Just r0
      | expo <= e = Just r0
      | otherwise =
          (\inv -> climb inv r0 (q `mod` prime) (prime ^ e))
            <$> recipMod (2 * r0) prime
      where
        err = r0 * r0 - n
        -- Only forced when @err@ is non-zero.
        (e, q) = splitOff prime err

    -- One correction step, repeated until the root is exact or accurate
    -- to the requested exponent.
    climb inv root elim pp
      | err' == 0 || expo <= ex = root'
      | otherwise = climb inv root' (rest `mod` prime) (prime ^ ex)
      where
        root' = (root + (inv * (prime - elim)) * pp) `mod` (prime * pp)
        err' = root' * root' - n
        (ex, rest) = splitOff prime err'
-- | Try to find a square root of @n@ modulo @2 ^ e@.
--
-- * For @e < 2@ the answer is @n `mod` 2@.
-- * If @n@ is zero modulo @2 ^ e@ the answer is @0@.
-- * Otherwise the residue is decomposed with 'shiftToOddCount' (presumably
--   into the count of factors of two and the odd cofactor -- confirm
--   against "Math.NumberTheory.Utils").  An odd count gives 'Nothing';
--   for an even count a root of the odd cofactor is sought and shifted
--   left by half the count, then reduced modulo @2 ^ e@.
sqM2P :: Integer -> Word -> Maybe Integer
sqM2P n e
  | e < 2 = Just (n `mod` 2)
  | residue == 0 = Just 0
  | odd twos = Nothing
  | otherwise = rescale <$> oddRoot oddPart oddExpo
  where
    modulus = 1 `shiftL` wordToInt e
    residue = n `mod` modulus
    (twos, oddPart) = shiftToOddCount residue
    oddExpo = e - twos

    -- Put back half of the removed factors of two and reduce.
    rescale x = (x `shiftL` wordToInt (twos `quot` 2)) `mod` modulus

    -- Root of an odd value; the second argument is the remaining exponent.
    oddRoot _ 1 = Just 1
    oddRoot 1 _ = Just 1
    oddRoot r _
      | rem4 r == 3 || rem8 r == 5 = Nothing
      | otherwise = improve r (fst (shiftToOddCount (r - 1)))
      where
        -- Adjust the candidate by a single bit until the precision
        -- reaches the remaining exponent.
        improve x pw
          | pw >= oddExpo = Just x
          | otherwise = improve x' pw'
          where
            x' = x + (1 `shiftL` (wordToInt pw - 1))
            d = x' * x' - r
            pw' = if d == 0 then oddExpo else fst (shiftToOddCount d)
-- | The two lowest bits of the argument after conversion to 'Int',
-- i.e. a value in @0 .. 3@.
rem4 :: Integral a => a -> Int
rem4 = (.&. 3) . fromIntegral
-- | The three lowest bits of the argument after conversion to 'Int',
-- i.e. a value in @0 .. 7@.
rem8 :: Integral a => a -> Int
rem8 = (.&. 7) . fromIntegral
-- | Find a value whose Jacobi symbol over @n@ is 'MinusOne'.
--
-- When @n@ is congruent to 3 or 5 modulo 8 the answer is @2@ without any
-- search.  Otherwise a fixed list of small odd primes is tried in order,
-- followed by the primes produced by 'sieveFrom' starting at
-- @68 + (n `rem` 4)@, and the first candidate with symbol 'MinusOne' is
-- returned.
findNonSquare :: Integer -> Integer
findNonSquare n
  | rem8 n `elem` [3, 5] = 2
  | otherwise = foldr pick exhausted candidates
  where
    candidates =
      [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53, 59, 61, 67]
        ++ map unPrime (sieveFrom (68 + (n `rem` 4)))

    -- Lazy in the second argument, so the fold stops at the first hit.
    pick p rest = case jacobi p n of
      MinusOne -> p
      _ -> rest

    exhausted = error "Should never have happened, prime list exhausted."