{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

module ZkFold.Base.Algebra.EllipticCurve.Class where

import           Data.Functor                    ((<&>))
import           Data.Kind                       (Type)
import           Numeric.Natural                 (Natural)
import           Prelude                         hiding (Num (..), sum, (/), (^))
import           Test.QuickCheck                 hiding (scale)

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Data.ByteString


-- | A point on the elliptic curve @curve@, given by its affine coordinates,
-- or the point at infinity.
--
-- Note that '_x' and '_y' are partial selectors: applying them to 'Inf'
-- is an error.
data Point curve
    = Point
        { _x :: BaseField curve -- ^ affine x-coordinate
        , _y :: BaseField curve -- ^ affine y-coordinate
        }
    | Inf -- ^ point at infinity (used as 'zero' for points)

-- | Interface of an elliptic curve identified by the type index @curve@.
class EllipticCurve curve where

    -- | Field in which point coordinates live.
    type BaseField curve :: Type

    -- | Field of scalars that points are multiplied by (see 'mul').
    type ScalarField curve :: Type

    -- | Point at infinity of the curve.
    -- NOTE(review): presumably the same as the 'Inf' constructor; this is
    -- not enforced by the class.
    inf :: Point curve

    -- | Fixed generator point. The 'Arbitrary' instance for 'Point' builds
    -- random points as scalar multiples of it.
    gen :: Point curve

    -- | Addition of two points; used as '(+)' on 'Point'.
    add :: Point curve -> Point curve -> Point curve

    -- | Multiplication of a point by a scalar.
    mul :: ScalarField curve -> Point curve -> Point curve

-- | Renders 'Inf' as @Inf@ and an affine point as @(x, y)@.
instance (EllipticCurve curve, Show (BaseField curve)) => Show (Point curve) where
    show Inf         = "Inf"
    show (Point x y) = concat ["(", show x, ", ", show y, ")"]

-- | Structural equality: 'Inf' equals only itself, and affine points are
-- equal when both coordinates are equal.
instance (EllipticCurve curve, Eq (BaseField curve)) => Eq (Point curve) where
    Inf         == Inf         = True
    Point x1 y1 == Point x2 y2 = x1 == x2 && y1 == y2
    -- Remaining combinations mix 'Inf' with an affine point.
    _           == _           = False

-- | Point addition is the curve's own 'add'.
instance EllipticCurve curve => AdditiveSemigroup (Point curve) where
    p + q = add p q

-- | Scaling by a natural number delegates to the generic 'natScale'.
instance EllipticCurve curve => Scale Natural (Point curve) where
    scale n p = natScale n p

-- | The additive identity of points is the point at infinity.
instance EllipticCurve curve => AdditiveMonoid (Point curve) where
    zero = Inf

-- | Scaling by a (possibly negative) integer delegates to the generic
-- 'intScale', which relies on the 'AdditiveGroup' instance for points.
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => Scale Integer (Point curve) where
    scale k p = intScale k p

-- | Negation of a point is 'pointNegate' (the y-coordinate is negated).
instance (EllipticCurve curve, AdditiveGroup (BaseField curve)) => AdditiveGroup (Point curve) where
    negate p = pointNegate p

-- | Serialisation format: one tag byte, @0@ for 'Inf' and @1@ for an affine
-- point. An affine point's tag is followed by the encoded x-coordinate and
-- then the encoded y-coordinate. Any other tag is rejected when decoding.
instance (EllipticCurve curve, Binary (BaseField curve)) => Binary (Point curve) where
    -- TODO: Point Compression
    -- When we know the equation of an elliptic curve, y^2 = x^3 + a * x + b
    -- then we only need to retain a flag sign byte,
    -- and the x-value to reconstruct the y-value of a point.
    put Inf = putWord8 0
    put (Point x y) = do
        putWord8 1
        put x
        put y

    get = do
        tag <- getWord8
        case tag of
            0 -> pure Inf
            1 -> Point <$> get <*> get
            _ -> fail ("Binary (Point curve): unexpected flag " <> show tag)

-- | Random points are generated as random scalar multiples of the curve's
-- generator 'gen'.
instance (EllipticCurve curve, Arbitrary (ScalarField curve)) => Arbitrary (Point curve) where
    arbitrary = arbitrary <&> \s -> mul s gen

-- | A pairing that maps a point of @curve1@ and a point of @curve2@ into the
-- target type @t@.
--
-- Both curves must share the same scalar field. @t@ must be a
-- multiplicative group with equality that can be raised to powers from that
-- scalar field. The target type is fixed by the pair of curves (functional
-- dependency @curve1 curve2 -> t@).
--
-- NOTE(review): properties such as bilinearity are presumably expected of
-- instances but are not enforced by the class.
class ( EllipticCurve curve1
      , EllipticCurve curve2
      , ScalarField curve1 ~ ScalarField curve2
      , Eq t
      , MultiplicativeGroup t
      , Exponent t (ScalarField curve1)
      ) => Pairing curve1 curve2 t | curve1 curve2 -> t where
    pairing :: Point curve1 -> Point curve2 -> t

-- | Adds two points in affine coordinates.
--
-- * 'Inf' is the identity: adding it to either side returns the other point.
-- * Points with different x-coordinates are added with the chord rule,
--   using slope @(y1 - y2) / (x1 - x2)@.
-- * For points on the curve that share an x-coordinate, either the points
--   are inverses of each other or they are the same point:
--
--     * if @y1 + y2 == zero@ the result is 'Inf'. This covers @P + (-P)@
--       and also doubling a point whose y-coordinate is zero.
--     * otherwise the points are equal and the result is @'pointDouble' p@.
--       That function assumes a curve of the form @y^2 = x^3 + b@.
pointAdd
    :: Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
pointAdd p   Inf     = p
pointAdd Inf q       = q
pointAdd p@(Point x1 y1) (Point x2 y2)
  | x1 /= x2        = Point x3 y3
  -- Same x, opposite y: the points cancel out.
  | y1 + y2 == zero = Inf
  -- Same x and same (non-zero) y: the chord formula would divide by zero,
  -- so use the tangent (doubling) rule instead of returning 'Inf'.
  | otherwise       = pointDouble p
  where
    slope = (y1 - y2) // (x1 - x2)
    x3    = slope * slope - x1 - x2
    y3    = slope * (x1 - x3) - y1

-- | Doubles a point in affine coordinates using the tangent rule.
--
-- The tangent slope is computed as @3x^2 / 2y@, so the curve is assumed to
-- be of the form @y^2 = x^3 + b@ (Weierstrass coefficient @a = 0@).
-- 'Inf' doubles to 'Inf'.
--
-- NOTE(review): a point with @y == zero@ is not special-cased. The result
-- then depends on how '//' treats division by zero in the base field.
pointDouble
    :: Field (BaseField curve)
    => Point curve -> Point curve
pointDouble Inf = Inf
pointDouble (Point x y) = Point xD yD
  where
    xSq    = x * x
    lambda = (xSq + xSq + xSq) // (y + y)
    xD     = lambda * lambda - x - x
    yD     = lambda * (x - xD) - y

-- | Adds two points. Equal points are passed to 'pointDouble'; all other
-- pairs are passed to 'pointAdd'.
addPoints
    :: EllipticCurve curve
    => Field (BaseField curve)
    => Eq (BaseField curve)
    => Point curve
    -> Point curve
    -> Point curve
addPoints p q =
    if p == q
        then pointDouble p
        else pointAdd p q

-- | Negates a point by negating its y-coordinate. 'Inf' is its own negation.
pointNegate
    :: AdditiveGroup (BaseField curve)
    => Point curve -> Point curve
pointNegate p = case p of
    Inf       -> Inf
    Point x y -> Point x (negate y)

-- | Multiplies a point by a scalar.
--
-- The scalar is first turned into a 'Natural':
--
-- 1. take its 'binaryExpansion';
-- 2. cast each bit with 'castBits';
-- 3. reassemble the bits with 'fromBinary'.
--
-- The point is then scaled by that number with 'natScale'.
pointMul
    :: forall curve
    .  EllipticCurve curve
    => BinaryExpansion (ScalarField curve)
    => Eq (ScalarField curve)
    => ScalarField curve
    -> Point curve
    -> Point curve
pointMul scalar point = natScale n point
  where
    n :: Natural
    n = fromBinary (castBits (binaryExpansion scalar))