{-# LANGUAGE CPP, DataKinds, DerivingStrategies, FlexibleInstances, PolyKinds #-}
{-# LANGUAGE MultiParamTypeClasses, NoImplicitPrelude, Safe, ScopedTypeVariables #-}
module Curves (Curve(..), CurvePt(..), Point(..)) where
import Prelude hiding (drop, length, sqrt)
import Control.Monad (mfilter)
import Data.ByteString (ByteString, cons, drop, index, length, pack)
import Data.Maybe (fromJust)
import Data.Typeable (Proxy (Proxy))
import GHC.TypeLits (Nat, KnownNat, natVal)
import Fields (Field (..))
-- | Point in homogeneous projective coordinates (X : Y : Z) on the short
-- Weierstrass curve y^2 = x^3 + a*x + b over the field @f@. The curve
-- coefficients @a@, @b@ and the affine base point (@baseX@, @baseY@) are
-- type-level naturals. A triple with non-zero Z stands for the affine point
-- (X/Z, Y/Z); Z == 0 is used for the neutral point.
data Point (a :: Nat) (b :: Nat) (baseX :: Nat) (baseY :: Nat) f
  = Projective
      { _x :: f -- ^ X coordinate
      , _y :: f -- ^ Y coordinate
      , _z :: f -- ^ Z coordinate (zero for the neutral point)
      }
  deriving stock (Show)
-- Reflect the type-level curve parameters to Integer values. The expansions
-- mention the type variables a, b, baseX and baseY, so they may only be used
-- where those variables are in scope (instance methods, via
-- ScopedTypeVariables). Each body is parenthesised so the macro is a single
-- expression even when passed directly as a function argument.
#define A (natVal (Proxy :: Proxy a))
#define B (natVal (Proxy :: Proxy b))
#define BASE_X (natVal (Proxy :: Proxy baseX))
#define BASE_Y (natVal (Proxy :: Proxy baseY))
-- Projective equality by cross-multiplication: two triples compare equal when
-- X1*Z2 == X2*Z1 and Y1*Z2 == Y2*Z1. For triples with non-zero Z this means
-- they denote the same affine point; no field inversion is required.
instance (Field f, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  Eq (Point a b baseX baseY f) where
  Projective xl yl zl == Projective xr yr zr = sameRatio xl xr && sameRatio yl yr
    where
      -- Compare one coordinate from each side, each scaled by the other Z.
      sameRatio cl cr = cl * zr == cr * zl
-- | Group operations and a compressed byte encoding for curve points.
class CurvePt a where
  -- | The designated base (generator) point.
  base :: a
  -- | Decode a point; @Nothing@ when the bytes are not a valid encoding.
  fromBytesC :: ByteString -> Maybe a
  -- | Whether the point satisfies the curve equation.
  isOnCurve :: a -> Bool
  -- | Additive inverse of a point.
  negatePt :: a -> a
  -- | The identity element for @pointAdd@.
  neutral :: a
  -- | Sum of two points.
  pointAdd :: a -> a -> a
  -- | Encode a point as bytes; intended to be undone by @fromBytesC@.
  toBytesC :: a -> ByteString
-- Point operations on y^2 = x^3 + a*x + b in homogeneous projective
-- coordinates (X : Y : Z). A triple with Z /= 0 stands for the affine point
-- (X/Z, Y/Z); on the curve, Z == 0 forces X == 0, i.e. the neutral point.
instance (Field f, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  CurvePt (Point a b baseX baseY f) where

  -- The affine base point (baseX, baseY) lifted with Z = 1.
  base = Projective (fromInteger (BASE_X)) (fromInteger (BASE_Y)) 1

  -- Decode the compressed encoding produced by toBytesC:
  --   * the single byte 0x00 decodes to the neutral point;
  --   * 0x02 or 0x03 followed by an encoded x-coordinate decodes to the point
  --     with that x whose y has sgn0 0 (prefix 0x02) or 1 (prefix 0x03).
  -- Wrong length, unknown prefix, an undecodable x, an x with no square root,
  -- a root whose sign cannot match the prefix, or a result that fails
  -- isOnCurve all give Nothing.
  fromBytesC bytes
    | length bytes == 1 && prefix == 0 = Just neutral
    | length bytes == encodedLen && (prefix == 0x02 || prefix == 0x03) = decoded
    where
      -- Forced only after a length test, so it never indexes an empty string.
      prefix = index bytes 0
      -- Prefix byte plus the width of one serialised field element (measured
      -- on the field image of the coefficient a, as before).
      encodedLen :: Int
      encodedLen = 1 + length (toBytesF (fromInteger (A) :: f))
      wantedSign :: Integer
      wantedSign = if prefix == 0x02 then 0 else 1
      xCoord :: Maybe f
      xCoord = fromBytesF (drop 1 bytes)
      -- Right-hand side of the curve equation: x^3 + a*x + b.
      rhs :: Maybe f
      rhs = (\t -> t ^ (3 :: Integer) + (fromInteger (A) :: f) * t + (fromInteger (B) :: f)) <$> xCoord
      fixSign t = if sgn0 t == wantedSign then t else negate t
      -- Reject a root whose sign still disagrees with the prefix. Assuming
      -- the RFC 9380 sgn0 (sgn0 0 == 0, and negation flips sgn0 of non-zero
      -- values) this only happens for y == 0 under prefix 0x03: a second
      -- encoding of a point that toBytesC always writes with prefix 0x02.
      yCoord :: Maybe f
      yCoord = mfilter ((== wantedSign) . sgn0) (fixSign <$> (rhs >>= sqrt))
      decoded :: Maybe (Point a b baseX baseY f)
      decoded = mfilter isOnCurve (Projective <$> xCoord <*> yCoord <*> Just 1)
  fromBytesC _ = Nothing

  -- Homogeneous curve equation: Z*Y^2 == X^3 + a*X*Z^2 + b*Z^3.
  -- The neutral point (0 : 1 : 0) satisfies it.
  isOnCurve (Projective x y z) =
    z * y ^ (2 :: Integer)
      == x ^ (3 :: Integer) + coeffA * x * z ^ (2 :: Integer) + coeffB * z ^ (3 :: Integer)
    where
      coeffA, coeffB :: f
      coeffA = fromInteger (A)
      coeffB = fromInteger (B)

  -- Negation flips the sign of Y only.
  negatePt (Projective x y z) = Projective x (negate y) z

  neutral = Projective 0 1 0

  -- Complete projective addition for arbitrary a: Renes, Costello and Batina,
  -- "Complete addition formulas for prime order elliptic curves" (2016),
  -- Algorithm 1. The mN values are its intermediate products, with b3 = 3*b.
  -- The same formula handles doubling and neutral inputs.
  pointAdd (Projective x1 y1 z1) (Projective x2 y2 z2) = result
    where
      coeffA, b3 :: f
      coeffA = fromInteger (A)
      b3 = fromInteger (3 * B)
      m0 = x1 * x2
      m1 = y1 * y2
      m2 = z1 * z2
      m3 = (x1 + y1) * (x2 + y2)
      m4 = (x1 + z1) * (x2 + z2)
      m5 = (y1 + z1) * (y2 + z2)
      m6 = coeffA * (- m0 - m2 + m4)
      m7 = b3 * m2
      m8 = (m1 - m6 - m7) * (m1 + m6 + m7)
      m9 = coeffA * m2
      m10 = b3 * (- m0 - m2 + m4)
      m11 = coeffA * (m0 - m9)
      m12 = (m0 * 3 + m9) * (m10 + m11)
      m13 = (- m1 - m2 + m5) * (m10 + m11)
      m14 = (- m0 - m1 + m3) * (m1 - m6 - m7)
      m15 = (- m0 - m1 + m3) * (m0 * 3 + m9)
      m16 = (- m1 - m2 + m5) * (m1 + m6 + m7)
      result :: Point a b baseX baseY f
      result = Projective (- m13 + m14) (m8 + m12) (m15 + m16)

  -- Compressed encoding, the inverse of fromBytesC: 0x00 for the neutral
  -- point, otherwise 0x02 / 0x03 (sgn0 of the affine y being 0 / non-zero)
  -- followed by the affine x. Calls error for a point failing isOnCurve.
  toBytesC pt
    | not (isOnCurve pt) = error "trying to serialize point not on curve"
    | _z pt == 0 = pack [0]
    | sgn0 affineY == 0 = cons 0x02 (toBytesF affineX)
    | otherwise = cons 0x03 (toBytesF affineX)
    where
      -- Shared inverse of Z; only forced once Z is known to be non-zero.
      zInv = inv0 (_z pt)
      affineX = _x pt * zInv
      affineY = _y pt * zInv
-- | Curve point types that can be multiplied by, and mapped from, elements
-- of the field @b@.
class (CurvePt a, Field b) => Curve a b where
  -- | Scalar multiplication: @pointMul k p@ scales @p@ by @k@.
  pointMul :: b -> a -> a
  -- | Simplified SWU map: @mapToCurveSimpleSwu u z@ sends the field element
  -- @u@ to a curve point, with @z@ as the map constant.
  mapToCurveSimpleSwu :: b -> b -> a
-- Scalar multiplication and the simplified SWU map for projective points over
-- f1, with scalars and map inputs taken from the field f2.
instance (Field f1, Field f2, KnownNat a, KnownNat b, KnownNat baseX, KnownNat baseY) =>
  Curve (Point a b baseX baseY f1) f2 where

  -- Right-to-left double-and-add: at each step add the current multiple of
  -- the point when the low bit of the scalar (sgn0) is set, then double the
  -- point and shift the scalar right. Assumes shiftR1 halves the scalar and
  -- sgn0 reports its low bit (confirm against Fields). NOTE(review): this
  -- branches on every scalar bit and recursion depth follows the bit length,
  -- so it is not constant-time.
  pointMul scalar point = go scalar point neutral
    where
      go :: f2 -> Point a b baseX baseY f1 -> Point a b baseX baseY f1
         -> Point a b baseX baseY f1
      go k addend acc
        | k == 0 = acc
        | otherwise = go (shiftR1 k) (pointAdd addend addend) accNext
        where
          accNext = if sgn0 k /= 0 then pointAdd acc addend else acc

  -- Simplified SWU map (RFC 9380, Section 6.6.2, straight-line form) from the
  -- input fu to a point with Z = 1, using fz as the map constant Z. The map
  -- needs both curve coefficients non-zero; otherwise this calls error.
  mapToCurveSimpleSwu fu fz =
    if A * B /= 0 then result else error "Curve params A*B must not be zero"
    where
      coeffA, coeffB, u, z :: f1
      coeffA = fromInteger (A)
      coeffB = fromInteger (B)
      -- Move both inputs into the base field through their integer value.
      u = fromInteger (toI fu)
      z = fromInteger (toI fz)
      -- Step 1: tv1 = inv0(Z^2 * u^4 + Z * u^2).
      tv1 = inv0 (z ^ (2 :: Integer) * u ^ (4 :: Integer) + z * u ^ (2 :: Integer))
      -- Steps 2-3: x1 = (-b / a) * (1 + tv1), or b / (Z * a) when tv1 == 0.
      x1
        | toI tv1 == 0 = coeffB * inv0 (z * coeffA)
        | otherwise = (fromInteger ((-1) * B) * inv0 coeffA) * (1 + tv1)
      -- Steps 4-6: both candidate x-coordinates and the matching right-hand
      -- sides of the curve equation.
      gx1 = x1 ^ (3 :: Integer) + coeffA * x1 + coeffB
      x2 = z * u ^ (2 :: Integer) * x1
      gx2 = x2 ^ (3 :: Integer) + coeffA * x2 + coeffB
      -- Steps 7-8: use x1 when gx1 is a square, otherwise x2. The fromJust on
      -- gx2 relies on Z meeting the RFC requirements (which make gx2 square
      -- whenever gx1 is not); that property is not checked here.
      (x, ya)
        | isSqr gx1 = (x1, fromJust (sqrt gx1))
        | otherwise = (x2, fromJust (sqrt gx2))
      -- Step 9: flip y so that sgn0 y agrees with sgn0 u.
      y = if sgn0 u /= sgn0 ya then negate ya else ya
      result :: Point a b baseX baseY f1
      result = Projective x y 1