{-# LANGUAGE BangPatterns #-}
module Math.NumberTheory.Factor (module Math.NumberTheory.Prime,
                                 pfactors, ppfactors, pfactorsTo, ppfactorsTo) where
import Control.Arrow (second, (&&&))
import Data.Either (lefts)
import Data.List as L
import Math.Core.Utils (multisetSumAsc)
import Math.NumberTheory.Prime
-- | Prime factorisation of a non-zero integer, as a list of primes with
-- repetition in non-decreasing order (presuming 'multisetSumAsc' merges two
-- ascending lists).  Negative inputs are prefixed with @-1@, and
-- @pfactors 1 == []@.
--
-- Primes below 10000 are removed by trial division.  Any cofactor left after
-- that is tested with 'isMillerRabinPrime' and, if composite, split with
-- 'findFactorParallelECM', recursing on both parts.
--
-- Calling it on 0 is an error: 0 has no prime factorisation.
pfactors :: Integer -> [Integer]
pfactors n
    | n > 0     = trialDivide n smallPrimes
    | n < 0     = -1 : trialDivide (-n) smallPrimes
    | otherwise = error "pfactors: 0 has no prime factorisation"
    where
      smallPrimes = takeWhile (< 10000) primes

      -- Divide out the listed primes in order.  Every prime smaller than the
      -- head d has already been removed, so m < d*d means m is itself prime.
      trialDivide m (d:ds)
          | m == 1    = []
          | m < d*d   = [m]
          | r == 0    = d : trialDivide q (d:ds)
          | otherwise = trialDivide m ds
          where (q,r) = quotRem m d
      trialDivide m [] = splitLarge m

      -- m has no prime factor below 10000: split it with ECM until the
      -- pieces pass the primality test, then merge the sorted results.
      splitLarge m
          | isMillerRabinPrime m = [m]
          | otherwise = let d = findFactorParallelECM m
                        in multisetSumAsc (splitLarge d) (splitLarge (m `div` d))
-- | Prime-power factorisation: each distinct prime factor from 'pfactors'
-- paired with how many consecutive times it occurs there.
ppfactors :: Integer -> [(Integer,Int)]
ppfactors n = [ (p, length run) | run@(p:_) <- L.group (pfactors n) ]
-- | Every integer @m@ with @1 <= m <= n@, paired with its prime factors listed
-- with repetition in non-increasing order (for example @(12,[3,2,2])@).
-- The enumeration is depth-first, not sorted by @m@.  Returns @[]@ when
-- @n < 1@; previously @(1,[])@ was returned even though 1 exceeds @n@.
pfactorsTo :: Integer -> [(Integer,[Integer])]
pfactorsTo n
    | n < 1     = []
    | otherwise = pfactorsTo' (1,[]) primes
    where
      -- (m,qs): the current product and its factors.  ps holds the primes
      -- still allowed, all >= the prime most recently pushed onto qs.  Either
      -- use ph again (keeping ps) or move past it for good (pt).  Once
      -- m*ph > n, every larger prime also overshoots, so m is a leaf.
      pfactorsTo' (!m,!qs) ps@(ph:pt)
          | m' > n    = [(m,qs)]
          | otherwise = pfactorsTo' (m',ph:qs) ps ++ pfactorsTo' (m,qs) pt
          where m' = m*ph
-- | Like 'pfactorsTo', but with each factor list collapsed into
-- (prime, exponent) pairs by grouping equal adjacent primes.
ppfactorsTo n = [ (m, [ (p, length run) | run@(p:_) <- L.group qs ])
                | (m, qs) <- pfactorsTo n ]
-- | Extended Euclidean algorithm: @extendedEuclid a b == (u, v, d)@ with
-- @u*a + v*b == d@, where @d@ is @gcd a b@ up to sign.  @d@ is non-negative
-- when both arguments are.
extendedEuclid :: Integral a => a -> a -> (a, a, a)
extendedEuclid a 0 = (1, 0, a)
extendedEuclid a b = (t, s - q * t, d)
  where
    -- a = q*b + r, and s*b + t*r == d, hence t*a + (s - q*t)*b == d.
    (q, r)    = a `quotRem` b
    (s, t, d) = extendedEuclid b r
-- | @EC n a b@: the curve @y^2 = x^3 + a*x + b@ with coordinates taken
-- modulo @n@.
data EllipticCurve
    = EC Integer Integer Integer
    deriving (Eq, Show)
-- | A point on an 'EllipticCurve': either the point at infinity (the group
-- identity) or affine coordinates @P x y@.
data EllipticCurvePt
    = Inf
    | P Integer Integer
    deriving (Eq, Show)
-- | Whether a point lies on the curve.  'Inf' always does; @P x y@ does when
-- @y^2 - (x^3 + a*x + b)@ is divisible by the curve's modulus.
isEltEC curve pt = case pt of
    Inf   -> True
    P x y -> onCurve curve x y
  where
    onCurve (EC n a b) x y = (y*y - (x*x*x + a*x + b)) `mod` n == 0
-- | Add two points on a curve modulo @n@ using the affine chord/tangent rule.
--
-- Returns @Left g@ when the slope's denominator is not invertible mod @n@.
-- Here @g@ is the gcd of @n@ and that denominator, which is how ECM discovers
-- a factor.  Coordinates are expected to be reduced mod @n@.
ecAdd _ Inf q = Right q
ecAdd _ p Inf = Right p
ecAdd (EC n a _) (P x1 y1) (P x2 y2)
    | x1 /= x2               = viaSlope (y1 - y2) (x1 - x2) (x1 + x2)
    | (y1 + y2) `mod` n == 0 = Right Inf
    | otherwise              = viaSlope (3*x1*x1 + a) (2*y1) (2*x1)
  where
    -- Slope m = num/den (mod n), then x3 = m^2 - xsum, y3 = m*(x1 - x3) - y1.
    viaSlope num den xsum =
        let (_, inv, g) = extendedEuclid n (den `mod` n)
            m  = num * inv `mod` n
            x3 = (m*m - xsum) `mod` n
            y3 = (-y1 + m*(x1 - x3)) `mod` n
        in if g == 1 then Right (P x3 y3) else Left g
-- | Scalar multiple @k*pt@ by binary double-and-add, for @k >= 0@; negative
-- @k@ matches no equation.
--
-- The first @Left@ produced by 'ecAdd' (a non-invertible denominator) is
-- returned immediately.
ecSmult _ 0 _ = Right Inf
ecSmult ec k pt | k > 0 = go k pt Inf
    where
      -- go i dbl acc: dbl is pt doubled once per bit already consumed;
      -- acc sums the doublings for the bits set so far.
      go 0 _ acc = Right acc
      go i dbl acc = do
          acc' <- if odd i then ecAdd ec acc dbl else Right acc
          dbl' <- ecAdd ec dbl dbl
          go (i `div` 2) dbl' acc'
-- | @4a^3 + 27b^2@.  This is the discriminant of @y^2 = x^3 + a*x + b@ up to
-- the constant factor -16, so a prime p > 3 divides it exactly when the
-- curve is singular mod p.
discriminantEC a b = cubeTerm + squareTerm
    where cubeTerm   = 4 * a * a * a
          squareTerm = 27 * b * b
-- | One ECM trial on curve @ec@: multiply @pt@ successively by each
-- multiplier in @ms@.
--
-- Returns @Left g@ if @g = gcd n (discriminantEC a b)@ is not 1, or the
-- first @Left@ from 'ecSmult'.  Otherwise returns @Right@ the final point.
ecTrial ec@(EC n a b) ms pt
    | g /= 1    = Left g
    | otherwise = go ms pt
    where
      g = gcd n (discriminantEC a b)
      go []     p = Right p
      go (k:ks) p = ecSmult ec k p >>= go ks
-- | The subexponential function @L(n) = exp (sqrt (ln n * ln (ln n)))@, used
-- by 'multipliers' to choose its bound.
l n = exp (sqrt (logN * log logN))
    where logN = log n
-- | ECM multipliers for a bound derived from @q@: for each prime
-- @p <= bound@, the largest power of @p@ that does not exceed @bound@, where
-- @bound = round (l q ** (1 / sqrt 2))@.
--
-- NOTE(review): for @q < e@ the argument of 'sqrt' inside 'l' is negative, so
-- the bound is meaningless.  Callers in this module pass @sqrt n@.
multipliers q = map largestPowerUpTo (takeWhile (<= bound) primes)
    where
      bound = round (l q ** (1 / sqrt 2))
      largestPowerUpTo p = last (takeWhile (<= bound) (iterate (* p) p))
-- | Single-curve ECM.  Tries the curves @y^2 = x^3 + a*x + 1@ for
-- @a = 1, 2, ..@ from the point @(0,1)@, and returns the first divisor found
-- that is not a multiple of @n@.
--
-- Only defined for @n@ coprime to 6.  It does not terminate if no proper
-- divisor exists (e.g. @n@ prime).  'pfactors' uses 'findFactorParallelECM'
-- instead.
findFactorECM n | gcd n 6 == 1 = head (filter isProper failures)
    where
      ms         = multipliers (sqrt $ fromInteger n)
      failures   = lefts [ecTrial (EC n a 1) ms (P 0 1) | a <- [1..]]
      isProper d = d `mod` n /= 0
    
    
-- | Invert every element of @as@ modulo @n@ with a single 'extendedEuclid'
-- call, using Montgomery's simultaneous-inversion trick.
--
-- Returns @Right invs@ (in the same order as @as@) when the product of @as@
-- is coprime to @n@.  Otherwise returns @Left g@, where @g = gcd a n@ for the
-- first element @a@ that is not coprime to @n@.
parallelInverse n as
    | g == 1    = Right invs
    | otherwise = Left (head [g' | a <- as, let g' = gcd a n, g' /= 1])
    where
      -- [1, a1, a1*a2, ..., a1*...*ak], each reduced mod n
      prefixes = scanl (\acc a -> acc * a `mod` n) 1 as
      -- [a1*...*ak, a2*...*ak, ..., ak, 1], each reduced mod n
      suffixes = scanr (\a acc -> a * acc `mod` n) 1 as
      (u, _, g) = extendedEuclid (last prefixes) n
      -- inverse of a_i = (product of all other elements) * (inverse of product)
      invs = [u * (pre * suf) `mod` n | (pre, suf) <- zip prefixes (tail suffixes)]
-- | Add corresponding points on a batch of curves.  All modular inversions
-- share one 'parallelInverse' call.
--
-- @parallelEcAdd n ecs ps1 ps2@ returns @Right@ the list of sums
-- @(ps1 !! i) + (ps2 !! i)@ on curve @ecs !! i@.  If some denominator is not
-- invertible mod @n@, it returns @Left d@ with the gcd reported by
-- 'parallelInverse'.
--
-- All reductions are mod @n@.  Every curve in @ecs@ is assumed to have
-- modulus @n@, as 'findFactorParallelECM' builds them, and point coordinates
-- are assumed reduced mod @n@.
parallelEcAdd n ecs ps1 ps2 =
    case parallelInverse n (zipWith denominator ps1 ps2) of
        Right invs -> Right [add ec p1 p2 inv | (ec,p1,p2,inv) <- L.zip4 ecs ps1 ps2 invs]
        Left d     -> Left d
    where
      -- P + (-P) = Inf needs no division; same test as in 'ecAdd'.
      cancels y1 y2 = (y1 + y2) `mod` n == 0

      -- The value each addition has to invert.  Pairs whose sum needs no
      -- division contribute 1.  Otherwise doubling a point with y == 0 would
      -- contribute 0, whose gcd with n is n, and the whole batch would fail.
      denominator Inf _ = 1
      denominator _ Inf = 1
      denominator (P x1 y1) (P x2 y2)
          | x1 /= x2      = x1 - x2
          | cancels y1 y2 = 1
          | otherwise     = 2 * y1

      -- Chord/tangent addition given the precomputed inverse of the
      -- denominator.
      add _ Inf pt _ = pt
      add _ pt Inf _ = pt
      add (EC _ a _) (P x1 y1) (P x2 y2) inv
          | x1 /= x2      = withSlope ((y1 - y2) * inv) (x1 + x2)
          | cancels y1 y2 = Inf
          | otherwise     = withSlope ((3*x1*x1 + a) * inv) (2*x1)
          where
            withSlope m xsum = let x3 = (m*m - xsum) `mod` n
                                   y3 = (-y1 + m*(x1 - x3)) `mod` n
                               in P x3 y3
-- | Batched scalar multiplication: @k@ times every point in @pts@, each on
-- its own curve from @ecs@, by double-and-add, for @k >= 0@.
--
-- Every step calls 'parallelEcAdd' once for the whole batch, and its first
-- @Left@ is returned immediately.
parallelEcSmult _ _ 0 pts = Right (Inf <$ pts)
parallelEcSmult n ecs k pts | k > 0 = go k pts (Inf <$ pts)
    where
      add = parallelEcAdd n ecs
      -- go i dbls acc: dbls are the points doubled once per consumed bit;
      -- acc accumulates the doublings for the bits set so far.
      go 0 _ acc = Right acc
      go i dbls acc = do
          acc'  <- if odd i then add acc dbls else Right acc
          dbls' <- add dbls dbls
          go (i `div` 2) dbls' acc'
-- | Batched analogue of 'ecTrial'.  First checks each curve's discriminant
-- against that curve's own modulus, then multiplies all points by each
-- multiplier in @ms@ in turn.
--
-- Returns @Left g@ for the first discriminant gcd that is not 1, or the first
-- @Left@ from 'parallelEcSmult'.  Otherwise returns @Right@ the final points.
parallelEcTrial n ecs ms pts =
    case filter (/= 1) [gcd m (discriminantEC a b) | EC m a b <- ecs] of
        (g:_) -> Left g
        []    -> go ms pts
    where
      go []     ps = Right ps
      go (k:ks) ps = parallelEcSmult n ecs k ps >>= go ks
-- | Batched ECM.  Tries 100 curves @y^2 = x^3 + a*x + 1@ at a time
-- (@a = 1..100@, then @101..200@, ...), all starting from @(0,1)@, and
-- returns the first divisor found that is not a multiple of @n@.
--
-- Only defined for @n@ coprime to 6.  It does not terminate if no proper
-- divisor exists (e.g. @n@ prime).
findFactorParallelECM n | gcd n 6 == 1 = head (filter isProper failures)
    where
      ms         = multipliers (sqrt $ fromInteger n)
      batch a    = [EC n (a + i) 1 | i <- [1..100]]
      start      = replicate 100 (P 0 1)
      failures   = lefts [parallelEcTrial n (batch a) ms start | a <- [0,100..]]
      isProper d = d `mod` n /= 0