{-# LANGUAGE DerivingVia #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module ZkFold.Base.Algebra.EllipticCurve.Pasta where

import           Prelude

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Field
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class
import           ZkFold.Base.Data.ByteString

-------------------------------- Introducing Fields ----------------------------------

-- | Modulus of 'Fp'. 'Fp' is used as the base field of 'Pallas' and as the
-- scalar field of 'Vesta'.
type FpModulus = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000001
-- Primality is declared by this module (an empty instance), not derived.
instance Prime FpModulus
type Fp = Zp FpModulus

-- | Modulus of 'Fq'. 'Fq' is used as the base field of 'Vesta' and as the
-- scalar field of 'Pallas'.
type FqModulus = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000001
-- Primality is declared by this module (an empty instance), not derived.
instance Prime FqModulus
type Fq = Zp FqModulus

------------------------------------ Pallas ------------------------------------

-- | The Pallas curve. Points have coordinates in 'Fp' and are scaled by
-- elements of 'Fq'. Its Weierstrass coefficients are a = 0 and b = 5
-- (see the 'WeierstrassCurve' instance).
data Pallas

instance EllipticCurve Pallas where
    type BaseField Pallas = Fp
    type ScalarField Pallas = Fq

    -- | Fixed generator (-1, 2). The x-coordinate literal is FpModulus - 1,
    -- i.e. -1 in 'Fp'. It satisfies 2^2 = (-1)^3 + 5.
    pointGen :: Point Pallas
    pointGen = pointXY genX genY
      where
        genX, genY :: BaseField Pallas
        genX = 0x40000000000000000000000000000000224698fc094cf91b992d30ed00000000
        genY = 0x02

    -- | Group law, delegated to the generic 'addPoints'.
    add :: Point Pallas -> Point Pallas -> Point Pallas
    add = addPoints

    -- | Scalar multiplication, delegated to the generic 'pointMul'.
    mul :: ScalarField Pallas -> Point Pallas -> Point Pallas
    mul = pointMul

instance WeierstrassCurve Pallas where
    -- | Coefficient a = 0.
    weierstrassA :: BaseField Pallas
    weierstrassA = zero

    -- | Coefficient b = 5, lifted from a natural-number constant into 'Fp'.
    weierstrassB :: BaseField Pallas
    weierstrassB = fromConstant coeffB
      where
        coeffB :: Natural
        coeffB = 5

------------------------------------ Vesta ------------------------------------

-- | The Vesta curve. Points have coordinates in 'Fq' and are scaled by
-- elements of 'Fp', with the field roles swapped relative to 'Pallas'.
-- Its Weierstrass coefficients are a = 0 and b = 5 (see the
-- 'WeierstrassCurve' instance).
data Vesta

instance EllipticCurve Vesta where
    type BaseField Vesta = Fq
    type ScalarField Vesta = Fp

    -- | Fixed generator (-1, 2). The x-coordinate literal is FqModulus - 1,
    -- i.e. -1 in 'Fq'. It satisfies 2^2 = (-1)^3 + 5.
    pointGen :: Point Vesta
    pointGen = pointXY genX genY
      where
        genX, genY :: BaseField Vesta
        genX = 0x40000000000000000000000000000000224698fc0994a8dd8c46eb2100000000
        genY = 0x02

    -- | Group law, delegated to the generic 'addPoints'.
    add :: Point Vesta -> Point Vesta -> Point Vesta
    add = addPoints

    -- | Scalar multiplication, delegated to the generic 'pointMul'.
    mul :: ScalarField Vesta -> Point Vesta -> Point Vesta
    mul = pointMul

instance WeierstrassCurve Vesta where
    -- | Coefficient a = 0.
    weierstrassA :: BaseField Vesta
    weierstrassA = zero

    -- | Coefficient b = 5, lifted from a natural-number constant into 'Fq'.
    weierstrassB :: BaseField Vesta
    weierstrassB = fromConstant coeffB
      where
        coeffB :: Natural
        coeffB = 5

------------------------------------ Encoding ------------------------------------

-- | Wire format: the x coordinate followed by the y coordinate, each encoded
-- with the 'Binary' instance of 'Fp'. The point at infinity is written as the
-- coordinate pair (0, 0), and 'get' maps (0, 0) back to 'pointInf'. With
-- a = 0 and b = 5, (0, 0) presumably is not a curve point (assuming short
-- Weierstrass form y^2 = x^3 + 5), so the sentinel is unambiguous.
instance Binary (Point Pallas) where
  put :: Point Pallas -> Put
  put (Point px py isInf)
    | isInf     = put @(Point Pallas) (pointXY zero zero)
    | otherwise = put px >> put py

  -- NOTE(review): decoded coordinates are not checked to lie on the curve;
  -- confirm whether callers ever decode untrusted bytes.
  get :: Get (Point Pallas)
  get = fromCoords <$> get <*> get
    where
      -- The x coordinate is read before y.
      fromCoords :: BaseField Pallas -> BaseField Pallas -> Point Pallas
      fromCoords px py
        | px == zero && py == zero = pointInf
        | otherwise                = pointXY px py

-- | Wire format: the x coordinate followed by the y coordinate, each encoded
-- with the 'Binary' instance of 'Fq'. The point at infinity is written as the
-- coordinate pair (0, 0), and 'get' maps (0, 0) back to 'pointInf'. With
-- a = 0 and b = 5, (0, 0) presumably is not a curve point (assuming short
-- Weierstrass form y^2 = x^3 + 5), so the sentinel is unambiguous.
instance Binary (Point Vesta) where
  put :: Point Vesta -> Put
  put (Point px py isInf)
    | isInf     = put @(Point Vesta) (pointXY zero zero)
    | otherwise = put px >> put py

  -- NOTE(review): decoded coordinates are not checked to lie on the curve;
  -- confirm whether callers ever decode untrusted bytes.
  get :: Get (Point Vesta)
  get = fromCoords <$> get <*> get
    where
      -- The x coordinate is read before y.
      fromCoords :: BaseField Vesta -> BaseField Vesta -> Point Vesta
      fromCoords px py
        | px == zero && py == zero = pointInf
        | otherwise                = pointXY px py