{-|
Module      : Crypto.PastaCurves.Fields (internal)
Description : Supports the instantiation of parameterized prime-modulus fields.
Copyright   : (c) Eric Schorn, 2022
Maintainer  : eric.schorn@nccgroup.com
Stability   : experimental
Portability : GHC
SPDX-License-Identifier: MIT

This internal module provides a (multi-use) field element template with an arbitrary
prime modulus along with a variety of supporting functionality such as basic arithmetic,
multiplicative inverse, square testing, square root, serialization and deserialization,
and hash2Field. The algorithms are NOT constant time; safety and simplicity are the top
priorities.
-}

{-# LANGUAGE CPP, DataKinds, DerivingStrategies, KindSignatures, NoImplicitPrelude #-}
{-# LANGUAGE ScopedTypeVariables, Trustworthy #-}

module Fields (Field(..), Fz(..), Num(..)) where

import Prelude hiding (concat, replicate)
import Crypto.Hash (Blake2b_512 (Blake2b_512), hashWith)
import Data.Bifunctor (bimap)
import Data.Bits ((.|.), shiftL, shiftR)
import Data.ByteArray (convert, length, xor)
import Data.ByteString (concat, foldl', pack, replicate)
import Data.ByteString.UTF8 (ByteString, fromString)
import Data.Char (chr)
import Data.Tuple (swap)
import Data.Typeable (Proxy (Proxy))
import GHC.Word (Word8)
import GHC.TypeLits (KnownNat, Nat, natVal)
import System.Random (Random(randomR), RandomGen)


-- | The `Fz (z :: Nat)` field element (template) type wraps an `Integer` and carries
-- the modulus @z@ at the type level. Equality is structural on the wrapped `Integer`,
-- so it is only meaningful for values that have been reduced into [0, z).
newtype Fz (z :: Nat) = Fz Integer
  deriving stock (Eq)


-- A CPP macro 'helper' to extract the modulus (an Integer) from (Fz z).
-- The macro expands textually into a function application, so it has to be wrapped
-- in parentheses whenever it is passed as a function argument, e.g. `_isSqr a (MOD)`.
-- The expansion refers to a type variable literally named z, which must be in scope
-- at the use site (the instance heads below provide it via ScopedTypeVariables).
#define MOD natVal (Proxy :: Proxy z)


-- | Modular arithmetic for `Fz`: every operation funnels its integer result through
-- `fromInteger`, which reduces it modulo @z@. `abs` and `signum` have no meaning for
-- field elements and raise an error.
instance KnownNat z => Num (Fz z) where
  fromInteger n = Fz (n `mod` MOD)

  Fz x + Fz y = fromInteger (x + y)

  Fz x - Fz y = fromInteger (x - y)

  Fz x * Fz y = fromInteger (x * y)

  abs = error "abs: not implemented/needed"

  signum = error "signum: not implemented/needed"


-- | Renders an `Fz` element as @0x@ followed by upper-case hexadecimal digits,
-- zero-padded to the number of nibbles needed to hold the modulus.
instance KnownNat z => Show (Fz z) where
  show (Fz a) = "0x" ++ map hexDigit [topNibble, topNibble - 1 .. 0]
    where
      -- Character for the k-th nibble of the value, counting from the least significant
      hexDigit :: Int -> Char
      hexDigit k = "0123456789ABCDEF" !! fromInteger (shiftR a (4 * k) `mod` 16)
      -- Bit length of the modulus: the smallest n for which 2^n exceeds it
      bitLen :: Int
      bitLen = until ((MOD <) . (2 ^)) (+ 1) 0
      -- Index of the most significant nibble to print
      topNibble :: Int
      topNibble = (bitLen + 3) `div` 4 - 1


-- | The smallest element is 0 and the largest is @z - 1@ (built through `fromInteger`).
instance KnownNat z => Bounded (Fz z) where
  maxBound = fromInteger (MOD - 1)

  minBound = Fz 0


-- | The `Field` class provides useful support functionality for field elements.
class (Num a, Eq a) => Field a where

  -- | The `fromBytesF` function is the primary deserialization constructor. It
  -- consumes a big-endian `ByteString` whose length is exactly the minimal number of
  -- bytes needed to hold the modulus, and returns a field element. The decoded
  -- integer must already be properly reduced into [0..modulus); if either the length
  -- or the range check fails, `Nothing` is returned.
  fromBytesF :: ByteString  -> Maybe a

  -- | The `_fromBytesF` function is the secondary deserialization constructor. It
  -- consumes a big-endian `ByteString` of any length and returns an element reduced
  -- modulo the modulus. This function is useful for random testing and
  -- hash2Field-style functions.
  _fromBytesF :: ByteString -> a

  -- | The `hash2Field` function provides intermediate functionality that is suitable
  -- for ultimately supporting the `Curves.hash2Curve` function. Its arguments are the
  -- message, a domain prefix and a curve identifier (the two strings form the domain
  -- separation tag). It returns a 2-tuple of field elements.
  hash2Field :: ByteString -> String -> String -> (a, a)

  -- | The `inv0` function returns the multiplicative inverse as calculated by Fermat's
  -- Little Theorem (mapping 0 to 0).
  inv0 :: a -> a

  -- | The `isSqr` function indicates whether the operand has a square root (zero
  -- counts as a square).
  isSqr :: a -> Bool

  -- | The `rndF` function returns a random (invertible/non-zero) field element,
  -- paired with the updated random generator (generator first).
  rndF :: (RandomGen r) => r -> (r, a)

  -- | The `sgn0` function returns the least significant bit of the field element as an
  -- Integer (0 or 1).
  sgn0 :: a -> Integer

  -- | The `shiftR1` function shifts the field element one bit to the right, effectively
  -- dividing it by two (and discarding the remainder).
  shiftR1 :: a -> a

  -- | The `Fields.sqrt` function implements the variable-time Tonelli-Shanks
  -- algorithm to calculate the operand square root. The function returns `Nothing`
  -- in the event of a problem (such as the operand not being a square, the modulus
  -- not being prime, etc).
  sqrt :: a -> Maybe a

  -- | The `toBytesF` function serializes an element into a big-endian `ByteString`
  -- whose length is the minimal number of bytes needed to hold the modulus
  -- (zero-padded on the left).
  toBytesF :: a -> ByteString

  -- | The `toI` function returns the field element as a properly reduced Integer.
  toI :: a -> Integer


-- | The `Fz z` type is an instance of the `Field` class. Several functions are largely
-- simple adapters to the more generic internal integer functions implemented further
-- below.
instance KnownNat z => Field (Fz z) where

  -- Validated deserialization, returns a Maybe field element. Follows section 2.3.6
  -- of https://www.secg.org/sec1-v2.pdf
  -- If the ByteString is not the correct length, or the decoded integer is >= the
  -- modulus, Nothing is returned.
  -- (class type: fromBytesF :: ByteString -> Maybe a)
  fromBytesF bytes
    | Data.ByteArray.length bytes /= corLen || integer >= MOD = Nothing
    | otherwise = Just $ fromInteger integer
    where
      -- Correct length: the minimal number of bytes that can hold the modulus
      corLen :: Int
      corLen = (7 + until ((MOD <) . (2 ^)) (+ 1) 0) `div` 8
      -- Big-endian decoding of the input bytes
      integer :: Integer
      integer = foldl' (\acc byte -> acc `shiftL` 8 .|. fromIntegral byte) 0 bytes

  -- Unvalidated deserialization (no limits wrt modulus), returns a reduced element.
  -- (class type: _fromBytesF :: ByteString -> a)
  _fromBytesF bytes =
    fromInteger $ foldl' (\acc byte -> acc `shiftL` 8 .|. fromIntegral byte) 0 bytes

  -- Field-level support for the hash2Curve function, returns a pair of field elements.
  -- The hash2field construction is per Zcash Pasta Curve (which is very similar but not
  -- identical to the CFRG hash-to-curve specification). Fortuitously, cryptonite sets
  -- the hash personalization to all zeros, see https://github.com/haskell-crypto/cryptonite/issues/333
  -- Zcash/Pasta code https://github.com/zcash/pasta_curves/blob/main/src/hashtocurve.rs#L10
  -- CFRG scheme (for ref) https://www.ietf.org/archive/id/draft-irtf-cfrg-hash-to-curve-14.html#name-hash_to_field-implementatio
  -- The UTF-8 encoded tag (domPref ++ "-" ++ curveId ++ "_XMD:BLAKE2b_SSWU_RO_") must
  -- be at most 255 bytes long, since its length is appended as a single byte.
  -- (class type: hash2Field :: ByteString -> String -> String -> (a, a))
  hash2Field msg domPref curveId
    | dstLen > 255 = error "strings too long"
    | otherwise = bimap _fromBytesF _fromBytesF (digest1, digest2)
    where
      -- Domain separation tag without its trailing length byte, UTF-8 encoded
      dst :: ByteString
      dst = fromString (domPref ++ "-" ++ curveId ++ "_XMD:BLAKE2b_SSWU_RO_")
      -- Tag length in BYTES (for ASCII input: 22 + length curveId + length domPref).
      -- Measured on the encoded bytes so that non-ASCII characters are counted correctly.
      dstLen :: Int
      dstLen = Data.ByteArray.length dst
      -- Calculate reusable prefix and suffix
      prefix :: ByteString
      prefix = replicate 128 0
      -- The length is appended as ONE raw byte. Building it with chr inside the UTF-8
      -- fromString would emit two bytes for any length of 128 or more.
      suffix :: ByteString
      suffix = concat [dst, pack [fromIntegral dstLen]]
      -- A little helper function to hash ByteStrings
      mkDigest :: ByteString -> ByteString
      mkDigest x = convert $ hashWith Blake2b_512 x
      -- Hash the message along with prefix, suffix, etc
      digest0 :: ByteString
      digest0 = mkDigest $ concat [prefix, msg, pack [0, 0x80, 0], suffix]
      -- Hash the hash again
      digest1 :: ByteString
      digest1 = mkDigest $ concat [digest0, pack [0x01], suffix]
      -- Mix the two above hashes together via bytewise XOR, then hash that too
      mix :: ByteString
      mix = xor digest0 digest1
      digest2 :: ByteString
      digest2 = mkDigest $ concat [mix, pack [0x02], suffix]

  -- Multiplicative inverse, with 0 mapped to 0, via Fermat's Little Theorem
  -- (class type: inv0 :: a -> a)
  inv0 (Fz a) = Fz $ _powMod a (MOD - 2) (MOD)

  -- Determines if the operand has a square root. Uses helper functions with Integers
  -- (class type: isSqr :: a -> Bool)
  isSqr (Fz a) = _isSqr a (MOD)

  -- Returns a random (invertible/non-zero) field element drawn from [1, modulus - 1],
  -- paired with the updated generator.
  rndF rndGen = fromInteger <$> swap (randomR (1, MOD - 1) rndGen)

  -- Returns the least significant bit of the field element as an Integer
  -- (class type: sgn0 :: a -> Integer)
  sgn0 (Fz a) = a `mod` 2

  -- Shift right by 1 (divides the element by 2, discarding the remainder)
  -- (class type: shiftR1 :: a -> a)
  shiftR1 (Fz a) = Fz $ a `div` 2

  -- Returns square root as Maybe field element. If any problems, returns Nothing.
  -- (class type: sqrt :: a -> Maybe a)
  sqrt (Fz a) = fromInteger <$> _sqrtVt a (MOD) s p c  -- Use helper function
    where
      -- rewrite (modulus - 1) as p * 2^s
      s :: Integer
      s = until ((/= 0) . ((MOD - 1) `rem`) . (2 ^)) (+ 1) 0 - 1
      p :: Integer
      p = (MOD - 1) `div` (2 ^ s)
      -- First non-square below the modulus, or 0 if there is none. The search is
      -- bounded so that it terminates even when no non-square exists (e.g. a modulus
      -- of 2 or some composite moduli); a 0 makes _sqrtVt return Nothing.
      z :: Integer
      z = case [x | x <- [1 .. MOD - 1], not (_isSqr x (MOD))] of
            (n : _) -> n
            []      -> 0
      -- Non-square raised to p, used by Tonelli-Shanks as its source of fixes
      c :: Integer
      c = _powMod z p (MOD)

  -- Serialization. Follows section 2.3.7 of https://www.secg.org/sec1-v2.pdf
  -- Produces exactly corLen big-endian bytes (zero-padded on the left).
  -- (class type: toBytesF :: a -> ByteString)
  toBytesF (Fz a) = pack bytesBE
    where
      -- Correct length: the minimal number of bytes that can hold the modulus
      corLen :: Int
      corLen = (7 + until ((MOD <) . (2 ^)) (+ 1) 0) `div` 8
      -- Most significant byte first; fromIntegral keeps only the low 8 bits
      bytesBE :: [Word8]
      bytesBE = [fromIntegral (shiftR a (8 * b)) | b <- [corLen - 1, corLen - 2 .. 0]]

  -- Returns the element as an Integer
  -- (class type: toI :: a -> Integer)
  toI (Fz a) = a


-- Complex/common support functions operating on integers rather than field elements

-- | Modular exponentiation by left-to-right consumption of exponent bits
-- (square-and-multiply).
-- _powMod :: operand -> exponent -> modulus
-- The operand is reduced first, so the result always lies in [0, modulus), including
-- for exponent 1 and for negative or oversized operands. An exponent of 0 yields 1.
-- Raises an error for a negative exponent or a modulus below 2.
_powMod :: Integer -> Integer -> Integer -> Integer
_powMod base ex q
  | ex < 0 || q < 2 = error "Invalid exponent/modulus"
  | otherwise = go 1 (base `mod` q) ex
  where
    -- go acc b e computes acc * b^e (mod q), halving e on every step
    go :: Integer -> Integer -> Integer -> Integer
    go acc _ 0 = acc
    go acc b e
      | even e    = go acc (b * b `mod` q) (shiftR e 1)
      | otherwise = go (acc * b `mod` q) (b * b `mod` q) (shiftR e 1)


-- | Is the operand a square, decided via the Legendre symbol (Euler criterion).
-- isSqr :: operand -> modulus
-- The symbol is operand^((modulus - 1) / 2) reduced by the modulus; for an odd prime
-- modulus it is 0 for a multiple of the modulus, 1 for a nonzero square and
-- modulus - 1 otherwise. Both 0 and 1 are treated as "is a square".
_isSqr :: Integer -> Integer -> Bool
_isSqr a q = _powMod a ((q - 1) `div` 2) q `elem` [0, 1]


-- | Variable-time Tonelli-Shanks algorithm, intended for an odd prime modulus.
-- _sqrtVt :: operand -> modulus -> s -> p -> c
-- Here (modulus - 1) == p * 2^s with p odd, and c is z^p (mod modulus) for some
-- non-square z; a c of 0 signals that no non-square was found.
-- Returns Nothing when the operand is not a square, when c is 0, or when the computed
-- candidate fails the final check root^2 == operand (mod modulus). That last case can
-- only arise if the modulus or the supplied parameters are not as assumed; without the
-- check, the internal failure value 0 would be returned as Just 0.
_sqrtVt :: Integer -> Integer -> Integer -> Integer -> Integer -> Maybe Integer
_sqrtVt 0 _ _ _ _ = Just 0
_sqrtVt a q _ _ _ | not (_isSqr a q) = Nothing  -- Not truly necessary
_sqrtVt _ _ _ _ 0 = Nothing  -- covers the bases
_sqrtVt a q s p c
  | root * root `mod` q == a `mod` q = Just root
  | otherwise                        = Nothing
  where
    root :: Integer
    root = loopy (_powMod a p q) (_powMod a ((p + 1) `div` 2) q) c s
    -- One Tonelli-Shanks step per call: tt is the remaining error term, rr the
    -- current candidate root, cc the current fix factor and ss the current 2-adic bound.
    -- Yields 0 on failure (tt == 0 or no progress possible).
    loopy :: Integer -> Integer -> Integer -> Integer -> Integer
    loopy tt _ _ ss | tt == 0 || ss == 0 = 0
    loopy 1 rr _ _ = rr
    loopy tt rr cc ss = loopy tNext rNext cNext sNext
      where
        -- Least i in [1, ss - 1] with tt^(2^i) == 1, or 0 when there is none
        sNext :: Integer
        sNext = case [i | i <- [1 .. ss - 1], _powMod tt (2 ^ i) q == 1] of
                  (i : _) -> i
                  []      -> 0
        ff :: Integer
        ff = _powMod cc (2 ^ (ss - sNext - 1)) q
        rNext :: Integer
        rNext = rr * ff `mod` q
        tNext :: Integer
        tNext = tt * _powMod ff 2 q `mod` q
        cNext :: Integer
        cNext = _powMod cc (2 ^ (ss - sNext)) q