{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Base.Algebra.EllipticCurve.Class where
import Data.Functor ((<&>))
import Data.Kind (Type)
import Numeric.Natural (Natural)
import Prelude hiding (Num (..), sum, (/), (^))
import Test.QuickCheck hiding (scale)
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Data.ByteString
-- | A point associated with the type-level tag @curve@: either an affine
-- pair of 'BaseField' coordinates or the distinguished value 'Inf'.
--
-- 'Inf' is the value used as 'zero' by the 'AdditiveMonoid' instance of
-- 'Point' in this module.
--
-- NOTE: '_x' and '_y' are partial record selectors; applying either of them
-- to 'Inf' is a runtime error.
data Point curve
    = Point
        { _x :: BaseField curve -- ^ x-coordinate of an affine point
        , _y :: BaseField curve -- ^ y-coordinate of an affine point
        }
    | Inf
-- | Operations that every curve tag @curve@ has to provide.
class EllipticCurve curve where
    -- | Type of the coordinates stored in a 'Point'.
    type BaseField curve :: Type

    -- | Type of the scalars accepted by 'mul'.
    type ScalarField curve :: Type

    -- | A distinguished point; presumably the point at infinity
    -- (TODO confirm against the instances, which are not visible here).
    inf :: Point curve

    -- | A distinguished point; presumably the group generator (TODO confirm).
    -- The 'Arbitrary' instance of 'Point' builds random points as scalar
    -- multiples of it.
    gen :: Point curve

    -- | Point addition; this is what '(+)' on 'Point' delegates to.
    add :: Point curve -> Point curve -> Point curve

    -- | Multiplication of a point by a scalar.
    mul :: ScalarField curve -> Point curve -> Point curve
-- | Renders 'Inf' as @Inf@ and an affine point as @(x, y)@, using the 'Show'
-- instance of the base field for both coordinates.
instance (EllipticCurve curve, Show (BaseField curve)) => Show (Point curve) where
    show Inf         = "Inf"
    show (Point x y) = "(" ++ show x ++ ", " ++ show y ++ ")"
-- | Structural equality: 'Inf' equals only 'Inf'; two affine points are equal
-- exactly when both coordinates are equal.
instance (EllipticCurve curve, Eq (BaseField curve)) => Eq (Point curve) where
    Inf         == Inf         = True
    Point x1 y1 == Point x2 y2 = x1 == x2 && y1 == y2
    -- One side is 'Inf' and the other is an affine point.
    _           == _           = False
-- | '(+)' on points is the curve's own 'add' operation.
instance EllipticCurve curve => AdditiveSemigroup (Point curve) where
    (+) = add
-- | Scaling a point by a 'Natural' is delegated to the generic 'natScale',
-- which only needs the 'AdditiveMonoid' structure of 'Point'.
instance EllipticCurve curve => Scale Natural (Point curve) where
    scale = natScale
-- | The additive identity of the point group is the 'Inf' constructor.
instance EllipticCurve curve => AdditiveMonoid (Point curve) where
    zero = Inf
-- | Scaling a point by an 'Integer' is delegated to the generic 'intScale',
-- which relies on the 'AdditiveGroup' structure of 'Point'.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => Scale Integer (Point curve) where
    scale = intScale
-- | Negation of a point is 'pointNegate', which negates the y-coordinate.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => AdditiveGroup (Point curve) where
    negate = pointNegate
-- | Wire format: a one-byte tag followed by the payload.
--
-- * tag @0@ — 'Inf', no payload;
-- * tag @1@ — an affine point: the encoding of @x@ followed by the encoding
--   of @y@, both via the 'Binary' instance of the base field.
--
-- Decoding fails (via 'fail') on any other tag value.
--
-- NOTE: 'get' only reads two field elements; it performs no check that the
-- decoded coordinates satisfy any curve equation.
instance (EllipticCurve curve, Binary (BaseField curve)) => Binary (Point curve) where
    put Inf         = putWord8 0
    put (Point x y) = putWord8 1 <> put x <> put y

    get = do
        flag <- getWord8
        case flag of
            0 -> pure Inf
            1 -> Point <$> get <*> get
            _ -> fail ("Binary (Point curve): unexpected flag " <> show flag)
-- | Random points are produced as @s \`mul\` gen@ for a random scalar @s@,
-- so every generated point is a scalar multiple of 'gen'.
instance (EllipticCurve curve, Arbitrary (ScalarField curve)) => Arbitrary (Point curve) where
    arbitrary = arbitrary <&> (`mul` gen)
-- | A map from a pair of points on two curves that share a scalar field into
-- a multiplicative group @t@.  The functional dependency makes the target
-- type @t@ determined by the two curve tags.
class
    ( EllipticCurve curve1
    , EllipticCurve curve2
    , ScalarField curve1 ~ ScalarField curve2
    , Eq t
    , MultiplicativeGroup t
    , Exponent t (ScalarField curve1)
    ) => Pairing curve1 curve2 t | curve1 curve2 -> t where
    -- | Evaluate the pairing on one point from each curve.
    pairing :: Point curve1 -> Point curve2 -> t
-- | Chord-rule addition of two points.
--
-- * If either argument is 'Inf', the other argument is returned.
-- * If both points are affine and share the same x-coordinate, the result is
--   'Inf' /regardless of the y-coordinates/.  In particular @pointAdd p p@
--   yields 'Inf', not the double of @p@; 'addPoints' dispatches equal points
--   to 'pointDouble' before falling back to this function.
-- * Otherwise the sum is computed from the slope of the line through the two
--   points.
pointAdd
    :: Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
pointAdd p   Inf = p
pointAdd Inf q   = q
pointAdd (Point x1 y1) (Point x2 y2)
    | x1 == x2  = Inf
    | otherwise = Point x3 y3
  where
    -- Slope of the chord through the two points; the guard above ensures the
    -- denominator is non-zero whenever this is evaluated.
    slope = (y1 - y2) // (x1 - x2)
    x3    = slope * slope - x1 - x2
    y3    = slope * (x1 - x3) - y1
-- | Tangent-rule doubling of a point.  'Inf' doubles to 'Inf'.
--
-- The slope used is @3 * x^2 / (2 * y)@.
--
-- NOTE(review): no curve coefficient appears in the numerator, so this
-- formula presumably assumes a curve of the form @y^2 = x^3 + b@ — confirm
-- that every curve using this helper has that shape.
--
-- NOTE(review): a point with @y == zero@ is not special-cased; the result
-- then depends on what '(//)' does for a zero denominator — verify if such
-- points can occur.
pointDouble
    :: Field (BaseField curve)
    => Point curve -> Point curve
pointDouble Inf         = Inf
pointDouble (Point x y) = Point x' y'
  where
    xx    = x * x
    -- Slope of the tangent at the point: 3 * x^2 / (2 * y).
    slope = (xx + xx + xx) // (y + y)
    x'    = slope * slope - x - x
    y'    = slope * (x - x') - y
-- | Addition of two points that also handles the case of equal arguments:
-- equal points (including @Inf@ with @Inf@) go through 'pointDouble', all
-- other pairs go through 'pointAdd'.
addPoints
    :: EllipticCurve curve
    => Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
addPoints p1 p2
    | p1 == p2  = pointDouble p1
    | otherwise = pointAdd p1 p2
-- | Additive inverse of a point: 'Inf' is its own inverse, and an affine
-- point keeps its x-coordinate while its y-coordinate is negated.
pointNegate
    :: AdditiveGroup (BaseField curve)
    => Point curve -> Point curve
pointNegate Inf         = Inf
pointNegate (Point x y) = Point x (negate y)
-- | Multiply a point by a scalar-field element.
--
-- The scalar is converted to a 'Natural' and the point is then scaled with
-- the generic 'natScale'.  The pipeline, read right to left:
--
-- 1. 'binaryExpansion' splits the scalar into a list of bits, each still a
--    scalar-field element;
-- 2. 'castBits' re-types those bits as 'Natural's (this is where the
--    'Eq' constraint on the scalar field is needed);
-- 3. 'fromBinary' reassembles the bits into a single 'Natural';
-- 4. 'natScale' scales the point by that number.
pointMul
    :: forall curve
    .  EllipticCurve curve
    => BinaryExpansion (ScalarField curve)
    => Eq (ScalarField curve)
    => ScalarField curve
    -> Point curve
    -> Point curve
pointMul = natScale . fromBinary . castBits . binaryExpansion