{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
module ZkFold.Base.Algebra.EllipticCurve.Class where
import Control.DeepSeq (NFData)
import Data.Functor ((<&>))
import Data.Kind (Type)
import GHC.Generics (Generic)
import Prelude hiding (Num (..), sum, (/), (^))
import qualified Prelude as P
import Test.QuickCheck hiding (scale)
import ZkFold.Base.Algebra.Basic.Class
import ZkFold.Base.Algebra.Basic.Number
-- | A point on an elliptic curve in affine coordinates, or the point at
-- infinity 'Inf'.
--
-- The record selectors '_x' and '_y' are partial: applying them to 'Inf'
-- fails at runtime.
data Point curve
    = Point
        { _x :: BaseField curve -- ^ affine x-coordinate
        , _y :: BaseField curve -- ^ affine y-coordinate
        }
    | Inf -- ^ point at infinity; used as 'zero' of the point group
    deriving (Generic)
-- | Deep evaluation of points, derived with the class's default methods
-- (DeriveAnyClass, backed by the 'Generic' instance above). Requires the base
-- field to be 'NFData' as well.
deriving instance NFData (BaseField curve) => NFData (Point curve)
-- | Interface of an elliptic-curve point group, indexed by a type-level
-- (poly-kinded) @curve@ tag.
class EllipticCurve curve where
    -- | Field in which point coordinates live.
    type BaseField curve :: Type
    -- | Type of scalars accepted by 'mul'.
    type ScalarField curve :: Type

    -- | Point at infinity as provided by the instance. Note that the
    -- 'AdditiveMonoid' instance uses the 'Inf' constructor directly.
    inf :: Point curve
    -- | Distinguished base point; random points in 'Arbitrary' are built as
    -- scalar multiples of it.
    gen :: Point curve
    -- | Group law; this is what '+' on points dispatches to.
    add :: Point curve -> Point curve -> Point curve
    -- | Scalar multiplication of a point by a 'ScalarField' element.
    mul :: ScalarField curve -> Point curve -> Point curve
-- | Renders 'Inf' as @Inf@ and an affine point as the tuple @(x, y)@.
instance (EllipticCurve curve, Show (BaseField curve)) => Show (Point curve) where
    show :: Point curve -> String
    show = \case
        Inf       -> "Inf"
        Point x y -> concat ["(", show x, ", ", show y, ")"]
-- | Structural equality: 'Inf' equals only itself, and two affine points are
-- equal exactly when both coordinates agree.
instance (EllipticCurve curve, Eq (BaseField curve)) => Eq (Point curve) where
    (==) :: Point curve -> Point curve -> Bool
    Point x1 y1 == Point x2 y2 = x1 == x2 && y1 == y2
    Inf         == Inf         = True
    _           == _           = False
-- | Point addition is the curve's own group law 'add'.
instance EllipticCurve curve => AdditiveSemigroup (Point curve) where
    (+) :: Point curve -> Point curve -> Point curve
    p + q = add p q
-- | Fallback scalar action for any scalar type with a binary expansion,
-- implemented by 'pointMul'. It is OVERLAPPABLE so that the more specific
-- 'Natural' and 'Integer' instances below are chosen for those types.
instance {-# OVERLAPPABLE #-}
    ( EllipticCurve curve
    , Eq s
    , BinaryExpansion s
    , Bits s ~ [s]
    ) => Scale s (Point curve) where
    scale :: s -> Point curve -> Point curve
    scale k p = pointMul k p
-- | Scaling by a 'Natural' goes through 'natScale' over the point monoid.
instance EllipticCurve curve => Scale Natural (Point curve) where
    scale :: Natural -> Point curve -> Point curve
    scale n p = natScale n p
-- | The additive identity of the point group is the point at infinity.
instance EllipticCurve curve => AdditiveMonoid (Point curve) where
    zero :: Point curve
    zero = Inf
-- | Scaling by an 'Integer' goes through 'intScale'. It needs the point group
-- to be an 'AdditiveGroup', hence the constraint on the base field.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => Scale Integer (Point curve) where
    scale :: Integer -> Point curve -> Point curve
    scale n p = intScale n p
-- | Point negation is delegated to 'pointNegate'.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => AdditiveGroup (Point curve) where
    negate :: Point curve -> Point curve
    negate p = pointNegate p
-- | Random points are produced as scalar multiples of 'gen' (via 'mul')
-- using a randomly generated scalar.
instance (EllipticCurve curve, Arbitrary (ScalarField curve)) => Arbitrary (Point curve) where
    arbitrary :: Gen (Point curve)
    arbitrary = arbitrary <&> \k -> mul k gen
-- | Pairing of points on two curves that share a scalar field, with values in
-- a multiplicative target group that can be raised to those scalars.
-- NOTE(review): bilinearity is presumably intended but is not enforced by
-- the types.
class (EllipticCurve curve1, EllipticCurve curve2, ScalarField curve1 ~ ScalarField curve2,
        Eq (TargetGroup curve1 curve2), MultiplicativeGroup (TargetGroup curve1 curve2),
        Exponent (TargetGroup curve1 curve2) (ScalarField curve1)) => Pairing curve1 curve2 where
    -- | Group in which pairing values live.
    type TargetGroup curve1 curve2 :: Type
    -- | Evaluate the pairing on one point from each curve.
    pairing :: Point curve1 -> Point curve2 -> TargetGroup curve1 curve2
-- | Chord-rule addition of two affine points, with 'Inf' as the identity.
--
-- If the two points share an x-coordinate the result is 'Inf'. This is right
-- for @p + negate p@ but NOT for @p + p@, so equal points must be routed to
-- a doubling routine first ('addPoints' does this).
pointAdd
    :: Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
pointAdd p Inf = p
pointAdd Inf q = q
pointAdd (Point xP yP) (Point xQ yQ)
    | xP == xQ  = Inf
    | otherwise =
        let lambda = (yP - yQ) // (xP - xQ)          -- slope of the chord
            xR     = lambda * lambda - xP - xQ
            yR     = lambda * (xP - xR) - yP
        in Point xR yR
-- | Tangent-rule doubling of an affine point; 'Inf' doubles to 'Inf'.
--
-- The tangent slope is computed as @3x^2 / 2y@. That is the slope for curves
-- @y^2 = x^3 + b@ (coefficient @a = 0@); a curve with nonzero @a@ would need
-- @(3x^2 + a) / 2y@.
-- NOTE(review): a point with @y = 0@ is not special-cased here, so the
-- result then depends on how '//' treats a zero divisor.
pointDouble
    :: Field (BaseField curve)
    => Point curve -> Point curve
pointDouble = \case
    Inf -> Inf
    Point x y ->
        let lambda = (x * x + x * x + x * x) // (y + y)   -- tangent slope
            xR     = lambda * lambda - x - x
            yR     = lambda * (x - xR) - y
        in Point xR yR
-- | Complete affine addition: adds two points, doubling when they are equal.
--
-- * Distinct points go through 'pointAdd', which also handles 'Inf' and
--   @p + negate p = Inf@.
-- * A point with @y = 0@ has a vertical tangent, so doubling it gives 'Inf'.
--   It is handled here because 'pointDouble' would divide by @2y = 0@.
-- * Any other equal pair is doubled with 'pointDouble'.
addPoints
    :: EllipticCurve curve
    => Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
addPoints p1 p2
    | p1 /= p2                   = pointAdd p1 p2
    | Point _ y <- p1, y == zero = Inf
    | otherwise                  = pointDouble p1
-- | Additive inverse of a point. An affine point keeps its x-coordinate and
-- has its y-coordinate negated; 'Inf' is its own inverse.
pointNegate
    :: AdditiveGroup (BaseField curve)
    => Point curve -> Point curve
pointNegate = \case
    Point x y -> Point x (negate y)
    Inf       -> Inf
-- | Scalar multiplication for any scalar type with a binary expansion.
--
-- The scalar's bits (from 'binaryExpansion') are cast with 'castBits' and
-- re-assembled into a 'Natural' with 'fromBinary'. The point is then scaled
-- by that natural with 'natScale'.
pointMul
    :: forall curve s
    . EllipticCurve curve
    => BinaryExpansion (s)
    => Bits s ~ [s]
    => Eq s
    => s
    -> Point curve
    -> Point curve
pointMul k p = natScale (fromBinary (castBits (binaryExpansion k))) p
-- | Curves that expose the coefficients @a@ and @b@ of a short Weierstrass
-- equation @y^2 = x^3 + a*x + b@. 'decompress' recovers y from x through
-- this equation.
class EllipticCurve curve => StandardEllipticCurve curve where
    -- | Coefficient @a@ of the curve equation.
    aParameter :: BaseField curve
    -- | Coefficient @b@ of the curve equation.
    bParameter :: BaseField curve
-- | Compressed encoding of a 'Point', or 'InfCompressed' for the point at
-- infinity. An affine point is stored as its x-coordinate plus a flag; as
-- produced by 'compress', the flag is 'True' when y is the larger of
-- @y@ and @negate y@ under the base field's 'Ord'.
data PointCompressed curve = PointCompressed (BaseField curve) Bool | InfCompressed
-- | Renders 'InfCompressed' by name and a compressed point as @(x, flag)@.
instance Show (BaseField curve) => Show (PointCompressed curve) where
    show :: PointCompressed curve -> String
    show = \case
        InfCompressed          -> "InfCompressed"
        PointCompressed x bigY -> concat ["(", show x, ", ", show bigY, ")"]
-- | Derived equality on compressed points (requires an 'Eq' base field).
deriving instance Eq (BaseField curve) => Eq (PointCompressed curve)
-- | Random compressed points: generate a random 'Point' and 'compress' it.
instance (Arbitrary (Point curve), AdditiveGroup (BaseField curve), Ord (BaseField curve)
    ) => Arbitrary (PointCompressed curve) where
    arbitrary :: Gen (PointCompressed curve)
    arbitrary = fmap compress arbitrary
-- | Compress a point to its x-coordinate plus a one-bit choice of y.
--
-- The flag records whether @y@ is the larger of @y@ and @negate y@ under the
-- base field's 'Ord'; 'decompress' uses it to pick the matching root.
compress
    :: ( AdditiveGroup (BaseField curve)
       , Ord (BaseField curve)
       )
    => Point curve -> PointCompressed curve
compress Inf         = InfCompressed
compress (Point x y) = PointCompressed x isBigY
  where
    isBigY = y > negate y
-- | Recover a 'Point' from its compressed form (the inverse of 'compress').
--
-- The y-coordinate is a square root of @x^3 + a*x + b@, where
-- @a = 'aParameter'@ and @b = 'bParameter'@. The flag then selects between
-- that root and its negation: the larger one (under 'Ord') when the flag is
-- 'True', the smaller one otherwise.
--
-- The square root is computed as @z^((p+1)/4)@, where @p@ is the base-field
-- order. This is only correct when @p ≡ 3 (mod 4)@; fields with
-- @p ≡ 1 (mod 4)@ would need a general algorithm such as Tonelli–Shanks.
--
-- The input is not validated. If @x^3 + a*x + b@ is not a square in the field
-- (i.e. @x@ is not on the curve), the result is not a point of the curve.
decompress
    :: forall curve .
    ( StandardEllipticCurve curve
    , FiniteField (BaseField curve)
    , Ord (BaseField curve)
    )
    => PointCompressed curve -> Point curve
decompress = \case
    InfCompressed -> Inf
    PointCompressed x bigY ->
        let a = aParameter @curve
            b = bParameter @curve
            p = order @(BaseField curve)
            -- For p ≡ 3 (mod 4) and a quadratic residue z:
            --   (z^((p+1)/4))^2 = z^((p+1)/2) = z * z^((p-1)/2) = z.
            -- The exponent must be (p+1)/4. Using (p+1)/2 yields ±z, which is
            -- not a square root.
            sqrt_ z = z ^ ((p + 1) `P.div` 4)
            y' = sqrt_ (x * x * x + a * x + b)
            y = (if bigY then maximum else minimum) [y', negate y']
        in Point x y