{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE TypeOperators       #-}
module ZkFold.Symbolic.Algorithms.ECDSA.ECDSA where
import           Data.Type.Equality
import           GHC.Base                                (($))
import           GHC.TypeLits                            (KnownNat, Log2)
import qualified GHC.TypeNats

import           ZkFold.Base.Algebra.Basic.Class
import           ZkFold.Base.Algebra.Basic.Number        (value)
import           ZkFold.Base.Algebra.EllipticCurve.Class (EllipticCurve (BaseField, gen), Point (Inf, Point))
import qualified ZkFold.Symbolic.Class                   as S
import           ZkFold.Symbolic.Data.Bool
import           ZkFold.Symbolic.Data.ByteString         (ByteString)
import           ZkFold.Symbolic.Data.Combinators        (Iso (..), NumberOfRegisters, RegisterSize (Auto))
import           ZkFold.Symbolic.Data.Eq
import           ZkFold.Symbolic.Data.FieldElement       (FieldElement)
import           ZkFold.Symbolic.Data.UInt               (UInt, eea)

-- | Check an ECDSA signature @(r, s)@ on @message@ against @publicKey@,
-- producing a symbolic 'Bool'.
--
-- The scalar modulus is the type-level natural @n@ (presumably the order of
-- the subgroup generated by 'gen'; the caller must supply it). With @G = gen@
-- the check computes
--
-- > u = message * s^-1 mod n
-- > v = r * s^-1 mod n
-- > R = u*G + v*publicKey
--
-- and accepts iff @r@ and @s@ are non-zero, @R@ is not the point at
-- infinity, and @r == x(R) mod n@.
--
-- NOTE(review): @message@ is converted to a 256-bit integer with 'from' and
-- used as-is, so it is presumably already the 256-bit hash of the signed
-- data; confirm with callers.
ecdsaVerify :: forall curve n c . (
      S.Symbolic c
    , KnownNat n
    , EllipticCurve curve
    , BaseField curve ~ UInt 256 'Auto c
    , Scale (FieldElement c) (Point curve)
    , Log2 (Order (S.BaseField c) GHC.TypeNats.- 1) ~ 255
    , SemiEuclidean (UInt 256 'Auto c)
    , KnownNat (NumberOfRegisters (S.BaseField c) 256 'Auto)
    )
    => Point curve                          -- ^ public key
    -> ByteString 256 c                     -- ^ message (see note above)
    -> (UInt 256 'Auto c, UInt 256 'Auto c) -- ^ signature @(r, s)@
    -> Bool c
ecdsaVerify publicKey message (r, s) = case rPoint of
        Inf       -> false
        -- Parenthesised explicitly so the result does not depend on the
        -- fixity declared for the symbolic '&&'.
        Point x _ -> nonZeroSig && (r == (x `mod` n))
    where
        -- Scalar modulus as a circuit constant, taken from the type-level @n@.
        n :: UInt 256 'Auto c
        n = fromConstant $ value @n

        g :: Point curve
        g = gen

        -- ECDSA requires r, s in [1, n-1] (SEC 1, section 4.1.4, step 1).
        -- r < n is already enforced by the final comparison, because
        -- x `mod` n is always below n; here the zero values are rejected.
        -- NOTE(review): s >= n is still accepted; consider an explicit
        -- upper-bound check on s as well.
        nonZeroSig :: Bool c
        nonZeroSig = not (r == zero) && not (s == zero)

        -- The first component returned by 'eea' is used as s^-1 mod n
        -- (presumably the Bezout coefficient of s; confirm against 'eea').
        (sInv, _, _) = eea s n

        -- NOTE(review): if 'UInt' 256 multiplication wraps modulo 2^256
        -- (confirm), the products below are truncated before the reduction
        -- mod n, so u and v are only right when the product fits in 256
        -- bits. Widening to 512 bits before `mod` would avoid this.
        u :: UInt 256 'Auto c
        u = (from message * sInv) `mod` n

        -- Parenthesised like u, so the reduction applies to the whole
        -- product regardless of the fixity of `mod`.
        v :: UInt 256 'Auto c
        v = (r * sInv) `mod` n

        -- Moves a scalar into the circuit field so it can drive 'scale'.
        toScalar :: UInt 256 'Auto c -> FieldElement c
        toScalar = from

        -- R = u*G + v*publicKey
        rPoint :: Point curve
        rPoint = scale (toScalar u) g + scale (toScalar v) publicKey