{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module ZkFold.Symbolic.Data.Ed25519 () where

import           Control.DeepSeq                           (NFData, force)
import           Data.Functor.Rep                          (Representable)
import           Data.Traversable                          (Traversable)
import           Prelude                                   (type (~), ($))
import qualified Prelude                                   as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class
import           ZkFold.Base.Algebra.EllipticCurve.Ed25519
import qualified ZkFold.Base.Data.Vector                   as V
import qualified ZkFold.Symbolic.Class                     as S
import           ZkFold.Symbolic.Class                     (Symbolic)
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.ByteString
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Conditional
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.FFA
import           ZkFold.Symbolic.Data.FieldElement


instance
    ( Symbolic c
    , S.BaseField c ~ a
    ) => SymbolicData (Point (Ed25519 c)) where

    type Context (Point (Ed25519 c)) = c
    type Support (Point (Ed25519 c)) = Support (FFA Ed25519_Base c)
    type Layout (Point (Ed25519 c)) = Layout (FFA Ed25519_Base c, FFA Ed25519_Base c)

    -- Arithmetic circuits cannot represent sum types, so 'Inf' has to be
    -- flattened into an ordinary coordinate pair.
    --
    -- 'Inf' is encoded as (0, 1), the neutral element of every twisted Edwards
    -- curve a*x^2 + y^2 = 1 + d*x^2*y^2. With the unified addition formula
    --   x3 = (x1*y2 + y1*x2) / (1 + d*x1*x2*y1*y2)
    --   y3 = (y1*y2 - a*x1*x2) / (1 - d*x1*x2*y1*y2)
    -- substituting (x1, y1) = (0, 1) yields exactly (x2, y2), and doubling
    -- (0, 1) yields (0, 1) with non-zero denominators. So the encoded value
    -- behaves as the identity even after it has lost its 'Inf' constructor.
    --
    -- (0, 0) must NOT be used here: it is not on the curve, and under the
    -- formula above (0, 0) + (x2, y2) = (0, 0), i.e. it absorbs instead of
    -- acting as the identity.
    --
    -- 'restore' always produces a 'Point', so after a round trip 'Inf'
    -- comes back as @Point 0 1@ rather than as the 'Inf' constructor.
    pieces Inf         = pieces (zero :: FFA Ed25519_Base c, one :: FFA Ed25519_Base c)
    pieces (Point x y) = pieces (x, y)

    -- Rebuild an affine point from the two restored coordinates.
    restore f = Point x y
        where
            (x, y) = restore f

-- | Structural equality of Ed25519 points, producing a symbolic 'Bool'.
--
-- 'Inf' is equal only to 'Inf'. Two affine points are equal exactly when
-- both of their coordinates compare equal.
instance (Symbolic c) => Eq (Bool c) (Point (Ed25519 c)) where
    p == q = case (p, q) of
        (Inf, Inf)                 -> true
        (Point x1 y1, Point x2 y2) -> (x1 == x2) && (y1 == y2)
        -- exactly one side is 'Inf'
        _                          -> false

    p /= q = case (p, q) of
        (Inf, Inf)                 -> false
        (Point x1 y1, Point x2 y2) -> (x1 /= x2) || (y1 /= y2)
        -- exactly one side is 'Inf'
        _                          -> true

-- | Ed25519 arithmetic inside a symbolic context @c@.
--
-- Point coordinates are foreign-field values (@'FFA' Ed25519_Base c@), and
-- scalars are native @'FieldElement' c@ values.
--
instance
    ( Symbolic c
    , NFData (c (V.Vector Size))
    ) => EllipticCurve (Ed25519 c)  where

    type BaseField (Ed25519 c) = FFA Ed25519_Base c
    type ScalarField (Ed25519 c) = FieldElement c

    inf = Inf

    -- Affine coordinates of the Ed25519 base point (RFC 8032, Section 5.1).
    gen = Point genX genY
        where
            genX = fromConstant (15112221349535400772501151409588531511454012693041857206046113283949847762202 :: Natural)
            genY = fromConstant (46316835694926478169428394003475163141307993866256225615783033603165251855960 :: Natural)

    -- 'bool' selects the doubling formula when the two points compare
    -- equal, and the general addition formula otherwise.
    add p q = bool @(Bool c) @(Point (Ed25519 c)) summed doubled (p == q)
        where
            summed  = acAdd25519 p q
            doubled = acDouble25519 p

    -- pointMul relies on natScale, which turns the scalar into a Natural.
    -- A symbolic scalar cannot be turned into a Natural, so multiplication is
    -- delegated to the 'Scale' instance instead.
    --
    mul = scale

-- | Scalar multiplication of a curve point by a native field element.
--
-- The scalar is expanded into bits. Each successive doubling of @x@
-- (x, 2x, 4x, ...) is kept when its paired bit is set and replaced by 'zero'
-- otherwise, and all resulting terms are summed.
--
-- NOTE(review): bit index @upper@ is paired with @x@ itself (weight 1) and
-- index 0 with @2^upper * x@. This assumes index 0 of @bits@ holds the most
-- significant bit. The 'binaryExpansion' output is wrapped in 'ByteString'
-- without reordering — confirm that its bit order matches 'isSet'.
instance
    ( EllipticCurve c
    , SymbolicData (Point c)
    , l ~ Layout (Point c)
    , Representable l
    , Traversable l
    , ctx ~ Context (Point c)
    , Symbolic ctx
    , a ~ S.BaseField ctx
    , bits ~ NumberOfBits a
    ) => Scale (FieldElement ctx) (Point c) where

    scale sc x = sum
        [ bool @(Bool ctx) zero pt (isSet bits i)
        | (i, pt) <- P.zip indices doublings
        ]
        where
            -- Bit decomposition of the scalar.
            bits :: ByteString bits ctx
            bits = ByteString $ binaryExpansion sc

            -- Largest valid bit index.
            upper :: Natural
            upper = value @bits -! 1

            -- Bit indices, visited from 'upper' down to 0.
            indices :: [Natural]
            indices = [upper, upper -! 1 .. 0]

            -- x, 2x, 4x, ... (infinite; truncated by the zip).
            doublings :: [Point c]
            doublings = P.iterate (\pt -> pt + pt) x

-- | Twisted Edwards coefficient @a = -1@ of Ed25519, as a foreign-field
-- constant.
a :: Symbolic ctx => FFA Ed25519_Base ctx
a = fromConstant (-1 :: P.Integer)

-- | Twisted Edwards coefficient @d = -121665 / 121666@ of Ed25519, computed
-- as a foreign-field division.
d :: Symbolic ctx => FFA Ed25519_Base ctx
d = dNum // dDen
    where
        dNum = fromConstant (-121665 :: P.Integer)
        dDen = fromConstant (121666 :: Natural)

-- | Adds two Ed25519 points with the unified twisted Edwards formula
--
-- > x3 = (x1*y2 + y1*x2) / (1 + d*x1*x2*y1*y2)
-- > y3 = (y1*y2 - a*x1*x2) / (1 - d*x1*x2*y1*y2)
--
-- A structural 'Inf' on either side returns the other operand unchanged.
-- Each output coordinate is fully evaluated with 'force'.
acAdd25519
    :: forall c
    .  Symbolic c
    => NFData (c (V.Vector Size))
    => Point (Ed25519 c)
    -> Point (Ed25519 c)
    -> Point (Ed25519 c)
acAdd25519 p q = case (p, q) of
    (Inf, _) -> q
    (_, Inf) -> p
    (Point x1 y1, Point x2 y2) ->
        let xx    = x1 * x2
            yy    = y1 * y2
            dxxyy = d * xx * yy
            numX  = x1 * y2 + y1 * x2
            numY  = yy - a * xx
        in Point (force $ numX // (one + dxxyy)) (force $ numY // (one - dxxyy))

-- | Doubles an Ed25519 point with the dedicated twisted Edwards formula
--
-- > x3 = 2*x*y / (a*x^2 + y^2)
-- > y3 = (y^2 - a*x^2) / (2 - a*x^2 - y^2)
--
-- A structural 'Inf' doubles to 'Inf'. Each output coordinate is fully
-- evaluated with 'force'.
--
-- For the input (0, 0), the x-denominator is zero. The original note says
-- this is fine "due to our laws for finv" (presumably finv 0 = 0 — confirm).
-- The numerator 2*x*y is zero in that case as well.
acDouble25519
    :: forall c
    .  Symbolic c
    => NFData (c (V.Vector Size))
    => Point (Ed25519 c)
    -> Point (Ed25519 c)
acDouble25519 p = case p of
    Inf -> Inf
    Point x1 y1 ->
        let sqX  = x1 * x1
            sqY  = y1 * y1
            crXY = x1 * y1
            outX = force $ (crXY + crXY) // (a * sqX + sqY)
            outY = force $ (sqY - a * sqX) // (one + one - a * sqX - sqY)
        in Point outX outY