{-# LANGUAGE AllowAmbiguousTypes  #-}
{-# LANGUAGE DeriveAnyClass       #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Algebra.EllipticCurve.Class where

import           Control.DeepSeq                  (NFData)
import           Data.Functor                     ((<&>))
import           Data.Kind                        (Type)
import           GHC.Generics                     (Generic)
import           Prelude                          hiding (Num (..), sum, (/), (^))
import qualified Prelude                          as P
import           Test.QuickCheck                  hiding (scale)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number

-- | A point on the elliptic curve @curve@: either a finite point given by its
-- two coordinates in the curve's 'BaseField', or the point at infinity.
data Point curve
  = Point
      { _x :: BaseField curve  -- ^ x-coordinate
      , _y :: BaseField curve  -- ^ y-coordinate
      }
  | Inf                        -- ^ the point at infinity
  deriving (Generic)

-- Deep-evaluation instance. The empty instance body relies on the default,
-- 'Generic'-based 'rnf' from deepseq, which forces the coordinates of a
-- finite point; this is exactly what an anyclass-derived instance produces.
instance NFData (BaseField curve) => NFData (Point curve)

-- | Operations every concrete elliptic curve provides.
class EllipticCurve curve where

    -- | Type of the point coordinates.
    type BaseField curve :: Type
    -- | Type of the scalars accepted by 'mul'.
    type ScalarField curve :: Type

    -- | The curve's point at infinity.
    inf :: Point curve

    -- | Distinguished base point of the curve (random test points are
    -- generated as scalar multiples of it).
    gen :: Point curve

    -- | Point addition; this is what '+' on 'Point' delegates to.
    add :: Point curve -> Point curve -> Point curve

    -- | Multiply a point by a scalar.
    mul :: ScalarField curve -> Point curve -> Point curve

-- | Renders the point at infinity as @Inf@ and a finite point as @(x, y)@.
instance (EllipticCurve curve, Show (BaseField curve)) => Show (Point curve) where
    show p = case p of
        Inf       -> "Inf"
        Point x y -> concat ["(", show x, ", ", show y, ")"]

-- | Structural equality: 'Inf' equals only itself, and finite points are
-- equal when both coordinates are equal.
instance (EllipticCurve curve, Eq (BaseField curve)) => Eq (Point curve) where
    p == q = case (p, q) of
        (Inf, Inf)                 -> True
        (Point x1 y1, Point x2 y2) -> x1 == x2 && y1 == y2
        _                          -> False

-- | '+' on points is the curve's own 'add'.
instance EllipticCurve curve => AdditiveSemigroup (Point curve) where
    p + q = add p q

-- | Fallback scaling by any scalar type with a list-of-bits binary expansion,
-- delegating to 'pointMul'. Marked overlappable so that more specific
-- instances (such as the 'Natural' and 'Integer' ones) win.
instance {-# OVERLAPPABLE #-}
    ( EllipticCurve curve
    , Eq s
    , BinaryExpansion s
    , Bits s ~ [s]
    ) => Scale s (Point curve) where
    scale k p = pointMul k p

-- | Scaling by a 'Natural' uses the generic 'natScale'.
instance EllipticCurve curve => Scale Natural (Point curve) where
    scale n p = natScale n p

-- | The additive identity is the 'Inf' constructor itself (not the class
-- method 'inf').
instance EllipticCurve curve => AdditiveMonoid (Point curve) where
    zero = Inf

-- | Scaling by an 'Integer' uses the generic 'intScale' (negative scalars
-- rely on the 'AdditiveGroup' instance for points).
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => Scale Integer (Point curve) where
    scale i p = intScale i p

-- | Point negation is 'pointNegate'.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => AdditiveGroup (Point curve) where
    negate p = pointNegate p

-- | A random point is a random scalar multiple of the curve's 'gen'.
instance (EllipticCurve curve, Arbitrary (ScalarField curve)) => Arbitrary (Point curve) where
    arbitrary = do
        k <- arbitrary
        return (mul k gen)

-- | A pairing between points of two curves that share a scalar field. The
-- result lives in a multiplicative 'TargetGroup' that supports equality and
-- exponentiation by scalars. (Bilinearity is not enforced by the types.)
class ( EllipticCurve curve1
      , EllipticCurve curve2
      , ScalarField curve1 ~ ScalarField curve2
      , Eq (TargetGroup curve1 curve2)
      , MultiplicativeGroup (TargetGroup curve1 curve2)
      , Exponent (TargetGroup curve1 curve2) (ScalarField curve1)
      ) => Pairing curve1 curve2 where
    -- | Codomain of 'pairing'.
    type TargetGroup curve1 curve2 :: Type
    -- | Pair a point of @curve1@ with a point of @curve2@.
    pairing :: Point curve1 -> Point curve2 -> TargetGroup curve1 curve2

-- | Add two points using the line through them (chord slope).
--
-- 'Inf' is returned unchanged on either side. For two finite points with the
-- same x-coordinate the result is 'Inf'. That is right for @P + (-P)@ but not
-- for @P + P@, so callers that may pass equal points should use 'addPoints',
-- which sends that case to 'pointDouble'.
pointAdd
    :: Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
pointAdd p q = case (p, q) of
    (_, Inf) -> p
    (Inf, _) -> q
    (Point x1 y1, Point x2 y2)
        | x1 == x2  -> Inf
        | otherwise ->
            let slope = (y1 - y2) // (x1 - x2)
                x3    = slope * slope - x1 - x2
                y3    = slope * (x1 - x3) - y1
            in  Point x3 y3

-- | Double a point using the tangent line.
--
-- The slope used is @3x^2 / 2y@, the tangent slope of @y^2 = x^3 + b@. The
-- @a@ coefficient of the standard form does not appear, so the result is only
-- the true double on curves with @a = 0@.
-- NOTE(review): a finite point with @y = 0@ is not special-cased; the result
-- then depends on how '//' behaves for a zero divisor.
pointDouble
    :: Field (BaseField curve)
    => Point curve -> Point curve
pointDouble p = case p of
    Inf       -> Inf
    Point x y ->
        let slope = (x * x + x * x + x * x) // (y + y)  -- (3 x^2) / (2 y)
            x3    = slope * slope - x - x
            y3    = slope * (x - x3) - y
        in  Point x3 y3

-- | Add two points, dispatching on whether they are equal: equal points are
-- doubled with 'pointDouble', and any other pair goes through 'pointAdd'.
addPoints
    :: EllipticCurve curve
    => Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
addPoints p1 p2 =
    if p1 == p2
        then pointDouble p1
        else pointAdd p1 p2

-- | Negate a point by negating its y-coordinate; 'Inf' maps to 'Inf'.
pointNegate
    :: AdditiveGroup (BaseField curve)
    => Point curve -> Point curve
pointNegate = \case
    Inf       -> Inf
    Point x y -> Point x (negate y)

-- | Scalar multiplication through a 'Natural'. The scalar is expanded to its
-- bit list, the bits are reinterpreted as bits of a 'Natural', and the point
-- is then scaled with 'natScale'.
pointMul
    :: forall curve s
    .  EllipticCurve curve
    => BinaryExpansion (s)
    => Bits s ~ [s]
    => Eq s
    => s
    -> Point curve
    -> Point curve
pointMul k p = natScale n p
  where
    -- The scalar k rebuilt as a Natural from its binary expansion.
    n :: Natural
    n = fromBinary (castBits (binaryExpansion k))

-- | An elliptic curve in (short Weierstrass) standard form,
-- @y^2 = x^3 + a * x + b@, exposing its two coefficients.
class EllipticCurve curve => StandardEllipticCurve curve where
    -- | The coefficient @a@ of the linear term.
    aParameter :: BaseField curve
    -- | The constant term @b@.
    bParameter :: BaseField curve

-- | Compressed form of a 'Point': the x-coordinate plus a flag choosing one
-- of the two candidate y-coordinates, or the point at infinity.
data PointCompressed curve
    = PointCompressed (BaseField curve) Bool
    | InfCompressed

-- | Renders @InfCompressed@ literally and a compressed point as @(x, flag)@.
instance Show (BaseField curve) => Show (PointCompressed curve) where
    show pc = case pc of
        InfCompressed          -> "InfCompressed"
        PointCompressed x bigY -> concat ["(", show x, ", ", show bigY, ")"]

-- | Structural equality, written out by hand. It matches what deriving would
-- generate: constructors must agree, then fields are compared left to right.
instance Eq (BaseField curve) => Eq (PointCompressed curve) where
    InfCompressed         == InfCompressed         = True
    PointCompressed x1 b1 == PointCompressed x2 b2 = x1 == x2 && b1 == b2
    _                     == _                     = False

-- | Random compressed points are obtained by compressing random points.
instance (Arbitrary (Point curve), AdditiveGroup (BaseField curve), Ord (BaseField curve)
        ) => Arbitrary (PointCompressed curve) where
    arbitrary = fmap compress arbitrary

-- | Compress a point to its x-coordinate plus a flag that is 'True' exactly
-- when @y@ is strictly greater (by 'Ord') than @negate y@.
compress
  :: ( AdditiveGroup (BaseField curve)
     , Ord (BaseField curve)
     )
  => Point curve -> PointCompressed curve
compress Inf         = InfCompressed
compress (Point x y) = PointCompressed x isBigY
  where
    -- Records which of the two roots {y, -y} this point uses.
    isBigY = y > negate y

-- | Recover a 'Point' from its compressed form.
--
-- The y-coordinate is a square root of the curve's right-hand side
-- @x^3 + a*x + b@. Between that root and its negation, the larger one (by
-- 'Ord') is chosen when the stored flag is 'True' and the smaller otherwise.
-- This inverts the choice made by 'compress'.
--
-- The square root is computed as @z^((p + 1) / 4)@, where @p@ is the order of
-- the base field. This closed form is valid for prime fields with
-- @p ≡ 3 (mod 4)@; for other field orders it does not yield a square root.
-- NOTE(review): nothing checks that @x^3 + a*x + b@ is a quadratic residue,
-- so an @x@ that is not the abscissa of a curve point decodes to an
-- off-curve point.
decompress
  :: forall curve .
     ( StandardEllipticCurve curve
     , FiniteField (BaseField curve)
     , Ord (BaseField curve)
     )
  => PointCompressed curve -> Point curve
decompress = \case
  InfCompressed -> Inf
  PointCompressed x bigY ->
    let a = aParameter @curve
        b = bParameter @curve
        p = order @(BaseField curve)
        -- For p ≡ 3 (mod 4), (z^((p+1)/4))^2 = z^((p+1)/2) = z * z^((p-1)/2),
        -- which equals z whenever z is a quadratic residue (Euler's criterion).
        -- Raising to (p+1)/2 instead would give back ±z, not a root.
        sqrt_ z = z ^ ((p + 1) `P.div` 4)
        y' = sqrt_ (x * x * x + a * x + b)
        -- The list always has two elements, so maximum/minimum are total here.
        y = (if bigY then maximum else minimum) [y', negate y']
    in
        Point x y