--Copyright 2001, 2002, 2003 David J. Sankel
--
--This file is part of rsa-haskell.
--rsa-haskell is free software; you can redistribute it and/or modify
--it under the terms of the GNU General Public License as published by
--the Free Software Foundation; either version 2 of the License, or
--(at your option) any later version.
--
--rsa-haskell is distributed in the hope that it will be useful,
--but WITHOUT ANY WARRANTY; without even the implied warranty of
--MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
--GNU General Public License for more details.
--
--You should have received a copy of the GNU General Public License
--along with rsa-haskell; if not, write to the Free Software
--Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

module Codec.Encryption.RSA.NumberTheory(
inverse, extEuclGcd, simplePrimalityTest, getPrime, pg, isPrime,
rabinMillerPrimalityTest, expmod, factor, testInverse, primes, (/|),
randomOctet
)    where


import System.Random(getStdRandom,randomR)
--The following line is required for the GHC-optimized implementation
--  (see comments beginning with GHC):
-- import Data.Bits(setBit)
import Data.List(elemIndex)
import Data.Maybe(fromJust)
import Data.Char(chr,ord)
import Data.Bits(xor)

-- | Produce a 'String' of @n@ random octets: every character has a code
-- point drawn uniformly from [0, 255].
--
-- Precondition: @n >= 0@; a negative count raises an error.
randomOctet :: Int -> IO String
randomOctet n
  | n < 0     = error "randomOctet argument doesn't meet preconditions"
  | otherwise = fmap (map chr) (sequence (replicate n randomByte))
  where
    -- One draw from the global StdGen, in the octet range.
    randomByte :: IO Int
    randomByte = getStdRandom (randomR (0, 255))

-- | Prime-exponent factorisation.
--
-- Returns @[r_1, r_2, ..., r_n]@ such that
-- @a = p_1^r_1 * p_2^r_2 * ... * p_n^r_n@, where @p_i@ is the i-th prime
-- (2, 3, 5, ...) and @p_n@ is the largest prime factor of @a@.
-- Inputs @<= 1@ yield @[]@.
factor :: Integer -> [Int]
factor = factor_1

-- | An implementation of 'factor'.
--
-- Walks the primes in order, recording the exponent of each one in @a@ and
-- dividing it out.  It stops as soon as the remaining cofactor is 1, so the
-- list ends at the largest prime factor (no trailing zeros), and primes
-- beyond that factor are never visited.
factor_1 :: Integer -> [Int]
factor_1 a
  | a <= 1    = []
  | otherwise = go a primes
  where
    go :: Integer -> [Integer] -> [Int]
    go 1 _        = []
    go n (p : ps) =
      let k = largestPower p n
      in  k : go (n `div` p ^ k) ps
    go _ []       = []  -- unreachable: 'primes' is infinite

-- | Another implementation of 'factor', returning 'Integer' exponents.
--
-- Composite @a@: exponents of the primes @<= a `div` 2@, trailing zeros
-- removed (every prime factor of a composite number is @<= a/2@).
-- Prime @a@: no prime @<= a/2@ divides it, so the result is one zero for
-- each prime below @a@ followed by a single 1.
-- Inputs @<= 1@ yield @[]@.
factor_2 :: Integer -> [Integer]
factor_2 a
  | a <= 1    = []
  | null p    = replicate primeIndex 0 ++ [1]
  | otherwise = p
  where
    p :: [Integer]
    p = map fromIntegral . reverse . dropWhile (== 0) . reverse
      . map (`largestPower` a) . takeWhile (<= a `div` 2) $ primes

    -- Number of primes strictly below a, i.e. a's 0-based index in 'primes'
    -- when a is prime.
    primeIndex :: Int
    primeIndex = length (takeWhile (< a) primes)

-- | Find the inverse of @x@ modulo @n@, normalised into @[0, n)@ for @n > 0@.
--
-- The result is a true inverse only when @gcd x n == 1@; otherwise it is
-- the reduced Bezout coefficient and @x * inverse x n `mod` n@ is not 1.
-- Use 'testInverse' to check.
inverse :: Integer -> Integer -> Integer
inverse x n = fst (extEuclGcd x n) `mod` n

-- | True when @inverse a b@ really is a multiplicative inverse of @a@
-- modulo @b@.
testInverse :: Integer -> Integer -> Bool
testInverse a b = (inverse a b * a) `mod` b == 1

-- | Extended Euclidean algorithm.
--
-- Returns @(x, y)@ with @x*a + y*b == gcd a b@.
-- Intended for non-negative arguments; with negative arguments the
-- iteration may not terminate.
extEuclGcd :: Integer -> Integer -> (Integer, Integer)
extEuclGcd a b = extEuclGcd_iter a b (1, 0) (0, 1)

-- | Worker for 'extEuclGcd'.
--
-- Invariant (given the initial coefficients @(1,0)@ and @(0,1)@ for inputs
-- @A@, @B@): @a == c1*A + c2*B@ and @b == d1*A + d2*B@.  The larger of
-- @a@/@b@ is reduced modulo the smaller, and its coefficient pair is updated
-- to match.  When one value divides the other, that value is the gcd and
-- its coefficient pair is returned.
--
-- A zero argument is handled directly (gcd(a,0) = a, gcd(0,b) = b) instead
-- of triggering a division by zero.
extEuclGcd_iter :: Integer -> Integer
  -> (Integer, Integer) -> (Integer, Integer) -> (Integer, Integer)
extEuclGcd_iter a b (c1, c2) (d1, d2)
  | b == 0    = (c1, c2)
  | a == 0    = (d1, d2)
  | a > b     =
      if r1 == 0
        then (d1, d2)
        else extEuclGcd_iter r1 b (c1 - q1 * d1, c2 - q1 * d2) (d1, d2)
  | otherwise =
      if r2 == 0
        then (c1, c2)
        else extEuclGcd_iter a r2 (c1, c2) (d1 - q2 * c1, d2 - q2 * c2)
  where
    -- Lazy bindings: only forced in the branches where the divisor is
    -- known to be non-zero.  a - q1*b == r1 and b - q2*a == r2.
    (q1, r1) = a `divMod` b
    (q2, r2) = b `divMod` a

-- This will return a random Integer of n bits.  The highest order bit
-- will always be 1.

-- GHC-optimized implementation
-- getNumber :: Int -> IO Integer
-- getNumber n = do
--                  i <- getStdRandom ( randomR (0, a-1 ) )
--                  return (setBit i (n-1))
--               where
--                   a = (2^n) ::Integer

--This is the portable version
-- | A uniformly random Integer of exactly @n@ bits: the value lies in
-- @[2^(n-1), 2^n - 1]@, so the highest-order bit is always 1.
--
-- Precondition: @n >= 1@.
getNumber :: Int -> IO Integer
getNumber n
  | n < 1     = error "getNumber: number of bits must be at least 1"
  | otherwise = do
      low <- getStdRandom (randomR (0, topBit - 1))
      return (low + topBit)
  where
    topBit :: Integer
    topBit = 2 ^ (n - 1)

--Returns a probable prime number of nBits bits

-- GHC-optimized implementation
-- getPrime  :: Int -> IO Integer
-- getPrime nBits = do
--                 r <- getNumber nBits
--                 let p = (setBit r 0) --Make it odd for speed
--                 pIsPrime <- isPrime p
--                 if( pIsPrime )
--                    then return p
--                    else getPrime nBits

--This is the portable version
-- | A probable prime with exactly @nBits@ bits.
--
-- Draws random @nBits@-bit numbers, bumps even ones to the next odd value,
-- and retries until 'isPrime' accepts.  The bump never overflows the bit
-- width, because the largest even candidate is @2^nBits - 2@.
--
-- Precondition: @nBits >= 2@ (there is no odd prime with 1 bit, so the
-- search would never finish).
getPrime :: Int -> IO Integer
getPrime nBits
  | nBits < 2 = error "getPrime: nBits must be at least 2"
  | otherwise = search
  where
    search :: IO Integer
    search = do
      r <- getNumber nBits
      let p = if 2 /| r then r else r + 1  -- make it odd for speed
      pIsPrime <- isPrime p
      if pIsPrime then return p else search

-- | Prime Generate: a random (probable) prime @p@ with @lo <= p <= hi@ and
-- @gcd p e == 1@.
--
-- Candidates are drawn uniformly from the range until one passes.  If the
-- range contains no such prime, this never returns.
pg :: Integer -> Integer -> Integer -> IO Integer
pg lo hi e = do
  p <- getStdRandom (randomR (lo, hi))
  pIsPrime <- isPrime p
  if pIsPrime && gcd p e == 1
    then return p
    else pg lo hi e

-- | Primality test.
--
-- * @a <= 1@: not prime.
-- * @a <= 2000@: exact, by trial division ('simplePrimalityTest').
-- * otherwise: trial division by the primes below 2000, then five
--   independent Rabin-Miller rounds; all must pass (probabilistic).
isPrime :: Integer -> IO Bool
isPrime a
  | a <= 1                      = return False
  | a <= 2000                   = return (simplePrimalityTest a)
  | not (simplePrimalityTest a) = return False
  | otherwise                   =
      -- Do this 5 times for safety.
      fmap and (mapM rabinMillerPrimalityTest (replicate 5 a))

-- | Trial division of @a@ by every prime below @min 2000 a@.
--
-- For @2 <= a <= 2000@ this is an exact primality test.  For larger @a@ it
-- only rules out factors below 2000.  It returns True for @a <= 1@ (no
-- divisors are tried), so callers must exclude those themselves.
simplePrimalityTest :: Integer -> Bool
simplePrimalityTest a = all (/| a) (takeWhile (< bound) primes)
  where
    bound :: Integer
    bound = min 2000 a

-- | Returns the greatest @z@ such that @x^z@ divides @y@.
--
-- Computed by repeatedly dividing @y@ by @x@.  The answer is unbounded
-- (or undefined) when @y == 0@ or @|x| <= 1@, so those inputs raise an
-- error instead of looping forever.
largestPower :: Integer -> Integer -> Int
largestPower x y
  | y == 0 || abs x <= 1 =
      error "largestPower: undefined for y == 0 or |x| <= 1"
  | otherwise = go 0 y
  where
    go :: Int -> Integer -> Int
    go k n
      | n `mod` x == 0 = let k' = k + 1 in k' `seq` go k' (n `div` x)
      | otherwise      = k

-- | One round of the Rabin-Miller probabilistic primality test.
--
-- Writes @p - 1 = 2^b * m@ with @m@ odd and runs a single random-witness
-- round.  The small cases that the round cannot handle are decided
-- directly:
--
-- * @p < 2@ is not prime.
-- * 2 and 3 are prime.
-- * Any other even number is composite.
rabinMillerPrimalityTest :: Integer -> IO Bool
rabinMillerPrimalityTest p
  | p < 2     = return False
  | p < 4     = return True
  | even p    = return False
  | otherwise = rabinMillerPrimalityTest_iter_1 p b m
  where
    b :: Integer
    b = fromIntegral (largestPower 2 (p - 1))
    m :: Integer
    m = (p - 1) `div` (2 ^ b)

-- | Single Rabin-Miller round with a fresh random witness.
--
-- Arguments: the candidate @p@, then @b@ and @m@ with @p - 1 == 2^b * m@
-- and @m@ odd (as computed by 'rabinMillerPrimalityTest'; assumes odd
-- @p >= 5@).
--
-- The witness is drawn from @[2, p - 2]@.  A witness of 0 (or any multiple
-- of p) would make even a true prime fail the test, and 1 or @p - 1@ prove
-- nothing.
rabinMillerPrimalityTest_iter_1 :: Integer -> Integer -> Integer -> IO Bool
rabinMillerPrimalityTest_iter_1 p b m = do
  a <- getStdRandom (randomR (2, max 2 (p - 2)))
  return (rabinMillerPrimalityTest_iter_2 p b 0 (expmod a m p))

-- | Squaring phase of a Rabin-Miller round.
--
-- @z@ is @a^(m * 2^j) mod p@ and @j@ is the number of squarings done so far
-- (starting at 0); @b@ is the exponent of 2 in @p - 1@.
--
-- * @z == p - 1@ at any step: the witness passes (probable prime).
-- * @z == 1@ at step 0: passes.
-- * @z == 1@ at a later step: the previous value was a nontrivial square
--   root of 1, so @p@ is composite.
-- * While squarings remain (@j + 1 < b@), square and continue; otherwise
--   composite.
rabinMillerPrimalityTest_iter_2 :: Integer -> Integer -> Integer -> Integer -> Bool
rabinMillerPrimalityTest_iter_2 p b j z
  | z == p - 1 = True
  | z == 1     = j == 0
  | j + 1 < b  = rabinMillerPrimalityTest_iter_2 p b (j + 1) (z ^ 2 `mod` p)
  | otherwise  = False

-- | Modular exponentiation: @a^x (mod m)@ by square-and-multiply, using
-- O(log x) multiplications.
--
-- Every result is reduced modulo @m@ (so @expmod a 0 1 == 0@).  A negative
-- exponent is rejected, because the recursion would never reach 0.
expmod :: Integer -> Integer -> Integer -> Integer
expmod a x m
  | x < 0     = error "expmod: negative exponent"
  | x == 0    = 1 `mod` m
  | even x    = let h = expmod a (x `div` 2) m in (h * h) `mod` m
  | otherwise = (a * expmod a (x - 1) m) `mod` m

-- | Integer square root: the largest @x@ with @x^2 <= i@.
--
-- Uses exact integer Newton iteration rather than 'Double' 'sqrt', which
-- loses precision above 2^53 and overflows for very large inputs.
-- Negative inputs are rejected.
intSqrt :: Integer -> Integer
intSqrt i
  | i < 0     = error "intSqrt: negative argument"
  | i < 2     = i
  | otherwise = newton i
  where
    -- Starting from any value >= sqrt i, the iterates strictly decrease
    -- until they reach floor (sqrt i); the first non-decrease stops it.
    newton :: Integer -> Integer
    newton x =
      let y = (x + i `div` x) `div` 2
      in  if y >= x then x else newton y

-- | The "does not divide" relation: @a /| b@ is True when @a@ does not
-- divide @b@ evenly.  @a@ must be non-zero.
(/|) :: Integer -> Integer -> Bool
a /| b = b `mod` a /= 0

-- | The infinite, ascending list of primes.
--
-- Each odd candidate is tested by trial division against the primes
-- already produced, up to its integer square root.  The self-reference is
-- productive because those primes are always already in the list.
primes :: [Integer]
primes = 2 : [x | x <- [3, 5 ..], all (/| x) (takeWhile (<= intSqrt x) primes)]