{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-warn-type-defaults #-}
module Math.NumberTheory.Curves.Montgomery
( Point
, pointX
, pointZ
, pointN
, pointA24
, SomePoint(..)
, newPoint
, add
, double
, multiply
) where
import Data.Proxy
import GHC.Exts
import GHC.Integer.Logarithms
import GHC.TypeNats.Compat
import Math.NumberTheory.Utils (recipMod)
-- | A point on a Montgomery curve modulo @n@, kept in projective x/z
-- coordinates (only the ratio @x : z@ matters for equality).
--
-- The curve constant @a24@ and the modulus @n@ are type-level naturals, so
-- points belonging to different curves cannot be combined by accident.
-- Both coordinates are strict 'Integer's.
data Point (a24 :: Nat) (n :: Nat) = Point
  { pointX, pointZ :: !Integer
  }
-- | The curve constant @a24@ of the curve a point belongs to, read back
-- from the point's type; the point value itself is never inspected.
pointA24 :: forall a24 n. KnownNat a24 => Point a24 n -> Integer
pointA24 = const a24Value
  where
    a24Value = toInteger (natVal (Proxy :: Proxy a24))
-- | The modulus @n@ of the curve a point belongs to, read back from the
-- point's type; the point value itself is never inspected.
pointN :: forall a24 n. KnownNat n => Point a24 n -> Integer
pointN = const modulus
  where
    modulus = toInteger (natVal (Proxy :: Proxy n))
-- | Projective equality modulo @n@.
--
-- Points with @z == 0@ are all equal to one another and unequal to every
-- point with non-zero @z@. Otherwise @(x1 : z1)@ equals @(x2 : z2)@ when
-- @x1 * z2@ and @x2 * z1@ agree modulo @n@.
instance KnownNat n => Eq (Point a24 n) where
  p@(Point x1 z1) == Point x2 z2
    | z1 == 0 = z2 == 0
    | z2 == 0 = False
    | otherwise = x1 * z2 `mod` modulus == x2 * z1 `mod` modulus
    where
      modulus = pointN p
-- | Renders a point as @(x, z) (a24 c, mod n)@, including the type-level
-- curve constant and modulus.
instance (KnownNat a24, KnownNat n) => Show (Point a24 n) where
  show p@(Point x z) = concat
    [ "(", show x, ", ", show z, ") (a24 "
    , show (pointA24 p), ", mod "
    , show (pointN p), ")"
    ]
-- | A 'Point' whose curve constant and modulus are hidden existentially,
-- so a point built from runtime values (as 'newPoint' does) can be
-- returned. Matching on the constructor brings the 'KnownNat' evidence
-- for both back into scope.
data SomePoint where
  SomePoint :: forall a24 n. (KnownNat a24, KnownNat n) => Point a24 n -> SomePoint
-- | Shows the wrapped point exactly as its own 'Show' instance would.
instance Show SomePoint where
  show sp = case sp of
    SomePoint pt -> show pt
-- | Build a point on a Montgomery curve modulo @n@ from the seed @s@,
-- following the parametrisation computed below:
--
-- > u = s^2 - 5,  v = 4 s,  x = u^3,  z = v^3   (mod n)
-- > a24 = (v - u)^3 (3 u + v) / (16 u^3 v)      (mod n)
--
-- Returns 'Nothing' in any of these cases:
--
-- * @n < 2@: no curve exists. @n == 0@ would otherwise make every
--   @rem@/@mod@ below throw DivideByZero.
-- * 'recipMod' finds no inverse of the denominator modulo @n@.
-- * @a24@ reduces to 0 or 1.
-- * @a24@ is negative and so cannot be lifted to a type-level 'Nat'.
newPoint :: Integer -> Integer -> Maybe SomePoint
newPoint s n
  | n < 2 = Nothing
  | otherwise = do
      a24denRecip <- recipMod a24den n
      a24 <- case a24num * a24denRecip `rem` n of
        -- a24 plays the role of (A + 2) / 4 in 'double'. A value of 0 means
        -- A = -2, which gives a singular curve.
        0 -> Nothing
        -- A value of 1 means A = 2, which also gives a singular curve.
        1 -> Nothing
        t -> Just t
      -- Type-level naturals cannot be negative.
      SomeNat (_ :: Proxy a24Ty) <- if a24 < 0
        then Nothing
        else Just $ someNatVal $ fromInteger a24
      -- n >= 2 is guaranteed by the guard above, so fromInteger is safe.
      SomeNat (_ :: Proxy nTy) <- Just $ someNatVal $ fromInteger n
      return $ SomePoint (Point x z :: Point a24Ty nTy)
  where
    u = s * s `rem` n - 5
    v = 4 * s
    d = v - u
    x = u * u * u `mod` n
    z = v * v * v `mod` n
    a24num = d * d * d * (3 * u + v) `mod` n
    a24den = 16 * x * v `rem` n
-- | Differential addition: @add p0 p1 p2@ returns @p1 + p2@.
--
-- With only x/z coordinates a sum cannot be formed directly, so @p0@ must
-- be the difference of @p1@ and @p2@. 'multiply' always passes the base
-- point here. The result is reduced into @[0, n)@.
add :: KnownNat n => Point a24 n -> Point a24 n -> Point a24 n -> Point a24 n
add diff@(Point xDiff zDiff) (Point xP zP) (Point xQ zQ) =
  Point (sumSq * zDiff `mod` modulus) (diffSq * xDiff `mod` modulus)
  where
    modulus = pointN diff
    square w = w * w `rem` modulus
    cross1 = (xP - zP) * (xQ + zQ) `rem` modulus
    cross2 = (xP + zP) * (xQ - zQ) `rem` modulus
    sumSq = square (cross1 + cross2)
    diffSq = square (cross1 - cross2)
-- | Point doubling in x/z coordinates.
--
-- With @plusSq = (x + z)^2@ and @minusSq = (x - z)^2@, their difference is
-- @4 x z@. The result is @(plusSq * minusSq : 4xz * (minusSq + a24 * 4xz))@,
-- reduced modulo @n@.
double :: (KnownNat a24, KnownNat n) => Point a24 n -> Point a24 n
double pt@(Point x z) = Point (plusSq * minusSq `mod` modulus) zOut
  where
    modulus = pointN pt
    square w = w * w `rem` modulus
    plusSq = square (x + z)
    minusSq = square (x - z)
    fourXZ = plusSq - minusSq
    zOut = (minusSq + pointA24 pt * fourXZ `rem` modulus) * fourXZ `mod` modulus
-- | Scalar multiplication @k * P@ by a Montgomery ladder.
--
-- The ladder starts from @(P, 2P)@ and keeps a pair @(lo, hi)@ with
-- @hi - lo == P@, so every sum can use the differential 'add' with @P@ as
-- the known difference. It consumes the bits of @k@ below the most
-- significant one, from high to low. @k == 0@ yields @Point 0 0@, which
-- has @z == 0@, and @k == 1@ returns the input unchanged.
multiply :: (KnownNat a24, KnownNat n) => Word -> Point a24 n -> Point a24 n
multiply k p = case k of
  0 -> Point 0 0
  1 -> p
  W# w## -> ladder (wordLog2# w## -# 1#) p (double p)
    where
      -- Bit number i# of the multiplier, as 0## or 1##.
      bitAt i# = uncheckedShiftRL# w## i# `and#` 1##
      -- The last bit forces both halves of the pair before choosing the
      -- result, exactly as the intermediate steps would leave them.
      ladder 0# !lo !hi = case bitAt 0# of
        0## -> double lo
        _ -> add p lo hi
      -- A clear bit doubles the low half; a set bit doubles the high half.
      ladder i# lo hi = case bitAt i# of
        0## -> ladder (i# -# 1#) (double lo) (add p lo hi)
        _ -> ladder (i# -# 1#) (add p lo hi) (double hi)