{-# LANGUAGE RebindableSyntax     #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}

{-# OPTIONS_GHC -Wno-orphans #-}

module ZkFold.Symbolic.Data.Ed25519 (AcEd25519) where

import           Control.DeepSeq                           (NFData, force)
import           Prelude                                   (fromInteger, type (~), ($))
import qualified Prelude                                   as P

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number
import           ZkFold.Base.Algebra.EllipticCurve.Class
import           ZkFold.Base.Algebra.EllipticCurve.Ed25519
import qualified ZkFold.Base.Data.Vector                   as V
import qualified ZkFold.Symbolic.Class                     as S
import           ZkFold.Symbolic.Class                     (Symbolic)
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.ByteString
import           ZkFold.Symbolic.Data.Class
import           ZkFold.Symbolic.Data.Conditional
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.FFA
import           ZkFold.Symbolic.Data.FieldElement

-- | Type-level tag selecting the Ed25519 'EllipticCurve' instance whose
-- arithmetic runs in the 'Symbolic' context @c@. It has no constructors and
-- is used only as a type index.
data AcEd25519 c

-- | Ed25519 over a symbolic backend: point coordinates are foreign-field
-- elements @'FFA' 'Ed25519_Base' c@, scalars are @'FieldElement' c@ and
-- predicates are @'Bool' c@.
instance
    ( Symbolic c
    , NFData (c (V.Vector Size))
    ) => EllipticCurve (AcEd25519 c) where

    type BaseField (AcEd25519 c) = FFA Ed25519_Base c
    type ScalarField (AcEd25519 c) = FieldElement c
    type BooleanOf (AcEd25519 c) = Bool c

    -- Base point of Ed25519, given by its affine (x, y) coordinates.
    gen = point
        (fromConstant (15112221349535400772501151409588531511454012693041857206046113283949847762202 :: Natural))
        (fromConstant (46316835694926478169428394003475163141307993866256225615783033603165251855960 :: Natural))

    -- Equal operands use the dedicated doubling routine; all other pairs use
    -- the general addition routine. Under RebindableSyntax this `if` resolves
    -- to the `ifThenElse` in scope (presumably the symbolic one from
    -- ZkFold.Symbolic.Data.Conditional) and selects on a `Bool c`.
    add x y = if x == y then acDouble25519 x else acAdd25519 x y

    -- pointMul uses natScale, which converts the scalar to Natural.
    -- Arithmetic circuits cannot be converted to Natural, so pointMul is not
    -- usable here; scalar multiplication goes through the 'Scale' instance.
    mul = scale

-- | Multiply a curve point by a circuit field element using the scalar's
-- binary expansion.
--
-- The points @x, 2x, 4x, ...@ are paired with bit indices
-- @upper, upper - 1, ..., 0@, so bit index @i@ is paired with
-- @2^(upper - i) * x@. Each pair contributes its point when the bit is set
-- and 'zero' otherwise, and the contributions are summed.
--
-- NOTE(review): this is the intended scalar multiplication only if
-- 'isSet' index 0 is the most significant bit of @binaryExpansion sc@.
-- Confirm this against the 'ByteString' / 'isSet' convention.
instance
    ( EllipticCurve c
    , SymbolicData (Point c)
    , l ~ Layout (Point c)
    , ctx ~ Context (Point c)
    , Symbolic ctx
    , a ~ S.BaseField ctx
    , bits ~ NumberOfBits a
    , BooleanOf c ~ Bool ctx
    ) => Scale (FieldElement ctx) (Point c) where

    scale sc x = sum $ P.zipWith pick bitIndices (P.iterate (\e -> e + e) x)
        where
            -- Keep @p@ when bit @b@ of the scalar is set, otherwise 'zero'.
            pick :: Natural -> Point c -> Point c
            pick b p = bool @(Bool ctx) zero p (isSet scBits b)

            -- The scalar's bits viewed as a ByteString.
            scBits :: ByteString bits ctx
            scBits = ByteString $ binaryExpansion sc

            -- Indices upper, upper - 1, ..., 0. Written as a reversed
            -- ascending range rather than [upper, upper -! 1 .. 0]: with
            -- upper == 0 that form has step 0, repeats 0 forever, and makes
            -- 'sum' diverge.
            bitIndices :: [Natural]
            bitIndices = P.reverse [0 .. upper]

            upper :: Natural
            upper = value @bits -! 1

-- | Curve coefficient @a = -1@, as a foreign-field constant. It is used by
-- the addition and doubling formulas in this module.
a :: Symbolic ctx => FFA Ed25519_Base ctx
a = fromConstant @P.Integer (-1)

-- | Curve constant @d = -121665 / 121666@, computed with foreign-field
-- division. It is used by the addition formula in this module.
d :: Symbolic ctx => FFA Ed25519_Base ctx
d = fromConstant @P.Integer (-121665) // fromConstant @Natural 121666

-- | Add two Ed25519 points with the formulas
--
-- > x3 = (x1*y2 + y1*x2) / (1 + d*x1*x2*y1*y2)
-- > y3 = (y1*y2 - a*x1*x2) / (1 - d*x1*x2*y1*y2)
--
-- If the first operand's infinity flag is set, the second operand is
-- returned. Otherwise, if the second operand's flag is set, the first operand
-- is returned.
acAdd25519
    :: forall c
    .  Symbolic c
    => NFData (c (V.Vector Size))
    => Point (AcEd25519 c)
    -> Point (AcEd25519 c)
    -> Point (AcEd25519 c)
acAdd25519 p@(Point x1 y1 isInf1) q@(Point x2 y2 isInf2) =
    if isInf1 then q
    else if isInf2 then p
    else point x3 y3
    where
        -- Products shared by both coordinate formulas.
        prodx = x1 * x2
        prody = y1 * y2
        prod4 = d * prodx * prody

        -- Each coordinate is deep-evaluated with 'force' before use.
        x3 = force $ (x1 * y2 + y1 * x2) // (one + prod4)
        y3 = force $ (prody - a * prodx) // (one - prod4)

-- | Double an Ed25519 point with the dedicated doubling formulas
--
-- > x3 = 2*x1*y1 / (a*x1^2 + y1^2)
-- > y3 = (y1^2 - a*x1^2) / (2 - a*x1^2 - y1^2)
--
-- A point whose infinity flag is set doubles to 'inf'.
acDouble25519
    :: forall c
    .  Symbolic c
    => NFData (c (V.Vector Size))
    => Point (AcEd25519 c)
    -> Point (AcEd25519 c)
acDouble25519 (Point x1 y1 isInf) =
    if isInf then inf else point x3 y3
    where
        xsq = x1 * x1
        ysq = y1 * y1
        xy  = x1 * y1

        -- Note: due to our laws for finv, the divisions below are meant to
        -- work as they should even for the point (0, 0), where the x3
        -- denominator is zero. NOTE(review): this presumably relies on
        -- finv 0 = 0; confirm.
        x3 = force $ (xy + xy) // (a * xsq + ysq)
        y3 = force $ (ysq - a * xsq) // (one + one - a * xsq - ysq)