{-# LANGUAGE BangPatterns #-}
-- |
-- Module      : Crypto.Number.ModArithmetic
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : Good

module Crypto.Number.ModArithmetic
    (
    -- * Exponentiation
      expSafe
    , expFast
    -- * Inverse computing
    , inverse
    , inverseCoprimes
    , inverseFermat
    -- * Squares
    , jacobi
    , squareRoot
    ) where

import Control.Exception (throw, Exception)
import Crypto.Number.Basic
import Crypto.Number.Compat

-- | Raised when two numbers are supposed to be coprime but are not.
--
-- Thrown by 'inverseCoprimes' when 'inverse' finds no modular inverse.
data CoprimesAssertionError = CoprimesAssertionError
    deriving (Show)

instance Exception CoprimesAssertionError

-- | Compute the modular exponentiation @base ^ exponent mod modulo@ using
-- algorithms designed to avoid side channels and timing measurements.
--
-- The hardened primitive is only attempted when the modulus is odd; for an
-- even modulus (or when the hardened primitive is unavailable) the ordinary
-- fast modular exponentiation is used instead.
--
-- When used with integer-simple, this function is no different from
-- 'expFast', and thus provides the same unstudied and dubious timing and
-- side-channel claims.
--
-- Before GHC 8.4.2, @powModSecInteger@ is missing from integer-gmp, so
-- 'expSafe' has the same security as 'expFast'.
expSafe :: Integer -- ^ base
        -> Integer -- ^ exponent
        -> Integer -- ^ modulo
        -> Integer -- ^ result
expSafe b e m
    | even m    = plain
    | otherwise = gmpPowModSecInteger b e m `onGmpUnsupported` plain
  where
    -- non-hardened path: GMP powm when supported, pure Haskell otherwise
    plain = gmpPowModInteger b e m `onGmpUnsupported` exponentiation b e m

-- | Compute the modular exponentiation @base ^ exponent mod modulo@ with the
-- fastest available algorithm, making no attempt to hide the parameters.
--
-- Use this function only when all the parameters are public; otherwise
-- 'expSafe' should be preferred.
expFast :: Integer -- ^ base
        -> Integer -- ^ exponent
        -> Integer -- ^ modulo
        -> Integer -- ^ result
-- GMP's powm when supported, the pure-Haskell 'exponentiation' otherwise
expFast b e m = onGmpUnsupported (gmpPowModInteger b e m) (exponentiation b e m)

-- | @exponentiation b e m@ computes modular exponentiation as /b^e mod m/
-- using repeated squaring.
--
-- This is the pure-Haskell fallback used by 'expSafe' and 'expFast' when
-- GMP support is unavailable; it is variable time.
--
-- Every branch reduces its result modulo @m@, so the trivial cases agree
-- with the general one (e.g. @exponentiation b 0 1 == 0@).
-- A negative exponent (with @b /= 1@) is rejected with an error: halving or
-- decrementing it never reaches 0 or 1, so the recursion would not end.
exponentiation :: Integer -> Integer -> Integer -> Integer
exponentiation b e m
    | b == 1    = 1 `mod` m
    | e < 0     = error "Crypto.Number.ModArithmetic.exponentiation: negative exponent"
    | e == 0    = 1 `mod` m
    | e == 1    = b `mod` m
    | even e    = let p = exponentiation b (e `div` 2) m  -- already reduced
                   in (p * p) `mod` m
    | otherwise = (b * exponentiation b (e - 1) m) `mod` m

-- | @inverse g m@ computes the modular inverse /g^(-1) mod m/.
--
-- Returns 'Nothing' when no inverse exists. The GMP primitive is used when
-- supported; otherwise the result comes from the extended Euclidean
-- algorithm ('gcde').
inverse :: Integer -> Integer -> Maybe Integer
inverse g m = gmpInverse g m `onGmpUnsupported` viaGcde
  where
    -- gcde is assumed to return (x, y, d) with x*g + y*m == d == gcd g m;
    -- an inverse exists only when that gcd does not exceed 1
    viaGcde =
        case gcde g m of
            (x, _, d)
                | d > 1     -> Nothing
                | otherwise -> Just (x `mod` m)

-- | Compute the modular inverse of two coprime numbers.
--
-- This is equivalent to 'inverse' except that the result is known to
-- exist, so it is returned directly instead of wrapped in 'Maybe'.
--
-- If the numbers are not actually coprime, this function throws a
-- 'CoprimesAssertionError'.
inverseCoprimes :: Integer -> Integer -> Integer
inverseCoprimes g m = maybe (throw CoprimesAssertionError) id (inverse g m)

-- | Computes the Jacobi symbol (a/n).
--
-- Intended domain: 0 ≤ a < n, with n ≥ 3 and odd. 'Nothing' is returned
-- when @n < 3@ or @n@ is even. Other values of @a@ are accepted too:
-- @a ≥ n@ is first reduced modulo @n@, and a negative @a@ uses
-- (a/n) = (-1/n)·(-a/n).
--
-- The Legendre and Jacobi symbols are indistinguishable exactly when the
-- lower argument is an odd prime, in which case they have the same value.
--
-- See algorithm 2.149 in "Handbook of Applied Cryptography" by Alfred J. Menezes et al.
jacobi :: Integer -> Integer -> Maybe Integer
jacobi a n
    | n < 3 || even n  = Nothing
    | a == 0 || a == 1 = Just a
    | a >= n           = jacobi (a `mod` n) n
    | a < 0            = (* minusOneSymbol) <$> jacobi (negate a) n
    | a1 == 1          = Just s
    | otherwise        = (* s) <$> jacobi (n `mod` a1) a1
  where
    -- (-1/n): 1 when n ≡ 1 (mod 4), -1 otherwise
    minusOneSymbol = if n `mod` 4 == 1 then 1 else -1

    -- a = 2^e * a1 with a1 odd; only forced once 1 < a < n
    (e, a1) = asPowerOf2AndOdd a

    -- contribution of 2^e: 1 when e is even or n ≡ ±1 (mod 8)
    twoSymbol
        | even e || n `mod` 8 `elem` [1, 7] = 1
        | otherwise                         = -1

    -- reciprocity flips the sign when both n and a1 are ≡ 3 (mod 4)
    s | n `mod` 4 == 3 && a1 `mod` 4 == 3 = negate twoSymbol
      | otherwise                         = twoSymbol

-- | Modular inverse using Fermat's little theorem: for a prime @p@,
-- @g^(p-2) ≡ g^(-1) (mod p)@ whenever @g@ is not a multiple of @p@.
--
-- The result is only meaningful when the modulus is prime, and no check is
-- made that an inverse exists. Because the power is computed with
-- 'expSafe', it avoids side channels in the same way.
inverseFermat :: Integer -> Integer -> Integer
inverseFermat g p = expSafe g pMinus2 p
  where
    pMinus2 = p - 2

-- | Raised when the assumption about the modulus is invalid.
--
-- Thrown by 'squareRoot' (and its Tonelli-Shanks helper) when the modulus
-- is smaller than 2 or is found not to be prime.
data ModulusAssertionError = ModulusAssertionError
    deriving (Show)

instance Exception ModulusAssertionError

-- | Modular square root of @g@ modulo a prime @p@.
--
-- Returns @Just y@ with @y * y ≡ g (mod p)@, or 'Nothing' when no root was
-- found. The method is chosen from @p mod 8@:
--
-- * @p ≡ 1 (mod 8)@: Tonelli-Shanks
--
-- * @p == 2@: the root is @g mod 2@
--
-- * @p ≡ 3 or 7 (mod 8)@: a single exponentiation by @(p + 1) / 4@
--
-- * @p ≡ 5 (mod 8)@: Atkin's method
--
-- If the modulus is found not to be prime, the function will raise a
-- 'ModulusAssertionError'.
--
-- This implementation is variable time and should be used with public
-- parameters only.
squareRoot :: Integer -> Integer -> Maybe Integer
squareRoot p
    | p < 2     = throw ModulusAssertionError
    | otherwise =
        case p `divMod` 8 of
            (_, 1) -> tonelliShanks p
            (0, 2) -> \a -> Just (a `mod` 2)
            (q, 3) -> viaQuarter (2 * q + 1)
            (u, 5) -> atkin u
            (q, 7) -> viaQuarter (2 * q + 2)
            _      -> throw ModulusAssertionError
  where
    -- accept y only if it really squares to g modulo p
    check g y
        | (y * y - g) `mod` p == 0 = Just y
        | otherwise                = Nothing

    -- p == 4k + 3 and k1 == k + 1 == (p + 1) / 4
    viaQuarter k1 g = check g (expFast g k1 p)

    -- p == 8u + 5
    atkin u g = check g y
      where
        gamma  = expFast (2 * g) u p
        gGamma = g * gamma
        i      = (2 * gGamma * gamma) `mod` p
        y      = (gGamma * (i - 1)) `mod` p

-- | Tonelli-Shanks square root of @a@ modulo an odd prime @p@; used by
-- 'squareRoot' when @p ≡ 1 (mod 8)@.
--
-- Returns @Just 0@ when @a ≡ 0 (mod p)@, and 'Nothing' when Euler's
-- criterion @a^((p-1)/2)@ yields @p - 1@ (a quadratic non-residue).
--
-- Throws 'ModulusAssertionError' whenever the computation shows that @p@
-- cannot be prime:
--
-- * Euler's criterion yields neither 1 nor @p - 1@;
--
-- * no quadratic non-residue exists below @p@;
--
-- * the order of the running value does not shrink;
--
-- * the final value is not a square root.
--
-- The bounded searches guarantee termination: without them some composite
-- moduli (e.g. @p = 9@, @a = 8@) looped forever.
--
-- Variable time; public parameters only.
tonelliShanks :: Integer -> Integer -> Maybe Integer
tonelliShanks p a
    | aa == 0   = Just 0
    | otherwise =
        case expFast aa p2 p of
            b | b == p1   -> Nothing
              | b == 1    -> Just $ verified $
                                 go (expFast aa ((s + 1) `div` 2) p)
                                    (expFast aa s p)
                                    (expFast n  s p)
                                    e
              | otherwise -> throw ModulusAssertionError
  where
    aa = a `mod` p
    p1 = p - 1
    p2 = p1 `div` 2
    n  = findN 2

    x `mul` y = (x * y) `mod` p

    -- square x, i times
    pow2m :: Int -> Integer -> Integer
    pow2m 0 x = x
    pow2m i x = pow2m (i - 1) (x `mul` x)

    -- p - 1 == 2^e * s with s odd
    (e, s) = asPowerOf2AndOdd p1

    -- find a quadratic non-residue; a prime p always has one below p,
    -- so exhausting the candidates means p is not prime
    findN :: Integer -> Integer
    findN i
        | i >= p               = throw ModulusAssertionError
        | expFast i p2 p == p1 = i
        | otherwise            = findN (i + 1)

    -- find m such that b^(2^m) == 1 (mod p); for a prime p, m is strictly
    -- below the current exponent r, so reaching r means p is not prime
    -- (and prevents pow2m from being called with a negative count)
    findM :: Int -> Integer -> Int -> Int
    findM r b m
        | m >= r    = throw ModulusAssertionError
        | b == 1    = m
        | otherwise = findM r (b `mul` b) (m + 1)

    -- for a prime p the result always squares to aa
    verified :: Integer -> Integer
    verified y
        | (y * y - aa) `mod` p == 0 = y
        | otherwise                 = throw ModulusAssertionError

    go :: Integer -> Integer -> Integer -> Int -> Integer
    go !x b g !r
        | b == 1    = x
        | otherwise =
            let r' = findM r b 0
                z  = pow2m (r - r' - 1) g
                g' = z `mul` z
             in go (x `mul` z) (b `mul` g') g' r'